Skip to content

Commit 25568ff

Browse files
Add interval and period support to data pipeline
Enhanced data loading, ingestion, and backtesting to support flexible interval and period parameters for historical data. Updated fetchers for Yahoo, Alpaca, and Polygon to accept interval arguments. Improved lookback handling, caching, and error handling. Adjusted default strategy parameters and CLI arguments for greater flexibility.
1 parent d1e5619 commit 25568ff

File tree

12 files changed

+233
-61
lines changed

12 files changed

+233
-61
lines changed

contracts/portfolio.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self,
1919
tickers: str | List[str],
2020
starting_cash: float,
2121
strategy: str,
22-
benchmark: str = "SPY",
22+
benchmark: Optional[str] = None,
2323
guardrail: Optional[str] = None,
2424
rebalance_freq: Optional[str] = None,
2525
recomposition_freq: Optional[str] = None,
@@ -53,6 +53,7 @@ def __init__(self,
5353
self.strategy = StrategyFactory.create_strategy(strategy)
5454

5555
# Initialise benchmark
56+
benchmark = benchmark if benchmark else tickers[0]
5657
benchmark = clean_ticker(benchmark)
5758
assert isinstance(benchmark, str), "Benchmark must be a string"
5859
self.benchmark = benchmark

core/backtester.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Optional
22
import pandas as pd
33
from tqdm import tqdm
44

@@ -38,13 +38,9 @@ def __init__(self, strategy: 'StrategyBase',
3838
self.executor = executor
3939

4040
self.tickers = portfolio.tickers
41-
self.start_date = None
42-
self.end_date = None
4341
self.signals = {}
4442

45-
def run(self, start_date: str, end_date: str):
46-
self.start_date = pd.to_datetime(start_date)
47-
self.end_date = pd.to_datetime(end_date)
43+
def run(self, end_date: str, start_date: Optional[str] = None, interval='1d', period='5y'):
4844

4945
# Fetch market data for all tickers
5046
# check for strategy look-back period, if any, and adjust start_date accordingly
@@ -56,15 +52,21 @@ def run(self, start_date: str, end_date: str):
5652
else:
5753
raise ValueError("lookback_period must be an integer representing days")
5854
self.market_data.get_market_data(self.tickers + [self.portfolio.benchmark],
59-
start_date=start_date, end_date=end_date)
55+
start_date=start_date, end_date=end_date,
56+
interval=interval, period=period)
6057

6158
# Iterate through the common index dates
6259
for current_date in tqdm(self.market_data.dates):
6360
try:
6461
# Generate a slice of historical data ending at current_date
6562
# --- MARKET DATA: FETCH HISTORICAL DATA FOR ALL TICKERS ---
6663
# current_date = pd.to_datetime(current_date)
67-
historical_data = self.market_data.get_history(self.tickers, end_date=current_date, lookback=self.strategy.lookback_period)
64+
65+
historical_data = self.market_data.get_history(self.tickers, lookback=self.strategy.lookback_period, end_date=current_date)
66+
if not historical_data:
67+
continue
68+
69+
# current_date = historical_data.iloc[-1].name
6870

6971
# --- PURE STRATEGY: ONLY GENERATE SIGNALS ---
7072
signals = self.strategy.generate_signals(historical_data, current_date=current_date,

core/data_loader.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,71 @@
11
import os
22
import hashlib
3+
from datetime import datetime, timedelta
4+
35
import pandas as pd
4-
from typing import Dict, List
6+
from typing import Dict, List, Optional
7+
8+
from utils.config import DATA_CACHE
9+
from utils.utils import period_to_timedelta
510

611

7-
def _make_cache_key(ticker: str, start_date: str, end_date: str, source: str) -> str:
8-
key = f"{ticker}_{start_date}_{end_date}_{source}"
12+
def _make_cache_key(*args, **kwargs) -> str:
13+
key = '_'.join(args)
914
return hashlib.md5(key.encode()).hexdigest()
1015

1116

12-
def _fetch_data(ticker: str, start_date: str, end_date: str, source: str) -> pd.DataFrame:
17+
def _fetch_data(ticker: str, start_date: str, end_date: str, interval: str, source: str) -> pd.DataFrame:
1318
if source == "yahoo":
1419
from data_ingestion.yahoo_fetcher import fetch_yahoo_data
15-
return fetch_yahoo_data(ticker, start_date, end_date)
20+
return fetch_yahoo_data(ticker, start_date, end_date, interval)
1621
elif source == "alpaca":
1722
from data_ingestion.alpaca_fetcher import fetch_alpaca_data
18-
return fetch_alpaca_data(ticker, start_date, end_date)
23+
return fetch_alpaca_data(ticker, start_date, end_date, interval)
1924
elif source == "polygon":
2025
from data_ingestion.polygon_fetcher import fetch_polygon_data
21-
return fetch_polygon_data(ticker, start_date, end_date)
26+
return fetch_polygon_data(ticker, start_date, end_date, interval)
2227
else:
2328
raise ValueError(f"Unsupported data source: {source}. Supported sources are 'yahoo', 'alpaca', and 'polygon'.")
2429

2530

26-
def load_price_data(ticker: str, start_date: str, end_date: str,
31+
def load_price_data(ticker: str, end_date: str,
32+
start_date: Optional[str] = None,
33+
interval: str = '1d',
2734
use_cache: bool = True,
2835
force_refresh: bool = False,
2936
source: str = "yahoo") -> pd.DataFrame:
3037
"""
3138
Load historical OHLCV data for a single ticker from the specified data source.
3239
3340
Args:
34-
ticker (str): The ticker ticker of the security.
35-
start_date (str): The start date of the data range.
41+
ticker (str): The ticker of the security.
3642
end_date (str): The end date of the data range.
43+
start_date (str, optional): The start date of the data range.
44+
interval (str): The data interval.
45+
period (int): The data period.
3746
use_cache (bool, optional): Whether to use cached data. Defaults to True.
3847
force_refresh (bool, optional): Whether to force a refresh of the data. Defaults to False.
3948
source (str, optional): The data source to use. Defaults to "yahoo".
4049
4150
Returns:
4251
pd.DataFrame: A pandas DataFrame containing the historical OHLCV data.
4352
"""
44-
os.makedirs("./data_cache", exist_ok=True)
45-
cache_key = _make_cache_key(ticker, start_date, end_date, source)
46-
cache_path = os.path.join("./data_cache", f"{cache_key}.parquet")
53+
os.makedirs(DATA_CACHE, exist_ok=True)
54+
cache_key = _make_cache_key(ticker, start_date, end_date, interval, source)
55+
cache_path = os.path.join(DATA_CACHE, f"{cache_key}.parquet")
4756

4857
if use_cache and os.path.exists(cache_path) and not force_refresh:
4958
try:
5059
return pd.read_parquet(cache_path)
5160
except Exception:
5261
print(f"⚠️ Cache corrupted at {cache_path}, refetching...")
5362

54-
df = _fetch_data(ticker, start_date, end_date, source)
63+
df = _fetch_data(ticker, start_date, end_date, interval, source)
64+
65+
try:
66+
df.index = df.index.tz_localize("UTC")
67+
except:
68+
df.index = df.index.tz_convert("UTC")
5569

5670
if df.empty or "Close" not in df.columns:
5771
raise ValueError(f"No data returned for {ticker} from {start_date} to {end_date}")
@@ -69,18 +83,38 @@ def __init__(self, use_cache=True, force_refresh=False, source="yahoo"):
6983
self.force_refresh = force_refresh
7084
self.source = source
7185

72-
def get_data(self, tickers: List[str], start_date: str, end_date: str) -> Dict[str, pd.DataFrame] | pd.DataFrame:
86+
def get_data(
87+
self, tickers: List[str], end_date: str, start_date: Optional[str] = None,
88+
interval: Optional[str] = '1d', period: Optional[int] = '5y',
89+
) -> Dict[str, pd.DataFrame] | pd.DataFrame:
7390
"""
7491
Return a dictionary of {ticker: DataFrame} for all requested tickers.
7592
7693
Args:
7794
tickers (List[str]): A list of ticker symbols.
7895
start_date (str): The start date of the data range.
7996
end_date (str): The end date of the data range.
97+
interval (str): The data interval.
98+
period (int): The data period.
8099
81100
Returns:
82101
Dict[str, pd.DataFrame]: A dictionary containing the historical OHLCV data for each ticker.
83102
"""
103+
start = pd.to_datetime(start_date) if start_date else datetime.today().date()
104+
end = pd.to_datetime(end_date) if end_date else datetime.today().date()
105+
106+
# If start falls on or after end, derive start by subtracting the requested period from end
107+
# or, if the interval is minute-based and the date range spans 60 days or more
108+
if start >= end:
109+
period_int = period_to_timedelta(period)
110+
start -= period_int
111+
start_date = start.strftime('%Y-%m-%d')
112+
113+
elif interval.endswith("m") and (end - start).days >= 60: # Yahoo-finance limitation
114+
start = end - timedelta(days=59)
115+
start_date = start.strftime('%Y-%m-%d')
116+
117+
data = {}
84118
# TODO: Convert Dict structure to a single Multi-indexed dataframe?
85119
# data = load_price_data(
86120
# tickers,
@@ -90,12 +124,12 @@ def get_data(self, tickers: List[str], start_date: str, end_date: str) -> Dict[s
90124
# force_refresh=self.force_refresh,
91125
# source=self.source
92126
# )
93-
data = {}
94127
for ticker in tickers:
95128
data[ticker] = load_price_data(
96129
ticker,
97-
start_date,
98-
end_date,
130+
start_date=start_date,
131+
end_date=end_date,
132+
interval=interval,
99133
use_cache=self.use_cache,
100134
force_refresh=self.force_refresh,
101135
source=self.source

core/market_data.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# core/market_data.py
2+
from datetime import timedelta
3+
24
import pandas as pd
3-
from typing import Dict, Any, List
5+
from typing import Dict, Any, List, Optional
46

57
from core.data_loader import DataIngestionManager
68

@@ -11,7 +13,7 @@ class MarketData:
1113
def __init__(self, ingestion_manager: DataIngestionManager, simulation_start_date: str = None):
1214
self._ingestion_manager = ingestion_manager
1315
self.data: Dict[str, pd.DataFrame] | None = None
14-
self._simulation_start_date = pd.to_datetime(simulation_start_date)
16+
self._simulation_start_date = pd.to_datetime(simulation_start_date).tz_localize("UTC")
1517

1618
def _validate_all_data(self):
1719
for ticker, df in self.data.items():
@@ -46,17 +48,22 @@ def _clean_and_align_data(self):
4648
# Align all DataFrames to the common index
4749
self._dates = common_index
4850
for ticker in self.data.keys():
49-
self.data[ticker] = self.data[ticker].reindex(common_index)
51+
df = self.data[ticker].reindex(common_index).copy()
52+
df['SEQ'] = range(len(df))
53+
self.data[ticker] = df
5054

5155
def get_market_data(self,
5256
tickers: List[str],
53-
start_date: str,
54-
end_date: str) -> None:
57+
end_date: str,
58+
start_date: Optional[str] = None,
59+
interval: str = '1d',
60+
period: str = '5y') -> None:
5561
"""
5662
Create MarketData by fetching from a DataIngestionManager.
5763
"""
5864
tickers = [t.strip().upper() for t in tickers]
59-
raw_data: Dict[str, pd.DataFrame] = self._ingestion_manager.get_data(tickers, start_date, end_date)
65+
66+
raw_data: Dict[str, pd.DataFrame] = self._ingestion_manager.get_data(tickers=tickers, start_date=start_date, end_date=end_date, interval=interval, period=period)
6067

6168
# Populate and validate the raw data
6269
self.data = raw_data
@@ -84,11 +91,12 @@ def get_series(self, ticker: str, price_type='Close') -> pd.Series:
8491
def get_available_symbols(self) -> list:
8592
return list(self.data.keys())
8693

87-
def get_history(self, ticker_list: List[str], end_date: pd.Timestamp, lookback: int) -> Dict[str, pd.DataFrame]:
94+
def get_history(self, ticker_list: List[str], end_date: str, lookback: int) -> Dict[str, pd.DataFrame]:
8895
"""
8996
Return historical price data for a ticker ending on `end_date` and going back `lookback` days.
9097
"""
9198
historical_data = {}
99+
92100
for ticker in ticker_list:
93101
if ticker not in self.data:
94102
raise ValueError(f"ticker {ticker} not found in market data.")
@@ -98,14 +106,19 @@ def get_history(self, ticker_list: List[str], end_date: pd.Timestamp, lookback:
98106
raise ValueError("lookback must be a positive integer.")
99107
if lookback > len(self.data[ticker]):
100108
raise ValueError(f"lookback {lookback} exceeds available data length for ticker {ticker}.")
101-
# Calculate start date based on lookback period
102-
if lookback == 0:
103-
start_date = end_date
104-
else:
105-
end_date = pd.to_datetime(end_date)
106-
start_date = end_date - pd.Timedelta(days=lookback)
107-
if ticker not in historical_data:
108-
historical_data[ticker] = self.data[ticker].loc[start_date:end_date].copy()
109+
# Calculate start date based on a lookback period
110+
idx = self.data[ticker].loc[end_date]['SEQ']
111+
historical_data[ticker] = self.data[ticker][
112+
(self.data[ticker]['SEQ'] > idx - lookback) &
113+
(self.data[ticker]['SEQ'] <= idx)
114+
]
115+
# if lookback == 0:
116+
# start_date = end_date
117+
# else:
118+
# end_date = pd.to_datetime(end_date)
119+
# start_date = end_date - pd.Timedelta(days=lookback)
120+
# if ticker not in historical_data:
121+
# historical_data[ticker] = self.data[ticker].loc[start_date:end_date].copy()
109122
return historical_data
110123

111124
def get_all_data(self) -> Dict[str, pd.DataFrame]:
@@ -124,4 +137,4 @@ def dates(self) -> pd.DatetimeIndex:
124137
return self._dates
125138

126139
# Filter dates to include only those on or after the simulation start date
127-
return self._dates[self._dates >= self._simulation_start_date]
140+
return self._dates[self._dates >= self._simulation_start_date]

data_ingestion/alpaca_fetcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dotenv import load_dotenv
55

66

7-
def fetch_alpaca_data(ticker: str, start_date: str, end_date: str, timeframe: str = "1Day") -> pd.DataFrame:
7+
def fetch_alpaca_data(ticker: str, start_date: str, end_date: str, interval: str = "1Day") -> pd.DataFrame:
88
"""
99
Fetch historical OHLCV data from Alpaca for a given ticker and date range.
1010
Requires ALPACA_API_KEY and ALPACA_API_SECRET in environment or .env.
@@ -26,7 +26,7 @@ def fetch_alpaca_data(ticker: str, start_date: str, end_date: str, timeframe: st
2626
params = {
2727
"start": start_date,
2828
"end": end_date,
29-
"timeframe": timeframe,
29+
"timeframe": interval,
3030
"adjustment": "all",
3131
"limit": 10000
3232
}

data_ingestion/polygon_fetcher.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,66 @@
1-
def fetch_polygon_data(ticker: str, start_date: str, end_date: str):
2-
raise NotImplementedError("Polygon data fetch not implemented. Set your API key and implement.")
1+
from polygon import RESTClient
2+
import pandas as pd
3+
4+
from utils.config import POLYGON_API_KEY
5+
from utils.utils import split_period
6+
7+
from tenacity import retry, stop_after_attempt, wait_exponential
8+
import time
9+
10+
11+
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=5, max=300))
12+
def fetch_with_retry(client, ticker, multiplier, timespan, start_date, end_date):
13+
try:
14+
polygon_response = client.list_aggs(
15+
ticker=ticker, multiplier=multiplier, timespan=timespan,
16+
from_=start_date, to=end_date, adjusted=True, sort='asc',
17+
limit=500
18+
)
19+
return polygon_response
20+
except Exception as e:
21+
print(f"Error: {e}. Retrying...")
22+
raise
23+
24+
25+
def fetch_polygon_data_with_backoff(client, ticker, multiplier, timespan, start_date, end_date):
26+
while True:
27+
try:
28+
return fetch_with_retry(client, ticker, multiplier, timespan, start_date, end_date)
29+
except Exception as e:
30+
print(f"Rate limit hit. Waiting before retrying...")
31+
for remaining in range(300, 0, -1):
32+
print(f"Retrying in {remaining} seconds...", end="\r")
33+
time.sleep(1)
34+
35+
36+
def fetch_polygon_data(ticker: str, start_date: str, end_date: str, interval: str) -> pd.DataFrame:
37+
"""
38+
Fetch historical price data from Polygon for a given stock.
39+
40+
Args:
41+
ticker (str): Stock ticker symbol.
42+
start_date (str): Start date in 'YYYY-MM-DD' format.
43+
end_date (str): End date in 'YYYY-MM-DD' format.
44+
interval (str): Price interval (e.g., '1m', '5m', '1h').
45+
api_key (str): Polygon API key.
46+
47+
Returns:
48+
pd.DataFrame: DataFrame containing historical price data.
49+
"""
50+
client = RESTClient(POLYGON_API_KEY)
51+
52+
aggs = []
53+
multiplier, timespan = split_period(interval)
54+
55+
polygon_response = fetch_polygon_data_with_backoff(
56+
client, ticker, multiplier, timespan, start_date, end_date
57+
)
58+
59+
for a in polygon_response:
60+
aggs.append(a)
61+
62+
df = pd.DataFrame(aggs)
63+
df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
64+
df.set_index("timestamp", inplace=True)
65+
df.columns = [k.capitalize() for k in df.columns]
66+
return df

data_ingestion/yahoo_fetcher.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,9 @@
33

44

55
def fetch_yahoo_data(ticker: str, start_date: str, end_date: str, interval: str = "1d") -> pd.DataFrame:
6-
data = yf.download(ticker, start=start_date, end=end_date, interval=interval, multi_level_index=False, threads=False)
6+
if interval.endswith("m"):
7+
data = yf.download(ticker, end=end_date, interval=interval, period='1wk', multi_level_index=False, threads=False)
8+
else:
9+
data = yf.download(ticker, start=start_date, end=end_date, multi_level_index=False, threads=False)
10+
711
return data

0 commit comments

Comments
 (0)