@@ -179,7 +179,15 @@ def retrieve_online_documents(
179179 embedding : Optional [List [float ]],
180180 top_k : int ,
181181 distance_metric : Optional [str ] = "L2" ,
182- ) -> List [Tuple [Optional [datetime ], Optional [EntityKeyProto ], Optional [ValueProto ], Optional [ValueProto ], Optional [ValueProto ]]]:
182+ ) -> List [
183+ Tuple [
184+ Optional [datetime ],
185+ Optional [EntityKeyProto ],
186+ Optional [ValueProto ],
187+ Optional [ValueProto ],
188+ Optional [ValueProto ],
189+ ]
190+ ]:
183191 assert isinstance (config .online_store , RemoteOnlineStoreConfig )
184192 config .online_store .__class__ = RemoteOnlineStoreConfig
185193
@@ -190,18 +198,16 @@ def retrieve_online_documents(
190198 if response .status_code == 200 :
191199 logger .debug ("Able to retrieve the online documents from feature server." )
192200 response_json = json .loads (response .text )
193- event_ts = self ._get_event_ts (response_json )
201+ event_ts : Optional [ datetime ] = self ._get_event_ts (response_json )
194202
195203 # Create feature name to index mapping for efficient lookup
196204 feature_name_to_index = {
197- name : idx for idx , name in enumerate (response_json ["metadata" ]["feature_names" ])
205+ name : idx
206+ for idx , name in enumerate (response_json ["metadata" ]["feature_names" ])
198207 }
199208
200209 vector_field_metadata = _get_feature_view_vector_field_metadata (table )
201210
202- # Extract feature names once
203- feature_names = response_json ["metadata" ]["feature_names" ]
204-
205211 # Process each result row
206212 num_results = len (response_json ["results" ][0 ]["values" ])
207213 result_tuples = []
@@ -215,13 +221,21 @@ def retrieve_online_documents(
215221 response_json , feature_name_to_index , vector_field_metadata , row_idx
216222 )
217223 distance_val = self ._extract_distance_value (
218- response_json , feature_name_to_index , ' distance' , row_idx
224+ response_json , feature_name_to_index , " distance" , row_idx
219225 )
220226 entity_key_proto = self ._construct_entity_key_from_response (
221- response_json , row_idx , feature_name_to_index
227+ response_json , row_idx , feature_name_to_index , table
222228 )
223229
224- result_tuples .append ((event_ts , entity_key_proto , feature_val , vector_value , distance_val ))
230+ result_tuples .append (
231+ (
232+ event_ts ,
233+ entity_key_proto ,
234+ feature_val ,
235+ vector_value ,
236+ distance_val ,
237+ )
238+ )
225239
226240 return result_tuples
227241 else :
@@ -238,50 +252,77 @@ def retrieve_online_documents_v2(
238252 top_k : int ,
239253 distance_metric : Optional [str ] = None ,
240254 query_string : Optional [str ] = None ,
241- ) -> List [Tuple [Optional [datetime ], Optional [EntityKeyProto ], Optional [Dict [str , ValueProto ]]]]:
255+ ) -> List [
256+ Tuple [
257+ Optional [datetime ],
258+ Optional [EntityKeyProto ],
259+ Optional [Dict [str , ValueProto ]],
260+ ]
261+ ]:
242262 assert isinstance (config .online_store , RemoteOnlineStoreConfig )
243263 config .online_store .__class__ = RemoteOnlineStoreConfig
244264
245265 req_body = self ._construct_online_documents_v2_api_json_request (
246- table , requested_features , embedding , top_k , distance_metric , query_string , api_version = 2
266+ table ,
267+ requested_features ,
268+ embedding ,
269+ top_k ,
270+ distance_metric ,
271+ query_string ,
272+ api_version = 2 ,
247273 )
248274 response = get_remote_online_documents (config = config , req_body = req_body )
249275 if response .status_code == 200 :
250276 logger .debug ("Able to retrieve the online documents from feature server." )
251277 response_json = json .loads (response .text )
252- event_ts = self ._get_event_ts (response_json )
253-
278+ event_ts : Optional [ datetime ] = self ._get_event_ts (response_json )
279+
254280 # Create feature name to index mapping for efficient lookup
255281 feature_name_to_index = {
256- name : idx for idx , name in enumerate (response_json ["metadata" ]["feature_names" ])
282+ name : idx
283+ for idx , name in enumerate (response_json ["metadata" ]["feature_names" ])
257284 }
258285
259286 # Process each result row
260- num_results = len (response_json ["results" ][0 ]["values" ]) if response_json ["results" ] else 0
287+ num_results = (
288+ len (response_json ["results" ][0 ]["values" ])
289+ if response_json ["results" ]
290+ else 0
291+ )
261292 result_tuples = []
262293
263294 for row_idx in range (num_results ):
264295 # Build feature values dictionary for requested features
265- feature_values_dict : Dict [ str , ValueProto ] = {}
266-
296+ feature_values_dict = {}
297+
267298 if requested_features :
268299 for feature_name in requested_features :
269300 if feature_name in feature_name_to_index :
270301 feature_idx = feature_name_to_index [feature_name ]
271- if self ._is_feature_present (response_json , feature_idx , row_idx ):
272- feature_values_dict [feature_name ] = self ._extract_feature_value (
273- response_json , feature_idx , row_idx
302+ if self ._is_feature_present (
303+ response_json , feature_idx , row_idx
304+ ):
305+ feature_values_dict [feature_name ] = (
306+ self ._extract_feature_value (
307+ response_json , feature_idx , row_idx
308+ )
274309 )
275310 else :
276311 feature_values_dict [feature_name ] = ValueProto ()
277312
278313 # Construct entity key proto using existing helper method
279314 entity_key_proto = self ._construct_entity_key_from_response (
280- response_json , row_idx , feature_name_to_index
315+ response_json , row_idx , feature_name_to_index , table
316+ )
317+
318+ result_tuples .append (
319+ (
320+ event_ts ,
321+ entity_key_proto ,
322+ feature_values_dict if feature_values_dict else None ,
323+ )
281324 )
282325
283- result_tuples .append ((event_ts , entity_key_proto , feature_values_dict ))
284-
285326 return result_tuples
286327 else :
287328 error_msg = f"Unable to retrieve the online documents using feature server API. Error_code={ response .status_code } , error_message={ response .text } "
@@ -293,8 +334,8 @@ def _extract_requested_feature_value(
293334 response_json : dict ,
294335 feature_name_to_index : dict ,
295336 requested_features : Optional [List [str ]],
296- row_idx : int
297- ) -> ValueProto :
337+ row_idx : int ,
338+ ) -> Optional [ ValueProto ] :
298339 """Extract the first available requested feature value."""
299340 if not requested_features :
300341 return ValueProto ()
@@ -303,7 +344,9 @@ def _extract_requested_feature_value(
303344 if feature_name in feature_name_to_index :
304345 feature_idx = feature_name_to_index [feature_name ]
305346 if self ._is_feature_present (response_json , feature_idx , row_idx ):
306- return self ._extract_feature_value (response_json , feature_idx , row_idx )
347+ return self ._extract_feature_value (
348+ response_json , feature_idx , row_idx
349+ )
307350
308351 return ValueProto ()
309352
@@ -312,15 +355,20 @@ def _extract_vector_field_value(
312355 response_json : dict ,
313356 feature_name_to_index : dict ,
314357 vector_field_metadata ,
315- row_idx : int
316- ) -> ValueProto :
358+ row_idx : int ,
359+ ) -> Optional [ ValueProto ] :
317360 """Extract vector field value from response."""
318- if not vector_field_metadata or vector_field_metadata .name not in feature_name_to_index :
361+ if (
362+ not vector_field_metadata
363+ or vector_field_metadata .name not in feature_name_to_index
364+ ):
319365 return ValueProto ()
320366
321367 vector_feature_idx = feature_name_to_index [vector_field_metadata .name ]
322368 if self ._is_feature_present (response_json , vector_feature_idx , row_idx ):
323- return self ._extract_feature_value (response_json , vector_feature_idx , row_idx )
369+ return self ._extract_feature_value (
370+ response_json , vector_feature_idx , row_idx
371+ )
324372
325373 return ValueProto ()
326374
@@ -329,22 +377,26 @@ def _extract_distance_value(
329377 response_json : dict ,
330378 feature_name_to_index : dict ,
331379 distance_feature_name : str ,
332- row_idx : int
333- ) -> ValueProto :
380+ row_idx : int ,
381+ ) -> Optional [ ValueProto ] :
334382 """Extract distance/score value from response."""
335383 if not distance_feature_name :
336384 return ValueProto ()
337385
338386 distance_feature_idx = feature_name_to_index [distance_feature_name ]
339387 if self ._is_feature_present (response_json , distance_feature_idx , row_idx ):
340- distance_value = response_json ["results" ][distance_feature_idx ]["values" ][row_idx ]
388+ distance_value = response_json ["results" ][distance_feature_idx ]["values" ][
389+ row_idx
390+ ]
341391 distance_val = ValueProto ()
342392 distance_val .float_val = float (distance_value )
343393 return distance_val
344394
345395 return ValueProto ()
346396
347- def _is_feature_present (self , response_json : dict , feature_idx : int , row_idx : int ) -> bool :
397+ def _is_feature_present (
398+ self , response_json : dict , feature_idx : int , row_idx : int
399+ ) -> bool :
348400 """Check if a feature is present in the response."""
349401 return response_json ["results" ][feature_idx ]["statuses" ][row_idx ] == "PRESENT"
350402
@@ -432,12 +484,19 @@ def _get_event_ts(self, response_json) -> datetime:
432484 return datetime .fromisoformat (event_ts .replace ("Z" , "+00:00" ))
433485
434486 def _construct_entity_key_from_response (
435- self , response_json : dict , row_idx : int , feature_name_to_index : dict
487+ self ,
488+ response_json : dict ,
489+ row_idx : int ,
490+ feature_name_to_index : dict ,
491+ table : FeatureView ,
436492 ) -> Optional [EntityKeyProto ]:
437493 """Construct EntityKeyProto from response data."""
438- # Look for entity key fields in the response
439- entity_fields = [name for name in feature_name_to_index .keys ()
440- if name .endswith ('_id' ) or name in ['id' , 'key' , 'entity_id' ]]
494+ # Use the feature view's join_keys to identify entity fields
495+ entity_fields = [
496+ join_key
497+ for join_key in table .join_keys
498+ if join_key in feature_name_to_index
499+ ]
441500
442501 if not entity_fields :
443502 return None
@@ -449,12 +508,16 @@ def _construct_entity_key_from_response(
449508 if entity_field in feature_name_to_index :
450509 feature_idx = feature_name_to_index [entity_field ]
451510 if self ._is_feature_present (response_json , feature_idx , row_idx ):
452- entity_value = self ._extract_feature_value (response_json , feature_idx , row_idx )
511+ entity_value = self ._extract_feature_value (
512+ response_json , feature_idx , row_idx
513+ )
453514 entity_key_proto .entity_values .append (entity_value )
454515
455516 return entity_key_proto if entity_key_proto .entity_values else None
456517
457- def _extract_feature_value (self , response_json : dict , feature_idx : int , row_idx : int ) -> ValueProto :
518+ def _extract_feature_value (
519+ self , response_json : dict , feature_idx : int , row_idx : int
520+ ) -> ValueProto :
458521 """Extract and convert a feature value to ValueProto."""
459522 raw_value = response_json ["results" ][feature_idx ]["values" ][row_idx ]
460523 if raw_value is None :
0 commit comments