Skip to content

Commit c99bc7f

Browse files
XuananLeLe Xuan An
authored andcommitted
refactor: Share compute engine timestamp helpers
Signed-off-by: Le Xuan An <anlx@viettel.com.vn>
1 parent 37accc6 commit c99bc7f

5 files changed

Lines changed: 44 additions & 30 deletions

File tree

sdk/python/feast/infra/compute_engines/flink/nodes.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,17 @@
2020
pandas_to_flink_table,
2121
register_flink_temporary_view,
2222
)
23-
from feast.infra.compute_engines.utils import create_offline_store_retrieval_job
24-
from feast.infra.offline_stores.offline_utils import (
25-
infer_event_timestamp_from_entity_df,
23+
from feast.infra.compute_engines.utils import (
24+
ENTITY_ROW_ID,
25+
ENTITY_TS_ALIAS,
26+
create_offline_store_retrieval_job,
27+
find_entity_timestamp_column,
28+
infer_entity_timestamp_column,
2629
)
2730
from feast.utils import _convert_arrow_to_proto
2831

2932
logger = logging.getLogger(__name__)
3033

31-
ENTITY_TS_ALIAS = "__entity_event_timestamp"
32-
ENTITY_ROW_ID = "__feast_entity_row_id"
3334
DEDUP_ROW_NUMBER = "__feast_row_number"
3435

3536

@@ -133,10 +134,9 @@ def _sql_value(
133134

134135

135136
def _entity_timestamp_column_from_columns(columns: List[str]) -> str:
136-
if ENTITY_TS_ALIAS in columns:
137-
return ENTITY_TS_ALIAS
138-
if "event_timestamp" in columns:
139-
return "event_timestamp"
137+
entity_ts_col = find_entity_timestamp_column(columns)
138+
if entity_ts_col:
139+
return entity_ts_col
140140
raise ValueError(
141141
"SQL-based entity_df for FlinkComputeEngine must select an "
142142
"`event_timestamp` column."
@@ -151,7 +151,7 @@ def _entity_value_from_dataframe(
151151
entity_df = entity_df.copy()
152152
entity_df[ENTITY_ROW_ID] = range(len(entity_df))
153153
entity_schema = dict(zip(entity_df.columns, entity_df.dtypes))
154-
entity_ts_col = infer_event_timestamp_from_entity_df(entity_schema)
154+
entity_ts_col = infer_entity_timestamp_column(entity_schema)
155155
if entity_ts_col != ENTITY_TS_ALIAS:
156156
entity_df = entity_df.rename(columns={entity_ts_col: ENTITY_TS_ALIAS})
157157
return (

sdk/python/feast/infra/compute_engines/local/nodes.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,14 @@
1414
from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue
1515
from feast.infra.compute_engines.local.local_node import LocalNode
1616
from feast.infra.compute_engines.utils import (
17+
ENTITY_TS_ALIAS,
1718
create_offline_store_retrieval_job,
18-
)
19-
from feast.infra.offline_stores.offline_utils import (
20-
infer_event_timestamp_from_entity_df,
19+
infer_entity_timestamp_column,
2120
)
2221
from feast.utils import _convert_arrow_to_proto
2322

2423
logger = logging.getLogger(__name__)
2524

26-
ENTITY_TS_ALIAS = "__entity_event_timestamp"
27-
2825

2926
class LocalSourceReadNode(LocalNode):
3027
def __init__(
@@ -99,7 +96,7 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue:
9996
entity_df = self.backend.from_arrow(pa.Table.from_pandas(context.entity_df))
10097

10198
entity_schema = dict(zip(entity_df.columns, entity_df.dtypes))
102-
entity_ts_col = infer_event_timestamp_from_entity_df(entity_schema)
99+
entity_ts_col = infer_entity_timestamp_column(entity_schema)
103100

104101
if entity_ts_col != ENTITY_TS_ALIAS:
105102
entity_df = self.backend.rename_columns(

sdk/python/feast/infra/compute_engines/ray/nodes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
safe_batch_processor,
2424
write_to_online_store,
2525
)
26-
from feast.infra.compute_engines.utils import create_offline_store_retrieval_job
26+
from feast.infra.compute_engines.utils import (
27+
ENTITY_TS_ALIAS,
28+
create_offline_store_retrieval_job,
29+
)
2730
from feast.infra.ray_initializer import get_ray_wrapper
2831
from feast.infra.ray_shared_utils import (
2932
apply_field_mapping,
@@ -33,9 +36,6 @@
3336

3437
logger = logging.getLogger(__name__)
3538

36-
# Entity timestamp alias for historical feature retrieval
37-
ENTITY_TS_ALIAS = "__entity_event_timestamp"
38-
3939

4040
class RayReadNode(DAGNode):
4141
"""

sdk/python/feast/infra/compute_engines/spark/nodes.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
from feast.infra.compute_engines.dag.value import DAGValue
3535
from feast.infra.compute_engines.spark.utils import map_in_arrow
3636
from feast.infra.compute_engines.utils import (
37+
ENTITY_TS_ALIAS,
3738
create_offline_store_retrieval_job,
39+
infer_entity_timestamp_column,
3840
)
3941
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
4042
SparkRetrievalJob,
@@ -43,9 +45,6 @@
4345
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
4446
SparkSource,
4547
)
46-
from feast.infra.offline_stores.offline_utils import (
47-
infer_event_timestamp_from_entity_df,
48-
)
4948

5049
logger = logging.getLogger(__name__)
5150

@@ -144,9 +143,6 @@ def _spark_types_compatible(expected: SparkDataType, actual: SparkDataType) -> b
144143
return False
145144

146145

147-
ENTITY_TS_ALIAS = "__entity_event_timestamp"
148-
149-
150146
# Rename entity_df event_timestamp_col to match feature_df
151147
def rename_entity_ts_column(
152148
spark_session: SparkSession, entity_df: DataFrame
@@ -159,9 +155,7 @@ def rename_entity_ts_column(
159155
spark_session=spark_session,
160156
entity_df=entity_df,
161157
)
162-
event_timestamp_col = infer_event_timestamp_from_entity_df(
163-
entity_schema=entity_schema,
164-
)
158+
event_timestamp_col = infer_entity_timestamp_column(entity_schema)
165159
if not isinstance(entity_df, DataFrame):
166160
entity_df = spark_session.createDataFrame(entity_df)
167161
entity_df = entity_df.withColumnRenamed(event_timestamp_col, ENTITY_TS_ALIAS)

sdk/python/feast/infra/compute_engines/utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,32 @@
11
from datetime import datetime
2-
from typing import Optional
2+
from typing import Any, Mapping, Optional, Sequence
33

44
from feast.data_source import DataSource
55
from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext
66
from feast.infra.offline_stores.offline_store import RetrievalJob
7+
from feast.infra.offline_stores.offline_utils import (
8+
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
9+
infer_event_timestamp_from_entity_df,
10+
)
11+
12+
ENTITY_TS_ALIAS = "__entity_event_timestamp"
13+
ENTITY_ROW_ID = "__feast_entity_row_id"
14+
15+
16+
def infer_entity_timestamp_column(entity_schema: Mapping[str, Any]) -> str:
17+
"""Resolve the entity timestamp column used for point-in-time joins."""
18+
if ENTITY_TS_ALIAS in entity_schema:
19+
return ENTITY_TS_ALIAS
20+
return infer_event_timestamp_from_entity_df(dict(entity_schema))
21+
22+
23+
def find_entity_timestamp_column(columns: Sequence[str]) -> Optional[str]:
24+
"""Find the timestamp column in an entity DataFrame schema, if present."""
25+
if ENTITY_TS_ALIAS in columns:
26+
return ENTITY_TS_ALIAS
27+
if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in columns:
28+
return DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
29+
return None
730

831

932
def create_offline_store_retrieval_job(

0 commit comments

Comments
 (0)