Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 152 additions & 1 deletion sdk/python/tests/unit/test_on_demand_python_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,13 @@ def setUp(self):
Field(name="avg_daily_trip_rank_names", dtype=Array(String)),
],
)
input_request = RequestSource(
name="vals_to_add",
schema=[
Field(name="val_to_add", dtype=Int64),
Field(name="val_to_add_2", dtype=Int64),
],
)

@on_demand_feature_view(
sources=[request_source, driver_stats_fv],
Expand Down Expand Up @@ -476,8 +483,37 @@ def python_view(inputs: dict[str, Any]) -> dict[str, Any]:
output["achieved_ranks"] = ranks
return output

@on_demand_feature_view(
    sources=[driver_stats_fv, input_request],
    schema=[
        Field(name="conv_rate_plus_val1", dtype=Float64),
        Field(name="conv_rate_plus_val2", dtype=Float64),
    ],
    mode="pandas",
)
def pandas_view(features_df: pd.DataFrame) -> pd.DataFrame:
    """Add each request-supplied value to ``conv_rate``.

    Reads the ``conv_rate``, ``val_to_add`` and ``val_to_add_2`` columns of
    ``features_df`` and returns a new frame holding only the two sums,
    ``conv_rate_plus_val1`` and ``conv_rate_plus_val2``, in that order.
    """
    conv_rate = features_df["conv_rate"]
    return pd.DataFrame(
        {
            "conv_rate_plus_val1": conv_rate + features_df["val_to_add"],
            "conv_rate_plus_val2": conv_rate + features_df["val_to_add_2"],
        }
    )

self.store.apply(
[driver, driver_stats_source, driver_stats_fv, python_view]
[
driver,
driver_stats_source,
driver_stats_fv,
python_view,
pandas_view,
input_request,
request_source,
]
)
fv_applied = self.store.get_feature_view("driver_hourly_stats")
assert fv_applied.entities == [driver.name]
Expand All @@ -488,6 +524,121 @@ def python_view(inputs: dict[str, Any]) -> dict[str, Any]:
feature_view_name="driver_hourly_stats", df=driver_df
)

batch_sample = pd.DataFrame(driver_entities, columns=["driver_id"])
batch_sample["val_to_add"] = 0
batch_sample["val_to_add_2"] = 1
batch_sample["event_timestamp"] = start_date
batch_sample["created"] = start_date
fv_only_cols = ["driver_id", "event_timestamp", "created"]

resp_base_fv = self.store.get_historical_features(
entity_df=batch_sample[fv_only_cols],
features=[
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:acc_rate",
"driver_hourly_stats:avg_daily_trips",
],
).to_df()
assert resp_base_fv is not None
assert sorted(resp_base_fv.columns) == [
"acc_rate",
"avg_daily_trips",
"conv_rate",
"created__",
"driver_id",
"event_timestamp",
]
resp = self.store.get_historical_features(
entity_df=batch_sample,
features=[
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:acc_rate",
"driver_hourly_stats:avg_daily_trips",
"pandas_view:conv_rate_plus_val1",
"pandas_view:conv_rate_plus_val2",
],
).to_df()
assert resp is not None
assert resp["conv_rate_plus_val1"].isnull().sum() == 0

# Now testing feature retrieval for driver ids not in the dataset
missing_batch_sample = pd.DataFrame([1234567890], columns=["driver_id"])
missing_batch_sample["val_to_add"] = 0
missing_batch_sample["val_to_add_2"] = 1
missing_batch_sample["event_timestamp"] = start_date
missing_batch_sample["created"] = start_date
resp_offline = self.store.get_historical_features(
entity_df=missing_batch_sample,
features=[
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:acc_rate",
"driver_hourly_stats:avg_daily_trips",
"pandas_view:conv_rate_plus_val1",
"pandas_view:conv_rate_plus_val2",
],
).to_df()
assert resp_offline is not None
assert resp_offline["conv_rate_plus_val1"].isnull().sum() == 1
assert sorted(resp_offline.columns) == [
"acc_rate",
"avg_daily_trips",
"conv_rate",
"conv_rate_plus_val1",
"conv_rate_plus_val2",
"created__",
"driver_id",
"event_timestamp",
"val_to_add",
"val_to_add_2",
]
with pytest.raises(TypeError):
_ = self.store.get_online_features(
entity_rows=[
{"driver_id": 1234567890, "val_to_add": 0, "val_to_add_2": 1}
],
features=[
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:acc_rate",
"driver_hourly_stats:avg_daily_trips",
"pandas_view:conv_rate_plus_val1",
"pandas_view:conv_rate_plus_val2",
],
)
resp_online = self.store.get_online_features(
entity_rows=[{"driver_id": 1001, "val_to_add": 0, "val_to_add_2": 1}],
features=[
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:acc_rate",
"driver_hourly_stats:avg_daily_trips",
"pandas_view:conv_rate_plus_val1",
"pandas_view:conv_rate_plus_val2",
],
).to_df()
assert resp_online is not None
assert sorted(resp_online.columns) == [
"acc_rate",
"avg_daily_trips",
"conv_rate",
"conv_rate_plus_val1",
"conv_rate_plus_val2",
"driver_id",
# It does not have the items below
# "created__",
# "event_timestamp",
# "val_to_add",
# "val_to_add_2",
]
# Note: the online and offline columns are not expected to match.
# Online, you want to be space-efficient because of network latency, so you send and receive only the minimally
# required data: after the transformation, only its output needs to be in the response.
# Offline, you will probably prioritize reproducibility and the ability to iterate, so you also want the
# underlying inputs to your transformation; the extra data is tolerable.
assert sorted(resp_online.columns) != sorted(resp_offline.columns)

def test_setup(self):
pass

def test_python_transformation_returning_all_data_types(self):
entity_rows = [
{
Expand Down