Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def __init__(self, name, project=None):
super().__init__(f"Feature view {name} does not exist")


class InvalidSparkSessionException(Exception):
    """Raised when an argument expected to be a ``SparkSession`` is not one.

    Args:
        spark_arg: The object that was passed in place of a SparkSession;
            its type is included in the error message for easier debugging.
    """

    def __init__(self, spark_arg):
        # Single f-string instead of a backslash continuation: the original
        # continuation embedded a leading space and a run of indentation
        # whitespace into the message, and misspelled "received".
        super().__init__(
            f"Need Spark Session to convert results to spark data frame, "
            f"received {type(spark_arg)} instead."
        )


class OnDemandFeatureViewNotFoundException(FeastObjectNotFoundException):
def __init__(self, name, project=None):
if project:
Expand Down
46 changes: 45 additions & 1 deletion sdk/python/feast/infra/offline_stores/snowflake.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import contextlib
import os
import uuid
import warnings
from datetime import datetime
from functools import reduce
from pathlib import Path
from typing import (
Any,
Expand All @@ -21,11 +23,16 @@
import pyarrow
from pydantic import Field, StrictStr
from pydantic.typing import Literal
from pyspark.sql import DataFrame, SparkSession
from pytz import utc

from feast import OnDemandFeatureView
from feast.data_source import DataSource
from feast.errors import EntitySQLEmptyResults, InvalidEntityType
from feast.errors import (
EntitySQLEmptyResults,
InvalidEntityType,
InvalidSparkSessionException,
)
from feast.feature_logging import LoggingConfig, LoggingSource
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
from feast.infra.offline_stores import offline_utils
Expand Down Expand Up @@ -57,6 +64,8 @@

raise FeastExtrasDependencyImportError("snowflake", str(e))

warnings.filterwarnings("ignore", category=DeprecationWarning)


class SnowflakeOfflineStoreConfig(FeastConfigBaseModel):
"""Offline store config for Snowflake"""
Expand Down Expand Up @@ -447,6 +456,41 @@ def to_sql(self) -> str:
with self._query_generator() as query:
return query

def to_spark_df(self, spark_session: SparkSession) -> DataFrame:
    """Convert Snowflake query results to a PySpark DataFrame.

    Args:
        spark_session: An active SparkSession of the current environment,
            used to build the per-batch DataFrames.

    Returns:
        A PySpark DataFrame containing the union of all Arrow result batches.

    Raises:
        InvalidSparkSessionException: If ``spark_session`` is not a SparkSession.
        EntitySQLEmptyResults: If the query returned no result batches.
    """
    # Guard clause: fail fast on a bad argument instead of nesting the body.
    if not isinstance(spark_session, SparkSession):
        raise InvalidSparkSessionException(spark_session)

    with self._query_generator() as query:
        arrow_batches = execute_snowflake_statement(
            self.snowflake_conn, query
        ).fetch_arrow_batches()

        # NOTE: fetch_arrow_batches() returns an iterator, which is truthy
        # even when it yields nothing, so the original `if arrow_batches:`
        # check could never detect an empty result (and reduce() over zero
        # batches would raise TypeError). Materialize the per-batch
        # DataFrames so emptiness is actually observable.
        batch_dfs = [
            spark_session.createDataFrame(batch.to_pandas())
            for batch in arrow_batches
        ]
        if not batch_dfs:
            raise EntitySQLEmptyResults(query)

        return reduce(DataFrame.unionAll, batch_dfs)

def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
assert isinstance(storage, SavedDatasetSnowflakeStorage)
self.to_snowflake(table_name=storage.snowflake_options.table)
Expand Down