
Commit e484c12

feat: Offline Store historical features retrieval based on datetime range in Ray (#5738)
* feat: Offline Store historical features retrieval based on datetime range in Ray
* Reformatted code to fix lint issues
* Preserve event_timestamp in non-entity mode for correct point-in-time joins
* Minor lint changes
* Added test cases for datetime range based feature retrieval in Ray

Signed-off-by: Aniket Paluskar <apaluska@redhat.com>
1 parent 75b5628 commit e484c12
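A minimal usage sketch of the non-entity retrieval mode this commit adds, based on the calls exercised in the new integration tests below. The repo path and the "driver_stats" feature view and feature names are assumptions borrowed from the test fixtures, not part of this diff.

from datetime import datetime, timedelta, timezone

from feast import FeatureStore

store = FeatureStore(repo_path=".")  # assumes a repo configured with the Ray offline store

end_date = datetime.now(timezone.utc)
start_date = end_date - timedelta(days=7)

# entity_df=None triggers non-entity mode: the entity set is derived from the
# feature sources within [start_date, end_date], preserving each distinct
# (entity_key, event_timestamp) pair for point-in-time correctness.
df = store.get_historical_features(
    entity_df=None,
    features=[
        "driver_stats:conv_rate",
        "driver_stats:acc_rate",
    ],
    start_date=start_date,
    end_date=end_date,
).to_df()

print(df.head())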

File tree

2 files changed: +318 −4 lines

sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py
sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py


sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py

Lines changed: 193 additions & 4 deletions
@@ -1,7 +1,7 @@
 import logging
 import os
 import uuid
-from datetime import datetime
+from datetime import datetime, timedelta
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

@@ -1197,6 +1197,171 @@ def schema(self) -> pa.Schema:
         return pa.Table.from_pandas(df).schema


+def _compute_non_entity_dates_ray(
+    feature_views: List[FeatureView],
+    start_date_opt: Optional[datetime],
+    end_date_opt: Optional[datetime],
+) -> Tuple[datetime, datetime]:
+    # Why: derive bounded time window when no entity_df is provided using explicit dates or max TTL fallback
+    end_date = (
+        make_tzaware(end_date_opt) if end_date_opt else make_tzaware(datetime.utcnow())
+    )
+    if start_date_opt is None:
+        max_ttl_seconds = 0
+        for fv in feature_views:
+            if getattr(fv, "ttl", None):
+                try:
+                    ttl_val = fv.ttl
+                    if isinstance(ttl_val, timedelta):
+                        max_ttl_seconds = max(
+                            max_ttl_seconds, int(ttl_val.total_seconds())
+                        )
+                except Exception:
+                    pass
+        start_date = (
+            end_date - timedelta(seconds=max_ttl_seconds)
+            if max_ttl_seconds > 0
+            else end_date - timedelta(days=30)
+        )
+    else:
+        start_date = make_tzaware(start_date_opt)
+    return start_date, end_date
+
+
+def _make_filter_range(timestamp_field: str, start_date: datetime, end_date: datetime):
+    # Why: factory function for time-range filtering in Ray map_batches
+    def _filter_range(batch: pd.DataFrame) -> pd.Series:
+        ts = pd.to_datetime(batch[timestamp_field], utc=True)
+        return (ts >= start_date) & (ts <= end_date)
+
+    return _filter_range
+
+
+def _make_select_distinct_entity_timestamps(join_keys: List[str], timestamp_field: str):
+    # Why: factory function for distinct (entity_keys, event_timestamp) projection in Ray map_batches
+    # This preserves multiple transactions per entity ID with different timestamps for proper PIT joins
+    def _select_distinct_entity_timestamps(batch: pd.DataFrame) -> pd.DataFrame:
+        cols = [c for c in join_keys if c in batch.columns]
+        if timestamp_field in batch.columns:
+            # Rename timestamp to standardized event_timestamp
+            batch = batch.copy()
+            if timestamp_field != "event_timestamp":
+                batch["event_timestamp"] = batch[timestamp_field]
+            cols = cols + ["event_timestamp"]
+        if not cols:
+            return pd.DataFrame(columns=join_keys + ["event_timestamp"])
+        return batch[cols].drop_duplicates().reset_index(drop=True)
+
+    return _select_distinct_entity_timestamps
+
+
+def _distinct_entities_for_feature_view_ray(
+    store: "RayOfflineStore",
+    config: RepoConfig,
+    fv: FeatureView,
+    registry: BaseRegistry,
+    project: str,
+    start_date: datetime,
+    end_date: datetime,
+) -> Tuple[Dataset, List[str]]:
+    # Why: read minimal columns, filter by time, and project distinct (join_keys, event_timestamp) per FeatureView
+    # This preserves multiple transactions per entity ID for proper point-in-time joins
+    ray_wrapper = get_ray_wrapper()
+    entities = fv.entities or []
+    entity_objs = [registry.get_entity(e, project) for e in entities]
+    original_join_keys, _rev_feats, timestamp_field, _created_col = _get_column_names(
+        fv, entity_objs
+    )
+
+    source_info = resolve_feature_view_source_with_fallback(
+        fv, config, is_materialization=False
+    )
+    source_path = store._get_source_path(source_info.data_source, config)
+    required_columns = list(set(original_join_keys + [timestamp_field]))
+    ds = ray_wrapper.read_parquet(source_path, columns=required_columns)
+
+    field_mapping = getattr(fv.batch_source, "field_mapping", None)
+    if field_mapping:
+        ds = apply_field_mapping(ds, field_mapping)
+        original_join_keys = [field_mapping.get(k, k) for k in original_join_keys]
+        timestamp_field = field_mapping.get(timestamp_field, timestamp_field)
+
+    if fv.projection.join_key_map:
+        join_keys = [
+            fv.projection.join_key_map.get(key, key) for key in original_join_keys
+        ]
+    else:
+        join_keys = original_join_keys
+
+    ds = ensure_timestamp_compatibility(ds, [timestamp_field])
+    ds = ds.filter(_make_filter_range(timestamp_field, start_date, end_date))
+    # Extract distinct (entity_keys, event_timestamp) combinations - not just entity_keys
+    ds = ds.map_batches(
+        _make_select_distinct_entity_timestamps(join_keys, timestamp_field),
+        batch_format="pandas",
+    )
+    return ds, join_keys
+
+
+def _make_align_columns(all_join_keys: List[str], include_timestamp: bool = False):
+    # Why: factory function for schema alignment in Ray map_batches
+    # When include_timestamp=True, also aligns event_timestamp column for proper PIT joins
+    def _align_columns(batch: pd.DataFrame) -> pd.DataFrame:
+        batch = batch.copy()
+        output_cols = list(all_join_keys)
+        if include_timestamp:
+            output_cols = output_cols + ["event_timestamp"]
+        for k in output_cols:
+            if k not in batch.columns:
+                batch[k] = pd.NA
+        return batch[output_cols]
+
+    return _align_columns
+
+
+def _make_distinct_by_keys(keys: List[str], include_timestamp: bool = False):
+    # Why: factory function for deduplication in Ray map_batches
+    # When include_timestamp=True, deduplicates on (keys + event_timestamp) for proper PIT joins
+    def _distinct(batch: pd.DataFrame) -> pd.DataFrame:
+        subset = list(keys)
+        if include_timestamp and "event_timestamp" in batch.columns:
+            subset = subset + ["event_timestamp"]
+        return batch.drop_duplicates(subset=subset).reset_index(drop=True)
+
+    return _distinct
+
+
+def _align_and_union_entities_ray(
+    datasets: List[Dataset],
+    all_join_keys: List[str],
+    include_timestamp: bool = False,
+) -> Dataset:
+    # Why: align schemas across FeatureViews and union to a unified entity set
+    # When include_timestamp=True, preserves distinct (entity_keys, event_timestamp) combinations
+    # for proper point-in-time joins with multiple transactions per entity
+    ray_wrapper = get_ray_wrapper()
+    output_cols = list(all_join_keys)
+    if include_timestamp:
+        output_cols = output_cols + ["event_timestamp"]
+    if not datasets:
+        return ray_wrapper.from_pandas(pd.DataFrame(columns=output_cols))
+
+    aligned = [
+        ds.map_batches(
+            _make_align_columns(all_join_keys, include_timestamp=include_timestamp),
+            batch_format="pandas",
+        )
+        for ds in datasets
+    ]
+    entity_ds = aligned[0]
+    for ds in aligned[1:]:
+        entity_ds = entity_ds.union(ds)
+    return entity_ds.map_batches(
+        _make_distinct_by_keys(all_join_keys, include_timestamp=include_timestamp),
+        batch_format="pandas",
+    )
+
+
 class RayOfflineStore(OfflineStore):
     def __init__(self) -> None:
         self._staging_location: Optional[str] = None
@@ -1874,17 +2039,41 @@ def get_historical_features(
         config: RepoConfig,
         feature_views: List[FeatureView],
         feature_refs: List[str],
-        entity_df: Union[pd.DataFrame, str],
+        entity_df: Optional[Union[pd.DataFrame, str]],
         registry: BaseRegistry,
         project: str,
         full_feature_names: bool = False,
+        **kwargs: Any,
     ) -> RetrievalJob:
         store = RayOfflineStore()
         store._init_ray(config)

-        # Load entity_df as Ray dataset for distributed processing
+        # Load or derive entity dataset for distributed processing
         ray_wrapper = get_ray_wrapper()
-        if isinstance(entity_df, str):
+        if entity_df is None:
+            # Non-entity mode: derive entity set from feature sources within a bounded time window
+            # Preserves distinct (entity_keys, event_timestamp) combinations for proper PIT joins
+            # This handles cases where multiple transactions per entity ID exist
+            start_date, end_date = _compute_non_entity_dates_ray(
+                feature_views, kwargs.get("start_date"), kwargs.get("end_date")
+            )
+            per_view_entity_ds: List[Dataset] = []
+            all_join_keys: List[str] = []
+            for fv in feature_views:
+                ds, join_keys = _distinct_entities_for_feature_view_ray(
+                    store, config, fv, registry, project, start_date, end_date
+                )
+                per_view_entity_ds.append(ds)
+                for k in join_keys:
+                    if k not in all_join_keys:
+                        all_join_keys.append(k)
+            # Use include_timestamp=True to preserve actual event_timestamp from data
+            # instead of assigning a fixed end_date to all entities
+            entity_ds = _align_and_union_entities_ray(
+                per_view_entity_ds, all_join_keys, include_timestamp=True
+            )
+            entity_df_sample = entity_ds.limit(1000).to_pandas()
+        elif isinstance(entity_df, str):
             entity_ds = ray_wrapper.read_csv(entity_df)
             entity_df_sample = entity_ds.limit(1000).to_pandas()
         else:
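Illustration (not part of the diff): a small pandas sketch of why the distinct projection above keeps event_timestamp. Deduplicating on the join keys alone would collapse multiple transactions per entity, whereas deduplicating on join keys plus event_timestamp, as _make_distinct_by_keys does with include_timestamp=True, keeps one row per transaction for the point-in-time join.

import pandas as pd

batch = pd.DataFrame(
    {
        "driver_id": [9001, 9001, 9001],
        "event_timestamp": pd.to_datetime(
            ["2024-01-01 10:00", "2024-01-01 11:00", "2024-01-01 12:00"], utc=True
        ),
    }
)

# Keys only: collapses to a single row and loses two transactions.
keys_only = batch.drop_duplicates(subset=["driver_id"])
assert len(keys_only) == 1

# Keys plus timestamp: all three (driver_id, event_timestamp) pairs survive.
keys_and_ts = batch.drop_duplicates(subset=["driver_id", "event_timestamp"])
assert len(keys_and_ts) == 3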

sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py

Lines changed: 125 additions & 0 deletions
@@ -1,3 +1,5 @@
+from datetime import timedelta
+
 import pandas as pd
 import pytest

@@ -144,3 +146,126 @@ def test_ray_offline_store_persist(environment, universal_data_sources):
     import os

     assert os.path.exists(saved_path)
+
+
+@pytest.mark.integration
+@pytest.mark.universal_offline_stores
+def test_ray_offline_store_non_entity_mode_basic(environment, universal_data_sources):
+    """Test historical features retrieval without entity_df (non-entity mode).
+
+    This tests the basic functionality where entity_df=None and start_date/end_date
+    are provided to retrieve all features within the time range.
+    """
+    store = environment.feature_store
+
+    (entities, datasets, data_sources) = universal_data_sources
+    feature_views = construct_universal_feature_views(data_sources)
+
+    store.apply(
+        [
+            driver(),
+            feature_views.driver,
+        ]
+    )
+
+    # Use the environment's start and end dates for the query
+    start_date = environment.start_date
+    end_date = environment.end_date
+
+    # Non-entity mode: entity_df=None with start_date and end_date
+    result_df = store.get_historical_features(
+        entity_df=None,
+        features=[
+            "driver_stats:conv_rate",
+            "driver_stats:acc_rate",
+            "driver_stats:avg_daily_trips",
+        ],
+        full_feature_names=False,
+        start_date=start_date,
+        end_date=end_date,
+    ).to_df()
+
+    # Verify data was retrieved
+    assert len(result_df) > 0, "Non-entity mode should return data"
+    assert "conv_rate" in result_df.columns
+    assert "acc_rate" in result_df.columns
+    assert "avg_daily_trips" in result_df.columns
+    assert "event_timestamp" in result_df.columns
+    assert "driver_id" in result_df.columns
+
+    # Verify timestamps are within the requested range
+    result_df["event_timestamp"] = pd.to_datetime(
+        result_df["event_timestamp"], utc=True
+    )
+    assert (result_df["event_timestamp"] >= start_date).all()
+    assert (result_df["event_timestamp"] <= end_date).all()
+
+
+@pytest.mark.integration
+@pytest.mark.universal_offline_stores
+def test_ray_offline_store_non_entity_mode_preserves_multiple_timestamps(
+    environment, universal_data_sources
+):
+    """Test that non-entity mode preserves multiple transactions per entity ID.
+
+    This is a regression test for the fix that ensures distinct (entity_key, event_timestamp)
+    combinations are preserved, not just distinct entity keys. This is critical for
+    proper point-in-time joins when an entity has multiple transactions.
+    """
+    store = environment.feature_store
+
+    (entities, datasets, data_sources) = universal_data_sources
+    feature_views = construct_universal_feature_views(data_sources)
+
+    store.apply(
+        [
+            driver(),
+            feature_views.driver,
+        ]
+    )
+
+    now = _utc_now()
+    ts1 = pd.Timestamp(now - timedelta(hours=2)).round("ms")
+    ts2 = pd.Timestamp(now - timedelta(hours=1)).round("ms")
+    ts3 = pd.Timestamp(now).round("ms")
+
+    # Write data with multiple timestamps for the same entity (driver_id=9001)
+    df_to_write = pd.DataFrame.from_dict(
+        {
+            "event_timestamp": [ts1, ts2, ts3],
+            "driver_id": [9001, 9001, 9001],  # Same entity, different timestamps
+            "conv_rate": [0.1, 0.2, 0.3],
+            "acc_rate": [0.9, 0.8, 0.7],
+            "avg_daily_trips": [10, 20, 30],
+            "created": [ts1, ts2, ts3],
+        },
+    )
+
+    store.write_to_offline_store(
+        feature_views.driver.name, df_to_write, allow_registry_cache=False
+    )
+
+    # Query without entity_df - should get all 3 rows for driver_id=9001
+    result_df = store.get_historical_features(
+        entity_df=None,
+        features=[
+            "driver_stats:conv_rate",
+            "driver_stats:acc_rate",
+        ],
+        full_feature_names=False,
+        start_date=ts1 - timedelta(minutes=1),
+        end_date=ts3 + timedelta(minutes=1),
+    ).to_df()
+
+    # Filter to just our test entity
+    result_df = result_df[result_df["driver_id"] == 9001]
+
+    # Verify we got all 3 rows with different timestamps (not just 1 row)
+    assert len(result_df) == 3, (
+        f"Expected 3 rows for driver_id=9001 (one per timestamp), got {len(result_df)}"
+    )
+
+    # Verify the feature values are correct for each timestamp
+    result_df = result_df.sort_values("event_timestamp").reset_index(drop=True)
+    assert list(result_df["conv_rate"]) == [0.1, 0.2, 0.3]
+    assert list(result_df["acc_rate"]) == [0.9, 0.8, 0.7]
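The tests above always pass explicit start_date/end_date. When start_date is omitted, the new _compute_non_entity_dates_ray helper falls back to the largest feature-view TTL, or to 30 days if no TTL is set. A standalone sketch of that defaulting logic, simplified from the diff (default_window and its arguments are hypothetical names for this illustration, and TTLs are passed as plain timedeltas):

from datetime import datetime, timedelta, timezone

def default_window(ttls, end_date=None):
    # Mirror of the fallback in _compute_non_entity_dates_ray:
    # explicit end_date, else "now"; start = end - max TTL, else end - 30 days.
    end = end_date or datetime.now(timezone.utc)
    max_ttl = max(
        (int(t.total_seconds()) for t in ttls if isinstance(t, timedelta)), default=0
    )
    start = end - (timedelta(seconds=max_ttl) if max_ttl > 0 else timedelta(days=30))
    return start, end

# Two feature views with TTLs of 1 and 7 days -> a 7-day window ending now.
start, end = default_window([timedelta(days=1), timedelta(days=7)])
assert (end - start) == timedelta(days=7)

# No TTLs configured -> the 30-day default window.
start, end = default_window([])
assert (end - start) == timedelta(days=30)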
