Merged
6 changes: 3 additions & 3 deletions sdk/python/feast/infra/common/retrieval_task.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Union
from typing import Optional, Union

import pandas as pd

@@ -15,5 +15,5 @@ class HistoricalRetrievalTask:
feature_view: Union[BatchFeatureView, StreamFeatureView]
full_feature_name: bool
registry: Registry
start_time: datetime
end_time: datetime
start_time: Optional[datetime] = None
Collaborator Author:
The retrieval job doesn't have to provide a start or end date (see the sketch after this hunk).

end_time: Optional[datetime] = None
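
A minimal sketch of the idea behind this change, using a simplified stand-in rather than the actual `HistoricalRetrievalTask` (the other fields are omitted): with the bounds typed as `Optional` and defaulted to `None`, a task can be built without a retrieval window.

```python
from dataclasses import dataclass
from datetime import datetime
from typing import Optional


# Simplified stand-in for the task dataclass; illustration only.
@dataclass
class RetrievalWindow:
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None


# No bounds: downstream readers scan the full history of the source.
unbounded = RetrievalWindow()

# Explicit bounds: downstream readers restrict the scan to this window.
bounded = RetrievalWindow(
    start_time=datetime(2024, 1, 1),
    end_time=datetime(2024, 2, 1),
)
```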
6 changes: 3 additions & 3 deletions sdk/python/feast/infra/compute_engines/feature_builder.py
@@ -72,10 +72,10 @@ def _should_validate(self):
def build(self) -> ExecutionPlan:
last_node = self.build_source_node()

# PIT join entities to the feature data, and perform filtering
if isinstance(self.task, HistoricalRetrievalTask):
last_node = self.build_join_node(last_node)
# Join entity_df with source if needed
last_node = self.build_join_node(last_node)
Collaborator Author:
Without an entity_df provided, the join node is just a pass-through node (see the sketch after this hunk).


# PIT filter, TTL, and user-defined filter
last_node = self.build_filter_node(last_node)

if self._should_aggregate():
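
To illustrate the pass-through behaviour described in the comment above, here is a minimal, hypothetical sketch in plain pandas (not Feast's internal DAG node API; the helper name and parameters are assumptions): when no entity frame is supplied, the join step forwards the feature data unchanged.

```python
from typing import List, Optional

import pandas as pd


def join_or_pass_through(
    feature_df: pd.DataFrame,
    entity_df: Optional[pd.DataFrame],
    join_keys: List[str],
) -> pd.DataFrame:
    """Hypothetical helper mirroring the join node's behaviour."""
    if entity_df is None:
        # Materialization path: no entity_df, so the step is a pass-through.
        return feature_df
    # Retrieval path: attach entity rows so point-in-time filtering
    # can happen downstream.
    return entity_df.merge(feature_df, on=join_keys, how="left")
```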
23 changes: 0 additions & 23 deletions sdk/python/feast/infra/compute_engines/local/feature_builder.py
@@ -2,7 +2,6 @@

from feast.infra.common.materialization_job import MaterializationTask
from feast.infra.common.retrieval_task import HistoricalRetrievalTask
from feast.infra.compute_engines.dag.plan import ExecutionPlan
from feast.infra.compute_engines.feature_builder import FeatureBuilder
from feast.infra.compute_engines.local.backends.base import DataFrameBackend
from feast.infra.compute_engines.local.nodes import (
@@ -95,25 +94,3 @@ def build_output_nodes(self, input_node):
node = LocalOutputNode("output")
node.add_input(input_node)
self.nodes.append(node)

def build(self) -> ExecutionPlan:
Collaborator Author:
We can remove this method since it is identical to the parent class's build().

last_node = self.build_source_node()

if isinstance(self.task, HistoricalRetrievalTask):
last_node = self.build_join_node(last_node)

last_node = self.build_filter_node(last_node)

if self._should_aggregate():
last_node = self.build_aggregation_node(last_node)
elif isinstance(self.task, HistoricalRetrievalTask):
last_node = self.build_dedup_node(last_node)

if self._should_transform():
last_node = self.build_transformation_node(last_node)

if self._should_validate():
last_node = self.build_validation_node(last_node)

self.build_output_nodes(last_node)
return ExecutionPlan(self.nodes)
39 changes: 33 additions & 6 deletions sdk/python/feast/infra/compute_engines/local/nodes.py
@@ -1,8 +1,9 @@
from datetime import timedelta
from datetime import datetime, timedelta
from typing import Optional

import pyarrow as pa

from feast.data_source import DataSource
from feast.infra.compute_engines.dag.context import ExecutionContext
from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue
from feast.infra.compute_engines.local.backends.base import DataFrameBackend
@@ -15,14 +16,40 @@


class LocalSourceReadNode(LocalNode):
def __init__(self, name: str, feature_view, task):
def __init__(
self,
name: str,
source: DataSource,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
):
super().__init__(name)
self.feature_view = feature_view
self.task = task
self.source = source
self.start_time = start_time
self.end_time = end_time

def execute(self, context: ExecutionContext) -> ArrowTableValue:
# TODO : Implement the logic to read from offline store
return ArrowTableValue(data=pa.Table.from_pandas(context.entity_df))
offline_store = context.offline_store
(
join_key_columns,
feature_name_columns,
timestamp_field,
created_timestamp_column,
) = context.column_info

# 📥 Reuse Feast's robust query resolver
retrieval_job = offline_store.pull_all_from_table_or_query(
Collaborator Author:
Reuse pull_all_from_table_or_query so data can now be read through every Feast offline store.

config=context.repo_config,
data_source=self.source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=self.start_time,
end_date=self.end_time,
)
arrow_table = retrieval_job.to_arrow()
return ArrowTableValue(data=arrow_table)


class LocalJoinNode(LocalNode):
sdk/python/feast/infra/compute_engines/spark/feature_builder.py
@@ -9,9 +9,8 @@
SparkAggregationNode,
SparkDedupNode,
SparkFilterNode,
SparkHistoricalRetrievalReadNode,
SparkJoinNode,
SparkMaterializationReadNode,
SparkReadNode,
SparkTransformationNode,
SparkWriteNode,
)
@@ -27,12 +26,10 @@ def __init__(
self.spark_session = spark_session

def build_source_node(self):
if isinstance(self.task, MaterializationTask):
node = SparkMaterializationReadNode("source", self.task)
else:
node = SparkHistoricalRetrievalReadNode(
"source", self.task, self.spark_session
)
source = self.feature_view.batch_source
start_time = self.task.start_time
end_time = self.task.end_time
node = SparkReadNode("source", source, start_time, end_time)
self.nodes.append(node)
return node

120 changes: 30 additions & 90 deletions sdk/python/feast/infra/compute_engines/spark/node.py
@@ -1,13 +1,12 @@
from datetime import timedelta
from datetime import datetime, timedelta
from typing import List, Optional, Union, cast

from pyspark.sql import DataFrame, SparkSession, Window
from pyspark.sql import functions as F

from feast import BatchFeatureView, StreamFeatureView
from feast.aggregation import Aggregation
from feast.infra.common.materialization_job import MaterializationTask
from feast.infra.common.retrieval_task import HistoricalRetrievalTask
from feast.data_source import DataSource
from feast.infra.compute_engines.dag.context import ExecutionContext
from feast.infra.compute_engines.dag.model import DAGFormat
from feast.infra.compute_engines.dag.node import DAGNode
@@ -23,7 +22,6 @@
from feast.infra.offline_stores.offline_utils import (
infer_event_timestamp_from_entity_df,
)
from feast.utils import _get_fields_with_aliases

ENTITY_TS_ALIAS = "__entity_event_timestamp"

@@ -49,18 +47,21 @@ def rename_entity_ts_column(
return entity_df


class SparkMaterializationReadNode(DAGNode):
class SparkReadNode(DAGNode):
def __init__(
self, name: str, task: Union[MaterializationTask, HistoricalRetrievalTask]
self,
name: str,
source: DataSource,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
):
super().__init__(name)
self.task = task
self.source = source
self.start_time = start_time
self.end_time = end_time

def execute(self, context: ExecutionContext) -> DAGValue:
offline_store = context.offline_store
start_time = self.task.start_time
end_time = self.task.end_time

(
join_key_columns,
feature_name_columns,
@@ -69,15 +70,15 @@ def execute(self, context: ExecutionContext) -> DAGValue:
) = context.column_info

# 📥 Reuse Feast's robust query resolver
retrieval_job = offline_store.pull_latest_from_table_or_query(
retrieval_job = offline_store.pull_all_from_table_or_query(
config=context.repo_config,
data_source=self.task.feature_view.batch_source,
data_source=self.source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_time,
end_date=end_time,
start_date=self.start_time,
end_date=self.end_time,
)
spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df()

@@ -88,74 +89,8 @@ def execute(self, context: ExecutionContext) -> DAGValue:
"source": "feature_view_batch_source",
"timestamp_field": timestamp_field,
"created_timestamp_column": created_timestamp_column,
"start_date": start_time,
"end_date": end_time,
},
)


class SparkHistoricalRetrievalReadNode(DAGNode):
def __init__(
self, name: str, task: HistoricalRetrievalTask, spark_session: SparkSession
):
super().__init__(name)
self.task = task
self.spark_session = spark_session

def execute(self, context: ExecutionContext) -> DAGValue:
"""
Read data from the offline store on the Spark engine.
TODO: Some functionality is duplicated with SparkMaterializationReadNode and spark get_historical_features.
Args:
context: SparkExecutionContext
Returns: DAGValue
"""
fv = self.task.feature_view
source = fv.batch_source

(
join_key_columns,
feature_name_columns,
timestamp_field,
created_timestamp_column,
) = context.column_info

# TODO: Use pull_all_from_table_or_query when it supports not filtering by timestamp
# retrieval_job = offline_store.pull_all_from_table_or_query(
# config=context.repo_config,
# data_source=source,
# join_key_columns=join_key_columns,
# feature_name_columns=feature_name_columns,
# timestamp_field=timestamp_field,
# start_date=min_ts,
# end_date=max_ts,
# )
# spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df()

columns = join_key_columns + feature_name_columns + [timestamp_field]
if created_timestamp_column:
columns.append(created_timestamp_column)

(fields_with_aliases, aliases) = _get_fields_with_aliases(
fields=columns,
field_mappings=source.field_mapping,
)
fields_with_alias_string = ", ".join(fields_with_aliases)

from_expression = source.get_table_query_string()

query = f"""
SELECT {fields_with_alias_string}
FROM {from_expression}
"""
spark_df = self.spark_session.sql(query)

return DAGValue(
data=spark_df,
format=DAGFormat.SPARK,
metadata={
"source": "feature_view_batch_source",
"timestamp_field": timestamp_field,
"start_date": self.start_time,
"end_date": self.end_time,
},
)

@@ -227,7 +162,12 @@ def execute(self, context: ExecutionContext) -> DAGValue:
feature_df: DataFrame = feature_value.data

entity_df = context.entity_df
assert entity_df is not None, "entity_df must be set in ExecutionContext"
if entity_df is None:
return DAGValue(
data=feature_df,
format=DAGFormat.SPARK,
metadata={"joined_on": None},
)

# Get timestamp fields from feature view
join_keys, feature_cols, ts_col, created_ts_col = context.column_info
@@ -272,13 +212,13 @@ def execute(self, context: ExecutionContext) -> DAGValue:
if ENTITY_TS_ALIAS in input_df.columns:
filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(ENTITY_TS_ALIAS))

# Optional TTL filter: feature.ts >= entity.event_timestamp - ttl
if self.ttl:
ttl_seconds = int(self.ttl.total_seconds())
lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr(
f"INTERVAL {ttl_seconds} seconds"
)
filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound)
# Optional TTL filter: feature.ts >= entity.event_timestamp - ttl
if self.ttl:
ttl_seconds = int(self.ttl.total_seconds())
lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr(
f"INTERVAL {ttl_seconds} seconds"
)
filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound)

# Optional custom filter condition
if self.filter_condition:
21 changes: 17 additions & 4 deletions sdk/python/feast/infra/offline_stores/bigquery.py
@@ -52,6 +52,7 @@
BigQuerySource,
SavedDatasetBigQueryStorage,
)
from .offline_utils import get_timestamp_filter_sql

try:
from google.api_core import client_info as http_client_info
@@ -188,8 +189,9 @@ def pull_all_from_table_or_query(
join_key_columns: List[str],
feature_name_columns: List[str],
timestamp_field: str,
start_date: datetime,
end_date: datetime,
created_timestamp_column: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
) -> RetrievalJob:
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
assert isinstance(data_source, BigQuerySource)
@@ -201,15 +203,26 @@
project=project_id,
location=config.offline_store.location,
)

timestamp_fields = [timestamp_field]
if created_timestamp_column:
timestamp_fields.append(created_timestamp_column)
field_string = ", ".join(
BigQueryOfflineStore._escape_query_columns(join_key_columns)
+ BigQueryOfflineStore._escape_query_columns(feature_name_columns)
+ [timestamp_field]
+ timestamp_fields
)
timestamp_filter = get_timestamp_filter_sql(
Collaborator Author:
Major change: the timestamp filter is now built by get_timestamp_filter_sql (see the sketch below).

start_date,
end_date,
timestamp_field,
quote_fields=False,
cast_style="timestamp_func",
)
query = f"""
SELECT {field_string}
FROM {from_expression}
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}')
WHERE {timestamp_filter}
"""
return BigQueryRetrievalJob(
query=query,
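
For illustration, a minimal sketch of what a helper like get_timestamp_filter_sql could emit, assuming it returns a SQL predicate string and degrades to a match-all condition when no bounds are given; the actual helper in offline_utils may differ in signature and output.

```python
from datetime import datetime
from typing import Optional


def timestamp_filter_sketch(
    start_date: Optional[datetime],
    end_date: Optional[datetime],
    timestamp_field: str,
) -> str:
    """Hypothetical approximation of a timestamp-filter helper."""
    if start_date and end_date:
        return (
            f"{timestamp_field} BETWEEN TIMESTAMP('{start_date}') "
            f"AND TIMESTAMP('{end_date}')"
        )
    if start_date:
        return f"{timestamp_field} >= TIMESTAMP('{start_date}')"
    if end_date:
        return f"{timestamp_field} <= TIMESTAMP('{end_date}')"
    # No bounds provided: keep every row.
    return "TRUE"
```

Under that assumption, omitting both bounds yields a WHERE clause that keeps all rows, which is what lets the retrieval task leave start_time and end_time unset.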