@@ -1,7 +1,7 @@
import logging
import os
import uuid
-from datetime import datetime
+from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

@@ -1197,6 +1197,171 @@ def schema(self) -> pa.Schema:
return pa.Table.from_pandas(df).schema


Contributor

I think we should make a common utility function for this, so that it can be used in all stores without repeating the code.

wdyt?

def _compute_non_entity_dates_ray(
    feature_views: List[FeatureView],
    start_date_opt: Optional[datetime],
    end_date_opt: Optional[datetime],
) -> Tuple[datetime, datetime]:
# Why: derive bounded time window when no entity_df is provided using explicit dates or max TTL fallback
end_date = (
make_tzaware(end_date_opt) if end_date_opt else make_tzaware(datetime.utcnow())
)
if start_date_opt is None:
max_ttl_seconds = 0
for fv in feature_views:
if getattr(fv, "ttl", None):
try:
ttl_val = fv.ttl
if isinstance(ttl_val, timedelta):
max_ttl_seconds = max(
max_ttl_seconds, int(ttl_val.total_seconds())
)
except Exception:
pass
start_date = (
end_date - timedelta(seconds=max_ttl_seconds)
if max_ttl_seconds > 0
else end_date - timedelta(days=30)
)
else:
start_date = make_tzaware(start_date_opt)
return start_date, end_date
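
A minimal sketch of the store-agnostic helper the review comment above is asking for. The name `infer_bounded_time_window` and its placement are hypothetical, not part of this PR; it lifts the same window-derivation logic so other offline stores could reuse it instead of copying the code:

# Hypothetical shared utility (name/placement illustrative, not in this PR);
# same derivation as _compute_non_entity_dates_ray, minus the Ray-specific name.
from datetime import datetime, timedelta
from typing import List, Optional, Tuple

from feast import FeatureView
from feast.utils import make_tzaware


def infer_bounded_time_window(
    feature_views: List[FeatureView],
    start_date: Optional[datetime],
    end_date: Optional[datetime],
    default_window: timedelta = timedelta(days=30),
) -> Tuple[datetime, datetime]:
    end = make_tzaware(end_date) if end_date else make_tzaware(datetime.utcnow())
    if start_date is not None:
        return make_tzaware(start_date), end
    # Fall back to the widest TTL across the requested views, else a default window.
    ttls = [
        fv.ttl
        for fv in feature_views
        if isinstance(getattr(fv, "ttl", None), timedelta)
    ]
    max_ttl = max(ttls, default=timedelta(0))
    return end - (max_ttl if max_ttl > timedelta(0) else default_window), end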


def _make_filter_range(timestamp_field: str, start_date: datetime, end_date: datetime):
# Why: factory function for time-range filtering in Ray map_batches
def _filter_range(batch: pd.DataFrame) -> pd.Series:
ts = pd.to_datetime(batch[timestamp_field], utc=True)
return (ts >= start_date) & (ts <= end_date)

return _filter_range
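
A quick illustration (not in the diff) of the predicate this factory returns, on a toy batch with hypothetical tz-aware bounds:

# Illustrative only: the factory closes over the window and returns a
# batch-level boolean mask, as consumed by ds.filter further down in the diff.
from datetime import datetime, timezone
import pandas as pd

start = datetime(2024, 1, 1, tzinfo=timezone.utc)
end = datetime(2024, 12, 31, tzinfo=timezone.utc)
in_range = _make_filter_range("created", start, end)
in_range(pd.DataFrame({"created": ["2023-06-01", "2024-06-01"]}))
# -> pd.Series([False, True])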


def _make_select_distinct_entity_timestamps(join_keys: List[str], timestamp_field: str):
# Why: factory function for distinct (entity_keys, event_timestamp) projection in Ray map_batches
# This preserves multiple transactions per entity ID with different timestamps for proper PIT joins
def _select_distinct_entity_timestamps(batch: pd.DataFrame) -> pd.DataFrame:
cols = [c for c in join_keys if c in batch.columns]
if timestamp_field in batch.columns:
# Rename timestamp to standardized event_timestamp
batch = batch.copy()
if timestamp_field != "event_timestamp":
batch["event_timestamp"] = batch[timestamp_field]
cols = cols + ["event_timestamp"]
if not cols:
return pd.DataFrame(columns=join_keys + ["event_timestamp"])
return batch[cols].drop_duplicates().reset_index(drop=True)

return _select_distinct_entity_timestamps
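
For illustration (not part of the diff): on a toy batch the projection keeps every distinct (join key, event timestamp) pair, not just one row per entity:

# Illustrative only: both events for driver 1 survive the projection,
# which is what enables per-event point-in-time joins downstream.
import pandas as pd

batch = pd.DataFrame(
    {"driver_id": [1, 1, 2], "created": ["2024-01-01", "2024-01-02", "2024-01-01"]}
)
project = _make_select_distinct_entity_timestamps(["driver_id"], "created")
project(batch)
# -> 3 rows over columns [driver_id, event_timestamp]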


def _distinct_entities_for_feature_view_ray(
store: "RayOfflineStore",
config: RepoConfig,
fv: FeatureView,
registry: BaseRegistry,
project: str,
start_date: datetime,
end_date: datetime,
) -> Tuple[Dataset, List[str]]:
# Why: read minimal columns, filter by time, and project distinct (join_keys, event_timestamp) per FeatureView
# This preserves multiple transactions per entity ID for proper point-in-time joins
ray_wrapper = get_ray_wrapper()
entities = fv.entities or []
entity_objs = [registry.get_entity(e, project) for e in entities]
original_join_keys, _rev_feats, timestamp_field, _created_col = _get_column_names(
fv, entity_objs
)

source_info = resolve_feature_view_source_with_fallback(
fv, config, is_materialization=False
)
source_path = store._get_source_path(source_info.data_source, config)
required_columns = list(set(original_join_keys + [timestamp_field]))
ds = ray_wrapper.read_parquet(source_path, columns=required_columns)

field_mapping = getattr(fv.batch_source, "field_mapping", None)
if field_mapping:
ds = apply_field_mapping(ds, field_mapping)
original_join_keys = [field_mapping.get(k, k) for k in original_join_keys]
timestamp_field = field_mapping.get(timestamp_field, timestamp_field)

if fv.projection.join_key_map:
join_keys = [
fv.projection.join_key_map.get(key, key) for key in original_join_keys
]
else:
join_keys = original_join_keys

ds = ensure_timestamp_compatibility(ds, [timestamp_field])
ds = ds.filter(_make_filter_range(timestamp_field, start_date, end_date))
# Extract distinct (entity_keys, event_timestamp) combinations - not just entity_keys
ds = ds.map_batches(
_make_select_distinct_entity_timestamps(join_keys, timestamp_field),
batch_format="pandas",
)
return ds, join_keys


def _make_align_columns(all_join_keys: List[str], include_timestamp: bool = False):
# Why: factory function for schema alignment in Ray map_batches
# When include_timestamp=True, also aligns event_timestamp column for proper PIT joins
def _align_columns(batch: pd.DataFrame) -> pd.DataFrame:
batch = batch.copy()
output_cols = list(all_join_keys)
if include_timestamp:
output_cols = output_cols + ["event_timestamp"]
for k in output_cols:
if k not in batch.columns:
batch[k] = pd.NA
return batch[output_cols]

return _align_columns


def _make_distinct_by_keys(keys: List[str], include_timestamp: bool = False):
# Why: factory function for deduplication in Ray map_batches
# When include_timestamp=True, deduplicates on (keys + event_timestamp) for proper PIT joins
def _distinct(batch: pd.DataFrame) -> pd.DataFrame:
subset = list(keys)
if include_timestamp and "event_timestamp" in batch.columns:
subset = subset + ["event_timestamp"]
return batch.drop_duplicates(subset=subset).reset_index(drop=True)

return _distinct


def _align_and_union_entities_ray(
datasets: List[Dataset],
all_join_keys: List[str],
include_timestamp: bool = False,
) -> Dataset:
# Why: align schemas across FeatureViews and union to a unified entity set
# When include_timestamp=True, preserves distinct (entity_keys, event_timestamp) combinations
# for proper point-in-time joins with multiple transactions per entity
ray_wrapper = get_ray_wrapper()
output_cols = list(all_join_keys)
if include_timestamp:
output_cols = output_cols + ["event_timestamp"]
if not datasets:
return ray_wrapper.from_pandas(pd.DataFrame(columns=output_cols))

aligned = [
ds.map_batches(
_make_align_columns(all_join_keys, include_timestamp=include_timestamp),
batch_format="pandas",
)
for ds in datasets
]
entity_ds = aligned[0]
for ds in aligned[1:]:
entity_ds = entity_ds.union(ds)
return entity_ds.map_batches(
_make_distinct_by_keys(all_join_keys, include_timestamp=include_timestamp),
batch_format="pandas",
)
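
To make the align/union/dedup step concrete, a pandas-level toy example of the two batch functions it composes (illustrative, not part of the diff):

# Illustrative only: per-view batches with disjoint join keys are aligned to a
# shared schema (missing keys become pd.NA), unioned, then deduplicated.
import pandas as pd

a = pd.DataFrame({"driver_id": [1], "event_timestamp": ["2024-01-01"]})
b = pd.DataFrame({"customer_id": [9], "event_timestamp": ["2024-01-01"]})
align = _make_align_columns(["driver_id", "customer_id"], include_timestamp=True)
unioned = pd.concat([align(a), align(b)], ignore_index=True)
dedup = _make_distinct_by_keys(["driver_id", "customer_id"], include_timestamp=True)
dedup(unioned)
# -> 2 rows over columns [driver_id, customer_id, event_timestamp]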


class RayOfflineStore(OfflineStore):
def __init__(self) -> None:
self._staging_location: Optional[str] = None
@@ -1874,17 +2039,41 @@ def get_historical_features(
config: RepoConfig,
feature_views: List[FeatureView],
feature_refs: List[str],
-        entity_df: Union[pd.DataFrame, str],
+        entity_df: Optional[Union[pd.DataFrame, str]],
registry: BaseRegistry,
project: str,
full_feature_names: bool = False,
**kwargs: Any,
) -> RetrievalJob:
store = RayOfflineStore()
store._init_ray(config)

-        # Load entity_df as Ray dataset for distributed processing
+        # Load or derive entity dataset for distributed processing
ray_wrapper = get_ray_wrapper()
-        if isinstance(entity_df, str):
if entity_df is None:
# Non-entity mode: derive entity set from feature sources within a bounded time window
# Preserves distinct (entity_keys, event_timestamp) combinations for proper PIT joins
# This handles cases where multiple transactions per entity ID exist
start_date, end_date = _compute_non_entity_dates_ray(
feature_views, kwargs.get("start_date"), kwargs.get("end_date")
)
per_view_entity_ds: List[Dataset] = []
all_join_keys: List[str] = []
for fv in feature_views:
ds, join_keys = _distinct_entities_for_feature_view_ray(
store, config, fv, registry, project, start_date, end_date
)
per_view_entity_ds.append(ds)
for k in join_keys:
if k not in all_join_keys:
all_join_keys.append(k)
# Use include_timestamp=True to preserve actual event_timestamp from data
# instead of assigning a fixed end_date to all entities
entity_ds = _align_and_union_entities_ray(
per_view_entity_ds, all_join_keys, include_timestamp=True
)
entity_df_sample = entity_ds.limit(1000).to_pandas()
elif isinstance(entity_df, str):
entity_ds = ray_wrapper.read_csv(entity_df)
entity_df_sample = entity_ds.limit(1000).to_pandas()
else: