Commit 3af968b

fix write node
Signed-off-by: HaoXuAI <sduxuhao@gmail.com>
1 parent 155d1ed commit 3af968b

4 files changed: +94 −43 lines changed

sdk/python/feast/infra/compute_engines/spark/nodes.py

Lines changed: 24 additions & 3 deletions

@@ -17,6 +17,9 @@
     SparkRetrievalJob,
     _get_entity_schema,
 )
+from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
+    SparkSource,
+)
 from feast.infra.offline_stores.offline_utils import (
     infer_event_timestamp_from_entity_df,
 )
@@ -282,12 +285,30 @@ def execute(self, context: ExecutionContext) -> DAGValue:
             feature_view=self.feature_view, repo_config=context.repo_config
         )

-        # ✅ 1. Write to online or offline store (if enabled)
-        if self.feature_view.online or self.feature_view.offline:
+        # ✅ 1. Write to online store if online enabled
+        if self.feature_view.online:
             spark_df.mapInArrow(
-                lambda x: map_in_arrow(x, serialized_artifacts), spark_df.schema
+                lambda x: map_in_arrow(x, serialized_artifacts, mode="online"),
+                spark_df.schema,
             ).count()

+        # ✅ 2. Write to offline store if offline enabled
+        if self.feature_view.offline:
+            if not isinstance(self.feature_view.batch_source, SparkSource):
+                spark_df.mapInArrow(
+                    lambda x: map_in_arrow(x, serialized_artifacts, mode="offline"),
+                    spark_df.schema,
+                ).count()
+            # Directly write spark df to spark offline store without using mapInArrow
+            else:
+                dest_path = self.feature_view.batch_source.path
+                file_format = self.feature_view.batch_source.file_format
+                if not dest_path or not file_format:
+                    raise ValueError(
+                        "Destination path and file format must be specified for SparkSource."
+                    )
+                spark_df.write.format(file_format).mode("append").save(dest_path)
+
         return DAGValue(
             data=spark_df,
             format=DAGFormat.SPARK,
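
Note on the direct-write branch above: it only triggers when the feature view's batch source is a SparkSource that has both a path and a file_format set. A hedged sketch of such a source declaration follows (names and paths are invented for illustration; the exact SparkSource constructor arguments should be checked against the Feast version in use):

from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
    SparkSource,
)

# Hypothetical batch source: with both `path` and `file_format` present, the write
# node above appends the Spark DataFrame directly via
# spark_df.write.format(file_format).mode("append").save(dest_path).
driver_stats_source = SparkSource(
    name="driver_hourly_stats",
    path="/tmp/feast/driver_hourly_stats",  # used as dest_path by the write node
    file_format="parquet",                  # used as file_format by the write node
    timestamp_field="event_timestamp",
)

Any other batch source type still goes through map_in_arrow with mode="offline", which routes to the offline store's offline_write_batch.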

sdk/python/feast/infra/compute_engines/spark/utils.py

Lines changed: 4 additions & 5 deletions

@@ -1,4 +1,4 @@
-from typing import Dict, Iterable, Optional
+from typing import Dict, Iterable, Literal, Optional

 import pyarrow as pa
 from pyspark import SparkConf
@@ -27,6 +27,7 @@ def get_or_create_new_spark_session(
 def map_in_arrow(
     iterator: Iterable[pa.RecordBatch],
     serialized_artifacts: "SerializedArtifacts",
+    mode: Literal["online", "offline"] = "online",
 ):
     for batch in iterator:
         table = pa.Table.from_batches([batch])
@@ -37,9 +38,8 @@ def map_in_arrow(
             offline_store,
             repo_config,
         ) = serialized_artifacts.unserialize()
-        print("write_feature_view", feature_view)

-        if feature_view.online:
+        if mode == "online":
             join_key_to_value_type = {
                 entity.name: entity.dtype.to_value_type()
                 for entity in feature_view.entity_columns
@@ -55,8 +55,7 @@ def map_in_arrow(
                 data=rows_to_write,
                 progress=lambda x: None,
             )
-        if feature_view.offline:
-            print("offline_to_write", table)
+        if mode == "offline":
            offline_store.offline_write_batch(
                config=repo_config,
                feature_view=feature_view,
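
For readers unfamiliar with the mechanism: mapInArrow hands the mapped function one iterator of pyarrow.RecordBatch objects per partition, and the compute engine relies on .count() purely to force that side-effecting pass. A minimal, self-contained sketch of the contract (not Feast code; the DataFrame and column names are invented for illustration):

from typing import Iterator

import pyarrow as pa
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame([(1001, 0.5), (1002, 0.7)], ["driver_id", "conv_rate"])


def write_partition(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
    # Same shape as map_in_arrow: called once per partition, batches arrive as Arrow.
    for batch in batches:
        table = pa.Table.from_batches([batch])
        print(f"would write {table.num_rows} rows here")  # side-effecting store write
        yield batch  # pass batches through so the action below has rows to count


# .count() exists only to trigger the distributed side effect on every partition.
df.mapInArrow(write_partition, df.schema).count()

Passing an explicit mode argument means each mapInArrow pass now writes to exactly one store, instead of inferring the destination from feature_view.online / feature_view.offline inside the worker.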

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 65 additions & 30 deletions

@@ -3,7 +3,7 @@
 import uuid
 import warnings
 from datetime import datetime, timezone
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import numpy as np
 import pandas
@@ -54,6 +54,8 @@ class SparkOfflineStoreConfig(FeastConfigBaseModel):
     region: Optional[StrictStr] = None
     """ AWS Region if applicable for s3-based staging locations"""

+    mode: Optional[Literal["driver", "worker"]] = "driver"
+

 class SparkOfflineStore(OfflineStore):
     @staticmethod
@@ -218,6 +220,22 @@ def offline_write_batch(
         table: pyarrow.Table,
         progress: Optional[Callable[[int], Any]],
     ):
+        """
+        Write pyarrow table to offline store.
+        This method supports two execution modes:
+        - "driver": Uses Spark to perform schema validation, type casting, and appending the data to the offline store.
+          This mode must run on the Spark driver and supports advanced functionality like schema enforcement.
+        - "worker": A simplified, worker-safe implementation that writes Arrow tables directly to storage.
+          This mode is designed for distributed execution within mapInArrow or other parallel contexts.
+
+        Args:
+            config: RepoConfig
+            feature_view: FeatureView
+            table: pyarrow.Table
+            progress: Callable[[int], Any]
+            mode: Literal["driver", "worker"], default is "driver"
+
+        """
         assert isinstance(config.offline_store, SparkOfflineStoreConfig)
         assert isinstance(feature_view.batch_source, SparkSource)

@@ -230,38 +248,55 @@ def offline_write_batch(
                 f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
             )

-        spark_session = get_spark_session_or_start_new_with_repoconfig(
-            store_config=config.offline_store
-        )
+        mode = config.offline_store.mode

-        if feature_view.batch_source.path:
-            # write data to disk so that it can be loaded into spark (for preserving column types)
-            with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp_file:
-                print(tmp_file.name)
-                pq.write_table(table, tmp_file.name)
-
-                # load data
-                df_batch = spark_session.read.parquet(tmp_file.name)
-
-                # load existing data to get spark table schema
-                df_existing = spark_session.read.format(
-                    feature_view.batch_source.file_format
-                ).load(feature_view.batch_source.path)
-
-                # cast columns if applicable
-                df_batch = _cast_data_frame(df_batch, df_existing)
-
-                df_batch.write.format(feature_view.batch_source.file_format).mode(
-                    "append"
-                ).save(feature_view.batch_source.path)
-        elif feature_view.batch_source.query:
-            raise NotImplementedError(
-                "offline_write_batch not implemented for batch sources specified by query"
+        if mode == "driver":
+            spark_session = get_spark_session_or_start_new_with_repoconfig(
+                store_config=config.offline_store
             )
+
+            if feature_view.batch_source.path:
+                # write data to disk so that it can be loaded into spark (for preserving column types)
+                with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp_file:
+                    print(tmp_file.name)
+                    pq.write_table(table, tmp_file.name)
+
+                    # load data
+                    df_batch = spark_session.read.parquet(tmp_file.name)
+
+                    # load existing data to get spark table schema
+                    df_existing = spark_session.read.format(
+                        feature_view.batch_source.file_format
+                    ).load(feature_view.batch_source.path)
+
+                    # cast columns if applicable
+                    df_batch = _cast_data_frame(df_batch, df_existing)
+
+                    df_batch.write.format(feature_view.batch_source.file_format).mode(
+                        "append"
+                    ).save(feature_view.batch_source.path)
+            elif feature_view.batch_source.query:
+                raise NotImplementedError(
+                    "offline_write_batch not implemented for batch sources specified by query"
+                )
+            else:
+                raise NotImplementedError(
+                    "offline_write_batch not implemented for batch sources specified by a table"
+                )
+        elif mode == "worker":
+            # Safe worker-side Arrow write
+            if not feature_view.batch_source.path:
+                raise ValueError("Path is required for worker mode.")
+
+            unique_name = f"batch_{uuid.uuid4().hex}.parquet"
+            output_path = os.path.join(feature_view.batch_source.path, unique_name)
+
+            pq.write_table(table, output_path)
+
+            if progress:
+                progress(table.num_rows)
         else:
-            raise NotImplementedError(
-                "offline_write_batch not implemented for batch sources specified by a table"
-            )
+            raise ValueError(f"Unsupported mode: {mode}")

     @staticmethod
     def pull_all_from_table_or_query(
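
The new mode setting lives on SparkOfflineStoreConfig rather than being passed per call, so offline_write_batch keeps its signature and reads the mode from config.offline_store. A hedged sketch of selecting the worker path programmatically (the field name and default come from the diff above; whether a real repo config needs additional fields depends on the deployment):

from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
    SparkOfflineStoreConfig,
)

# "driver" (default): load existing data through a Spark session, cast the incoming
# batch to match its schema, and append through Spark.
# "worker": write each Arrow table as a uniquely named parquet file directly under
# batch_source.path; safe to call from executors inside mapInArrow.
config = SparkOfflineStoreConfig(mode="worker")
print(config.mode)  # -> worker

In a feature_store.yaml this would presumably surface as a mode: worker entry under offline_store, with mode: driver remaining the default behavior.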

sdk/python/tests/integration/compute_engines/spark/test_compute.py

Lines changed: 1 addition & 5 deletions

@@ -271,7 +271,6 @@ def tqdm_builder(length):
             fs=fs,
             feature="driver_hourly_stats:conv_rate",
             entity_df=entity_df,
-            expected_value=1.6,
         )
     finally:
         spark_environment.teardown()
@@ -303,15 +302,12 @@ def _check_offline_features(
     fs,
     feature,
     entity_df,
-    expected_value,
 ):
     offline_df = fs.get_historical_features(
         entity_df=entity_df,
         features=[feature],
    ).to_df()
-
-    assert len(offline_df) == 2
-    assert offline_df["driver_id"].to_list() == [1001, 1002]
+    assert len(offline_df) == 4


 if __name__ == "__main__":
