|
| 1 | +import base64 |
1 | 2 | from datetime import datetime |
2 | 3 | from pathlib import Path |
3 | 4 | from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union |
@@ -247,7 +248,7 @@ def online_write_batch( |
247 | 248 | ) -> None: |
248 | 249 | self.client = self._connect(config) |
249 | 250 | collection = self._get_or_create_collection(config, table) |
250 | | - vector_cols = [f.name for f in table.features if f.vector_index] |
| 251 | + vector_cols = [f.name for f in table.schema if f.vector_index] |
251 | 252 | entity_batch_to_insert = [] |
252 | 253 | unique_entities: dict[str, dict[str, Any]] = {} |
253 | 254 | required_fields = {field["name"] for field in collection["fields"]} |
@@ -503,6 +504,14 @@ def retrieve_online_documents_v2( |
503 | 504 | entity_name_feast_primitive_type_map = { |
504 | 505 | k.name: k.dtype for k in table.entity_columns |
505 | 506 | } |
| 507 | + # Also include feature columns for proper type mapping |
| 508 | + feature_name_feast_primitive_type_map = { |
| 509 | + k.name: k.dtype for k in table.features |
| 510 | + } |
| 511 | + field_name_feast_primitive_type_map = { |
| 512 | + **entity_name_feast_primitive_type_map, |
| 513 | + **feature_name_feast_primitive_type_map, |
| 514 | + } |
506 | 515 | self.client = self._connect(config) |
507 | 516 | collection_name = _table_id(config.project, table) |
508 | 517 | collection = self._get_or_create_collection(config, table) |
@@ -662,14 +671,25 @@ def retrieve_online_documents_v2( |
662 | 671 | embedding |
663 | 672 | ) |
664 | 673 | res[ann_search_field] = serialized_embedding |
665 | | - elif entity_name_feast_primitive_type_map.get( |
666 | | - field, PrimitiveFeastType.INVALID |
667 | | - ) in [ |
668 | | - PrimitiveFeastType.STRING, |
669 | | - PrimitiveFeastType.BYTES, |
670 | | - ]: |
| 674 | + elif ( |
| 675 | + field_name_feast_primitive_type_map.get( |
| 676 | + field, PrimitiveFeastType.INVALID |
| 677 | + ) |
| 678 | + == PrimitiveFeastType.STRING |
| 679 | + ): |
671 | 680 | res[field] = ValueProto(string_val=str(field_value)) |
672 | | - elif entity_name_feast_primitive_type_map.get( |
| 681 | + elif ( |
| 682 | + field_name_feast_primitive_type_map.get( |
| 683 | + field, PrimitiveFeastType.INVALID |
| 684 | + ) |
| 685 | + == PrimitiveFeastType.BYTES |
| 686 | + ): |
| 687 | + try: |
| 688 | + decoded_bytes = base64.b64decode(field_value) |
| 689 | + res[field] = ValueProto(bytes_val=decoded_bytes) |
| 690 | + except Exception: |
| 691 | + res[field] = ValueProto(string_val=str(field_value)) |
| 692 | + elif field_name_feast_primitive_type_map.get( |
673 | 693 | field, PrimitiveFeastType.INVALID |
674 | 694 | ) in [ |
675 | 695 | PrimitiveFeastType.INT64, |
@@ -732,9 +752,13 @@ def _extract_proto_values_to_dict( |
732 | 752 | else: |
733 | 753 | if ( |
734 | 754 | serialize_to_string |
735 | | - and proto_val_type not in ["string_val"] + numeric_types |
| 755 | + and proto_val_type |
| 756 | + not in ["string_val", "bytes_val"] + numeric_types |
736 | 757 | ): |
737 | 758 | vector_values = feature_values.SerializeToString().decode() |
| 759 | + elif proto_val_type == "bytes_val": |
| 760 | + byte_data = getattr(feature_values, proto_val_type) |
| 761 | + vector_values = base64.b64encode(byte_data).decode("utf-8") |
738 | 762 | else: |
739 | 763 | if not isinstance(feature_values, str): |
740 | 764 | vector_values = str( |
|
0 commit comments