Skip to content

Commit bfa9e36

Browse files
committed
fix: Fix SparkRetrievalJob.persist() failing for SparkSource
Signed-off-by: ntkathole <nikhilkathole2683@gmail.com>
1 parent fa271be commit bfa9e36

4 files changed

Lines changed: 293 additions & 8 deletions

File tree

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,20 +1053,33 @@ def persist(
10531053
):
10541054
"""
10551055
Run the retrieval and persist the results in the same offline store used for read.
1056-
Please note the persisting is done only within the scope of the spark session for local warehouse directory.
1056+
Supports both table-based and path-based SparkSource configurations.
1057+
For table-based: persists via saveAsTable (remote warehouse) or createOrReplaceTempView (local).
1058+
For path-based: writes directly to the specified path in the given file format.
10571059
"""
10581060
assert isinstance(storage, SavedDatasetSparkStorage)
1061+
10591062
table_name = storage.spark_options.table
1060-
if not table_name:
1061-
raise ValueError("Cannot persist, table_name is not defined")
1062-
if self._has_remote_warehouse_in_config():
1063-
file_format = storage.spark_options.file_format
1063+
path = storage.spark_options.path
1064+
file_format = storage.spark_options.file_format
1065+
1066+
if path:
10641067
if not file_format:
1065-
self.to_spark_df().write.saveAsTable(table_name)
1068+
file_format = "parquet"
1069+
write_mode = "overwrite" if allow_overwrite else "error"
1070+
self.to_spark_df().write.format(file_format).mode(write_mode).save(path)
1071+
elif table_name:
1072+
if self._has_remote_warehouse_in_config():
1073+
if not file_format:
1074+
self.to_spark_df().write.saveAsTable(table_name)
1075+
else:
1076+
self.to_spark_df().write.format(file_format).saveAsTable(table_name)
10661077
else:
1067-
self.to_spark_df().write.format(file_format).saveAsTable(table_name)
1078+
self.to_spark_df().createOrReplaceTempView(table_name)
10681079
else:
1069-
self.to_spark_df().createOrReplaceTempView(table_name)
1080+
raise ValueError(
1081+
"Cannot persist: either 'table' or 'path' must be specified in SavedDatasetSparkStorage"
1082+
)
10701083

10711084
def _has_remote_warehouse_in_config(self) -> bool:
10721085
"""

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,3 +453,14 @@ def to_data_source(self) -> DataSource:
453453
file_format=self.spark_options.file_format,
454454
table_format=self.spark_options.table_format,
455455
)
456+
457+
@staticmethod
458+
def from_data_source(data_source: DataSource) -> "SavedDatasetSparkStorage":
459+
assert isinstance(data_source, SparkSource)
460+
return SavedDatasetSparkStorage(
461+
table=data_source.table,
462+
query=data_source.query,
463+
path=data_source.path,
464+
file_format=data_source.file_format,
465+
table_format=data_source.table_format,
466+
)

sdk/python/feast/saved_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __new__(cls, name, bases, dct):
3434

3535
_DATA_SOURCE_TO_SAVED_DATASET_STORAGE = {
3636
"FileSource": "feast.infra.offline_stores.file_source.SavedDatasetFileStorage",
37+
"SparkSource": "feast.infra.offline_stores.contrib.spark_offline_store.spark_source.SavedDatasetSparkStorage",
3738
}
3839

3940

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
"""
2+
Unit tests for SparkRetrievalJob.persist() and SavedDatasetSparkStorage.from_data_source().
3+
4+
Covers the fix for https://github.com/feast-dev/feast/issues/6261 where:
5+
1. SavedDatasetStorage.from_data_source() did not support SparkSource
6+
2. SavedDatasetSparkStorage lacked a from_data_source() method
7+
3. SparkRetrievalJob.persist() only supported table-based storage, not path-based
8+
"""
9+
10+
from unittest.mock import MagicMock
11+
12+
import pytest
13+
14+
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
15+
SparkOfflineStoreConfig,
16+
SparkRetrievalJob,
17+
)
18+
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
19+
SavedDatasetSparkStorage,
20+
SparkSource,
21+
)
22+
from feast.infra.offline_stores.file_source import FileSource
23+
from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
24+
from feast.repo_config import RepoConfig
25+
from feast.saved_dataset import SavedDatasetStorage
26+
from feast.table_format import IcebergFormat
27+
28+
# ---------------------------------------------------------------------------
29+
# Shared fixtures
30+
# ---------------------------------------------------------------------------
31+
32+
33+
@pytest.fixture()
34+
def repo_config():
35+
return RepoConfig(
36+
registry="file:///tmp/registry.db",
37+
project="test",
38+
provider="local",
39+
online_store=SqliteOnlineStoreConfig(type="sqlite"),
40+
offline_store=SparkOfflineStoreConfig(type="spark"),
41+
)
42+
43+
44+
@pytest.fixture()
45+
def table_spark_source():
46+
return SparkSource(
47+
name="my_table",
48+
table="db.my_table",
49+
timestamp_field="event_timestamp",
50+
)
51+
52+
53+
@pytest.fixture()
54+
def path_spark_source():
55+
return SparkSource(
56+
name="my_path_source",
57+
path="s3a://bucket/data/features/",
58+
file_format="parquet",
59+
timestamp_field="event_timestamp",
60+
)
61+
62+
63+
def _make_spark_retrieval_job(repo_config, remote_warehouse=True):
64+
"""Build a SparkRetrievalJob with a mocked SparkSession."""
65+
mock_spark = MagicMock()
66+
67+
if remote_warehouse:
68+
mock_spark.conf.get.side_effect = lambda key: {
69+
"hive.metastore.uris": "thrift://metastore:9083",
70+
}.get(key, None)
71+
else:
72+
73+
def _local_conf_get(key):
74+
if key == "hive.metastore.uris":
75+
raise Exception("not set")
76+
if key == "spark.sql.warehouse.dir":
77+
return "file:///tmp/spark-warehouse"
78+
return None
79+
80+
mock_spark.conf.get.side_effect = _local_conf_get
81+
82+
return SparkRetrievalJob(
83+
spark_session=mock_spark,
84+
query="SELECT 1",
85+
full_feature_names=False,
86+
config=repo_config,
87+
)
88+
89+
90+
# ---------------------------------------------------------------------------
91+
# Group 1: SavedDatasetSparkStorage.from_data_source()
92+
# ---------------------------------------------------------------------------
93+
94+
95+
class TestSavedDatasetSparkStorageFromDataSource:
96+
def test_from_data_source_with_table_source(self, table_spark_source):
97+
storage = SavedDatasetSparkStorage.from_data_source(table_spark_source)
98+
99+
assert isinstance(storage, SavedDatasetSparkStorage)
100+
assert storage.spark_options.table == "db.my_table"
101+
assert storage.spark_options.query is None
102+
assert storage.spark_options.path is None
103+
104+
def test_from_data_source_with_path_source(self, path_spark_source):
105+
storage = SavedDatasetSparkStorage.from_data_source(path_spark_source)
106+
107+
assert isinstance(storage, SavedDatasetSparkStorage)
108+
assert storage.spark_options.path == "s3a://bucket/data/features/"
109+
assert storage.spark_options.file_format == "parquet"
110+
assert storage.spark_options.table is None
111+
assert storage.spark_options.query is None
112+
113+
def test_from_data_source_rejects_non_spark_source(self):
114+
file_source = FileSource(
115+
path="/tmp/data.parquet",
116+
timestamp_field="event_timestamp",
117+
)
118+
with pytest.raises(AssertionError):
119+
SavedDatasetSparkStorage.from_data_source(file_source)
120+
121+
122+
# ---------------------------------------------------------------------------
123+
# Group 2: SavedDatasetStorage.from_data_source() dispatch
124+
# ---------------------------------------------------------------------------
125+
126+
127+
class TestSavedDatasetStorageDispatch:
128+
def test_from_data_source_resolves_spark(self, table_spark_source):
129+
storage = SavedDatasetStorage.from_data_source(table_spark_source)
130+
131+
assert isinstance(storage, SavedDatasetSparkStorage)
132+
assert storage.spark_options.table == "db.my_table"
133+
134+
def test_from_data_source_resolves_path_spark(self, path_spark_source):
135+
storage = SavedDatasetStorage.from_data_source(path_spark_source)
136+
137+
assert isinstance(storage, SavedDatasetSparkStorage)
138+
assert storage.spark_options.path == "s3a://bucket/data/features/"
139+
assert storage.spark_options.file_format == "parquet"
140+
141+
def test_roundtrip_table_source(self, table_spark_source):
142+
storage = SavedDatasetStorage.from_data_source(table_spark_source)
143+
roundtripped = storage.to_data_source()
144+
145+
assert isinstance(roundtripped, SparkSource)
146+
assert roundtripped.table == table_spark_source.table
147+
assert roundtripped.query == table_spark_source.query
148+
assert roundtripped.path == table_spark_source.path
149+
150+
def test_roundtrip_path_source(self):
151+
source = SparkSource(
152+
name="my_path_source",
153+
table="fallback_name",
154+
timestamp_field="event_timestamp",
155+
)
156+
storage = SavedDatasetStorage.from_data_source(source)
157+
roundtripped = storage.to_data_source()
158+
159+
assert isinstance(roundtripped, SparkSource)
160+
assert roundtripped.table == source.table
161+
162+
163+
# ---------------------------------------------------------------------------
164+
# Group 3: SparkRetrievalJob.persist()
165+
# ---------------------------------------------------------------------------
166+
167+
168+
class TestSparkRetrievalJobPersist:
169+
def test_persist_with_table_saves_as_table(self, repo_config):
170+
job = _make_spark_retrieval_job(repo_config, remote_warehouse=True)
171+
storage = SavedDatasetSparkStorage(table="output_table")
172+
173+
job.persist(storage)
174+
175+
mock_df = job.spark_session.sql.return_value
176+
mock_df.write.saveAsTable.assert_called_once_with("output_table")
177+
178+
def test_persist_with_table_and_format(self, repo_config):
179+
job = _make_spark_retrieval_job(repo_config, remote_warehouse=True)
180+
storage = SavedDatasetSparkStorage(table="output_table", file_format="parquet")
181+
182+
job.persist(storage)
183+
184+
mock_df = job.spark_session.sql.return_value
185+
mock_df.write.format.assert_called_once_with("parquet")
186+
mock_df.write.format.return_value.saveAsTable.assert_called_once_with(
187+
"output_table"
188+
)
189+
190+
def test_persist_with_path_writes_to_path(self, repo_config):
191+
job = _make_spark_retrieval_job(repo_config, remote_warehouse=True)
192+
storage = SavedDatasetSparkStorage(
193+
path="s3a://bucket/output/", file_format="parquet"
194+
)
195+
196+
job.persist(storage)
197+
198+
mock_df = job.spark_session.sql.return_value
199+
mock_df.write.format.assert_called_once_with("parquet")
200+
mock_df.write.format.return_value.mode.assert_called_once_with("error")
201+
mock_df.write.format.return_value.mode.return_value.save.assert_called_once_with(
202+
"s3a://bucket/output/"
203+
)
204+
205+
def test_persist_with_path_defaults_to_parquet(self, repo_config):
206+
"""When path is set with table_format but no file_format, persist defaults to parquet."""
207+
job = _make_spark_retrieval_job(repo_config, remote_warehouse=True)
208+
storage = SavedDatasetSparkStorage(
209+
path="s3a://bucket/output/",
210+
file_format=None,
211+
table_format=IcebergFormat(catalog="test_catalog"),
212+
)
213+
214+
job.persist(storage)
215+
216+
mock_df = job.spark_session.sql.return_value
217+
mock_df.write.format.assert_called_once_with("parquet")
218+
219+
def test_persist_with_path_allow_overwrite(self, repo_config):
220+
job = _make_spark_retrieval_job(repo_config, remote_warehouse=True)
221+
storage = SavedDatasetSparkStorage(
222+
path="s3a://bucket/output/", file_format="parquet"
223+
)
224+
225+
job.persist(storage, allow_overwrite=True)
226+
227+
mock_df = job.spark_session.sql.return_value
228+
mock_df.write.format.return_value.mode.assert_called_once_with("overwrite")
229+
230+
def test_persist_with_path_custom_format(self, repo_config):
231+
job = _make_spark_retrieval_job(repo_config, remote_warehouse=True)
232+
storage = SavedDatasetSparkStorage(
233+
path="s3a://bucket/output/", file_format="avro"
234+
)
235+
236+
job.persist(storage)
237+
238+
mock_df = job.spark_session.sql.return_value
239+
mock_df.write.format.assert_called_once_with("avro")
240+
mock_df.write.format.return_value.mode.return_value.save.assert_called_once_with(
241+
"s3a://bucket/output/"
242+
)
243+
244+
def test_persist_raises_without_table_or_path(self, repo_config):
245+
job = _make_spark_retrieval_job(repo_config, remote_warehouse=True)
246+
storage = SavedDatasetSparkStorage(query="SELECT * FROM t")
247+
248+
with pytest.raises(
249+
ValueError, match="either 'table' or 'path' must be specified"
250+
):
251+
job.persist(storage)
252+
253+
def test_persist_local_warehouse_creates_temp_view(self, repo_config):
254+
job = _make_spark_retrieval_job(repo_config, remote_warehouse=False)
255+
storage = SavedDatasetSparkStorage(table="output_table")
256+
257+
job.persist(storage)
258+
259+
mock_df = job.spark_session.sql.return_value
260+
mock_df.createOrReplaceTempView.assert_called_once_with("output_table")

0 commit comments

Comments
 (0)