 import uuid
 import warnings
 from datetime import datetime, timezone
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import numpy as np
 import pandas
@@ -54,6 +54,8 @@ class SparkOfflineStoreConfig(FeastConfigBaseModel):
     region: Optional[StrictStr] = None
     """AWS Region if applicable for s3-based staging locations"""

+    mode: Optional[Literal["driver", "worker"]] = "driver"
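+    """ Execution mode for offline_write_batch: "driver" (default) validates and appends data via Spark on the driver; "worker" writes Arrow tables directly to storage from worker tasks """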
+

 class SparkOfflineStore(OfflineStore):
     @staticmethod
@@ -218,6 +220,22 @@ def offline_write_batch(
         table: pyarrow.Table,
         progress: Optional[Callable[[int], Any]],
     ):
223+ """
224+ Write pyarrow table to offline store.
225+ This method supports two execution modes:
226+ - "driver": Uses Spark to perform schema validation, type casting, and appending the data to the offline store.
227+ This mode must run on the Spark driver and supports advanced functionality like schema enforcement.
228+ - "worker": A simplified, worker-safe implementation that writes Arrow tables directly to storage.
229+ This mode is designed for distributed execution within mapInArrow or other parallel contexts.
230+
231+ Args:
232+ config: RepoConfig
233+ feature_view: FeatureView
234+ table: pyarrow.Table
235+ progress: Callable[[int], Any]
236+ mode: Literal["driver", "worker"], default is "driver"
237+
238+ """
         assert isinstance(config.offline_store, SparkOfflineStoreConfig)
         assert isinstance(feature_view.batch_source, SparkSource)

@@ -230,38 +248,55 @@ def offline_write_batch(
230248 f"The schema is expected to be { pa_schema } with the columns (in this exact order) to be { column_names } ."
231249 )
232250
233- spark_session = get_spark_session_or_start_new_with_repoconfig (
234- store_config = config .offline_store
235- )
251+ mode = config .offline_store .mode
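+        # Dispatch on the configured execution mode (SparkOfflineStoreConfig.mode, default "driver")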

-        if feature_view.batch_source.path:
-            # write data to disk so that it can be loaded into spark (for preserving column types)
-            with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp_file:
-                print(tmp_file.name)
-                pq.write_table(table, tmp_file.name)
-
-                # load data
-                df_batch = spark_session.read.parquet(tmp_file.name)
-
-                # load existing data to get spark table schema
-                df_existing = spark_session.read.format(
-                    feature_view.batch_source.file_format
-                ).load(feature_view.batch_source.path)
-
-                # cast columns if applicable
-                df_batch = _cast_data_frame(df_batch, df_existing)
-
-                df_batch.write.format(feature_view.batch_source.file_format).mode(
-                    "append"
-                ).save(feature_view.batch_source.path)
-        elif feature_view.batch_source.query:
-            raise NotImplementedError(
-                "offline_write_batch not implemented for batch sources specified by query"
+        if mode == "driver":
+            spark_session = get_spark_session_or_start_new_with_repoconfig(
+                store_config=config.offline_store
             )
+
+            if feature_view.batch_source.path:
+                # write data to disk so that it can be loaded into spark (for preserving column types)
+                with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp_file:
+                    print(tmp_file.name)
+                    pq.write_table(table, tmp_file.name)
+
+                    # load data
+                    df_batch = spark_session.read.parquet(tmp_file.name)
+
+                    # load existing data to get spark table schema
+                    df_existing = spark_session.read.format(
+                        feature_view.batch_source.file_format
+                    ).load(feature_view.batch_source.path)
+
+                    # cast columns if applicable
+                    df_batch = _cast_data_frame(df_batch, df_existing)
+
+                    df_batch.write.format(feature_view.batch_source.file_format).mode(
+                        "append"
+                    ).save(feature_view.batch_source.path)
+            elif feature_view.batch_source.query:
+                raise NotImplementedError(
+                    "offline_write_batch not implemented for batch sources specified by query"
+                )
+            else:
+                raise NotImplementedError(
+                    "offline_write_batch not implemented for batch sources specified by a table"
+                )
+        elif mode == "worker":
+            # Safe worker-side Arrow write
+            if not feature_view.batch_source.path:
+                raise ValueError("Path is required for worker mode.")
+
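+            # Each call writes a separate, uniquely named parquet file so that
+            # parallel worker tasks appending to the same path do not collide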
+            unique_name = f"batch_{uuid.uuid4().hex}.parquet"
+            output_path = os.path.join(feature_view.batch_source.path, unique_name)
+
+            pq.write_table(table, output_path)
+
+            if progress:
+                progress(table.num_rows)
         else:
-            raise NotImplementedError(
-                "offline_write_batch not implemented for batch sources specified by a table"
-            )
+            raise ValueError(f"Unsupported mode: {mode}")

     @staticmethod
     def pull_all_from_table_or_query(