Skip to content

Commit 3e6fd07

Browse files
refactor price + polygon integration (#5)
2 parents cddb598 + 22218ed commit 3e6fd07

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+399
-395
lines changed

.gitignore

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,8 @@ cython_debug/
175175
.idea/
176176
.qodo/
177177
/data_cache
178-
/figures
179-
/logs
178+
/output
180179
product_roadmap.md
181180
/run
182-
/logs
183181
/.run
184182
/.vscode

contracts/asset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC
22
from typing import Type, Optional
33

4-
from contracts.utils import clean_ticker
4+
from utils.utils import clean_ticker
55

66

77
class AssetBase(ABC):

contracts/portfolio.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
from contracts.asset import Asset, CashAsset
66
from contracts.order import Order
7-
from contracts.utils import clean_ticker
8-
from core.guardrails.base import GuardrailFactory
9-
from strategies.stock.base import StrategyBase, StrategyFactory
7+
from utils.utils import clean_ticker
8+
from guardrails.base import GuardrailFactory
9+
from strategies.base import StrategyBase, StrategyFactory
1010

1111

1212
class Portfolio:
@@ -19,7 +19,7 @@ def __init__(self,
1919
tickers: str | List[str],
2020
starting_cash: float,
2121
strategy: str,
22-
benchmark: str = "SPY",
22+
benchmark: Optional[str] = None,
2323
guardrail: Optional[str] = None,
2424
rebalance_freq: Optional[str] = None,
2525
recomposition_freq: Optional[str] = None,
@@ -53,6 +53,7 @@ def __init__(self,
5353
self.strategy = StrategyFactory.create_strategy(strategy)
5454

5555
# Initialise benchmark
56+
benchmark = benchmark if benchmark else tickers[0]
5657
benchmark = clean_ticker(benchmark)
5758
assert isinstance(benchmark, str), "Benchmark must be a string"
5859
self.benchmark = benchmark

contracts/utils.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

core/backtester.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from typing import List
1+
from typing import List, Optional
22
import pandas as pd
33
from tqdm import tqdm
44

5-
from core.executors.backtest import BacktestExecutor
6-
from core.executors.base import BaseExecutor
5+
from executors.backtest import BacktestExecutor
6+
from executors.base import BaseExecutor
77
from core.market_data import MarketData
88
from contracts.portfolio import Portfolio
99
from contracts.order import Order, OrderSide, OrderType
10-
from strategies.stock import StrategyBase
10+
from strategies.base import StrategyBase
1111

1212

1313
class Backtester:
@@ -38,13 +38,9 @@ def __init__(self, strategy: 'StrategyBase',
3838
self.executor = executor
3939

4040
self.tickers = portfolio.tickers
41-
self.start_date = None
42-
self.end_date = None
4341
self.signals = {}
4442

45-
def run(self, start_date: str, end_date: str):
46-
self.start_date = pd.to_datetime(start_date)
47-
self.end_date = pd.to_datetime(end_date)
43+
def run(self, end_date: str, start_date: Optional[str] = None, interval='1d', period='5y'):
4844

4945
# Fetch market data for all tickers
5046
# check for strategy look-back period, if any, and adjust start_date accordingly
@@ -56,15 +52,21 @@ def run(self, start_date: str, end_date: str):
5652
else:
5753
raise ValueError("lookback_period must be an integer representing days")
5854
self.market_data.get_market_data(self.tickers + [self.portfolio.benchmark],
59-
start_date=start_date, end_date=end_date)
55+
start_date=start_date, end_date=end_date,
56+
interval=interval, period=period)
6057

6158
# Iterate through the common index dates
6259
for current_date in tqdm(self.market_data.dates):
6360
try:
6461
# Generate slice of
6562
# --- MARKET DATA: FETCH HISTORICAL DATA FOR ALL TICKERS ---
6663
# current_date = pd.to_datetime(current_date)
67-
historical_data = self.market_data.get_history(self.tickers, end_date=current_date, lookback=self.strategy.lookback_period)
64+
65+
historical_data = self.market_data.get_history(self.tickers, lookback=self.strategy.lookback_period, end_date=current_date)
66+
if not historical_data:
67+
continue
68+
69+
# current_date = historical_data.iloc[-1].name
6870

6971
# --- PURE STRATEGY: ONLY GENERATE SIGNALS ---
7072
signals = self.strategy.generate_signals(historical_data, current_date=current_date,

core/data_loader.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,71 @@
11
import os
22
import hashlib
3+
from datetime import datetime, timedelta
4+
35
import pandas as pd
4-
from typing import Dict, List
6+
from typing import Dict, List, Optional
7+
8+
from utils.config import DATA_CACHE
9+
from utils.utils import period_to_timedelta
510

611

7-
def _make_cache_key(ticker: str, start_date: str, end_date: str, source: str) -> str:
8-
key = f"{ticker}_{start_date}_{end_date}_{source}"
12+
def _make_cache_key(*args, **kwargs) -> str:
13+
key = '_'.join(args)
914
return hashlib.md5(key.encode()).hexdigest()
1015

1116

12-
def _fetch_data(ticker: str, start_date: str, end_date: str, source: str) -> pd.DataFrame:
17+
def _fetch_data(ticker: str, start_date: str, end_date: str, interval: str, source: str) -> pd.DataFrame:
1318
if source == "yahoo":
1419
from data_ingestion.yahoo_fetcher import fetch_yahoo_data
15-
return fetch_yahoo_data(ticker, start_date, end_date)
20+
return fetch_yahoo_data(ticker, start_date, end_date, interval)
1621
elif source == "alpaca":
1722
from data_ingestion.alpaca_fetcher import fetch_alpaca_data
18-
return fetch_alpaca_data(ticker, start_date, end_date)
23+
return fetch_alpaca_data(ticker, start_date, end_date, interval)
1924
elif source == "polygon":
2025
from data_ingestion.polygon_fetcher import fetch_polygon_data
21-
return fetch_polygon_data(ticker, start_date, end_date)
26+
return fetch_polygon_data(ticker, start_date, end_date, interval)
2227
else:
2328
raise ValueError(f"Unsupported data source: {source}. Supported sources are 'yahoo', 'alpaca', and 'polygon'.")
2429

2530

26-
def load_price_data(ticker: str, start_date: str, end_date: str,
31+
def load_price_data(ticker: str, end_date: str,
32+
start_date: Optional[str] = None,
33+
interval: str = '1d',
2734
use_cache: bool = True,
2835
force_refresh: bool = False,
2936
source: str = "yahoo") -> pd.DataFrame:
3037
"""
3138
Load historical OHLCV data for a single ticker from the specified data source.
3239
3340
Args:
34-
ticker (str): The ticker ticker of the security.
35-
start_date (str): The start date of the data range.
41+
ticker (str): The ticker of the security.
3642
end_date (str): The end date of the data range.
43+
start_date (str, optional): The start date of the data range.
44+
interval (str): The data interval.
45+
period (int): The data period.
3746
use_cache (bool, optional): Whether to use cached data. Defaults to True.
3847
force_refresh (bool, optional): Whether to force a refresh of the data. Defaults to False.
3948
source (str, optional): The data source to use. Defaults to "yahoo".
4049
4150
Returns:
4251
pd.DataFrame: A pandas DataFrame containing the historical OHLCV data.
4352
"""
44-
os.makedirs("./data_cache", exist_ok=True)
45-
cache_key = _make_cache_key(ticker, start_date, end_date, source)
46-
cache_path = os.path.join("./data_cache", f"{cache_key}.parquet")
53+
os.makedirs(DATA_CACHE, exist_ok=True)
54+
cache_key = _make_cache_key(ticker, start_date, end_date, interval, source)
55+
cache_path = os.path.join(DATA_CACHE, f"{cache_key}.parquet")
4756

4857
if use_cache and os.path.exists(cache_path) and not force_refresh:
4958
try:
5059
return pd.read_parquet(cache_path)
5160
except Exception:
5261
print(f"⚠️ Cache corrupted at {cache_path}, refetching...")
5362

54-
df = _fetch_data(ticker, start_date, end_date, source)
63+
df = _fetch_data(ticker, start_date, end_date, interval, source)
64+
65+
try:
66+
df.index = df.index.tz_localize("UTC")
67+
except:
68+
df.index = df.index.tz_convert("UTC")
5569

5670
if df.empty or "Close" not in df.columns:
5771
raise ValueError(f"No data returned for {ticker} from {start_date} to {end_date}")
@@ -69,18 +83,38 @@ def __init__(self, use_cache=True, force_refresh=False, source="yahoo"):
6983
self.force_refresh = force_refresh
7084
self.source = source
7185

72-
def get_data(self, tickers: List[str], start_date: str, end_date: str) -> Dict[str, pd.DataFrame] | pd.DataFrame:
86+
def get_data(
87+
self, tickers: List[str], end_date: str, start_date: Optional[str] = None,
88+
interval: Optional[str] = '1d', period: Optional[int] = '5y',
89+
) -> Dict[str, pd.DataFrame] | pd.DataFrame:
7390
"""
7491
Return a dictionary of {ticker: DataFrame} for all requested tickers.
7592
7693
Args:
7794
tickers (List[str]): A list of ticker symbols.
7895
start_date (str): The start date of the data range.
7996
end_date (str): The end date of the data range.
97+
interval (str): The data interval.
98+
period (str): The look-back period (e.g. '5y'); subtracted from the start when the start date is not earlier than the end date.
8099
81100
Returns:
82101
Dict[str, pd.DataFrame]: A dictionary containing the historical OHLCV data for each ticker.
83102
"""
103+
start = pd.to_datetime(start_date) if start_date else datetime.today().date()
104+
end = pd.to_datetime(end_date) if end_date else datetime.today().date()
105+
106+
# IF start after end
107+
# OR, if the interval is in minutes but the start-to-end range spans 60 days or more
108+
if start >= end:
109+
period_int = period_to_timedelta(period)
110+
start -= period_int
111+
start_date = start.strftime('%Y-%m-%d')
112+
113+
elif interval.endswith("m") and (end - start).days >= 60: # Yahoo-finance limitation
114+
start = end - timedelta(days=59)
115+
start_date = start.strftime('%Y-%m-%d')
116+
117+
data = {}
84118
# TODO: Convert Dict structure to a single Multi-indexed dataframe?
85119
# data = load_price_data(
86120
# tickers,
@@ -90,12 +124,12 @@ def get_data(self, tickers: List[str], start_date: str, end_date: str) -> Dict[s
90124
# force_refresh=self.force_refresh,
91125
# source=self.source
92126
# )
93-
data = {}
94127
for ticker in tickers:
95128
data[ticker] = load_price_data(
96129
ticker,
97-
start_date,
98-
end_date,
130+
start_date=start_date,
131+
end_date=end_date,
132+
interval=interval,
99133
use_cache=self.use_cache,
100134
force_refresh=self.force_refresh,
101135
source=self.source

core/guardrails.py

Whitespace-only changes.

core/market_data.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# core/market_data.py
2+
from datetime import timedelta
3+
24
import pandas as pd
3-
from typing import Dict, Any, List
5+
from typing import Dict, Any, List, Optional
46

57
from core.data_loader import DataIngestionManager
68

@@ -11,7 +13,7 @@ class MarketData:
1113
def __init__(self, ingestion_manager: DataIngestionManager, simulation_start_date: str = None):
1214
self._ingestion_manager = ingestion_manager
1315
self.data: Dict[str, pd.DataFrame] | None = None
14-
self._simulation_start_date = pd.to_datetime(simulation_start_date)
16+
self._simulation_start_date = pd.to_datetime(simulation_start_date).tz_localize("UTC")
1517

1618
def _validate_all_data(self):
1719
for ticker, df in self.data.items():
@@ -46,17 +48,22 @@ def _clean_and_align_data(self):
4648
# Align all DataFrames to the common index
4749
self._dates = common_index
4850
for ticker in self.data.keys():
49-
self.data[ticker] = self.data[ticker].reindex(common_index)
51+
df = self.data[ticker].reindex(common_index).copy()
52+
df['SEQ'] = range(len(df))
53+
self.data[ticker] = df
5054

5155
def get_market_data(self,
5256
tickers: List[str],
53-
start_date: str,
54-
end_date: str) -> None:
57+
end_date: str,
58+
start_date: Optional[str] = None,
59+
interval: str = '1d',
60+
period: str = '5y') -> None:
5561
"""
5662
Create MarketData by fetching from a DataIngestionManager.
5763
"""
5864
tickers = [t.strip().upper() for t in tickers]
59-
raw_data: Dict[str, pd.DataFrame] = self._ingestion_manager.get_data(tickers, start_date, end_date)
65+
66+
raw_data: Dict[str, pd.DataFrame] = self._ingestion_manager.get_data(tickers=tickers, start_date=start_date, end_date=end_date, interval=interval, period=period)
6067

6168
# Populate and validate the raw data
6269
self.data = raw_data
@@ -84,11 +91,12 @@ def get_series(self, ticker: str, price_type='Close') -> pd.Series:
8491
def get_available_symbols(self) -> list:
8592
return list(self.data.keys())
8693

87-
def get_history(self, ticker_list: List[str], end_date: pd.Timestamp, lookback: int) -> Dict[str, pd.DataFrame]:
94+
def get_history(self, ticker_list: List[str], end_date: str, lookback: int) -> Dict[str, pd.DataFrame]:
8895
"""
8996
Return historical price data for each ticker in `ticker_list`, ending on `end_date` and going back `lookback` rows.
9097
"""
9198
historical_data = {}
99+
92100
for ticker in ticker_list:
93101
if ticker not in self.data:
94102
raise ValueError(f"ticker {ticker} not found in market data.")
@@ -98,14 +106,19 @@ def get_history(self, ticker_list: List[str], end_date: pd.Timestamp, lookback:
98106
raise ValueError("lookback must be a positive integer.")
99107
if lookback > len(self.data[ticker]):
100108
raise ValueError(f"lookback {lookback} exceeds available data length for ticker {ticker}.")
101-
# Calculate start date based on lookback period
102-
if lookback == 0:
103-
start_date = end_date
104-
else:
105-
end_date = pd.to_datetime(end_date)
106-
start_date = end_date - pd.Timedelta(days=lookback)
107-
if ticker not in historical_data:
108-
historical_data[ticker] = self.data[ticker].loc[start_date:end_date].copy()
109+
# Select the last `lookback` rows up to and including end_date, using the SEQ row counter
110+
idx = self.data[ticker].loc[end_date]['SEQ']
111+
historical_data[ticker] = self.data[ticker][
112+
(self.data[ticker]['SEQ'] > idx - lookback) &
113+
(self.data[ticker]['SEQ'] <= idx)
114+
]
115+
# if lookback == 0:
116+
# start_date = end_date
117+
# else:
118+
# end_date = pd.to_datetime(end_date)
119+
# start_date = end_date - pd.Timedelta(days=lookback)
120+
# if ticker not in historical_data:
121+
# historical_data[ticker] = self.data[ticker].loc[start_date:end_date].copy()
109122
return historical_data
110123

111124
def get_all_data(self) -> Dict[str, pd.DataFrame]:
@@ -124,4 +137,4 @@ def dates(self) -> pd.DatetimeIndex:
124137
return self._dates
125138

126139
# Filter dates to include only those on or after the simulation start date
127-
return self._dates[self._dates >= self._simulation_start_date]
140+
return self._dates[self._dates >= self._simulation_start_date]

0 commit comments

Comments
 (0)