Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 84 additions & 38 deletions sdk/python/feast/infra/online_stores/faiss_online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,30 @@ def teardown(self):
self.entity_keys = {}


def _table_id(project: str, table: FeatureView, enable_versioning: bool = False) -> str:
"""Compute the table key, including version suffix when versioning is enabled."""
name = table.name
if enable_versioning:
# Prefer version_tag from the projection (set by version-qualified refs like @v2)
# over current_version_number (the FV's active version in metadata).
version = getattr(table.projection, "version_tag", None)
if version is None:
version = getattr(table, "current_version_number", None)
if version is not None and version > 0:
name = f"{table.name}_v{version}"
return f"{project}_{name}"


class FaissOnlineStore(OnlineStore):
_index: Optional[faiss.IndexIVFFlat] = None
_in_memory_store: InMemoryStore = InMemoryStore()
_indices: Dict[str, faiss.IndexIVFFlat] = {}
_in_memory_stores: Dict[str, InMemoryStore] = {}
_config: Optional[FaissOnlineStoreConfig] = None
_logger: logging.Logger = logging.getLogger(__name__)

def _get_index(self, config: RepoConfig) -> faiss.IndexIVFFlat:
if self._index is None or self._config is None:
raise ValueError("Index is not initialized")
return self._index
def _get_index(
self, table_key: str
) -> Optional[faiss.IndexIVFFlat]:
return self._indices.get(table_key)

def update(
self,
Expand All @@ -63,32 +77,47 @@ def update(
entities_to_keep: Sequence[Entity],
partial: bool,
):
feature_views = tables_to_keep
if not feature_views:
return

feature_names = [f.name for f in feature_views[0].features]
dimension = len(feature_names)

self._config = FaissOnlineStoreConfig(**config.online_store.dict())
if self._index is None or not partial:
quantizer = faiss.IndexFlatL2(dimension)
self._index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist)
self._index.train(
np.random.rand(self._config.nlist * 100, dimension).astype(np.float32)
)
self._in_memory_store = InMemoryStore()
versioning = config.registry.enable_online_feature_view_versioning

for table in tables_to_delete:
table_key = _table_id(config.project, table, versioning)
self._indices.pop(table_key, None)
self._in_memory_stores.pop(table_key, None)

for table in tables_to_keep:
table_key = _table_id(config.project, table, versioning)
feature_names = [f.name for f in table.features]
dimension = len(feature_names)

if table_key not in self._indices or not partial:
quantizer = faiss.IndexFlatL2(dimension)
index = faiss.IndexIVFFlat(
quantizer, dimension, self._config.nlist
)
index.train(
np.random.rand(self._config.nlist * 100, dimension).astype(
np.float32
)
)
self._indices[table_key] = index
self._in_memory_stores[table_key] = InMemoryStore()

self._in_memory_store.update(feature_names, {})
self._in_memory_stores[table_key].update(feature_names, {})

def teardown(
self,
config: RepoConfig,
tables: Sequence[FeatureView],
entities: Sequence[Entity],
):
self._index = None
self._in_memory_store.teardown()
versioning = config.registry.enable_online_feature_view_versioning
for table in tables:
table_key = _table_id(config.project, table, versioning)
self._indices.pop(table_key, None)
store = self._in_memory_stores.pop(table_key, None)
if store is not None:
store.teardown()

def online_read(
self,
Expand All @@ -97,23 +126,28 @@ def online_read(
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
if self._index is None:
versioning = config.registry.enable_online_feature_view_versioning
table_key = _table_id(config.project, table, versioning)
index = self._get_index(table_key)
in_memory_store = self._in_memory_stores.get(table_key)

if index is None or in_memory_store is None:
return [(None, None)] * len(entity_keys)

results: List[Tuple[Optional[datetime], Optional[Dict[str, Any]]]] = []
for entity_key in entity_keys:
serialized_key = serialize_entity_key(
entity_key, config.entity_key_serialization_version
).hex()
idx = self._in_memory_store.entity_keys.get(serialized_key, -1)
idx = in_memory_store.entity_keys.get(serialized_key, -1)
if idx == -1:
results.append((None, None))
else:
feature_vector = self._index.reconstruct(int(idx))
feature_vector = index.reconstruct(int(idx))
feature_dict = {
name: ValueProto(double_val=value)
for name, value in zip(
self._in_memory_store.feature_names, feature_vector
in_memory_store.feature_names, feature_vector
)
}
results.append((None, feature_dict))
Expand All @@ -128,8 +162,16 @@ def online_write_batch(
],
progress: Optional[Callable[[int], Any]],
) -> None:
if self._index is None:
self._logger.warning("Index is not initialized. Skipping write operation.")
versioning = config.registry.enable_online_feature_view_versioning
table_key = _table_id(config.project, table, versioning)
index = self._get_index(table_key)
in_memory_store = self._in_memory_stores.get(table_key)

if index is None or in_memory_store is None:
self._logger.warning(
"Index for table '%s' is not initialized. Skipping write operation.",
table_key,
)
return

feature_vectors = []
Expand All @@ -142,7 +184,7 @@ def online_write_batch(
feature_vector = np.array(
[
feature_dict[name].double_val
for name in self._in_memory_store.feature_names
for name in in_memory_store.feature_names
],
dtype=np.float32,
)
Expand All @@ -153,21 +195,21 @@ def online_write_batch(
feature_vectors_array = np.array(feature_vectors)

existing_indices = [
self._in_memory_store.entity_keys.get(sk, -1) for sk in serialized_keys
in_memory_store.entity_keys.get(sk, -1) for sk in serialized_keys
]
mask = np.array(existing_indices) != -1
if np.any(mask):
self._index.remove_ids(
index.remove_ids(
np.array([idx for idx in existing_indices if idx != -1])
)

new_indices = np.arange(
self._index.ntotal, self._index.ntotal + len(feature_vectors_array)
index.ntotal, index.ntotal + len(feature_vectors_array)
)
self._index.add(feature_vectors_array)
index.add(feature_vectors_array)

for sk, idx in zip(serialized_keys, new_indices):
self._in_memory_store.entity_keys[sk] = idx
in_memory_store.entity_keys[sk] = idx

if progress:
progress(len(data))
Expand All @@ -189,12 +231,16 @@ def retrieve_online_documents(
Optional[ValueProto],
]
]:
if self._index is None:
versioning = config.registry.enable_online_feature_view_versioning
table_key = _table_id(config.project, table, versioning)
index = self._get_index(table_key)

if index is None:
self._logger.warning("Index is not initialized. Returning empty result.")
return []

query_vector = np.array(embedding, dtype=np.float32).reshape(1, -1)
distances, indices = self._index.search(query_vector, top_k)
distances, indices = index.search(query_vector, top_k)

results: List[
Tuple[
Expand All @@ -209,7 +255,7 @@ def retrieve_online_documents(
if idx == -1:
continue

feature_vector = self._index.reconstruct(int(idx))
feature_vector = index.reconstruct(int(idx))

timestamp = Timestamp()
timestamp.GetCurrentTime()
Expand Down
3 changes: 2 additions & 1 deletion sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,10 @@ def get_online_features(

def _check_versioned_read_support(self, grouped_refs):
"""Raise an error if versioned reads are attempted on unsupported stores."""
from feast.infra.online_stores.faiss_online_store import FaissOnlineStore
from feast.infra.online_stores.sqlite import SqliteOnlineStore

if isinstance(self, SqliteOnlineStore):
if isinstance(self, (SqliteOnlineStore, FaissOnlineStore)):
Comment on lines +259 to +262
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Importing the optional `faiss` dependency in the base `OnlineStore` breaks all online stores when `faiss` is not installed

The PR adds `from feast.infra.online_stores.faiss_online_store import FaissOnlineStore` at `online_store.py:259` inside `_check_versioned_read_support`. This method is called from `get_online_features` (`online_store.py:191`) and `get_online_features_async` (`online_store.py:315`) for every online store (Redis, DynamoDB, Bigtable, etc.). Since `faiss_online_store.py:5` has a top-level `import faiss` and `faiss-cpu` is an optional dependency (`pyproject.toml:64`: `faiss = ["faiss-cpu>=1.7.0,<=1.10.0"]`), this import will raise `ModuleNotFoundError` for any user who hasn't installed the `faiss` extra. The pre-existing `SqliteOnlineStore` import was safe because `sqlite.py` only uses stdlib/core imports. The fix should wrap this import in a `try`/`except ImportError` and handle the case where `FaissOnlineStore` is unavailable.

Suggested change
from feast.infra.online_stores.faiss_online_store import FaissOnlineStore
from feast.infra.online_stores.sqlite import SqliteOnlineStore
if isinstance(self, SqliteOnlineStore):
if isinstance(self, (SqliteOnlineStore, FaissOnlineStore)):
try:
from feast.infra.online_stores.faiss_online_store import FaissOnlineStore
except ImportError:
FaissOnlineStore = None
from feast.infra.online_stores.sqlite import SqliteOnlineStore
supported = [SqliteOnlineStore]
if FaissOnlineStore is not None:
supported.append(FaissOnlineStore)
if isinstance(self, tuple(supported)):
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

return
for table, _ in grouped_refs:
version_tag = getattr(table.projection, "version_tag", None)
Expand Down
Loading