Skip to content

Commit 9b9a787

Browse files
refactor
1 parent 50d6606 commit 9b9a787

File tree

12 files changed

+267
-176
lines changed

12 files changed

+267
-176
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ flowchart TD
157157

158158
## 💡 Main Features
159159

160-
- 📈 **Backtesting Engine** — Realistic execution, guardrails, cash balance checks
160+
- 📈 **Backtesting Engine** — Realistic execution, guardrails, cash shares checks
161161
- 🧠 **Pluggable Strategy Interface** — Stateful/stateless signal generation
162162
- 💼 **Portfolio Tracking** — Accurate PnL with trade logs, equity curves
163163
- 🛡️ **Guardrail System** — Risk management hooks (stop-loss, asset unregister)

core/backtester.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import List
22
import pandas as pd
3-
3+
from core.executors.backtest import BacktestExecutor
44
from core.executors.base import BaseExecutor
55
from core.market_data import MarketData
66
from contracts.portfolio import Portfolio
@@ -23,33 +23,49 @@ def __init__(self, strategy: 'StrategyBase',
2323
portfolio: Portfolio,
2424
executor: 'BaseExecutor',
2525
**executor_kwargs):
26+
"""
27+
:param strategy:
28+
:param market_data:
29+
:param portfolio:
30+
:param executor:
31+
:param executor_kwargs:
32+
"""
2633
self.strategy = strategy
2734
self.market_data = market_data
2835
self.portfolio = portfolio
36+
self.executor = executor
2937

30-
if executor is not None:
31-
self.executor = executor
32-
else:
33-
from core.executors.backtest import BacktestExecutor
34-
self.executor = BacktestExecutor(portfolio=self.portfolio, market_data=market_data, **executor_kwargs)
35-
self.tickers = []
38+
self.tickers = portfolio.tickers
3639
self.start_date = None
3740
self.end_date = None
3841
self.signals = {}
3942

40-
def run(self, tickers: List[str], start_date: str, end_date: str):
41-
self.tickers = tickers
42-
self.start_date = start_date
43-
self.end_date = end_date
44-
price_frames = [self.market_data.get_series(tic) for tic in tickers]
45-
common_index = price_frames[0].index
46-
for df in price_frames[1:]:
47-
common_index = common_index.intersection(df.index)
48-
common_index = common_index.sort_values()
49-
for current_date in common_index:
43+
def run(self, start_date: str, end_date: str):
44+
self.start_date = pd.to_datetime(start_date)
45+
self.end_date = pd.to_datetime(end_date)
46+
47+
# Fetch market data for all tickers
48+
# check for strategy lok-back period, if any, and adjust start_date accordingly
49+
if hasattr(self.strategy, 'lookback_period'):
50+
lookback_period = self.strategy.lookback_period
51+
if isinstance(lookback_period, int):
52+
# Convert lookback period to a date offset
53+
start_date = (pd.to_datetime(start_date) - pd.DateOffset(days=(self.strategy.lookback_period+1))).strftime('%Y-%m-%d')
54+
else:
55+
raise ValueError("lookback_period must be an integer representing days")
56+
self.market_data.get_market_data(self.tickers + [self.portfolio.benchmark],
57+
start_date=start_date, end_date=end_date)
58+
59+
# Iterate through the common index dates
60+
for current_date in self.market_data.dates:
61+
# Generate slice of
62+
# --- MARKET DATA: FETCH HISTORICAL DATA FOR ALL TICKERS ---
63+
# current_date = pd.to_datetime(current_date)
64+
historical_data = self.market_data.get_history(self.tickers, end_date=current_date, lookback=self.strategy.lookback_period)
65+
5066
# --- PURE STRATEGY: ONLY GENERATE SIGNALS ---
51-
signals = self.strategy.generate_signals(self.market_data, current_date=current_date,
52-
positions=self.portfolio.positions)
67+
signals = self.strategy.generate_signals(historical_data, current_date=current_date,
68+
positions=self.portfolio.positions, cash=self.portfolio.cash)
5369
# --- EXECUTOR: SUBMIT ORDERS BASED ON SIGNALS ---
5470
for symbol, order_size in signals.items():
5571
if order_size == 0:

core/data_loader.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(self, use_cache=True, force_refresh=False, source="yahoo"):
6969
self.force_refresh = force_refresh
7070
self.source = source
7171

72-
def get_data(self, tickers: List[str], start_date: str, end_date: str) -> Dict[str, pd.DataFrame]:
72+
def get_data(self, tickers: List[str], start_date: str, end_date: str) -> Dict[str, pd.DataFrame] | pd.DataFrame:
7373
"""
7474
Return a dictionary of {ticker: DataFrame} for all requested tickers.
7575
@@ -81,6 +81,15 @@ def get_data(self, tickers: List[str], start_date: str, end_date: str) -> Dict[s
8181
Returns:
8282
Dict[str, pd.DataFrame]: A dictionary containing the historical OHLCV data for each ticker.
8383
"""
84+
# TODO: Convert Dict structure to a single Multi-indexed dataframe?
85+
# data = load_price_data(
86+
# tickers,
87+
# start_date,
88+
# end_date,
89+
# use_cache=self.use_cache,
90+
# force_refresh=self.force_refresh,
91+
# source=self.source
92+
# )
8493
data = {}
8594
for ticker in tickers:
8695
data[ticker] = load_price_data(

core/executors/backtest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def step(self, current_time):
8181
self.order_status[order_id] = OrderStatus.REJECTED
8282
self.fills[order_id] = OrderResult(order_id=order_id, status=OrderStatus.REJECTED, message=str(e))
8383
# Track equity
84-
net_worth = self.portfolio.positions['CASH'].balance
84+
net_worth = self.portfolio.positions['CASH'].shares
8585

8686
for ticker in self.portfolio.tickers:
8787
shares = self.portfolio.positions[ticker].shares

core/market_data.py

Lines changed: 74 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
class MarketData:
99
REQUIRED_COLUMNS = ['Open', 'High', 'Low', 'Close', 'Volume']
1010

11-
def __init__(self, data: Dict[str, pd.DataFrame]):
12-
self.data = data
13-
self._validate_all_data()
11+
def __init__(self, ingestion_manager: DataIngestionManager, simulation_start_date: str = None):
12+
self._ingestion_manager = ingestion_manager
13+
self.data: Dict[str, pd.DataFrame] | None = None
14+
self._simulation_start_date = pd.to_datetime(simulation_start_date)
1415

1516
def _validate_all_data(self):
1617
for ticker, df in self.data.items():
@@ -22,18 +23,45 @@ def _validate_all_data(self):
2223
if missing_cols:
2324
raise ValueError(f"{ticker} is missing required columns: {missing_cols}")
2425

25-
@classmethod
26-
def from_ingestion(cls,
27-
tickers: List[str],
28-
start_date: str,
29-
end_date: str,
30-
ingestion_manager: DataIngestionManager) -> 'MarketData':
26+
def _clean_and_align_data(self):
27+
"""
28+
Ensure all DataFrames have the required columns and share a common index
29+
:return:
30+
"""
31+
common_index = None
32+
for ticker, df in self.data.items():
33+
# Filter DataFrame to include only REQUIRED_COLUMNS
34+
self.data[ticker] = df[self.REQUIRED_COLUMNS].copy()
35+
# Update common_index to the intersection of all DataFrame indices
36+
common_index = df.index if common_index is None else common_index.intersection(df.index)
37+
common_index = common_index.sort_values()
38+
39+
if common_index.empty:
40+
raise ValueError("No common dates found across all tickers. Please check the date range and data availability.")
41+
if common_index is None:
42+
raise ValueError("No data available for the specified tickers and date range.")
43+
if not common_index.is_monotonic_increasing:
44+
raise ValueError("Common index dates are not in increasing order. Please check the data integrity.")
45+
46+
# Align all DataFrames to the common index
47+
self._dates = common_index
48+
for ticker in self.data.keys():
49+
self.data[ticker] = self.data[ticker].reindex(common_index)
50+
51+
def get_market_data(self,
52+
tickers: List[str],
53+
start_date: str,
54+
end_date: str) -> None:
3155
"""
3256
Create MarketData by fetching from a DataIngestionManager.
3357
"""
3458
tickers = [t.strip().upper() for t in tickers]
35-
raw_data = ingestion_manager.get_data(tickers, start_date, end_date)
36-
return cls(data=raw_data)
59+
raw_data: Dict[str, pd.DataFrame] = self._ingestion_manager.get_data(tickers, start_date, end_date)
60+
61+
# Populate and validate the raw data
62+
self.data = raw_data
63+
self._validate_all_data()
64+
self._clean_and_align_data()
3765

3866
def get_price(self, ticker: str, date: pd.Timestamp, price_type='Close') -> float | None:
3967
try:
@@ -49,16 +77,44 @@ def get_series(self, ticker: str, price_type='Close') -> pd.Series:
4977
def get_available_symbols(self) -> list:
5078
return list(self.data.keys())
5179

52-
def get_history(self, ticker: str, end_date: pd.Timestamp, lookback: int) -> pd.DataFrame:
80+
def get_history(self, ticker_list: List[str], end_date: pd.Timestamp, lookback: int) -> Dict[str, pd.DataFrame]:
5381
"""
5482
Return historical price data for a ticker ending on `end_date` and going back `lookback` days.
5583
"""
56-
if ticker not in self.data:
57-
raise ValueError(f"ticker {ticker} not found in market data.")
58-
59-
df = self.data[ticker]
60-
start_date = end_date - pd.Timedelta(days=lookback)
61-
return df.loc[start_date:end_date].copy()
84+
historical_data = {}
85+
for ticker in ticker_list:
86+
if ticker not in self.data:
87+
raise ValueError(f"ticker {ticker} not found in market data.")
88+
if end_date not in self.data[ticker].index:
89+
raise ValueError(f"end_date {end_date} not found in market data for ticker {ticker}.")
90+
if lookback <= 0:
91+
raise ValueError("lookback must be a positive integer.")
92+
if lookback > len(self.data[ticker]):
93+
raise ValueError(f"lookback {lookback} exceeds available data length for ticker {ticker}.")
94+
# Calculate start date based on lookback period
95+
if lookback == 0:
96+
start_date = end_date
97+
else:
98+
end_date = pd.to_datetime(end_date)
99+
start_date = end_date - pd.Timedelta(days=lookback)
100+
if ticker not in historical_data:
101+
historical_data[ticker] = self.data[ticker].loc[start_date:end_date].copy()
102+
return historical_data
62103

63104
def get_all_data(self) -> Dict[str, pd.DataFrame]:
64105
return self.data
106+
107+
@property
108+
def dates(self) -> pd.DatetimeIndex:
109+
"""
110+
Returns the common index of all DataFrames from the simulation start date onwards.
111+
Note: This method returns dates from the 'simulation' start date and not the data's start date (which includes the strategy lookback).
112+
"""
113+
if self.data is None or not self.data:
114+
raise ValueError("Market data has not been loaded yet.")
115+
116+
if self._simulation_start_date is None:
117+
return self._dates
118+
119+
# Filter dates to include only those on or after the simulation start date
120+
return self._dates[self._dates >= self._simulation_start_date]

requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,7 @@ dotenv~=0.9.9
88
python-dotenv~=1.1.0
99
dataclasses~=0.6
1010
dataclasses-json~=0.6.3
11-
streamlit~=1.45.1
11+
streamlit~=1.45.1
12+
pyfolio~=0.9.2
13+
backtrader~=1.9.76.123
14+
prefect~=2.14.12

run_backtest.py

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
from utils.metrics import summarize_metrics
1212
from contracts.portfolio import Portfolio
1313
from core.executors.backtest import BacktestExecutor
14-
from core.executors.paper import PaperExecutor
15-
from core.executors.live import LiveExecutor
1614

1715
from analytics.performance_evaluator import PerformanceEvaluator
1816
import yfinance as yf
@@ -40,46 +38,39 @@ def parse_args():
4038

4139
def main():
4240
args = parse_args()
43-
tickers = [s.strip().upper() for s in args.tickers.split(",")]
44-
if args.strategy not in StrategyFactory.get_supported_strategies():
45-
raise ValueError(
46-
f"Unsupported strategy: {args.strategy}. Supported strategies: {StrategyFactory.get_supported_strategies()}")
47-
strategy = StrategyFactory.create_strategy(args.strategy)
48-
ingestion = DataIngestionManager(use_cache=True, force_refresh=args.refresh, source=args.source)
4941

50-
benchmark = args.benchmark.upper()
51-
if benchmark.startswith("^"):
52-
benchmark = benchmark[1:] # Remove caret for yfinance compatibility
53-
54-
market_data = MarketData.from_ingestion(tickers+[benchmark], args.start, args.end, ingestion)
55-
guardrails = [TrailingStopLossGuardrail()]
42+
# --- Portfolio Setup ---
5643
portfolio = Portfolio(
57-
name=f"{args.strategy.capitalize()}Portfolio",
58-
tickers=tickers,
44+
name=f"{args.strategy.capitalize()}-Portfolio",
45+
tickers=args.tickers,
5946
benchmark=args.benchmark,
6047
starting_cash=args.cash,
61-
strategy=strategy,
48+
strategy=args.strategy,
6249
metadata={"source": f"{args.mode.capitalize()}Executor"}
6350
)
64-
if args.mode == "paper":
65-
executor = PaperExecutor(portfolio=portfolio, slippage=args.slippage, market_data=market_data)
66-
elif args.mode == "live":
67-
if not args.broker:
68-
raise ValueError("--broker must be specified for live trading mode (e.g. brokers.alpaca.AlpacaBrokerAPI)")
69-
# Dynamically import the broker API module
70-
import importlib
71-
broker_module, broker_class = args.broker.rsplit('.', 1)
72-
broker_api_cls = getattr(importlib.import_module(broker_module), broker_class)
73-
broker_api = broker_api_cls() # User must configure credentials in the broker API implementation
74-
executor = LiveExecutor(portfolio=portfolio, broker_api=broker_api, market_data=market_data)
75-
else:
76-
executor = BacktestExecutor(portfolio=portfolio, market_data=market_data, guardrails=guardrails)
51+
tickers = portfolio.tickers
52+
benchmark = portfolio.benchmark
53+
strategy = portfolio.strategy
54+
55+
# --- Market Data Setup ---
56+
ingestion = DataIngestionManager(source=args.source)
57+
market_data = MarketData(ingestion, simulation_start_date=args.start)
58+
59+
# --- Executor Setup ---
60+
guardrails = [TrailingStopLossGuardrail()]
61+
executor = BacktestExecutor(portfolio=portfolio, market_data=market_data, guardrails=guardrails)
62+
63+
# --- Backtester Setup ---
7764
bt = Backtester(strategy=strategy, market_data=market_data, portfolio=portfolio, executor=executor)
78-
bt.run(tickers=tickers, start_date=args.start, end_date=args.end)
79-
equity_curve = bt.get_equity_curve()
80-
trade_log = bt.get_trade_log()
65+
bt.run(start_date=args.start, end_date=args.end)
66+
67+
# --- Backtest Results ---
68+
print(f"\n🚀 Backtest completed for {portfolio.name} with {len(tickers)} tickers from {args.start} to {args.end}")
8169
print(f"\n💰 Starting Cash: ${args.cash:,.2f}")
8270
print(f"\n📈 Final Net Worth: ${bt.get_final_net_worth():,.2f}")
71+
72+
equity_curve = bt.get_equity_curve()
73+
trade_log = bt.get_trade_log()
8374
if args.export:
8475
os.makedirs('./logs', exist_ok=True)
8576
equity_curve.to_csv("./logs/equity_curve.csv")

strategies/stock/base.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,32 @@ class StrategyBase(ABC):
1111
Abstract base class for trading strategies.
1212
All strategies must implement get_name and generate_signals.
1313
"""
14+
15+
def __init__(self):
16+
self.lookback_period = None
17+
1418
@abstractmethod
1519
def get_name(self) -> str:
1620
"""
1721
Return unique strategy name for identification.
1822
"""
1923
pass
2024

25+
@property
26+
def lookback(self) -> int:
27+
"""
28+
Return the lookback period in days for the strategy.
29+
:return:
30+
"""
31+
return self.lookback_period
32+
2133
@abstractmethod
2234
def generate_signals(
2335
self,
24-
market_data: 'MarketData',
36+
market_data: pd.DataFrame | Dict[str, pd.DataFrame],
2537
current_date: pd.Timestamp,
26-
positions: Dict[str, Asset | CashAsset]
38+
positions: Dict[str, Asset],
39+
cash: CashAsset
2740
) -> Dict[str, int]:
2841
"""
2942
For each asset (ticker), return the Number of shares to buy or sell.
@@ -33,6 +46,7 @@ def generate_signals(
3346
- <0 : Short (number of shares to sell)
3447
- 0 : No action (hold)
3548
Strategy must not look ahead beyond `current_date`.
49+
:param cash:
3650
"""
3751
pass
3852

0 commit comments

Comments
 (0)