3 changes: 3 additions & 0 deletions sdk/python/feast/arrow_error_handler.py
@@ -30,6 +30,9 @@ def wrapper(*args, **kwargs):
except Exception as e:
if isinstance(e, FeastError):
raise fl.FlightError(e.to_error_detail())
# Re-raise non-Feast exceptions so Arrow Flight returns a proper error
# instead of allowing the server method to return None.
raise e

return wrapper
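For context, a minimal self-contained sketch of the decorator pattern this hunk completes; the decorator name, the FeastError import path, and the try/return lines outside the hunk are illustrative assumptions, not taken from the diff.

import pyarrow.flight as fl

from feast.errors import FeastError  # assumed import path for this sketch


def flight_error_handler(func):  # illustrative name, not necessarily Feast's actual decorator
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            if isinstance(e, FeastError):
                # Feast errors carry structured detail the Flight client can decode.
                raise fl.FlightError(e.to_error_detail())
            # Anything else is re-raised so Arrow Flight reports a server error
            # instead of the wrapped method silently returning None.
            raise e

    return wrapper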

@@ -3,12 +3,13 @@
import uuid
import warnings
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
KeysView,
List,
Optional,
Tuple,
@@ -151,10 +152,11 @@ def get_historical_features(
config: RepoConfig,
feature_views: List[FeatureView],
feature_refs: List[str],
entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],
entity_df: Optional[Union[pandas.DataFrame, str, pyspark.sql.DataFrame]],
registry: BaseRegistry,
project: str,
full_feature_names: bool = False,
**kwargs,
) -> RetrievalJob:
assert isinstance(config.offline_store, SparkOfflineStoreConfig)
date_partition_column_formats = []
@@ -175,33 +177,75 @@
)
tmp_entity_df_table_name = offline_utils.get_temp_entity_table_name()

entity_schema = _get_entity_schema(
spark_session=spark_session,
entity_df=entity_df,
)
event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
entity_schema=entity_schema,
)
entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
entity_df,
event_timestamp_col,
spark_session,
)
_upload_entity_df(
spark_session=spark_session,
table_name=tmp_entity_df_table_name,
entity_df=entity_df,
event_timestamp_col=event_timestamp_col,
)
# Non-entity mode: synthesize a left table and timestamp range from start/end dates to avoid requiring entity_df.
# This makes date-range retrievals possible without enumerating entities upfront; sources remain bounded by time.
non_entity_mode = entity_df is None
if non_entity_mode:
# Why: derive bounded time window without requiring entities; uses max TTL fallback to constrain scans.
start_date, end_date = _compute_non_entity_dates(feature_views, kwargs)
entity_df_event_timestamp_range = (start_date, end_date)

# Build query contexts so we can reuse entity names and per-view table info consistently.
fv_query_contexts = offline_utils.get_feature_view_query_context(
feature_refs,
feature_views,
registry,
project,
entity_df_event_timestamp_range,
)

expected_join_keys = offline_utils.get_expected_join_keys(
project=project, feature_views=feature_views, registry=registry
)
offline_utils.assert_expected_columns_in_entity_df(
entity_schema=entity_schema,
join_keys=expected_join_keys,
entity_df_event_timestamp_col=event_timestamp_col,
)
# Collect the union of entity columns required across all feature views.
all_entities = _gather_all_entities(fv_query_contexts)

# Build a UNION DISTINCT of per-feature-view entity projections, time-bounded and partition-pruned.
_create_temp_entity_union_view(
spark_session=spark_session,
tmp_view_name=tmp_entity_df_table_name,
feature_views=feature_views,
fv_query_contexts=fv_query_contexts,
start_date=start_date,
end_date=end_date,
date_partition_column_formats=date_partition_column_formats,
)

# Add a stable as-of timestamp column for PIT joins.
left_table_query_string, event_timestamp_col = _make_left_table_query(
end_date=end_date, tmp_view_name=tmp_entity_df_table_name
)
entity_schema_keys = _entity_schema_keys_from(
all_entities=all_entities, event_timestamp_col=event_timestamp_col
)
else:
entity_schema = _get_entity_schema(
spark_session=spark_session,
entity_df=entity_df,
)
event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
entity_schema=entity_schema,
)
entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
entity_df,
event_timestamp_col,
spark_session,
)
_upload_entity_df(
spark_session=spark_session,
table_name=tmp_entity_df_table_name,
entity_df=entity_df,
event_timestamp_col=event_timestamp_col,
)
left_table_query_string = tmp_entity_df_table_name
entity_schema_keys = cast(KeysView[str], entity_schema.keys())

if not non_entity_mode:
expected_join_keys = offline_utils.get_expected_join_keys(
project=project, feature_views=feature_views, registry=registry
)
offline_utils.assert_expected_columns_in_entity_df(
entity_schema=entity_schema,
join_keys=expected_join_keys,
entity_df_event_timestamp_col=event_timestamp_col,
)

query_context = offline_utils.get_feature_view_query_context(
feature_refs,
@@ -232,9 +276,9 @@ def get_historical_features(
feature_view_query_contexts=cast(
List[offline_utils.FeatureViewQueryContext], spark_query_context
),
left_table_query_string=tmp_entity_df_table_name,
left_table_query_string=left_table_query_string,
entity_df_event_timestamp_col=event_timestamp_col,
entity_df_columns=entity_schema.keys(),
entity_df_columns=entity_schema_keys,
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
full_feature_names=full_feature_names,
)
@@ -248,7 +292,7 @@ def get_historical_features(
),
metadata=RetrievalMetadata(
features=feature_refs,
keys=list(set(entity_schema.keys()) - {event_timestamp_col}),
keys=list(set(entity_schema_keys) - {event_timestamp_col}),
min_event_timestamp=entity_df_event_timestamp_range[0],
max_event_timestamp=entity_df_event_timestamp_range[1],
),
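Taken together, the changes above allow a retrieval without an entity_df. A hedged usage sketch, calling the Spark offline store directly: the class name SparkOfflineStore and the feature reference are assumptions, and repo_config, feature_views, and registry are placeholders that the calling code must already provide.

from datetime import datetime, timedelta, timezone

end = datetime.now(timezone.utc)
start = end - timedelta(days=7)

# repo_config (with SparkOfflineStoreConfig), feature_views, and registry are
# placeholders assumed to exist in the surrounding code.
job = SparkOfflineStore.get_historical_features(
    config=repo_config,
    feature_views=feature_views,
    feature_refs=["driver_stats:conv_rate"],  # hypothetical feature reference
    entity_df=None,                           # non-entity mode
    registry=registry,
    project="my_project",
    full_feature_names=False,
    start_date=start,                         # forwarded via **kwargs
    end_date=end,                             # forwarded via **kwargs
)
df = job.to_df()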
@@ -540,6 +584,114 @@ def get_spark_session_or_start_new_with_repoconfig(
return spark_session


def _compute_non_entity_dates(
feature_views: List[FeatureView], kwargs: Dict[str, Any]
) -> Tuple[datetime, datetime]:
# Why: bounds the scan window when no entity_df is provided using explicit dates or max TTL fallback.
start_date_opt = cast(Optional[datetime], kwargs.get("start_date"))
end_date_opt = cast(Optional[datetime], kwargs.get("end_date"))
end_date: datetime = end_date_opt or datetime.now(timezone.utc)

if start_date_opt is None:
max_ttl_seconds = 0
for fv in feature_views:
if fv.ttl and isinstance(fv.ttl, timedelta):
max_ttl_seconds = max(max_ttl_seconds, int(fv.ttl.total_seconds()))
start_date: datetime = (
end_date - timedelta(seconds=max_ttl_seconds)
if max_ttl_seconds > 0
else end_date - timedelta(days=30)
)
else:
start_date = start_date_opt
return (start_date, end_date)
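A standalone illustration of the window fallback, using hypothetical TTLs and no Feast types: the largest TTL wins, and with no TTLs at all the window defaults to 30 days.

from datetime import datetime, timedelta, timezone

ttls = [timedelta(days=7), timedelta(days=30)]  # hypothetical feature view TTLs
end_date = datetime.now(timezone.utc)           # default when end_date is not supplied
max_ttl_seconds = max((int(t.total_seconds()) for t in ttls), default=0)
start_date = end_date - (
    timedelta(seconds=max_ttl_seconds) if max_ttl_seconds > 0 else timedelta(days=30)
)
# Here start_date is 30 days before end_date; with ttls == [], it would also be 30 days.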


def _gather_all_entities(
fv_query_contexts: List[offline_utils.FeatureViewQueryContext],
) -> List[str]:
# Why: ensure a unified entity set across feature views to align UNION schemas.
all_entities: List[str] = []
for ctx in fv_query_contexts:
for e in ctx.entities:
if e not in all_entities:
all_entities.append(e)
return all_entities
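For example, two query contexts with entities ["driver_id"] and ["driver_id", "trip_id"] (hypothetical names) yield ["driver_id", "trip_id"]; first-seen order is preserved, so the column order of the UNION built below stays deterministic.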


def _create_temp_entity_union_view(
spark_session: SparkSession,
tmp_view_name: str,
feature_views: List[FeatureView],
fv_query_contexts: List[offline_utils.FeatureViewQueryContext],
start_date: datetime,
end_date: datetime,
date_partition_column_formats: List[Optional[str]],
) -> None:
# Why: derive distinct entity keys observed in the time window without requiring an entity_df upfront.
start_date_str = _format_datetime(start_date)
end_date_str = _format_datetime(end_date)

# Compute the unified entity set to align schemas in the UNION.
all_entities = _gather_all_entities(fv_query_contexts)

per_view_selects: List[str] = []
for fv, ctx, date_format in zip(
feature_views, fv_query_contexts, date_partition_column_formats
):
assert isinstance(fv.batch_source, SparkSource)
from_expression = fv.batch_source.get_table_query_string()
timestamp_field = fv.batch_source.timestamp_field or "event_timestamp"
date_partition_column = fv.batch_source.date_partition_column
partition_clause = ""
if date_partition_column and date_format:
partition_clause = (
f" AND {date_partition_column} >= '{start_date.strftime(date_format)}'"
f" AND {date_partition_column} <= '{end_date.strftime(date_format)}'"
)

# Fill missing entity columns with NULL and cast to STRING to keep UNION schemas aligned.
select_entities: List[str] = []
ctx_entities_set = set(ctx.entities)
for col in all_entities:
if col in ctx_entities_set:
select_entities.append(f"CAST({col} AS STRING) AS {col}")
else:
select_entities.append(f"CAST(NULL AS STRING) AS {col}")

per_view_selects.append(
f"""
SELECT DISTINCT {", ".join(select_entities)}
FROM {from_expression}
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}'){partition_clause}
"""
)

union_query = "\nUNION DISTINCT\n".join([s.strip() for s in per_view_selects])
spark_session.sql(
f"CREATE OR REPLACE TEMPORARY VIEW {tmp_view_name} AS {union_query}"
)
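For two hypothetical sources (driver_hourly with entity driver_id, and trip_daily with entity trip_id plus a dt partition column), the emitted view definition would look roughly like the following; the table, column, and view names are invented and the timestamp literals depend on _format_datetime.

CREATE OR REPLACE TEMPORARY VIEW feast_entity_df_abc123 AS
SELECT DISTINCT CAST(driver_id AS STRING) AS driver_id, CAST(NULL AS STRING) AS trip_id
FROM driver_hourly
WHERE event_timestamp BETWEEN TIMESTAMP('2024-01-01 00:00:00') AND TIMESTAMP('2024-01-31 00:00:00')
UNION DISTINCT
SELECT DISTINCT CAST(NULL AS STRING) AS driver_id, CAST(trip_id AS STRING) AS trip_id
FROM trip_daily
WHERE event_timestamp BETWEEN TIMESTAMP('2024-01-01 00:00:00') AND TIMESTAMP('2024-01-31 00:00:00') AND dt >= '2024-01-01' AND dt <= '2024-01-31'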


def _make_left_table_query(end_date: datetime, tmp_view_name: str) -> Tuple[str, str]:
# Why: use a stable as-of timestamp for PIT joins when no entity timestamps are provided.
event_timestamp_col = "entity_ts"
left_table_query_string = (
f"(SELECT *, TIMESTAMP('{_format_datetime(end_date)}') AS {event_timestamp_col} "
f"FROM {tmp_view_name})"
)
return left_table_query_string, event_timestamp_col
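With the view above and a hypothetical end_date of 2024-01-31 00:00:00 UTC, this returns "(SELECT *, TIMESTAMP('2024-01-31 00:00:00') AS entity_ts FROM feast_entity_df_abc123)": every synthesized entity row shares a single as-of timestamp, so the point-in-time join resolves the latest feature values within each view's TTL as of end_date.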


def _entity_schema_keys_from(
all_entities: List[str], event_timestamp_col: str
) -> KeysView[str]:
# Why: pass a KeysView[str] to PIT query builder to match entity_df branch typing.
return cast(
KeysView[str], {k: None for k in (all_entities + [event_timestamp_col])}.keys()
)
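Building a throwaway dict and taking its .keys() yields a genuine KeysView[str], so the point-in-time query builder receives the same type it gets from entity_schema.keys() in the entity_df branch.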


def _get_entity_df_event_timestamp_range(
entity_df: Union[pd.DataFrame, str],
entity_df_event_timestamp_col: str,
@@ -229,7 +229,8 @@ def get_table_query_string(self) -> str:
# If both the table query string and the actual query are null, we can load from file.
spark_session = SparkSession.getActiveSession()
if spark_session is None:
raise AssertionError("Could not find an active spark session.")
# Remote mode may not have an active session bound to the thread; create one on demand.
spark_session = SparkSession.builder.getOrCreate()
try:
df = self._load_dataframe_from_path(spark_session)
except Exception:
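A compact way to read this change, as a defensive-pattern sketch rather than the file's actual code: prefer the session bound to the current thread, and fall back to the builder instead of asserting.

from pyspark.sql import SparkSession

# Reuse the active session if one is bound to this thread; otherwise let
# getOrCreate() return the existing default session or create a fresh one,
# as can happen in remote (Arrow Flight) server threads.
spark_session = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()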