88class MarketData :
99 REQUIRED_COLUMNS = ['Open' , 'High' , 'Low' , 'Close' , 'Volume' ]
1010
11- def __init__ (self , data : Dict [str , pd .DataFrame ]):
12- self .data = data
13- self ._validate_all_data ()
11+ def __init__ (self , ingestion_manager : DataIngestionManager , simulation_start_date : str = None ):
12+ self ._ingestion_manager = ingestion_manager
13+ self .data : Dict [str , pd .DataFrame ] | None = None
14+ self ._simulation_start_date = pd .to_datetime (simulation_start_date )
1415
1516 def _validate_all_data (self ):
1617 for ticker , df in self .data .items ():
@@ -22,18 +23,45 @@ def _validate_all_data(self):
2223 if missing_cols :
2324 raise ValueError (f"{ ticker } is missing required columns: { missing_cols } " )
2425
25- @classmethod
26- def from_ingestion (cls ,
27- tickers : List [str ],
28- start_date : str ,
29- end_date : str ,
30- ingestion_manager : DataIngestionManager ) -> 'MarketData' :
def _clean_and_align_data(self):
    """
    Trim every DataFrame to REQUIRED_COLUMNS and align them on a shared index.

    The shared index is the sorted intersection of all tickers' indices. It is
    stored on ``self._dates`` and every frame in ``self.data`` is reindexed to it.
    ``self.data`` is updated in place (same dict object, new frame values).

    :raises ValueError: if no data is loaded, if the tickers share no common
        dates, or if the resulting index is not in increasing order.
    :return: None
    """
    # Must be checked first: with no frames there is no index to intersect or sort.
    if not self.data:
        raise ValueError("No data available for the specified tickers and date range.")

    common_index = None
    for ticker, df in self.data.items():
        # Keep only the required columns; copy so the source frame is not mutated.
        # Re-assigning an existing key while iterating is safe (dict size is unchanged).
        self.data[ticker] = df[self.REQUIRED_COLUMNS].copy()
        # Narrow the running index to dates present in every frame seen so far.
        common_index = df.index if common_index is None else common_index.intersection(df.index)

    if common_index.empty:
        raise ValueError("No common dates found across all tickers. Please check the date range and data availability.")

    common_index = common_index.sort_values()
    # Defensive: sorting should already give increasing order; this guards against
    # index values that do not order cleanly.
    if not common_index.is_monotonic_increasing:
        raise ValueError("Common index dates are not in increasing order. Please check the data integrity.")

    # Align all DataFrames to the common index
    self._dates = common_index
    for ticker, df in self.data.items():
        self.data[ticker] = df.reindex(common_index)
50+
def get_market_data(self,
                    tickers: List[str],
                    start_date: str,
                    end_date: str) -> None:
    """
    Fetch data for the given tickers through the ingestion manager and load it
    into this instance.

    Ticker symbols are stripped of surrounding whitespace and upper-cased before
    the fetch. The fetched mapping is stored on ``self.data``, then validated and
    finally cleaned/aligned.

    :return: None; the result is available on ``self.data``.
    """
    symbols = [symbol.strip().upper() for symbol in tickers]

    # Store first: the validation and alignment steps both operate on self.data.
    self.data = self._ingestion_manager.get_data(symbols, start_date, end_date)
    self._validate_all_data()
    self._clean_and_align_data()
3765
3866 def get_price (self , ticker : str , date : pd .Timestamp , price_type = 'Close' ) -> float | None :
3967 try :
@@ -49,16 +77,44 @@ def get_series(self, ticker: str, price_type='Close') -> pd.Series:
def get_available_symbols(self) -> List[str]:
    """
    Return the ticker symbols currently held in ``self.data``.

    :return: list of the keys of ``self.data``; an empty list when no data
        has been loaded yet (``self.data`` is ``None``).
    """
    # self.data may be None before any load; report "nothing available"
    # instead of failing with AttributeError.
    if self.data is None:
        return []
    return list(self.data)
5179
52- def get_history (self , ticker : str , end_date : pd .Timestamp , lookback : int ) -> pd .DataFrame :
def get_history(self, ticker_list: List[str], end_date: pd.Timestamp, lookback: int) -> Dict[str, pd.DataFrame]:
    """
    Return historical data for several tickers, ending on `end_date` and going
    back `lookback` calendar days.

    :param ticker_list: tickers to fetch; repeated entries are returned once.
    :param end_date: last date of the window (anything ``pd.to_datetime`` accepts);
        must be present in each requested ticker's index.
    :param lookback: positive number of calendar days to go back from `end_date`.
    :return: mapping of ticker to a copy of its rows between
        ``end_date - lookback days`` and ``end_date`` (both ends inclusive).
    :raises ValueError: if data is not loaded, `lookback` is not positive, a ticker
        is unknown, `end_date` is missing for a ticker, or `lookback` exceeds the
        number of rows available for a ticker.
    """
    if self.data is None:
        raise ValueError("Market data has not been loaded yet.")
    if lookback <= 0:
        raise ValueError("lookback must be a positive integer.")

    # Normalise once, before it is used for the membership test or the slice.
    end_date = pd.to_datetime(end_date)
    start_date = end_date - pd.Timedelta(days=lookback)

    historical_data: Dict[str, pd.DataFrame] = {}
    for ticker in ticker_list:
        if ticker in historical_data:
            # Already validated and sliced for an earlier duplicate entry.
            continue
        if ticker not in self.data:
            raise ValueError(f"ticker {ticker} not found in market data.")
        df = self.data[ticker]
        if end_date not in df.index:
            raise ValueError(f"end_date {end_date} not found in market data for ticker {ticker}.")
        if lookback > len(df):
            raise ValueError(f"lookback {lookback} exceeds available data length for ticker {ticker}.")
        # Label-based slice: both endpoints are included. Copy so callers
        # cannot mutate the stored frame.
        historical_data[ticker] = df.loc[start_date:end_date].copy()
    return historical_data
62103
63104 def get_all_data (self ) -> Dict [str , pd .DataFrame ]:
64105 return self .data
106+
@property
def dates(self) -> pd.DatetimeIndex:
    """
    Common dates shared by all tickers, restricted to the simulation period.

    When a simulation start date is set, only dates on or after it are
    returned; earlier dates held in the common index are left out. With no
    simulation start date the full common index is returned.

    :raises ValueError: if no market data has been loaded.
    """
    # Covers both "never loaded" (None) and "loaded but empty".
    if not self.data:
        raise ValueError("Market data has not been loaded yet.")

    sim_start = self._simulation_start_date
    all_dates = self._dates
    if sim_start is not None:
        # Boolean mask keeps dates on or after the simulation start date.
        return all_dates[all_dates >= sim_start]
    return all_dates
0 commit comments