Skip to content

Commit 8d6cc24

Browse files
Fix: if feature_name_columns is empty, use SELECT * so the UDF receives the full raw source schema
Signed-off-by: abhijeet-dhumal <abhijeetdhumal652@gmail.com>
1 parent 92ffbb9 commit 8d6cc24

2 files changed

Lines changed: 194 additions & 5 deletions

File tree

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -387,12 +387,18 @@ def pull_all_from_table_or_query(
387387
timestamp_fields = [timestamp_field]
388388
if created_timestamp_column:
389389
timestamp_fields.append(created_timestamp_column)
390-
(fields_with_aliases, aliases) = _get_fields_with_aliases(
391-
fields=join_key_columns + feature_name_columns + timestamp_fields,
392-
field_mappings=data_source.field_mapping,
393-
)
394390

395-
fields_with_alias_string = ", ".join(fields_with_aliases)
391+
if feature_name_columns:
392+
(fields_with_aliases, _) = _get_fields_with_aliases(
393+
fields=join_key_columns + feature_name_columns + timestamp_fields,
394+
field_mappings=data_source.field_mapping,
395+
)
396+
fields_with_alias_string = ", ".join(fields_with_aliases)
397+
else:
398+
# Empty feature_name_columns signals "read all source columns".
399+
# Used by BatchFeatureView with TransformationMode.PYTHON/ray/pandas where
400+
# the UDF computes output features from raw input — don't project upfront.
401+
fields_with_alias_string = "*"
396402

397403
from_expression = data_source.get_table_query_string()
398404
timestamp_filter = get_timestamp_filter_sql(
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
"""
2+
Unit tests for SparkOfflineStore.pull_all_from_table_or_query SQL generation.
3+
4+
Covers the bug where feature_name_columns=[] (signalling "read all source
5+
columns" for BatchFeatureView UDF transformations) caused a bare
6+
SELECT user_id, event_timestamp FROM source
7+
instead of SELECT *, silently dropping all columns the UDF needs.
8+
"""
9+
import sys
10+
from datetime import datetime, timezone
11+
from types import ModuleType
12+
from unittest.mock import MagicMock, patch
13+
14+
import pytest
15+
16+
# ---------------------------------------------------------------------------
17+
# Stub pyspark so SparkOfflineStore can be imported without a Spark install.
18+
# Only the minimal surface needed for module-level imports in spark.py.
19+
# ---------------------------------------------------------------------------
20+
def _stub_pyspark():
21+
pyspark = ModuleType("pyspark")
22+
pyspark.SparkConf = MagicMock()
23+
sql = ModuleType("pyspark.sql")
24+
sql.SparkSession = MagicMock()
25+
sql.DataFrame = MagicMock()
26+
sql.functions = ModuleType("pyspark.sql.functions")
27+
sql.types = ModuleType("pyspark.sql.types")
28+
sql.types.StructType = MagicMock()
29+
pyspark.sql = sql
30+
for name, mod in {
31+
"pyspark": pyspark,
32+
"pyspark.sql": sql,
33+
"pyspark.sql.functions": sql.functions,
34+
"pyspark.sql.types": sql.types,
35+
}.items():
36+
sys.modules.setdefault(name, mod)
37+
38+
39+
_stub_pyspark()
40+
41+
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
42+
SparkOfflineStore,
43+
SparkOfflineStoreConfig,
44+
)
45+
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
46+
SparkSource,
47+
)
48+
from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
49+
from feast.repo_config import RepoConfig
50+
51+
# ---------------------------------------------------------------------------
52+
# Shared fixtures
53+
# ---------------------------------------------------------------------------
54+
55+
START = datetime(2023, 1, 1, tzinfo=timezone.utc)
56+
END = datetime(2024, 1, 1, tzinfo=timezone.utc)
57+
58+
59+
@pytest.fixture()
60+
def repo_config():
61+
return RepoConfig(
62+
registry="file:///tmp/registry.db",
63+
project="test",
64+
provider="local",
65+
online_store=SqliteOnlineStoreConfig(type="sqlite"),
66+
offline_store=SparkOfflineStoreConfig(type="spark"),
67+
)
68+
69+
70+
@pytest.fixture()
71+
def spark_source():
72+
src = SparkSource(
73+
name="raw_reviews",
74+
path="s3a://bucket/processed/reviews/",
75+
file_format="parquet",
76+
timestamp_field="event_timestamp",
77+
)
78+
return src
79+
80+
81+
def _run_pull_all(repo_config, spark_source, feature_name_columns):
82+
"""
83+
Call pull_all_from_table_or_query with a mocked SparkSession and return
84+
the SQL query string that would be issued.
85+
"""
86+
mock_spark = MagicMock()
87+
88+
with patch(
89+
"feast.infra.offline_stores.contrib.spark_offline_store.spark"
90+
".get_spark_session_or_start_new_with_repoconfig",
91+
return_value=mock_spark,
92+
):
93+
job = SparkOfflineStore.pull_all_from_table_or_query(
94+
config=repo_config,
95+
data_source=spark_source,
96+
join_key_columns=["user_id"],
97+
feature_name_columns=feature_name_columns,
98+
timestamp_field="event_timestamp",
99+
created_timestamp_column=None,
100+
start_date=START,
101+
end_date=END,
102+
)
103+
104+
return job.query.strip()
105+
106+
def test_pull_all_with_empty_feature_cols_generates_select_star(
107+
repo_config, spark_source
108+
):
109+
"""
110+
feature_name_columns=[] must produce SELECT * so UDF-based
111+
BatchFeatureViews receive all raw source columns for aggregation.
112+
"""
113+
sql = _run_pull_all(repo_config, spark_source, feature_name_columns=[])
114+
115+
assert sql.startswith("SELECT *"), (
116+
"Expected 'SELECT *' when feature_name_columns=[], "
117+
f"got: {sql[:120]!r}\n\n"
118+
"BatchFeatureView UDFs need all raw source columns to compute "
119+
"aggregations — projecting only join key + timestamp silently "
120+
"drops rating, text, helpful_vote, etc."
121+
)
122+
assert "user_id" not in sql.split("FROM")[0], (
123+
"SELECT * must not also explicitly list join key columns"
124+
)
125+
126+
127+
def test_pull_all_with_feature_cols_generates_explicit_projection(
128+
repo_config, spark_source
129+
):
130+
"""
131+
When feature_name_columns is non-empty (normal FeatureView path),
132+
the query must project only the requested columns — not SELECT *.
133+
"""
134+
sql = _run_pull_all(
135+
repo_config,
136+
spark_source,
137+
feature_name_columns=["avg_rating", "review_count"],
138+
)
139+
140+
assert "SELECT *" not in sql, (
141+
"Non-empty feature_name_columns must produce explicit SELECT projection, not SELECT *"
142+
)
143+
assert "avg_rating" in sql
144+
assert "review_count" in sql
145+
assert "user_id" in sql
146+
assert "event_timestamp" in sql
147+
148+
149+
def test_pull_all_empty_feature_cols_upstream_regression(repo_config, spark_source):
150+
"""
151+
Regression guard: the upstream (unfixed) behaviour with feature_name_columns=[]
152+
produced a query that only selected join key + timestamp, dropping all columns
153+
the UDF needs. Verify the fixed code does NOT produce that broken query.
154+
155+
Broken upstream SQL looked like:
156+
SELECT user_id, event_timestamp FROM ... WHERE ...
157+
"""
158+
sql = _run_pull_all(repo_config, spark_source, feature_name_columns=[])
159+
160+
projection = sql.split("FROM")[0]
161+
assert "user_id" not in projection, (
162+
"Upstream bug: query projected only 'user_id, event_timestamp', "
163+
"silently dropping all columns needed by the BFV UDF. "
164+
"Fixed query should use SELECT *."
165+
)
166+
167+
168+
@pytest.mark.parametrize(
169+
"feature_cols,expect_star",
170+
[
171+
([], True),
172+
(["f1"], False),
173+
(["f1", "f2", "f3"], False),
174+
],
175+
)
176+
def test_pull_all_select_star_only_when_feature_cols_empty(
177+
repo_config, spark_source, feature_cols, expect_star
178+
):
179+
sql = _run_pull_all(repo_config, spark_source, feature_name_columns=feature_cols)
180+
has_star = sql.strip().upper().startswith("SELECT *")
181+
assert has_star == expect_star, (
182+
f"feature_cols={feature_cols!r}: expected SELECT *={expect_star}, got SQL: {sql[:100]!r}"
183+
)

0 commit comments

Comments
 (0)