docs/reference/offline-stores/spark.md (5 additions, 1 deletion)
@@ -32,9 +32,13 @@ offline_store:
spark.sql.session.timeZone: "UTC"
spark.sql.execution.arrow.fallback.enabled: "true"
spark.sql.execution.arrow.pyspark.enabled: "true"
# Optional: spill large materializations to the staging location instead of collecting in the driver
staging_location: "s3://my-bucket/tmp/feast"
online_store:
path: data/online_store.db
```

> The `staging_location` can point to object storage (S3, GCS, or Azure Blob Storage) or a local filesystem directory (e.g., `/tmp/feast/staging`). Large materialization outputs are spilled there and read back into Feast instead of being collected on the Spark driver.
{% endcode %}

The full set of configuration options is available in [SparkOfflineStoreConfig](https://rtd.feast.dev/en/master/#feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStoreConfig).
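
For example, with a `staging_location` configured, a retrieval job can be exported and read back without collecting results on the driver. A minimal sketch (the entity rows and feature reference are illustrative):

```python
import pandas as pd
from feast import FeatureStore

store = FeatureStore(repo_path=".")

# Illustrative entity rows; column names depend on your feature views.
entity_df = pd.DataFrame(
    {"driver_id": [1001], "event_timestamp": [pd.Timestamp.now(tz="UTC")]}
)

job = store.get_historical_features(
    entity_df=entity_df,
    features=["driver_hourly_stats:avg_daily_trips"],
)

# Parquet files are written under staging_location; the file URIs are returned.
paths = job.to_remote_storage()

# With staging configured, to_arrow() also goes through the staging location
# instead of collecting the Spark DataFrame on the driver.
table = job.to_arrow()
```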
@@ -60,7 +64,7 @@ Below is a matrix indicating which functionality is supported by `SparkRetrievalJob`:
| export to arrow table | yes |
| export to arrow batches | no |
| export to SQL | no |
| export to data lake (S3, GCS, etc.) | no |
| export to data lake (S3, GCS, etc.) | yes |
| export to data warehouse | no |
| export as Spark dataframe | yes |
| local execution of Python-based on-demand transforms | no |
sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

@@ -16,6 +16,7 @@
Union,
cast,
)
from urllib.parse import urlparse

if TYPE_CHECKING:
from feast.saved_dataset import ValidationReference
@@ -24,6 +25,7 @@
import pandas
import pandas as pd
import pyarrow
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import pyspark
from pydantic import StrictStr
@@ -445,8 +447,43 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:

def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
"""Return dataset as pyarrow Table synchronously"""
        if self._should_use_staging_for_arrow():
            # Spill to the staging location and read back with pyarrow.dataset
            # instead of collecting through the driver (timeout is not applied
            # on this path).
            return self._to_arrow_via_staging()

return pyarrow.Table.from_pandas(self._to_df_internal(timeout=timeout))

def _should_use_staging_for_arrow(self) -> bool:
offline_store = getattr(self._config, "offline_store", None)
return bool(
isinstance(offline_store, SparkOfflineStoreConfig)
and getattr(offline_store, "staging_location", None)
)

def _to_arrow_via_staging(self) -> pyarrow.Table:
paths = self.to_remote_storage()
if not paths:
return pyarrow.table({})

        # Spark writes bookkeeping files (e.g. _SUCCESS) alongside the data;
        # keep only the parquet parts.
        parquet_paths = _filter_parquet_files(paths)
if not parquet_paths:
return pyarrow.table({})

normalized_paths = self._normalize_staging_paths(parquet_paths)
dataset = ds.dataset(normalized_paths, format="parquet")
return dataset.to_table()

    def _normalize_staging_paths(self, paths: List[str]) -> List[str]:
        """Normalize staging paths for PyArrow datasets.

        ``file://`` URIs are stripped down to plain filesystem paths; all
        other URI schemes and bare local paths pass through unchanged.
        """
        normalized = []
        for path in paths:
            if path.startswith("file://"):
                normalized.append(path[len("file://") :])
            else:
                normalized.append(path)
        return normalized
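
    # For context, a minimal standalone sketch of the read-back path the two
    # helpers above feed into (a hedged example; the temp path is illustrative):
    #
    #     import pandas as pd
    #     import pyarrow.dataset as ds
    #
    #     pd.DataFrame({"driver_id": [1, 2]}).to_parquet("/tmp/part-0.parquet")
    #     table = ds.dataset(["/tmp/part-0.parquet"], format="parquet").to_table()
    #     assert table.num_rows == 2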

def to_feast_df(
self,
validation_reference: Optional["ValidationReference"] = None,
@@ -508,55 +545,53 @@ def supports_remote_storage_export(self) -> bool:

def to_remote_storage(self) -> List[str]:
"""Currently only works for local and s3-based staging locations"""
if self.supports_remote_storage_export():
sdf: pyspark.sql.DataFrame = self.to_spark_df()

if self._config.offline_store.staging_location.startswith("/"):
local_file_staging_location = os.path.abspath(
self._config.offline_store.staging_location
)

# write to staging location
output_uri = os.path.join(
str(local_file_staging_location), str(uuid.uuid4())
)
sdf.write.parquet(output_uri)

return _list_files_in_folder(output_uri)
elif self._config.offline_store.staging_location.startswith("s3://"):
from feast.infra.utils import aws_utils

spark_compatible_s3_staging_location = (
self._config.offline_store.staging_location.replace(
"s3://", "s3a://"
)
)

# write to staging location
output_uri = os.path.join(
str(spark_compatible_s3_staging_location), str(uuid.uuid4())
)
sdf.write.parquet(output_uri)

return aws_utils.list_s3_files(
self._config.offline_store.region, output_uri
)
elif self._config.offline_store.staging_location.startswith("hdfs://"):
output_uri = os.path.join(
self._config.offline_store.staging_location, str(uuid.uuid4())
)
sdf.write.parquet(output_uri)
spark_session = get_spark_session_or_start_new_with_repoconfig(
store_config=self._config.offline_store
)
return _list_hdfs_files(spark_session, output_uri)
else:
raise NotImplementedError(
"to_remote_storage is only implemented for file://, s3:// and hdfs:// uri schemes"
)
        if not self.supports_remote_storage_export():
            raise NotImplementedError(
                "to_remote_storage requires offline_store.staging_location to be set"
            )

sdf: pyspark.sql.DataFrame = self.to_spark_df()
staging_location = self._config.offline_store.staging_location

if staging_location.startswith("/"):
local_file_staging_location = os.path.abspath(staging_location)
output_uri = os.path.join(local_file_staging_location, str(uuid.uuid4()))
sdf.write.parquet(output_uri)
return _list_files_in_folder(output_uri)
elif staging_location.startswith("s3://"):
from feast.infra.utils import aws_utils

            # Spark's Hadoop S3A connector expects the s3a:// scheme.
            spark_compatible_s3_staging_location = staging_location.replace(
                "s3://", "s3a://"
            )
output_uri = os.path.join(
spark_compatible_s3_staging_location, str(uuid.uuid4())
)
sdf.write.parquet(output_uri)
            # aws_utils lists via boto3, which expects the plain s3:// scheme.
            s3_uri_for_listing = output_uri.replace("s3a://", "s3://", 1)
return aws_utils.list_s3_files(
self._config.offline_store.region, s3_uri_for_listing
)
elif staging_location.startswith("gs://"):
output_uri = os.path.join(staging_location, str(uuid.uuid4()))
sdf.write.parquet(output_uri)
return _list_gcs_files(output_uri)
elif staging_location.startswith(("wasbs://", "abfs://", "abfss://")) or (
staging_location.startswith("https://")
and ".blob.core.windows.net" in staging_location
):
output_uri = os.path.join(staging_location, str(uuid.uuid4()))
sdf.write.parquet(output_uri)
return _list_azure_files(output_uri)
elif staging_location.startswith("hdfs://"):
output_uri = os.path.join(staging_location, str(uuid.uuid4()))
sdf.write.parquet(output_uri)
spark_session = get_spark_session_or_start_new_with_repoconfig(
store_config=self._config.offline_store
)
return _list_hdfs_files(spark_session, output_uri)
        else:
            raise NotImplementedError(
                "to_remote_storage is only implemented for local (absolute path), "
                "s3://, gs://, Azure blob, and hdfs:// staging locations"
            )

@property
def metadata(self) -> Optional[RetrievalMetadata]:
@@ -789,6 +824,10 @@ def _list_files_in_folder(folder):
return files


def _filter_parquet_files(paths: List[str]) -> List[str]:
return [path for path in paths if path.endswith(".parquet")]


def _list_hdfs_files(spark_session: SparkSession, uri: str) -> List[str]:
jvm = spark_session._jvm
jsc = spark_session._jsc
@@ -805,6 +844,81 @@ def _list_hdfs_files(spark_session: SparkSession, uri: str) -> List[str]:
return files


def _list_gcs_files(path: str) -> List[str]:
try:
from google.cloud import storage
except ImportError as e:
from feast.errors import FeastExtrasDependencyImportError

raise FeastExtrasDependencyImportError("gcp", str(e))

    if not path.startswith("gs://"):
        raise ValueError(f"GCS path must start with gs://, got {path}")
    bucket_path = path[len("gs://") :]
if "/" in bucket_path:
bucket, prefix = bucket_path.split("/", 1)
else:
bucket, prefix = bucket_path, ""

client = storage.Client()
bucket_obj = client.bucket(bucket)
blobs = bucket_obj.list_blobs(prefix=prefix)

files = []
    for blob in blobs:
        # Skip zero-byte "directory" placeholder objects.
        if not blob.name.endswith("/"):
            files.append(f"gs://{bucket}/{blob.name}")
return files
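
# Hypothetical output shape (bucket and run id are placeholders):
#
#     _list_gcs_files("gs://my-bucket/tmp/feast/run-1")
#     # -> ["gs://my-bucket/tmp/feast/run-1/part-00000-....snappy.parquet", ...]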


def _list_azure_files(path: str) -> List[str]:
try:
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
except ImportError as e:
from feast.errors import FeastExtrasDependencyImportError

raise FeastExtrasDependencyImportError("azure", str(e))

parsed = urlparse(path)
scheme = parsed.scheme

if scheme in ("wasbs", "abfs", "abfss"):
if "@" not in parsed.netloc:
raise ValueError("Azure staging URI must include container@account host")
container, account_host = parsed.netloc.split("@", 1)
account_url = f"https://{account_host}"
base = f"{scheme}://{container}@{account_host}"
prefix = parsed.path.lstrip("/")
else:
account_url = f"{parsed.scheme}://{parsed.netloc}"
container_and_prefix = parsed.path.lstrip("/").split("/", 1)
container = container_and_prefix[0]
base = f"{account_url}/{container}"
prefix = container_and_prefix[1] if len(container_and_prefix) > 1 else ""

credential = os.environ.get("AZURE_STORAGE_KEY") or os.environ.get(
"AZURE_STORAGE_ACCOUNT_KEY"
)
if credential:
client = BlobServiceClient(account_url=account_url, credential=credential)
else:
default_credential = DefaultAzureCredential(
exclude_shared_token_cache_credential=True
)
client = BlobServiceClient(
account_url=account_url, credential=default_credential
)

container_client = client.get_container_client(container)
blobs = container_client.list_blobs(name_starts_with=prefix if prefix else None)

files = []
for blob in blobs:
if not blob.name.endswith("/"):
files.append(f"{base}/{blob.name}")
return files
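
# A minimal sketch of how the container@account URI form above decomposes
# (account, container, and path are placeholders):
#
#     from urllib.parse import urlparse
#
#     parsed = urlparse("abfss://staging@myaccount.dfs.core.windows.net/feast/run-1")
#     container, account_host = parsed.netloc.split("@", 1)
#     assert container == "staging"
#     assert account_host == "myaccount.dfs.core.windows.net"
#     assert parsed.path.lstrip("/") == "feast/run-1"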


def _cast_data_frame(
df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
) -> pyspark.sql.DataFrame: