Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/reference/online-stores/elasticsearch.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ top_k = 5
# Retrieve the top k closest features to the query vector

feature_values = feature_store.retrieve_online_documents(
feature="my_feature",
features=["my_feature"],
query=query_vector,
top_k=top_k
)
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/online-stores/qdrant.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ top_k = 5
# the vector to use can be specified in the repo config.
# Reference: https://qdrant.tech/documentation/concepts/vectors/#named-vectors
feature_values = feature_store.retrieve_online_documents(
feature="my_feature",
features=["my_feature"],
query=query_vector,
top_k=top_k
)
Expand Down
88 changes: 38 additions & 50 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1831,19 +1831,15 @@ async def get_online_features_async(

def retrieve_online_documents(
self,
feature: Optional[str],
query: Union[str, List[float]],
top_k: int,
features: Optional[List[str]] = None,
features: List[str],
distance_metric: Optional[str] = "L2",
) -> OnlineResponse:
"""
Retrieves the top k closest document features. Note that embeddings are a subset of features.

Args:
feature: The list of document features that should be retrieved from the online document store. These features can be
specified either as a list of string document feature references or as a feature service. String feature
references must have format "feature_view:feature", e.g, "document_fv:document_embeddings".
features: The list of features that should be retrieved from the online store.
query: The query to retrieve the closest document features for.
top_k: The number of closest document features to retrieve.
Expand All @@ -1853,68 +1849,55 @@ def retrieve_online_documents(
raise ValueError(
"Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents."
)
feature_list: List[str] = (
features
if features is not None
else ([feature] if feature is not None else [])
)

(
available_feature_views,
_,
) = utils._get_feature_views_to_use(
registry=self._registry,
project=self.project,
features=feature_list,
features=features,
allow_cache=True,
hide_dummy_entity=False,
)
if features:
feature_view_set = set()
for feature in features:
feature_view_name = feature.split(":")[0]
feature_view = self.get_feature_view(feature_view_name)
feature_view_set.add(feature_view.name)
if len(feature_view_set) > 1:
raise ValueError(
"Document retrieval only supports a single feature view."
)
requested_feature = None
requested_features = [
f.split(":")[1] for f in features if isinstance(f, str) and ":" in f
]
else:
requested_feature = (
feature.split(":")[1] if isinstance(feature, str) else feature
)
requested_features = [requested_feature] if requested_feature else []

requested_feature_view_name = (
feature.split(":")[0] if feature else list(feature_view_set)[0]
)
feature_view_set = set()
for _feature in features:
feature_view_name = _feature.split(":")[0]
feature_view = self.get_feature_view(feature_view_name)
feature_view_set.add(feature_view.name)
if len(feature_view_set) > 1:
raise ValueError("Document retrieval only supports a single feature view.")
requested_features = [
f.split(":")[1] for f in features if isinstance(f, str) and ":" in f
]
requested_feature_view_name = list(feature_view_set)[0]
for feature_view in available_feature_views:
if feature_view.name == requested_feature_view_name:
requested_feature_view = feature_view
if not requested_feature_view:
break
else:
raise ValueError(
f"Feature view {requested_feature_view} not found in the registry."
)

requested_feature_view = available_feature_views[0]

provider = self._get_provider()
document_features = self._retrieve_from_online_store(
provider,
requested_feature_view,
requested_feature,
requested_features,
query,
top_k,
distance_metric,
)

# TODO: the vector value is currently not returned because it is the same as the feature value;
# once embedding is supported, the feature value may be the raw text before it is embedded.
entity_key_vals = [feature[1] for feature in document_features]
def _doc_feature(x):
return [feature[x] for feature in document_features]

entity_key_vals, document_feature_vals, document_feature_distance_vals = map(
_doc_feature, (1, 4, 5)
)
join_key_values: Dict[str, List[ValueProto]] = {}
for entity_key_val in entity_key_vals:
if entity_key_val is not None:
Expand All @@ -1924,18 +1907,25 @@ def retrieve_online_documents(
if join_key not in join_key_values:
join_key_values[join_key] = []
join_key_values[join_key].append(entity_value)

document_feature_vals = [feature[4] for feature in document_features]
document_feature_distance_vals = [feature[5] for feature in document_features]
online_features_response = GetOnlineFeaturesResponse(results=[])
requested_feature = requested_feature or requested_features[0]
if vector_field_metadata := _get_feature_view_vector_field_metadata(
requested_feature_view
):
vector_field_name = vector_field_metadata.name
data = {
**join_key_values,
vector_field_name: document_feature_vals,
"distance": document_feature_distance_vals,
}
_requested_features = [_feature.split(":")[-1] for _feature in features]
requested_features_data = {
_feature: data[_feature]
for _feature in _requested_features
if _feature in data
}
utils._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={
**join_key_values,
requested_feature: document_feature_vals,
"distance": document_feature_distance_vals,
},
data=requested_features_data,
)
return OnlineResponse(online_features_response)

Expand Down Expand Up @@ -2012,7 +2002,6 @@ def _retrieve_from_online_store(
self,
provider: Provider,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
query: List[float],
top_k: int,
Expand All @@ -2032,7 +2021,6 @@ def _retrieve_from_online_store(
documents = provider.retrieve_online_documents(
config=self.config,
table=table,
requested_feature=requested_feature,
requested_features=requested_features,
query=query,
top_k=top_k,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,7 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
requested_features: List[str],
embedding: List[float],
top_k: int,
*args,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,7 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_featres: Optional[List[str]],
requested_featres: List[str],
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = None,
Expand Down
10 changes: 3 additions & 7 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,7 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
requested_features: List[str],
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = None,
Expand All @@ -413,7 +412,6 @@ def retrieve_online_documents(
distance_metric: distance metric to use for retrieval.
config: The config for the current feature store.
table: The feature view whose feature values should be read.
requested_feature: The name of the feature whose embeddings should be used for retrieval.
requested_features: The list of features whose embeddings should be used for retrieval.
embedding: The embeddings to use for retrieval.
top_k: The number of documents to retrieve.
Expand All @@ -423,10 +421,8 @@ def retrieve_online_documents(
where the first item is the event timestamp for the row, and the second item is a dict of feature
name to embeddings.
"""
if not requested_feature and not requested_features:
raise ValueError(
"Either requested_feature or requested_features must be specified"
)
if not requested_features:
raise ValueError("Requested_features must be specified")
raise NotImplementedError(
f"Online store {self.__class__.__name__} does not support online retrieval"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
embedding: List[float],
top_k: int,
Expand All @@ -373,7 +372,6 @@ def retrieve_online_documents(
Args:
config: Feast configuration object
table: FeatureView object as the table to search
requested_feature: The requested feature as the column to search
requested_features: The list of features whose embeddings should be used for retrieval.
embedding: The query embedding to search for
top_k: The number of items to return
Expand All @@ -394,6 +392,11 @@ def retrieve_online_documents(
f"Distance metric {distance_metric} is not supported. Supported distance metrics are {SUPPORTED_DISTANCE_METRICS_DICT.keys()}"
)

if requested_features:
required_feature_names = ", ".join(
[feature for feature in requested_features]
)

distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric]

result: List[
Expand All @@ -415,19 +418,18 @@ def retrieve_online_documents(
"""
SELECT
entity_key,
feature_name,
{feature_names},
value,
vector_value,
vector_value {distance_metric_sql} %s::vector as distance,
event_ts FROM {table_name}
WHERE feature_name = {feature_name}
ORDER BY distance
LIMIT {top_k};
"""
).format(
distance_metric_sql=sql.SQL(distance_metric_sql),
table_name=sql.Identifier(table_name),
feature_name=sql.Literal(requested_feature),
feature_names=required_feature_names,
top_k=sql.Literal(top_k),
),
(embedding,),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,7 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
requested_features: List[str],
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = "cosine",
Expand Down
5 changes: 2 additions & 3 deletions sdk/python/feast/infra/online_stores/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,7 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_featuers: Optional[List[str]],
requested_features: List[str],
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = None,
Expand All @@ -341,7 +340,7 @@ def retrieve_online_documents(
Args:
config: Feast configuration object
table: FeatureView object as the table to search
requested_feature: The requested feature as the column to search
requested_features: The list of requested features to retrieve
embedding: The query embedding to search for
top_k: The number of items to return
Returns:
Expand Down
2 changes: 0 additions & 2 deletions sdk/python/feast/infra/passthrough_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
query: List[float],
top_k: int,
Expand All @@ -305,7 +304,6 @@ def retrieve_online_documents(
result = self.online_store.retrieve_online_documents(
config,
table,
requested_feature,
requested_features,
query,
top_k,
Expand Down
2 changes: 0 additions & 2 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,6 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
query: List[float],
top_k: int,
Expand All @@ -440,7 +439,6 @@ def retrieve_online_documents(
distance_metric: distance metric to use for the search.
config: The config for the current feature store.
table: The feature view whose embeddings should be searched.
requested_feature: the requested document feature name.
requested_features: the requested document feature names.
query: The query embedding to search for.
top_k: The number of documents to return.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ def test_retrieve_online_documents(environment, fake_document_data):
fs.write_to_online_store("item_embeddings", df)

documents = fs.retrieve_online_documents(
feature="item_embeddings:embedding_float",
features=["item_embeddings:embedding_float"],
query=[1.0, 2.0],
top_k=2,
distance_metric="L2",
Expand All @@ -881,7 +881,7 @@ def test_retrieve_online_documents(environment, fake_document_data):
assert len(documents["item_id"]) == 2

documents = fs.retrieve_online_documents(
feature="item_embeddings:embedding_float",
features=["item_embeddings:embedding_float"],
query=[1.0, 2.0],
top_k=2,
distance_metric="L1",
Expand All @@ -890,7 +890,7 @@ def test_retrieve_online_documents(environment, fake_document_data):

with pytest.raises(ValueError):
fs.retrieve_online_documents(
feature="item_embeddings:embedding_float",
features=["item_embeddings:embedding_float"],
query=[1.0, 2.0],
top_k=2,
distance_metric="wrong",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,9 @@ def test_sqlite_get_online_documents() -> None:
vector_length,
)
result = store.retrieve_online_documents(
feature="document_embeddings:Embeddings", query=query_embedding, top_k=3
query=query_embedding,
top_k=3,
features=["document_embeddings:Embeddings", "document_embeddings:distance"],
).to_dict()

assert "Embeddings" in result
Expand Down
Loading