Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 19 additions & 45 deletions sdk/python/feast/infra/offline_stores/offline_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,36 +76,11 @@ def to_df(
validation_reference (optional): The validation to apply against the retrieved dataframe.
timeout (optional): The query timeout if applicable.
"""
features_df = self._to_df_internal(timeout=timeout)

if self.on_demand_feature_views:
# TODO(adchia): Fix requirement to specify dependent feature views in feature_refs
for odfv in self.on_demand_feature_views:
if odfv.mode not in {"pandas", "substrait"}:
raise Exception(
f'OnDemandFeatureView mode "{odfv.mode}" not supported for offline processing.'
)
features_df = features_df.join(
odfv.get_transformed_features_df(
features_df,
self.full_feature_names,
)
)

if validation_reference:
if not flags_helper.is_test():
warnings.warn(
"Dataset validation is an experimental feature. "
"This API is unstable and it could and most probably will be changed in the future. "
"We do not guarantee that future changes will maintain backward compatibility.",
RuntimeWarning,
)

validation_result = validation_reference.profile.validate(features_df)
if not validation_result.is_success:
raise ValidationFailed(validation_result)

return features_df
return (
self.to_arrow(validation_reference=validation_reference, timeout=timeout)
.to_pandas()
.reset_index(drop=True)
)

def to_arrow(
self,
Expand All @@ -122,23 +97,20 @@ def to_arrow(
validation_reference (optional): The validation to apply against the retrieved dataframe.
timeout (optional): The query timeout if applicable.
"""
if not self.on_demand_feature_views and not validation_reference:
return self._to_arrow_internal(timeout=timeout)

features_df = self._to_df_internal(timeout=timeout)
features_table = self._to_arrow_internal(timeout=timeout)
if self.on_demand_feature_views:
for odfv in self.on_demand_feature_views:
if odfv.mode not in {"pandas", "substrait"}:
raise Exception(
f'OnDemandFeatureView mode "{odfv.mode}" not supported for offline processing.'
)
features_df = features_df.join(
odfv.get_transformed_features_df(
features_df,
self.full_feature_names,
)
transformed_arrow = odfv.transform_arrow(
features_table, self.full_feature_names
)

for col in transformed_arrow.column_names:
if col.startswith("__index"):
continue
features_table = features_table.append_column(
col, transformed_arrow[col]
)

if validation_reference:
if not flags_helper.is_test():
warnings.warn(
Expand All @@ -148,11 +120,13 @@ def to_arrow(
RuntimeWarning,
)

validation_result = validation_reference.profile.validate(features_df)
validation_result = validation_reference.profile.validate(
features_table.to_pandas()
)
if not validation_result.is_success:
raise ValidationFailed(validation_result)

return pyarrow.Table.from_pandas(features_df)
return features_table

def to_sql(self) -> str:
"""
Expand Down
55 changes: 55 additions & 0 deletions sdk/python/feast/on_demand_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import dill
import pandas as pd
import pyarrow
from typeguard import typechecked

from feast.base_feature_view import BaseFeatureView
Expand Down Expand Up @@ -391,6 +392,60 @@ def get_request_data_schema(self) -> Dict[str, ValueType]:
def _get_projected_feature_name(self, feature: str) -> str:
return f"{self.projection.name_to_use()}__{feature}"

def transform_arrow(
    self,
    pa_table: pyarrow.Table,
    full_feature_names: bool = False,
) -> pyarrow.Table:
    """Apply this on-demand view's transformation to an arrow table.

    Source feature columns are made available under both their short name
    (``feature``) and their fully qualified name (``<view>__<feature>``) so
    the UDF can reference either; the helper columns added for this are
    dropped again after the transformation runs.

    Args:
        pa_table: Table containing the source feature columns.
        full_feature_names: If True, output feature columns are renamed to
            the fully qualified ``<projection>__<feature>`` form; otherwise
            the short feature names are used.

    Returns:
        The transformed table with helper columns removed and output
        columns renamed according to ``full_feature_names``.

    Raises:
        TypeError: If ``pa_table`` is not a ``pyarrow.Table``.
    """
    if not isinstance(pa_table, pyarrow.Table):
        raise TypeError("transform_arrow only accepts pyarrow.Table")
    # Helper columns appended below; dropped after the transform.
    columns_to_cleanup = []
    for source_fv_projection in self.source_feature_view_projections.values():
        for feature in source_fv_projection.features:
            full_feature_ref = f"{source_fv_projection.name}__{feature.name}"
            if full_feature_ref in pa_table.column_names:
                # Make sure the partial feature name is always present
                pa_table = pa_table.append_column(
                    feature.name, pa_table[full_feature_ref]
                )
                columns_to_cleanup.append(feature.name)
            elif feature.name in pa_table.column_names:
                # Make sure the full feature name is always present
                pa_table = pa_table.append_column(
                    full_feature_ref, pa_table[feature.name]
                )
                columns_to_cleanup.append(full_feature_ref)

    df_with_transformed_features: pyarrow.Table = (
        self.feature_transformation.transform_arrow(pa_table)
    )

    # Work out whether the correct column names are used.
    rename_columns: Dict[str, str] = {}
    for feature in self.features:
        short_name = feature.name
        long_name = self._get_projected_feature_name(feature.name)
        if (
            short_name in df_with_transformed_features.column_names
            and full_feature_names
        ):
            rename_columns[short_name] = long_name
        elif not full_feature_names:
            rename_columns[long_name] = short_name

    # Cleanup extra columns used for transformation.
    for col in columns_to_cleanup:
        if col in df_with_transformed_features.column_names:
            # Fix: was `.dtop(col)` (AttributeError on pyarrow.Table).
            # Pass a list: `drop` with a bare str iterates its characters
            # on older pyarrow versions.
            df_with_transformed_features = df_with_transformed_features.drop(
                [col]
            )
    return df_with_transformed_features.rename_columns(
        [
            rename_columns.get(c, c)
            for c in df_with_transformed_features.column_names
        ]
    )

def get_transformed_features_df(
self,
df_with_features: pd.DataFrame,
Expand Down
14 changes: 14 additions & 0 deletions sdk/python/feast/transformation/pandas_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import dill
import pandas as pd
import pyarrow

from feast.field import Field, from_value_type
from feast.protos.feast.core.Transformation_pb2 import (
Expand All @@ -26,6 +27,19 @@ def __init__(self, udf: FunctionType, udf_string: str = ""):
self.udf = udf
self.udf_string = udf_string

def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
    """Run the pandas UDF against an arrow table.

    The table is converted to pandas, passed through ``self.udf``, and the
    resulting dataframe is converted back to arrow.

    Args:
        pa_table: Input features as an arrow table.

    Returns:
        The UDF output as a ``pyarrow.Table``.

    Raises:
        TypeError: If ``pa_table`` is not a ``pyarrow.Table`` or the UDF
            does not return a ``pandas.DataFrame``.
    """
    if not isinstance(pa_table, pyarrow.Table):
        raise TypeError(
            f"pa_table should be type pyarrow.Table but got {type(pa_table).__name__}"
        )
    output_df = self.udf(pa_table.to_pandas())
    # Validate the UDF's output *before* converting: the previous check ran
    # after Table.from_pandas, which can never yield a non-Table, so it was
    # dead code. This mirrors the validation style of `transform` below.
    if not isinstance(output_df, pd.DataFrame):
        raise TypeError(
            f"output_df should be type pd.DataFrame but got {type(output_df).__name__}"
        )
    return pyarrow.Table.from_pandas(output_df)

def transform(self, input_df: pd.DataFrame) -> pd.DataFrame:
if not isinstance(input_df, pd.DataFrame):
raise TypeError(
Expand Down
6 changes: 6 additions & 0 deletions sdk/python/feast/transformation/python_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, List

import dill
import pyarrow

from feast.field import Field, from_value_type
from feast.protos.feast.core.Transformation_pb2 import (
Expand All @@ -24,6 +25,11 @@ def __init__(self, udf: FunctionType, udf_string: str = ""):
self.udf = udf
self.udf_string = udf_string

def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
    """Always raise: arrow-based execution is unsupported for "python"-mode views."""
    message = (
        'OnDemandFeatureView mode "python" not supported for offline processing.'
    )
    raise Exception(message)

def transform(self, input_dict: Dict) -> Dict:
if not isinstance(input_dict, Dict):
raise TypeError(
Expand Down
9 changes: 9 additions & 0 deletions sdk/python/feast/transformation/substrait_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ def table_provider(names, schema: pyarrow.Schema):
).read_all()
return table.to_pandas()

def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
    """Execute the stored substrait plan directly against an arrow table."""

    def table_provider(names, schema: pyarrow.Schema):
        # Serve the requested columns straight from the in-memory input.
        return pa_table.select(schema.names)

    # run_query yields a record-batch reader; read_all materializes it
    # into a single table.
    result = pyarrow.substrait.run_query(
        self.substrait_plan, table_provider=table_provider
    )
    return result.read_all()

def infer_features(self, random_input: Dict[str, List[Any]]) -> List[Field]:
df = pd.DataFrame.from_dict(random_input)
output_df: pd.DataFrame = self.transform(df)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def test_to_sql():

@pytest.mark.parametrize("timeout", (None, 30))
def test_to_df_timeout(retrieval_job, timeout: Optional[int]):
    """to_df must forward its timeout to _to_arrow_internal exactly once.

    to_df is implemented on top of to_arrow, so the arrow-internal method
    is the one that must receive the timeout.
    """
    # Renamed from `mock_to_df_internal`: the patched target is
    # `_to_arrow_internal`, and the stale name was misleading.
    with patch.object(retrieval_job, "_to_arrow_internal") as mock_to_arrow_internal:
        retrieval_job.to_df(timeout=timeout)
        mock_to_arrow_internal.assert_called_once_with(timeout=timeout)

Expand Down