Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 38 additions & 15 deletions sdk/python/feast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,25 +687,48 @@ def _get_unique_entities(
entity_name_to_join_key_map,
join_key_values,
)
# Validate that all expected join keys exist and have non-empty values.
expected_keys = set(entity_name_to_join_key_map.values())
expected_keys.discard("__dummy_id")
missing_keys = sorted(
list(set([key for key in expected_keys if key not in table_entity_values]))
)
empty_keys = sorted(
list(set([key for key in expected_keys if not table_entity_values.get(key)]))
)

# Convert back to rowise.
keys = table_entity_values.keys()
# Sort the rowise data to allow for grouping but keep original index. This lambda is
# sufficient as Entity types cannot be complex (ie. lists).
if missing_keys or empty_keys:
if not any(table_entity_values.values()):
raise KeyError(
f"Missing join key values for keys: {missing_keys}. "
f"No values provided for keys: {empty_keys}. "
f"Provided join_key_values: {list(join_key_values.keys())}"
)

# Convert the column-oriented table_entity_values into row-wise data.
keys = list(table_entity_values.keys())
# Each row is a tuple of ValueProto objects corresponding to the join keys.
rowise = list(enumerate(zip(*table_entity_values.values())))

# If there are no rows, return empty tuples.
if not rowise:
return (), ()

# Sort rowise so that rows with the same join key values are adjacent.
rowise.sort(key=lambda row: tuple(getattr(x, x.WhichOneof("val")) for x in row[1]))

# Identify unique entities and the indexes at which they occur.
unique_entities: Tuple[Dict[str, ValueProto], ...]
indexes: Tuple[List[int], ...]
unique_entities, indexes = tuple(
zip(
*[
(dict(zip(keys, k)), [_[0] for _ in g])
for k, g in itertools.groupby(rowise, key=lambda x: x[1])
]
)
)
# Group rows by their composite join key value.
groups = [
(dict(zip(keys, key_tuple)), [idx for idx, _ in group])
for key_tuple, group in itertools.groupby(rowise, key=lambda row: row[1])
]

# If no groups were formed (should not happen for valid input), return empty tuples.
if not groups:
return (), ()

# Unpack the unique entities and their original row indexes.
unique_entities, indexes = tuple(zip(*groups))
return unique_entities, indexes


Expand Down
15 changes: 15 additions & 0 deletions sdk/python/tests/unit/online_store/test_online_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,21 @@ def test_get_online_features() -> None:

assert "trips" in result

with pytest.raises(KeyError) as excinfo:
_ = store.get_online_features(
features=["driver_locations:lon"],
entity_rows=[{"customer_id": 0}],
full_feature_names=False,
).to_dict()

error_message = str(excinfo.value)
assert "Missing join key values for keys:" in error_message
assert (
"Missing join key values for keys: ['customer_id', 'driver_id', 'item_id']."
in error_message
)
assert "Provided join_key_values: ['customer_id']" in error_message

result = store.get_online_features(
features=["customer_profile_pandas_odfv:on_demand_age"],
entity_rows=[{"driver_id": 1, "customer_id": "5"}],
Expand Down
88 changes: 84 additions & 4 deletions sdk/python/tests/unit/test_unit_feature_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import Dict, List

import pytest

from feast import utils
from feast.protos.feast.types.Value_pb2 import Value

Expand All @@ -17,7 +19,7 @@ class MockFeatureView:
projection: MockFeatureViewProjection


def test_get_unique_entities():
def test_get_unique_entities_success():
entity_values = {
"entity_1": [Value(int64_val=1), Value(int64_val=2), Value(int64_val=1)],
"entity_2": [
Expand All @@ -41,9 +43,87 @@ def test_get_unique_entities():
join_key_values=entity_values,
entity_name_to_join_key_map=entity_name_to_join_key_map,
)

assert unique_entities == (
expected_entities = (
{"entity_1": Value(int64_val=1), "entity_2": Value(string_val="1")},
{"entity_1": Value(int64_val=2), "entity_2": Value(string_val="2")},
)
assert indexes == ([0, 2], [1])
expected_indexes = ([0, 2], [1])

assert unique_entities == expected_entities
assert indexes == expected_indexes


def test_get_unique_entities_missing_join_key_success():
    """
    Tests that _get_unique_entities succeeds when only some join keys are provided.

    The feature view declares both "entity_1" and "entity_2", but values are
    supplied for "entity_2" only. The call is expected to return without
    raising, grouping rows on the join key that was provided.
    """
    # "entity_1" is deliberately omitted from the provided join key values.
    entity_values = {
        "entity_2": [
            Value(string_val="1"),
            Value(string_val="2"),
            Value(string_val="1"),
        ],
    }

    entity_name_to_join_key_map = {"entity_1": "entity_1", "entity_2": "entity_2"}

    fv = MockFeatureView(
        name="fv_1",
        entities=["entity_1", "entity_2"],
        projection=MockFeatureViewProjection(join_key_map={}),
    )

    unique_entities, indexes = utils._get_unique_entities(
        table=fv,
        join_key_values=entity_values,
        entity_name_to_join_key_map=entity_name_to_join_key_map,
    )
    expected_entities = (
        {"entity_2": Value(string_val="1")},
        {"entity_2": Value(string_val="2")},
    )
    # Rows 0 and 2 share the value "1"; row 1 is the only row with "2".
    expected_indexes = ([0, 2], [1])

    assert unique_entities == expected_entities
    assert indexes == expected_indexes
    # The omitted join key must not appear in any returned entity. Check
    # membership per entity dict: comparing the string against a list of
    # dict_keys views would always pass and verify nothing.
    assert all("entity_1" not in entity for entity in unique_entities)


def test_get_unique_entities_missing_all_join_keys_error():
    """
    Tests that _get_unique_entities raises a KeyError when values are supplied
    only for a join key ("entity_3") that is not one of the feature view's
    entities, and checks the contents of the resulting error message.
    """
    join_key_map = {
        "entity_1": "entity_1",
        "entity_2": "entity_2",
        "entity_3": "entity_3",
    }
    # Values exist only for "entity_3", which the feature view does not list.
    provided_values = {"entity_3": [Value(string_val="3")]}

    feature_view = MockFeatureView(
        name="fv_1",
        entities=["entity_1", "entity_2"],
        projection=MockFeatureViewProjection(join_key_map={}),
    )

    with pytest.raises(KeyError) as raised:
        utils._get_unique_entities(
            table=feature_view,
            join_key_values=provided_values,
            entity_name_to_join_key_map=join_key_map,
        )

    message = str(raised.value)
    # Each fragment must appear verbatim in the raised error's message.
    expected_fragments = (
        "Missing join key values for keys: ['entity_1', 'entity_2', 'entity_3']",
        "No values provided for keys: ['entity_1', 'entity_2', 'entity_3']",
        "Provided join_key_values: ['entity_3']",
    )
    for fragment in expected_fragments:
        assert fragment in message
Loading