Commit 5b787af

feat: Support staging for spark materialization (#5671) (#5797)
* feat: support staging for spark materialization (#5671)

Signed-off-by: Jacob Weinhold <29459386+jfw-ppi@users.noreply.github.com>
1 parent 58d0325 commit 5b787af

File tree (3 files changed: +374 additions, −48 deletions)

  • docs/reference/offline-stores
  • sdk/python
    • feast/infra/offline_stores/contrib/spark_offline_store
    • tests/unit/infra/offline_stores/contrib/spark_offline_store


docs/reference/offline-stores/spark.md

Lines changed: 5 additions & 1 deletion
@@ -32,9 +32,13 @@ offline_store:
         spark.sql.session.timeZone: "UTC"
         spark.sql.execution.arrow.fallback.enabled: "true"
         spark.sql.execution.arrow.pyspark.enabled: "true"
+    # Optional: spill large materializations to the staging location instead of collecting in the driver
+    staging_location: "s3://my-bucket/tmp/feast"
 online_store:
     path: data/online_store.db
 ```
+
+> The `staging_location` can point to object storage (like S3, GCS, or Azure blobs) or a local filesystem directory (e.g., `/tmp/feast/staging`) to spill large materialization outputs before reading them back into Feast.
 {% endcode %}
 
 The full set of configuration options is available in [SparkOfflineStoreConfig](https://rtd.feast.dev/en/master/#feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStoreConfig).
@@ -60,7 +64,7 @@ Below is a matrix indicating which functionality is supported by `SparkRetrieval
 | export to arrow table | yes |
 | export to arrow batches | no |
 | export to SQL | no |
-| export to data lake (S3, GCS, etc.) | no |
+| export to data lake (S3, GCS, etc.) | yes |
 | export to data warehouse | no |
 | export as Spark dataframe | yes |
 | local execution of Python-based on-demand transforms | no |
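
For context, a hedged usage sketch of the newly supported data lake export from the SDK side. The repo path, entity values, and feature reference below are placeholders (not from this PR); the staged parquet URIs end up under the configured `staging_location`.

```python
# Hypothetical usage: export a historical retrieval to the staging location.
import pandas as pd
from feast import FeatureStore

store = FeatureStore(repo_path=".")  # assumes a repo configured with the Spark offline store
entity_df = pd.DataFrame(
    {
        "driver_id": [1001, 1002],
        "event_timestamp": pd.to_datetime(["2021-04-12", "2021-04-12"]),
    }
)
job = store.get_historical_features(
    entity_df=entity_df,
    features=["driver_hourly_stats:conv_rate"],  # placeholder feature reference
)
staged_files = job.to_remote_storage()  # parquet file URIs written under staging_location
```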

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 161 additions & 47 deletions
@@ -16,6 +16,7 @@
     Union,
     cast,
 )
+from urllib.parse import urlparse
 
 if TYPE_CHECKING:
     from feast.saved_dataset import ValidationReference
@@ -24,6 +25,7 @@
 import pandas
 import pandas as pd
 import pyarrow
+import pyarrow.dataset as ds
 import pyarrow.parquet as pq
 import pyspark
 from pydantic import StrictStr
@@ -445,8 +447,43 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
 
     def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
         """Return dataset as pyarrow Table synchronously"""
+        if self._should_use_staging_for_arrow():
+            return self._to_arrow_via_staging()
+
         return pyarrow.Table.from_pandas(self._to_df_internal(timeout=timeout))
 
+    def _should_use_staging_for_arrow(self) -> bool:
+        offline_store = getattr(self._config, "offline_store", None)
+        return bool(
+            isinstance(offline_store, SparkOfflineStoreConfig)
+            and getattr(offline_store, "staging_location", None)
+        )
+
+    def _to_arrow_via_staging(self) -> pyarrow.Table:
+        paths = self.to_remote_storage()
+        if not paths:
+            return pyarrow.table({})
+
+        parquet_paths = _filter_parquet_files(paths)
+        if not parquet_paths:
+            return pyarrow.table({})
+
+        normalized_paths = self._normalize_staging_paths(parquet_paths)
+        dataset = ds.dataset(normalized_paths, format="parquet")
+        return dataset.to_table()
+
+    def _normalize_staging_paths(self, paths: List[str]) -> List[str]:
+        """Normalize staging paths for PyArrow datasets."""
+        normalized = []
+        for path in paths:
+            if path.startswith("file://"):
+                normalized.append(path[len("file://") :])
+            elif "://" in path:
+                normalized.append(path)
+            else:
+                normalized.append(path)
+        return normalized
+
     def to_feast_df(
         self,
         validation_reference: Optional["ValidationReference"] = None,
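
To make the new read-back path concrete, here is a minimal, self-contained sketch (not from this PR) of how staged parquet part files are reassembled into one Arrow table with `pyarrow.dataset`, including the `file://` stripping that `_normalize_staging_paths` performs. File names are illustrative.

```python
# Sketch of the staging read-back: Spark leaves parquet part files at the
# staging location, and pyarrow.dataset stitches them into a single table.
import tempfile

import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

with tempfile.TemporaryDirectory() as staging_dir:
    # Pretend these are part files written by sdf.write.parquet(output_uri).
    pq.write_table(pa.table({"driver_id": [1, 2]}), f"{staging_dir}/part-00000.parquet")
    pq.write_table(pa.table({"driver_id": [3]}), f"{staging_dir}/part-00001.parquet")

    # file:// URIs are reduced to plain paths, mirroring _normalize_staging_paths.
    uris = [
        f"file://{staging_dir}/part-00000.parquet",
        f"file://{staging_dir}/part-00001.parquet",
    ]
    local_paths = [u[len("file://"):] for u in uris]

    table = ds.dataset(local_paths, format="parquet").to_table()
    print(table.num_rows)  # 3
```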
@@ -508,55 +545,53 @@ def supports_remote_storage_export(self) -> bool:
 
     def to_remote_storage(self) -> List[str]:
         """Currently only works for local and s3-based staging locations"""
-        if self.supports_remote_storage_export():
-            sdf: pyspark.sql.DataFrame = self.to_spark_df()
-
-            if self._config.offline_store.staging_location.startswith("/"):
-                local_file_staging_location = os.path.abspath(
-                    self._config.offline_store.staging_location
-                )
-
-                # write to staging location
-                output_uri = os.path.join(
-                    str(local_file_staging_location), str(uuid.uuid4())
-                )
-                sdf.write.parquet(output_uri)
-
-                return _list_files_in_folder(output_uri)
-            elif self._config.offline_store.staging_location.startswith("s3://"):
-                from feast.infra.utils import aws_utils
-
-                spark_compatible_s3_staging_location = (
-                    self._config.offline_store.staging_location.replace(
-                        "s3://", "s3a://"
-                    )
-                )
-
-                # write to staging location
-                output_uri = os.path.join(
-                    str(spark_compatible_s3_staging_location), str(uuid.uuid4())
-                )
-                sdf.write.parquet(output_uri)
-
-                return aws_utils.list_s3_files(
-                    self._config.offline_store.region, output_uri
-                )
-            elif self._config.offline_store.staging_location.startswith("hdfs://"):
-                output_uri = os.path.join(
-                    self._config.offline_store.staging_location, str(uuid.uuid4())
-                )
-                sdf.write.parquet(output_uri)
-                spark_session = get_spark_session_or_start_new_with_repoconfig(
-                    store_config=self._config.offline_store
-                )
-                return _list_hdfs_files(spark_session, output_uri)
-            else:
-                raise NotImplementedError(
-                    "to_remote_storage is only implemented for file://, s3:// and hdfs:// uri schemes"
-                )
+        if not self.supports_remote_storage_export():
+            raise NotImplementedError()
+
+        sdf: pyspark.sql.DataFrame = self.to_spark_df()
+        staging_location = self._config.offline_store.staging_location
+
+        if staging_location.startswith("/"):
+            local_file_staging_location = os.path.abspath(staging_location)
+            output_uri = os.path.join(local_file_staging_location, str(uuid.uuid4()))
+            sdf.write.parquet(output_uri)
+            return _list_files_in_folder(output_uri)
+        elif staging_location.startswith("s3://"):
+            from feast.infra.utils import aws_utils
 
+            spark_compatible_s3_staging_location = staging_location.replace(
+                "s3://", "s3a://"
+            )
+            output_uri = os.path.join(
+                spark_compatible_s3_staging_location, str(uuid.uuid4())
+            )
+            sdf.write.parquet(output_uri)
+            s3_uri_for_listing = output_uri.replace("s3a://", "s3://", 1)
+            return aws_utils.list_s3_files(
+                self._config.offline_store.region, s3_uri_for_listing
+            )
+        elif staging_location.startswith("gs://"):
+            output_uri = os.path.join(staging_location, str(uuid.uuid4()))
+            sdf.write.parquet(output_uri)
+            return _list_gcs_files(output_uri)
+        elif staging_location.startswith(("wasbs://", "abfs://", "abfss://")) or (
+            staging_location.startswith("https://")
+            and ".blob.core.windows.net" in staging_location
+        ):
+            output_uri = os.path.join(staging_location, str(uuid.uuid4()))
+            sdf.write.parquet(output_uri)
+            return _list_azure_files(output_uri)
+        elif staging_location.startswith("hdfs://"):
+            output_uri = os.path.join(staging_location, str(uuid.uuid4()))
+            sdf.write.parquet(output_uri)
+            spark_session = get_spark_session_or_start_new_with_repoconfig(
+                store_config=self._config.offline_store
+            )
+            return _list_hdfs_files(spark_session, output_uri)
         else:
-            raise NotImplementedError()
+            raise NotImplementedError(
+                "to_remote_storage is only implemented for file://, s3://, gs://, azure, and hdfs uri schemes"
+            )
 
     @property
     def metadata(self) -> Optional[RetrievalMetadata]:
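
The S3 branch is the subtle one: Spark writes through the Hadoop `s3a://` connector, while `aws_utils.list_s3_files` expects a plain `s3://` URI, hence the scheme swap in both directions. A standalone sketch of that URI handling (bucket name and prefix are illustrative):

```python
# Illustration (not part of the PR) of how the staging output URI is built and
# then translated back for listing.
import os
import uuid

staging_location = "s3://my-bucket/tmp/feast"

# Spark writes through the s3a:// scheme, into a per-run UUID subdirectory.
spark_compatible = staging_location.replace("s3://", "s3a://")
output_uri = os.path.join(spark_compatible, str(uuid.uuid4()))
print(output_uri)  # e.g. s3a://my-bucket/tmp/feast/<uuid>

# Listing the staged part files goes back through s3://.
s3_uri_for_listing = output_uri.replace("s3a://", "s3://", 1)
print(s3_uri_for_listing)
```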
@@ -789,6 +824,10 @@ def _list_files_in_folder(folder):
     return files
 
 
+def _filter_parquet_files(paths: List[str]) -> List[str]:
+    return [path for path in paths if path.endswith(".parquet")]
+
+
 def _list_hdfs_files(spark_session: SparkSession, uri: str) -> List[str]:
     jvm = spark_session._jvm
     jsc = spark_session._jsc
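
The small filter above matters because a Spark parquet write also emits non-data marker files (such as `_SUCCESS`) next to the part files, and those must not be handed to the Arrow reader. An illustration with made-up paths:

```python
# Hypothetical staged listing: only the *.parquet part files survive the filter.
staged = [
    "s3://my-bucket/tmp/feast/9f2c1b7e/_SUCCESS",
    "s3://my-bucket/tmp/feast/9f2c1b7e/part-00000.snappy.parquet",
    "s3://my-bucket/tmp/feast/9f2c1b7e/part-00001.snappy.parquet",
]
print([p for p in staged if p.endswith(".parquet")])  # drops the _SUCCESS marker
```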
@@ -805,6 +844,81 @@ def _list_hdfs_files(spark_session: SparkSession, uri: str) -> List[str]:
     return files
 
 
+def _list_gcs_files(path: str) -> List[str]:
+    try:
+        from google.cloud import storage
+    except ImportError as e:
+        from feast.errors import FeastExtrasDependencyImportError
+
+        raise FeastExtrasDependencyImportError("gcp", str(e))
+
+    assert path.startswith("gs://"), "GCS path must start with gs://"
+    bucket_path = path[len("gs://") :]
+    if "/" in bucket_path:
+        bucket, prefix = bucket_path.split("/", 1)
+    else:
+        bucket, prefix = bucket_path, ""
+
+    client = storage.Client()
+    bucket_obj = client.bucket(bucket)
+    blobs = bucket_obj.list_blobs(prefix=prefix)
+
+    files = []
+    for blob in blobs:
+        if not blob.name.endswith("/"):
+            files.append(f"gs://{bucket}/{blob.name}")
+    return files
+
+
+def _list_azure_files(path: str) -> List[str]:
+    try:
+        from azure.identity import DefaultAzureCredential
+        from azure.storage.blob import BlobServiceClient
+    except ImportError as e:
+        from feast.errors import FeastExtrasDependencyImportError
+
+        raise FeastExtrasDependencyImportError("azure", str(e))
+
+    parsed = urlparse(path)
+    scheme = parsed.scheme
+
+    if scheme in ("wasbs", "abfs", "abfss"):
+        if "@" not in parsed.netloc:
+            raise ValueError("Azure staging URI must include container@account host")
+        container, account_host = parsed.netloc.split("@", 1)
+        account_url = f"https://{account_host}"
+        base = f"{scheme}://{container}@{account_host}"
+        prefix = parsed.path.lstrip("/")
+    else:
+        account_url = f"{parsed.scheme}://{parsed.netloc}"
+        container_and_prefix = parsed.path.lstrip("/").split("/", 1)
+        container = container_and_prefix[0]
+        base = f"{account_url}/{container}"
+        prefix = container_and_prefix[1] if len(container_and_prefix) > 1 else ""
+
+    credential = os.environ.get("AZURE_STORAGE_KEY") or os.environ.get(
+        "AZURE_STORAGE_ACCOUNT_KEY"
+    )
+    if credential:
+        client = BlobServiceClient(account_url=account_url, credential=credential)
+    else:
+        default_credential = DefaultAzureCredential(
+            exclude_shared_token_cache_credential=True
+        )
+        client = BlobServiceClient(
+            account_url=account_url, credential=default_credential
+        )
+
+    container_client = client.get_container_client(container)
+    blobs = container_client.list_blobs(name_starts_with=prefix if prefix else None)
+
+    files = []
+    for blob in blobs:
+        if not blob.name.endswith("/"):
+            files.append(f"{base}/{blob.name}")
+    return files
+
+
 def _cast_data_frame(
     df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
 ) -> pyspark.sql.DataFrame:
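
For reference, a short sketch (with made-up account, container, and prefix values) of the two Azure URI shapes `_list_azure_files` accepts and how they decompose into account host, container, and blob prefix, using only `urlparse`:

```python
# Illustration (not from the PR) of the supported Azure staging URI forms.
from urllib.parse import urlparse

for uri in (
    "wasbs://staging@myaccount.blob.core.windows.net/feast/tmp",
    "https://myaccount.blob.core.windows.net/staging/feast/tmp",
):
    parsed = urlparse(uri)
    if parsed.scheme in ("wasbs", "abfs", "abfss"):
        # container@account-host form: container before '@', account host after.
        container, account_host = parsed.netloc.split("@", 1)
        prefix = parsed.path.lstrip("/")
    else:
        # https form: first path segment is the container, the rest is the prefix.
        account_host = parsed.netloc
        container, _, prefix = parsed.path.lstrip("/").partition("/")
    print(uri, "->", account_host, container, prefix)
```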
