Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,13 @@ def query_generator() -> Iterator[str]:
table_name = offline_utils.get_temp_entity_table_name()

# If using CTE and entity_df is a SQL query, we don't need a table
if config.offline_store.entity_select_mode == EntitySelectMode.embed_query:
if isinstance(entity_df, str):
left_table_query_string = entity_df
else:
raise ValueError(
f"Invalid entity select mode: {config.offline_store.entity_select_mode} cannot be used with entity_df as a DataFrame"
)
use_cte = (
isinstance(entity_df, str)
and config.offline_store.entity_select_mode
== EntitySelectMode.embed_query
)
if use_cte:
left_table_query_string = entity_df
else:
left_table_query_string = table_name
_upload_entity_df(config, entity_df, table_name)
Expand Down Expand Up @@ -187,7 +187,7 @@ def query_generator() -> Iterator[str]:
entity_df_columns=entity_schema.keys(),
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
full_feature_names=full_feature_names,
entity_select_mode=config.offline_store.entity_select_mode,
use_cte=use_cte,
)
finally:
# Only cleanup if we created a table
Expand Down Expand Up @@ -386,7 +386,7 @@ def build_point_in_time_query(
entity_df_columns: KeysView[str],
query_template: str,
full_feature_names: bool = False,
entity_select_mode: EntitySelectMode = EntitySelectMode.temp_table,
use_cte: bool = False,
) -> str:
"""Build point-in-time query between each feature view table and the entity dataframe for PostgreSQL"""
template = Environment(loader=BaseLoader()).from_string(source=query_template)
Expand Down Expand Up @@ -414,7 +414,7 @@ def build_point_in_time_query(
"featureviews": feature_view_query_contexts,
"full_feature_names": full_feature_names,
"final_output_feature_names": final_output_feature_names,
"entity_select_mode": entity_select_mode.value,
"use_cte": use_cte,
}

query = template.render(template_context)
Expand Down Expand Up @@ -456,7 +456,7 @@ def _get_entity_schema(

MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """
WITH
{% if entity_select_mode == "embed_query" %}
{% if use_cte %}
entity_query AS ({{ left_table_query_string }}),
{% endif %}
/*
Expand All @@ -479,15 +479,17 @@ def _get_entity_schema(
{% endif %}
{% endfor %}
FROM
{% if entity_select_mode == "embed_query" %}
{% if use_cte %}
entity_query
{% else %}
{{ left_table_query_string }}
{% endif %}
)

{% if featureviews | length > 0 %}
,
{% endif %}

{% for featureview in featureviews %}

"{{ featureview.name }}__entity_dataframe" AS (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,89 @@ def test_get_historical_features_entity_select_modes_embed_query(
assert True # If we get here, the SQL is valid


@patch("feast.infra.offline_stores.contrib.postgres_offline_store.postgres._get_conn")
@patch(
    "feast.infra.offline_stores.contrib.postgres_offline_store.postgres.df_to_postgres_table"
)
@patch(
    "feast.infra.offline_stores.contrib.postgres_offline_store.postgres.get_query_schema"
)
def test_get_historical_features_entity_select_modes_embed_query_with_dataframe(
    mock_get_query_schema, mock_df_to_postgres_table, mock_get_conn
):
    """embed_query mode with a DataFrame entity_df must fall back to a temp table.

    The CTE (``entity_query AS (...)``) path is only usable when ``entity_df``
    is a SQL string, so passing a DataFrame while
    ``entity_select_mode="embed_query"`` should neither raise nor render the
    CTE branch of the template, and must still emit syntactically valid SQL.
    """
    mock_conn = MagicMock()
    mock_get_conn.return_value.__enter__.return_value = mock_conn

    # Schema returned for the uploaded entity table; keep it minimal — one
    # timestamp column plus one join key is enough to render the template.
    mock_get_query_schema.return_value = {
        "event_timestamp": pd.Timestamp,
        "driver_id": pd.Int64Dtype(),
    }

    test_repo_config = RepoConfig(
        project="test_project",
        registry="test_registry",
        provider="local",
        offline_store=PostgreSQLOfflineStoreConfig(
            type="postgres",
            host="localhost",
            port=5432,
            database="test_db",
            db_schema="public",
            user="test_user",
            password="test_password",
            # embed_query is requested, but the DataFrame entity_df below
            # should force the temp-table code path instead of the CTE.
            entity_select_mode="embed_query",
        ),
    )

    test_data_source = PostgreSQLSource(
        name="test_batch_source",
        description="test_batch_source",
        table="offline_store_database_name.offline_store_table_name",
        timestamp_field="event_published_datetime_utc",
    )

    test_feature_view = FeatureView(
        name="test_feature_view",
        entities=_mock_entity(),
        schema=[
            Field(name="feature1", dtype=Float32),
        ],
        source=test_data_source,
    )

    mock_registry = MagicMock()
    mock_registry.get_feature_view.return_value = test_feature_view

    # Use a DataFrame even though embed_query mode is used
    entity_df = pd.DataFrame(
        {"event_timestamp": [datetime(2021, 1, 1)], "driver_id": [1]}
    )

    retrieval_job = PostgreSQLOfflineStore.get_historical_features(
        config=test_repo_config,
        feature_views=[test_feature_view],
        feature_refs=["test_feature_view:feature1"],
        entity_df=entity_df,
        registry=mock_registry,
        project="test_project",
    )

    actual_query = retrieval_job.to_sql().strip()
    logger.debug("Actual query:\n%s", actual_query)

    # The non-CTE branch renders WITH followed directly by the hash comment
    # block — no `entity_query AS (...)` in between.
    assert actual_query.startswith("""WITH

/*
 Compute a deterministic hash for the `left_table_query_string` that will be used throughout
 all the logic as the field to GROUP BY the data
*/""")

    # sqlglot raises on malformed SQL; a non-empty parse result proves the
    # rendered query is syntactically valid (replaces the old no-op
    # `assert True` that followed a bare parse call).
    assert sqlglot.parse(actual_query)


@patch("feast.infra.offline_stores.contrib.postgres_offline_store.postgres._get_conn")
@patch(
"feast.infra.offline_stores.contrib.postgres_offline_store.postgres.df_to_postgres_table"
Expand Down
Loading