7 changes: 3 additions & 4 deletions sdk/python/feast/batch_feature_view.py
@@ -1,8 +1,7 @@
import functools
import warnings
from datetime import datetime, timedelta
from types import FunctionType
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import dill

@@ -61,7 +60,7 @@ class BatchFeatureView(FeatureView):
owner: str
timestamp_field: str
materialization_intervals: List[Tuple[datetime, datetime]]
udf: Optional[FunctionType]
udf: Optional[Callable[[Any], Any]]
udf_string: Optional[str]
feature_transformation: Transformation

@@ -78,7 +77,7 @@ def __init__(
description: str = "",
owner: str = "",
schema: Optional[List[Field]] = None,
udf: Optional[FunctionType] = None,
udf: Optional[Callable[[Any], Any]] = None,
udf_string: Optional[str] = "",
feature_transformation: Optional[Transformation] = None,
):
Empty file.
Empty file.
20 changes: 20 additions & 0 deletions sdk/python/feast/infra/compute_engines/spark/config.py
@@ -0,0 +1,20 @@
from typing import Dict, Optional

from pydantic import StrictStr

from feast.repo_config import FeastConfigBaseModel


class SparkComputeConfig(FeastConfigBaseModel):
type: StrictStr = "spark"
""" Spark Compute type selector"""

spark_conf: Optional[Dict[str, str]] = None
""" Configuration overlay for the spark session """
# The SparkSession is not serializable and we don't want to pass it around as an argument

staging_location: Optional[StrictStr] = None
""" Remote path for batch materialization jobs"""

region: Optional[StrictStr] = None
""" AWS Region if applicable for s3-based staging locations"""
19 changes: 19 additions & 0 deletions sdk/python/feast/infra/compute_engines/spark/utils.py
@@ -0,0 +1,19 @@
from typing import Dict, Optional

from pyspark import SparkConf
from pyspark.sql import SparkSession


def get_or_create_new_spark_session(
spark_config: Optional[Dict[str, str]] = None,
) -> SparkSession:
spark_session = SparkSession.getActiveSession()
if not spark_session:
spark_builder = SparkSession.builder
if spark_config:
spark_builder = spark_builder.config(
conf=SparkConf().setAll([(k, v) for k, v in spark_config.items()])
)

spark_session = spark_builder.getOrCreate()
return spark_session
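
As a usage sketch for the helper above (the configuration keys shown are ordinary Spark properties chosen for illustration, not values required by this change):

from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session

# First call: no active session yet, so the config overlay below is applied.
spark = get_or_create_new_spark_session(
    {"spark.master": "local[1]", "spark.app.name": "feast-spark-demo"}
)
# Second call: the now-active session is reused and the config argument is ignored.
assert get_or_create_new_spark_session() is spark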
3 changes: 2 additions & 1 deletion sdk/python/feast/stream_feature_view.py
@@ -151,8 +151,9 @@ def get_feature_transformation(self) -> Optional[Transformation]:
if self.mode in (
TransformationMode.PANDAS,
TransformationMode.PYTHON,
TransformationMode.SPARK_SQL,
TransformationMode.SPARK,
) or self.mode in ("pandas", "python", "spark"):
) or self.mode in ("pandas", "python", "spark_sql", "spark"):
return Transformation(
mode=self.mode, udf=self.udf, udf_string=self.udf_string or ""
)
4 changes: 2 additions & 2 deletions sdk/python/feast/transformation/base.py
@@ -81,7 +81,7 @@ def __init__(
description: str = "",
owner: str = "",
):
self.mode = mode if isinstance(mode, str) else mode.value
self.mode = mode
self.udf = udf
self.udf_string = udf_string
self.name = name
@@ -99,7 +99,7 @@ def to_proto(self) -> Union[UserDefinedFunctionProto, SubstraitTransformationPro
def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "Transformation":
return Transformation(mode=self.mode, udf=self.udf, udf_string=self.udf_string)

def transform(self, inputs: Any) -> Any:
def transform(self, *inputs: Any) -> Any:
raise NotImplementedError

def transform_arrow(self, *args, **kwargs) -> Any:
1 change: 1 addition & 0 deletions sdk/python/feast/transformation/factory.py
@@ -5,6 +5,7 @@
"pandas": "feast.transformation.pandas_transformation.PandasTransformation",
"substrait": "feast.transformation.substrait_transformation.SubstraitTransformation",
"sql": "feast.transformation.sql_transformation.SQLTransformation",
"spark_sql": "feast.transformation.spark_transformation.SparkTransformation",
"spark": "feast.transformation.spark_transformation.SparkTransformation",
}
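
Both the "spark" and "spark_sql" keys now resolve to SparkTransformation. As an illustration of how such a dotted-path registry is typically consumed (resolve below is a hypothetical stand-in; Feast's actual factory helper is not part of this diff):

import importlib

def resolve(dotted_path: str):
    # Split "package.module.ClassName" into module path and class name, then import it.
    module_name, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

cls = resolve("feast.transformation.spark_transformation.SparkTransformation")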

1 change: 1 addition & 0 deletions sdk/python/feast/transformation/mode.py
@@ -4,6 +4,7 @@
class TransformationMode(Enum):
PYTHON = "python"
PANDAS = "pandas"
SPARK_SQL = "spark_sql"
SPARK = "spark"
SQL = "sql"
SUBSTRAIT = "substrait"
120 changes: 117 additions & 3 deletions sdk/python/feast/transformation/spark_transformation.py
@@ -1,11 +1,125 @@
from typing import Any
from typing import Any, Dict, Optional, Union, cast

import pandas as pd
import pyspark.sql

from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session
from feast.transformation.base import Transformation
from feast.transformation.mode import TransformationMode


class SparkTransformation(Transformation):
def transform(self, inputs: Any) -> Any:
pass
r"""
SparkTransformation can be used to define a transformation using a Spark UDF or SQL query.
The active Spark session will be used, or a new one will be created if none is available.
E.g.:
spark_transformation = SparkTransformation(
mode=TransformationMode.SPARK,
udf=remove_extra_spaces,
udf_string="remove extra spaces",
)
OR
spark_transformation = Transformation(
mode=TransformationMode.SPARK_SQL,
udf=remove_extra_spaces_sql,
udf_string="remove extra spaces sql",
)
OR
@transformation(mode=TransformationMode.SPARK)
def remove_extra_spaces_udf(df: pd.DataFrame) -> pd.DataFrame:
return df.assign(name=df['name'].str.replace('\s+', ' '))
"""

def __new__(
cls,
mode: Union[TransformationMode, str],
udf: Any,
udf_string: str,
spark_config: Dict[str, Any] = {},
name: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
description: str = "",
owner: str = "",
*args,
**kwargs,
) -> "SparkTransformation":
"""
Creates a SparkTransformation
Args:
mode: (required) The mode of the transformation. Either TransformationMode.SPARK or TransformationMode.SPARK_SQL.
udf: (required) The user-defined transformation function.
udf_string: (required) The string representation of the udf, kept because dill's source extraction does not work in all cases.
spark_config: (optional) The spark configuration to use for the transformation.
name: (optional) The name of the transformation.
tags: (optional) Metadata tags for the transformation.
description: (optional) A description of the transformation.
owner: (optional) The owner of the transformation.
"""
instance = super(SparkTransformation, cls).__new__(
cls,
mode=mode,
spark_config=spark_config,
udf=udf,
udf_string=udf_string,
name=name,
tags=tags,
description=description,
owner=owner,
)
return cast(SparkTransformation, instance)

def __init__(
self,
mode: Union[TransformationMode, str],
udf: Any,
udf_string: str,
spark_config: Dict[str, Any] = {},
name: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
description: str = "",
owner: str = "",
*args,
**kwargs,
):
super().__init__(
mode=mode,
udf=udf,
name=name,
udf_string=udf_string,
tags=tags,
description=description,
owner=owner,
)
self.spark_session = get_or_create_new_spark_session(spark_config)

def transform(
self,
*inputs: Union[str, pd.DataFrame],
) -> pd.DataFrame:
if self.mode == TransformationMode.SPARK_SQL:
return self._transform_spark_sql(*inputs)
else:
return self._transform_spark_udf(*inputs)

@staticmethod
def _create_temp_view_for_dataframe(df: pyspark.sql.DataFrame, name: str):
df_temp_view = f"feast_transformation_temp_view_{name}"
df.createOrReplaceTempView(df_temp_view)
return df_temp_view

def _transform_spark_sql(
self, *inputs: Union[pyspark.sql.DataFrame, str]
) -> pd.DataFrame:
inputs_str = [
self._create_temp_view_for_dataframe(v, f"index_{i}")
if isinstance(v, pyspark.sql.DataFrame)
else v
for i, v in enumerate(inputs)
]
return self.spark_session.sql(self.udf(*inputs_str))

def _transform_spark_udf(self, *inputs: Any) -> pd.DataFrame:
return self.udf(*inputs)

def infer_features(self, *args, **kwargs) -> Any:
pass
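
To illustrate the SPARK_SQL path above, here is a minimal sketch (the dedupe_sql helper and the name/age columns are hypothetical, and an active Spark session such as the one created in the tests further down is assumed): the udf receives the generated temp-view name for each DataFrame input and must return a SQL string.

def dedupe_sql(table_name: str) -> str:
    # `table_name` is the temp view registered by _transform_spark_sql,
    # e.g. "feast_transformation_temp_view_index_0".
    return f"SELECT DISTINCT name, age FROM {table_name}"

dedupe = SparkTransformation(
    mode=TransformationMode.SPARK_SQL,
    udf=dedupe_sql,
    udf_string="dedupe sql",
)
# `df` would be a pyspark.sql.DataFrame with `name` and `age` columns; transform()
# registers it as a temp view and executes the SQL built by dedupe_sql against it.
# result_df = dedupe.transform(df)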
@@ -0,0 +1,18 @@
import pandas as pd

from feast.transformation.pandas_transformation import PandasTransformation


def pandas_udf(features_df: pd.DataFrame) -> pd.DataFrame:
df = pd.DataFrame()
df["output1"] = features_df["feature1"]
df["output2"] = features_df["feature2"]
return df


def test_init_pandas_transformation():
transformation = PandasTransformation(udf=pandas_udf, udf_string="udf1")
features_df = pd.DataFrame.from_dict({"feature1": [1, 2], "feature2": [2, 3]})
transformed_df = transformation.transform(features_df)
assert transformed_df["output1"].values[0] == 1
assert transformed_df["output2"].values[1] == 3
104 changes: 104 additions & 0 deletions sdk/python/tests/unit/transformation/test_spark_transformation.py
@@ -0,0 +1,104 @@
from unittest.mock import patch

import pytest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, regexp_replace
from pyspark.testing.utils import assertDataFrameEqual

from feast.transformation.base import Transformation
from feast.transformation.mode import TransformationMode
from feast.transformation.spark_transformation import SparkTransformation


def get_sample_df(spark):
sample_data = [
{"name": "John D.", "age": 30},
{"name": "Alice G.", "age": 25},
{"name": "Bob T.", "age": 35},
{"name": "Eve A.", "age": 28},
]
df = spark.createDataFrame(sample_data)
return df


def get_expected_df(spark):
expected_data = [
{"name": "John D.", "age": 30},
{"name": "Alice G.", "age": 25},
{"name": "Bob T.", "age": 35},
{"name": "Eve A.", "age": 28},
]

expected_df = spark.createDataFrame(expected_data)
return expected_df


def remove_extra_spaces(df, column_name):
df_transformed = df.withColumn(
column_name, regexp_replace(col(column_name), "\\s+", " ")
)
return df_transformed


def remove_extra_spaces_sql(df, column_name):
sql = f"""
SELECT
age,
regexp_replace({column_name}, '\\\\s+', ' ') as {column_name}
FROM {df}
"""
return sql


@pytest.fixture
def spark_fixture():
spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
yield spark


@patch("feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session")
def test_spark_transformation(spark_fixture):
spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
df = get_sample_df(spark)

spark_transformation = Transformation(
mode=TransformationMode.SPARK,
udf=remove_extra_spaces,
udf_string="remove extra spaces",
)

transformed_df = spark_transformation.transform(df, "name")
expected_df = get_expected_df(spark)
assertDataFrameEqual(transformed_df, expected_df)


@patch("feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session")
def test_spark_transformation_init_transformation(spark_fixture):
spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
df = get_sample_df(spark)

spark_transformation = SparkTransformation(
mode=TransformationMode.SPARK,
udf=remove_extra_spaces,
udf_string="remove extra spaces",
)

transformed_df = spark_transformation.transform(df, "name")
expected_df = get_expected_df(spark)
assertDataFrameEqual(transformed_df, expected_df)


@patch("feast.infra.compute_engines.spark.utils.get_or_create_new_spark_session")
def test_spark_transformation_sql(spark_fixture):
spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
df = get_sample_df(spark)

spark_transformation = SparkTransformation(
mode=TransformationMode.SPARK_SQL,
udf=remove_extra_spaces_sql,
udf_string="remove extra spaces sql",
)

transformed_df = spark_transformation.transform(df, "name")
expected_df = get_expected_df(spark)
assertDataFrameEqual(transformed_df, expected_df)