Skip to content

Commit 50d6606

Browse files
updated data-structures
1 parent 1f10f48 commit 50d6606

File tree

3 files changed

+115
-38
lines changed

3 files changed

+115
-38
lines changed

contracts/asset.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from abc import ABC
22
from typing import Type, Optional
33

4+
from contracts.utils import clean_ticker
5+
46

57
class AssetBase(ABC):
68
"""
@@ -12,49 +14,56 @@ class AssetBase(ABC):
1214

1315
def __init__(self, ticker: str, shares: int | float = 0):
1416
assert isinstance(ticker, str), "Ticker must be a string"
15-
self._ticker = ticker.strip().strip('^').upper()
17+
self._ticker = clean_ticker(ticker)
1618

1719
assert isinstance(shares, self.asset_type), f"Shares must be an {self.asset_type.__name__}"
18-
self._shares = shares
20+
self._shares: int | float = shares
1921
self._trade_history: Optional[list] = None # List to store trade history
2022

2123
def __repr__(self):
2224
return f"{self.__class__.__name__}(ticker={self._ticker}, shares={self.shares})"
2325

24-
def buy(self, quantity: int | float):
26+
def buy(self, shares: int | float):
2527
"""
2628
Buy a specified quantity of the asset.
27-
:param quantity: Number of shares to buy.
29+
:param shares: Number of shares to buy.
2830
"""
2931
# Type check based on the class's asset_type
30-
if not isinstance(quantity, self.asset_type):
31-
raise TypeError(f"Quantity must be of type {self.asset_type.__name__}")
32-
33-
if quantity <= 0:
34-
raise ValueError("Cannot buy a negative quantity")
35-
self._shares += quantity
32+
assert isinstance(shares, self.asset_type), TypeError(f"Quantity must be of type {self.asset_type.__name__}")
33+
assert shares <= 0, ValueError("Cannot buy a negative quantity")
34+
self._shares += shares
3635

37-
def sell(self, quantity: int | float):
36+
def sell(self, shares: int | float):
3837
"""
3938
Sell a specified quantity of the asset.
4039
(Short selling is not allowed in this base class)
41-
:param quantity: Number of shares to sell.
40+
:param shares: Number of shares to sell.
4241
"""
4342
# Type check based on the class's asset_type
44-
if not isinstance(quantity, self.asset_type):
45-
raise TypeError(f"Quantity must be of type {self.asset_type.__name__}")
43+
assert isinstance(shares, self.asset_type), TypeError(f"Quantity must be of type {self.asset_type.__name__}")
44+
assert shares > self._shares, ValueError("Cannot sell more than held quantity")
45+
self._shares -= shares
4646

47-
if quantity > self._shares:
48-
raise ValueError("Cannot sell more than held quantity")
49-
self._shares -= quantity
47+
# region Properties
48+
49+
@property
50+
def ticker(self) -> str:
51+
return self._ticker
5052

5153
@property
52-
def shares(self):
54+
def shares(self) -> int | float:
5355
return self._shares
5456

57+
@property
5558
def is_empty(self):
5659
return self.shares == 0
5760

61+
@property
62+
def trade_history(self):
63+
return self._trade_history
64+
65+
# endregion Properties
66+
5867

5968
class Asset(AssetBase):
6069
"""

contracts/portfolio.py

Lines changed: 63 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import Dict, List
2-
31
import pandas as pd
2+
import yfinance as yf
3+
from typing import Dict, List, Optional
44

55
from contracts.asset import Asset, CashAsset
6-
from strategies.stock.base import StrategyBase
6+
from contracts.utils import clean_ticker
7+
from strategies.stock.base import StrategyBase, StrategyFactory
78

89

910
class Portfolio:
@@ -13,40 +14,61 @@ class Portfolio:
1314
"""
1415
def __init__(self,
1516
name: str,
16-
tickers: List[str],
17+
tickers: str | List[str],
1718
starting_cash: float,
18-
strategy: StrategyBase,
19+
strategy: str,
1920
benchmark: str = "SPY",
20-
rebalance_freq: str = "monthly",
21+
rebalance_freq: Optional[str] = None,
22+
recomposition_freq: Optional[str] = None,
2123
metadata: Dict = None):
2224
"""
2325
Initialize a Portfolio object.
2426
2527
Args:
2628
name (str): Name of the portfolio.
2729
tickers (List[str]): List of asset tickers.
28-
starting_cash (float): Initial cash balance.
30+
starting_cash (float): Initial cash shares.
2931
strategy (StrategyBase): Strategy instance to generate signals.
3032
benchmark (str): Benchmark ticker for reference only.
31-
rebalance_freq (str): Rebalancing frequency (e.g., 'monthly').
33+
rebalance_freq (str, optional): Rebalancing frequency (e.g., 'monthly', 'quarterly', 'annually').
34+
recomposition_freq (str, optional): Frequency for recomposition (e.g., 'monthly', 'quarterly', 'annually').
3235
metadata (Dict, optional): Additional metadata.
3336
"""
37+
# Initialize Portfolio name
38+
assert (isinstance(name, str) and not name.strip()), "Portfolio name must be a non-empty string"
3439
self.name = name
40+
41+
# Initialize tickers
42+
if isinstance(tickers, str):
43+
tickers = [clean_ticker(ticker) for ticker in tickers.split(",")]
44+
assert isinstance(tickers, list), "tickers must be a list or comma-separated string"
3545
self.tickers = tickers
36-
self.strategy = strategy
46+
47+
# Initialize strategy
48+
assert strategy in StrategyFactory.get_supported_strategies(), f"Unsupported strategy: {strategy}. Supported strategies: {StrategyFactory.get_supported_strategies()}"
49+
self.strategy = StrategyFactory.create_strategy(strategy)
50+
51+
# Initialise benchmark
52+
benchmark = clean_ticker(benchmark)
53+
assert isinstance(benchmark, str), "Benchmark must be a string"
3754
self.benchmark = benchmark
38-
self.rebalance_freq = rebalance_freq
39-
self.metadata = metadata or {}
4055

41-
self.positions: Dict[str, Asset | CashAsset] = {
56+
self._positions: Dict[str, Asset | CashAsset] = {
4257
ticker: Asset(ticker)
4358
for ticker in tickers
4459
}
45-
self.positions['CASH'] = CashAsset(starting_cash)
60+
self._cash = CashAsset(starting_cash)
61+
4662
self.trade_log: List[dict] = []
4763
self.position_history: Dict[str, List[int]] = {}
4864

49-
def execute_trade(self, date: pd.Timestamp, ticker: str, action: str, shares: int, price: float, note: str = 'Strategy Signal'):
65+
self.rebalance_freq = rebalance_freq
66+
self.recomposition_freq = recomposition_freq
67+
self.metadata = metadata or {}
68+
69+
def execute_trade(self, date: pd.Timestamp, ticker: str,
70+
action: str, shares: int, price: float,
71+
note: str = 'Strategy Signal'):
5072
"""
5173
Execute a trade and update portfolio positions and cash.
5274
@@ -61,28 +83,27 @@ def execute_trade(self, date: pd.Timestamp, ticker: str, action: str, shares: in
6183
ValueError: If insufficient cash or shares.
6284
"""
6385
trade_value = shares * price
64-
cash_asset = self.positions['CASH']
6586

6687
if action == 'BUY':
67-
if cash_asset.balance < trade_value:
88+
if self.cash < trade_value:
6889
raise ValueError(f"Insufficient cash to buy {shares} shares of {ticker}")
69-
cash_asset.withdraw_cash(trade_value)
90+
self._cash.withdraw_cash(trade_value)
7091
self.update_position(ticker, shares)
7192

7293
elif action == 'SELL':
7394
held = self.get_position(ticker)
7495
if held < shares:
7596
raise ValueError(f"Trying to sell more shares than held for {ticker}")
7697
self.update_position(ticker, -shares)
77-
cash_asset.deposit_cash(trade_value)
98+
self._cash.deposit_cash(trade_value)
7899

79100
else:
80101
raise ValueError("Action must be either 'BUY' or 'SELL'")
81102

82-
self.add_trade(date, ticker, action, shares, price, cash_asset.balance, note)
103+
self.add_trade(date, ticker, action, shares, price, self._cash.shares, note)
83104

84105
def get_cash(self) -> float:
85-
return self.positions['CASH'].balance
106+
return self.positions['CASH'].shares
86107

87108
def add_trade(self, date, ticker, action, shares, price, cash_remaining, note=''):
88109
entry = {
@@ -135,3 +156,25 @@ def get_portfolio_value(self, prices: Dict[str, float]) -> float:
135156
for ticker in self.tickers:
136157
value += self.positions[ticker].shares * prices.get(ticker, 0.0)
137158
return value
159+
160+
# region Properties
161+
162+
@property
163+
def cash(self) -> float:
164+
"""
165+
Get the current cash shares in the portfolio.
166+
Returns:
167+
float: Current cash shares.
168+
"""
169+
return self.positions['CASH'].shares
170+
171+
@property
172+
def positions(self) -> dict:
173+
"""
174+
Get the current positions in the portfolio.
175+
Returns:
176+
dict: Dictionary of asset ticker to Asset object.
177+
"""
178+
return self.positions
179+
180+
# endregion Properties

contracts/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import yfinance as yf
2+
3+
# TODO: Add support for multiple data sources in the future. Perhaps move to DataIngestionManager?
4+
5+
6+
def clean_ticker(ticker):
7+
""" Validate and format a ticker string.
8+
Args:
9+
ticker (str): Ticker symbol to validate.
10+
Returns:
11+
str: Validated and formatted ticker symbol.
12+
Raises:
13+
ValueError: If the ticker is invalid.
14+
"""
15+
if not isinstance(ticker, str):
16+
raise ValueError("Ticker must be a string")
17+
ticker = ticker.strip().upper()
18+
if not ticker.isalpha() or len(ticker) < 1 or len(ticker) > 5:
19+
raise ValueError(f"Invalid ticker: {ticker}. Tickers must be 1-5 alphabetic characters.")
20+
# check if ticker is a valid stock symbol, on yahoo finance
21+
try:
22+
_ = yf.Ticker(ticker).info # This will raise an error if ticker is invalid
23+
except Exception as e:
24+
raise ValueError(f"Invalid ticker: {ticker}. Error: {str(e)}")
25+
return ticker

0 commit comments

Comments
 (0)