@@ -1,4 +1,6 @@
import os
import tempfile
import uuid
import warnings
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -13,6 +15,7 @@
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pytz import utc
from feast.infra.utils import aws_utils

from feast import FeatureView, OnDemandFeatureView
from feast.data_source import DataSource
@@ -46,6 +49,12 @@ class SparkOfflineStoreConfig(FeastConfigBaseModel):
""" Configuration overlay for the spark session """
# sparksession is not serializable and we dont want to pass it around as an argument

staging_location: Optional[StrictStr] = None
""" Remote path for batch materialization jobs """

region: Optional[StrictStr] = None
""" AWS region, required when the staging location is s3-based """


class SparkOfflineStore(OfflineStore):
@staticmethod
@@ -105,6 +114,7 @@ def pull_latest_from_table_or_query(
return SparkRetrievalJob(
spark_session=spark_session,
query=query,
config=config,
full_feature_names=False,
on_demand_feature_views=None,
)
@@ -129,6 +139,7 @@ def get_historical_features(
"Some functionality may still be unstable so functionality can change in the future.",
RuntimeWarning,
)

spark_session = get_spark_session_or_start_new_with_repoconfig(
store_config=config.offline_store
)
@@ -192,6 +203,7 @@ def get_historical_features(
min_event_timestamp=entity_df_event_timestamp_range[0],
max_event_timestamp=entity_df_event_timestamp_range[1],
),
config=config,
)

@staticmethod
@@ -286,7 +298,10 @@ def pull_all_from_table_or_query(
"""

return SparkRetrievalJob(
spark_session=spark_session, query=query, full_feature_names=False
spark_session=spark_session,
query=query,
full_feature_names=False,
config=config,
)


@@ -296,6 +311,7 @@ def __init__(
spark_session: SparkSession,
query: str,
full_feature_names: bool,
config: RepoConfig,
on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
metadata: Optional[RetrievalMetadata] = None,
):
@@ -305,6 +321,7 @@ def __init__(
self._full_feature_names = full_feature_names
self._on_demand_feature_views = on_demand_feature_views or []
self._metadata = metadata
self._config = config

@property
def full_feature_names(self) -> bool:
@@ -342,6 +359,53 @@ def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
raise ValueError("Cannot persist, table_name is not defined")
self.to_spark_df().createOrReplaceTempView(table_name)

def supports_remote_storage_export(self) -> bool:
return self._config.offline_store.staging_location is not None

def to_remote_storage(self) -> List[str]:
"""Currently only works for local and s3-based staging locations"""
    if self.supports_remote_storage_export():
        sdf: pyspark.sql.DataFrame = self.to_spark_df()

        if self._config.offline_store.staging_location.startswith("file://"):
            # Strip the URI scheme so os.path can resolve the local path
            local_file_staging_location = os.path.abspath(
                self._config.offline_store.staging_location.replace(
                    "file://", ""
                )
            )

            # Write to a unique folder under the staging location
            output_uri = os.path.join(
                str(local_file_staging_location), str(uuid.uuid4())
            )
            sdf.write.parquet(output_uri)

            return _list_files_in_folder(output_uri)
        elif self._config.offline_store.staging_location.startswith("s3://"):
            # Spark talks to S3 through the Hadoop S3A connector, which
            # expects the s3a:// scheme
            spark_compatible_s3_staging_location = (
                self._config.offline_store.staging_location.replace(
                    "s3://", "s3a://"
                )
            )

            # Write to a unique folder under the staging location
            output_uri = os.path.join(
                str(spark_compatible_s3_staging_location), str(uuid.uuid4())
            )
            sdf.write.parquet(output_uri)

            # List the written files using the canonical s3:// scheme;
            # s3a:// is only needed by Spark itself
            return aws_utils.list_s3_files(
                self._config.offline_store.region,
                output_uri.replace("s3a://", "s3://"),
            )
        else:
            raise NotImplementedError(
                "to_remote_storage is only implemented for file:// and s3:// URI schemes"
            )
    else:
        raise NotImplementedError(
            "to_remote_storage requires offline_store.staging_location to be set"
        )
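A caller would typically probe for staging support before exporting; a hedged usage sketch (store, entity_df, and features are illustrative names, not part of this diff):

    # Illustrative only: export a retrieval job's result set to the staging
    # location configured above, assuming a FeatureStore backed by this
    # Spark offline store.
    job = store.get_historical_features(entity_df=entity_df, features=features)
    if job.supports_remote_storage_export():
        parquet_uris = job.to_remote_storage()  # one URI per written file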

@property
def metadata(self) -> Optional[RetrievalMetadata]:
"""
@@ -444,6 +508,17 @@ def _format_datetime(t: datetime) -> str:
return dt


def _list_files_in_folder(folder):
"""List full filenames in a folder"""
files = []
for file in os.listdir(folder):
filename = os.path.join(folder, file)
if os.path.isfile(filename):
files.append(filename)

return files
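Note that sdf.write.parquet() produces a folder of part files plus a _SUCCESS marker, all of which this helper returns; a caller that wants only the data files could filter them, as in this hedged variant (not part of this diff):

    def _list_parquet_files_in_folder(folder):
        """Like _list_files_in_folder, but keep only parquet part files."""
        return [
            os.path.join(folder, f)
            for f in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, f)) and f.endswith(".parquet")
        ]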


def _cast_data_frame(
df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
) -> pyspark.sql.DataFrame:
@@ -58,6 +58,10 @@ def create_offline_store_config(self):
self.spark_offline_store_config = SparkOfflineStoreConfig()
self.spark_offline_store_config.type = "spark"
self.spark_offline_store_config.spark_conf = self.spark_conf
# Use tempfile.mkdtemp() here: str(tempfile.TemporaryDirectory()) would
# yield the object's repr, and the directory would be deleted as soon as
# the object is garbage collected.
self.spark_offline_store_config.staging_location = (
    "file://" + tempfile.mkdtemp()
)
self.spark_offline_store_config.region = "eu-west-1"
return self.spark_offline_store_config

def create_data_source(