 import logging
 import os
 import uuid
-from datetime import datetime
+from datetime import datetime, timedelta
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

@@ -1197,6 +1197,171 @@ def schema(self) -> pa.Schema:
         return pa.Table.from_pandas(df).schema


+def _compute_non_entity_dates_ray(
+    feature_views: List[FeatureView],
+    start_date_opt: Optional[datetime],
+    end_date_opt: Optional[datetime],
+) -> Tuple[datetime, datetime]:
+    # Why: derive a bounded time window when no entity_df is provided, using
+    # explicit dates or the largest FeatureView TTL (30 days if no TTL is set)
+    end_date = (
+        make_tzaware(end_date_opt) if end_date_opt else make_tzaware(datetime.utcnow())
+    )
+    if start_date_opt is None:
+        max_ttl_seconds = 0
+        for fv in feature_views:
+            if getattr(fv, "ttl", None):
+                try:
+                    ttl_val = fv.ttl
+                    if isinstance(ttl_val, timedelta):
+                        max_ttl_seconds = max(
+                            max_ttl_seconds, int(ttl_val.total_seconds())
+                        )
+                except Exception:
+                    pass
+        start_date = (
+            end_date - timedelta(seconds=max_ttl_seconds)
+            if max_ttl_seconds > 0
+            else end_date - timedelta(days=30)
+        )
+    else:
+        start_date = make_tzaware(start_date_opt)
+    return start_date, end_date
+
+
+def _make_filter_range(timestamp_field: str, start_date: datetime, end_date: datetime):
+    # Why: factory for the batch predicate that restricts rows to the
+    # [start_date, end_date] window when filtering the Ray dataset
+    def _filter_range(batch: pd.DataFrame) -> pd.Series:
+        ts = pd.to_datetime(batch[timestamp_field], utc=True)
+        return (ts >= start_date) & (ts <= end_date)
+
+    return _filter_range
+
+
+def _make_select_distinct_entity_timestamps(join_keys: List[str], timestamp_field: str):
+    # Why: factory function for the distinct (entity_keys, event_timestamp) projection in Ray map_batches.
+    # Keeping the timestamp preserves multiple transactions per entity ID for proper PIT joins
+    def _select_distinct_entity_timestamps(batch: pd.DataFrame) -> pd.DataFrame:
+        cols = [c for c in join_keys if c in batch.columns]
+        if timestamp_field in batch.columns:
+            # Copy the timestamp into a standardized event_timestamp column
+            batch = batch.copy()
+            if timestamp_field != "event_timestamp":
+                batch["event_timestamp"] = batch[timestamp_field]
+            cols = cols + ["event_timestamp"]
+        if not cols:
+            return pd.DataFrame(columns=join_keys + ["event_timestamp"])
+        return batch[cols].drop_duplicates().reset_index(drop=True)
+
+    return _select_distinct_entity_timestamps
+
+
+def _distinct_entities_for_feature_view_ray(
+    store: "RayOfflineStore",
+    config: RepoConfig,
+    fv: FeatureView,
+    registry: BaseRegistry,
+    project: str,
+    start_date: datetime,
+    end_date: datetime,
+) -> Tuple[Dataset, List[str]]:
+    # Why: read minimal columns, filter by time, and project distinct (join_keys, event_timestamp) per FeatureView.
+    # This preserves multiple transactions per entity ID for proper point-in-time joins
+    ray_wrapper = get_ray_wrapper()
+    entities = fv.entities or []
+    entity_objs = [registry.get_entity(e, project) for e in entities]
+    original_join_keys, _rev_feats, timestamp_field, _created_col = _get_column_names(
+        fv, entity_objs
+    )
+
+    source_info = resolve_feature_view_source_with_fallback(
+        fv, config, is_materialization=False
+    )
+    source_path = store._get_source_path(source_info.data_source, config)
+    required_columns = list(set(original_join_keys + [timestamp_field]))
+    ds = ray_wrapper.read_parquet(source_path, columns=required_columns)
+
+    field_mapping = getattr(fv.batch_source, "field_mapping", None)
+    if field_mapping:
+        ds = apply_field_mapping(ds, field_mapping)
+        original_join_keys = [field_mapping.get(k, k) for k in original_join_keys]
+        timestamp_field = field_mapping.get(timestamp_field, timestamp_field)
+
+    if fv.projection.join_key_map:
+        join_keys = [
+            fv.projection.join_key_map.get(key, key) for key in original_join_keys
+        ]
+    else:
+        join_keys = original_join_keys
+
+    ds = ensure_timestamp_compatibility(ds, [timestamp_field])
+    ds = ds.filter(_make_filter_range(timestamp_field, start_date, end_date))
+    # Extract distinct (entity_keys, event_timestamp) combinations, not just entity keys
+    ds = ds.map_batches(
+        _make_select_distinct_entity_timestamps(join_keys, timestamp_field),
+        batch_format="pandas",
+    )
+    return ds, join_keys
+
+
+def _make_align_columns(all_join_keys: List[str], include_timestamp: bool = False):
+    # Why: factory function for schema alignment in Ray map_batches.
+    # When include_timestamp=True, also aligns the event_timestamp column for proper PIT joins
+    def _align_columns(batch: pd.DataFrame) -> pd.DataFrame:
+        batch = batch.copy()
+        output_cols = list(all_join_keys)
+        if include_timestamp:
+            output_cols = output_cols + ["event_timestamp"]
+        for k in output_cols:
+            if k not in batch.columns:
+                batch[k] = pd.NA
+        return batch[output_cols]
+
+    return _align_columns
+
+
+def _make_distinct_by_keys(keys: List[str], include_timestamp: bool = False):
+    # Why: factory function for deduplication in Ray map_batches.
+    # When include_timestamp=True, deduplicates on (keys + event_timestamp) for proper PIT joins
+    def _distinct(batch: pd.DataFrame) -> pd.DataFrame:
+        subset = list(keys)
+        if include_timestamp and "event_timestamp" in batch.columns:
+            subset = subset + ["event_timestamp"]
+        return batch.drop_duplicates(subset=subset).reset_index(drop=True)
+
+    return _distinct
+
+
+def _align_and_union_entities_ray(
+    datasets: List[Dataset],
+    all_join_keys: List[str],
+    include_timestamp: bool = False,
+) -> Dataset:
+    # Why: align schemas across FeatureViews and union them into a unified entity set.
+    # When include_timestamp=True, distinct (entity_keys, event_timestamp) combinations
+    # are preserved for proper point-in-time joins with multiple transactions per entity
+    ray_wrapper = get_ray_wrapper()
+    output_cols = list(all_join_keys)
+    if include_timestamp:
+        output_cols = output_cols + ["event_timestamp"]
+    if not datasets:
+        return ray_wrapper.from_pandas(pd.DataFrame(columns=output_cols))
+
+    aligned = [
+        ds.map_batches(
+            _make_align_columns(all_join_keys, include_timestamp=include_timestamp),
+            batch_format="pandas",
+        )
+        for ds in datasets
+    ]
+    entity_ds = aligned[0]
+    for ds in aligned[1:]:
+        entity_ds = entity_ds.union(ds)
+    return entity_ds.map_batches(
+        _make_distinct_by_keys(all_join_keys, include_timestamp=include_timestamp),
+        batch_format="pandas",
+    )
+
+
 class RayOfflineStore(OfflineStore):
     def __init__(self) -> None:
         self._staging_location: Optional[str] = None
@@ -1874,17 +2039,41 @@ def get_historical_features(
         config: RepoConfig,
         feature_views: List[FeatureView],
         feature_refs: List[str],
-        entity_df: Union[pd.DataFrame, str],
+        entity_df: Optional[Union[pd.DataFrame, str]],
         registry: BaseRegistry,
         project: str,
         full_feature_names: bool = False,
+        **kwargs: Any,
     ) -> RetrievalJob:
         store = RayOfflineStore()
         store._init_ray(config)

-        # Load entity_df as Ray dataset for distributed processing
+        # Load or derive the entity dataset for distributed processing
         ray_wrapper = get_ray_wrapper()
-        if isinstance(entity_df, str):
+        if entity_df is None:
+            # Non-entity mode: derive the entity set from the feature sources within a
+            # bounded time window, preserving distinct (entity_keys, event_timestamp)
+            # combinations so PIT joins handle multiple transactions per entity ID
+            start_date, end_date = _compute_non_entity_dates_ray(
+                feature_views, kwargs.get("start_date"), kwargs.get("end_date")
+            )
+            per_view_entity_ds: List[Dataset] = []
+            all_join_keys: List[str] = []
+            for fv in feature_views:
+                ds, join_keys = _distinct_entities_for_feature_view_ray(
+                    store, config, fv, registry, project, start_date, end_date
+                )
+                per_view_entity_ds.append(ds)
+                for k in join_keys:
+                    if k not in all_join_keys:
+                        all_join_keys.append(k)
+            # include_timestamp=True preserves each entity's actual event_timestamp
+            # instead of assigning a fixed end_date to every entity
+            entity_ds = _align_and_union_entities_ray(
+                per_view_entity_ds, all_join_keys, include_timestamp=True
+            )
+            entity_df_sample = entity_ds.limit(1000).to_pandas()
+        elif isinstance(entity_df, str):
             entity_ds = ray_wrapper.read_csv(entity_df)
             entity_df_sample = entity_ds.limit(1000).to_pandas()
         else:
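For reviewers who want to exercise the new path directly, here is a minimal sketch of a call in non-entity mode. It is a sketch only: `repo_config`, `feature_views`, and `registry` stand in for objects from an initialized Feast repo, the feature ref is a placeholder, and in normal use Feast internals invoke this method rather than user code.

```python
# Sketch: repo_config, feature_views, registry, and the feature ref below are
# placeholders for objects obtained from a real Feast repo.
from datetime import datetime, timedelta, timezone

end = datetime.now(timezone.utc)
start = end - timedelta(days=7)

job = RayOfflineStore.get_historical_features(
    config=repo_config,
    feature_views=feature_views,
    feature_refs=["driver_hourly_stats:conv_rate"],
    entity_df=None,  # None triggers the new non-entity mode
    registry=registry,
    project="my_project",
    start_date=start,  # consumed from **kwargs by _compute_non_entity_dates_ray
    end_date=end,
)
df = job.to_df()
```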
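And a toy illustration of why the projection keeps (join_keys, event_timestamp) pairs rather than bare keys: running the batch function from this diff over synthetic data (the `driver_id` and `created` column names are made up), both transactions for driver 1 survive deduplication.

```python
import pandas as pd

# Synthetic batch: driver 1 appears twice with different event times.
batch = pd.DataFrame(
    {
        "driver_id": [1, 1, 2],
        "created": pd.to_datetime(
            ["2024-01-01 00:00", "2024-01-02 00:00", "2024-01-01 00:00"], utc=True
        ),
    }
)

select = _make_select_distinct_entity_timestamps(["driver_id"], "created")
print(select(batch))
# Deduplication runs over (driver_id, event_timestamp), so both of driver 1's
# rows are kept:
#    driver_id           event_timestamp
# 0          1 2024-01-01 00:00:00+00:00
# 1          1 2024-01-02 00:00:00+00:00
# 2          2 2024-01-01 00:00:00+00:00
```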