Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 39 additions & 15 deletions sdk/python/feast/infra/registry_stores/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, List, Optional, Set, Union
from typing import Any, Callable, List, Optional, Set, Union

from sqlalchemy import ( # type: ignore
BigInteger,
Expand Down Expand Up @@ -560,7 +560,7 @@ def update_infra(self, infra: Infra, project: str, commit: bool = True):
)

def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
return self._get_object(
infra_object = self._get_object(
managed_infra,
"infra_obj",
project,
Expand All @@ -570,6 +570,8 @@ def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
"infra_proto",
None,
)
infra_object = infra_object or InfraProto()
return Infra.from_proto(infra_object)

def apply_user_metadata(
self,
Expand Down Expand Up @@ -683,11 +685,18 @@ def commit(self):
pass

def _apply_object(
self, table, project: str, id_field_name, obj, proto_field_name, name=None
self,
table: Table,
project: str,
id_field_name,
obj,
proto_field_name,
name=None,
):
self._maybe_init_project_metadata(project)

name = name or obj.name
name = name or obj.name if hasattr(obj, "name") else None
assert name, f"name needs to be provided for {obj}"
with self.engine.connect() as conn:
update_datetime = datetime.utcnow()
update_time = int(update_datetime.timestamp())
Expand Down Expand Up @@ -749,7 +758,14 @@ def _maybe_init_project_metadata(self, project):
conn.execute(insert_stmt)
usage.set_current_project_uuid(new_project_uuid)

def _delete_object(self, table, name, project, id_field_name, not_found_exception):
def _delete_object(
self,
table: Table,
name: str,
project: str,
id_field_name: str,
not_found_exception: Optional[Callable],
):
with self.engine.connect() as conn:
stmt = delete(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
Expand All @@ -763,14 +779,14 @@ def _delete_object(self, table, name, project, id_field_name, not_found_exceptio

def _get_object(
self,
table,
name,
project,
proto_class,
python_class,
id_field_name,
proto_field_name,
not_found_exception,
table: Table,
name: str,
project: str,
proto_class: Any,
python_class: Any,
id_field_name: str,
proto_field_name: str,
not_found_exception: Optional[Callable],
):
self._maybe_init_project_metadata(project)

Expand All @@ -782,10 +798,18 @@ def _get_object(
if row:
_proto = proto_class.FromString(row[proto_field_name])
return python_class.from_proto(_proto)
raise not_found_exception(name, project)
if not_found_exception:
raise not_found_exception(name, project)
else:
return None

def _list_objects(
self, table, project, proto_class, python_class, proto_field_name
self,
table: Table,
project: str,
proto_class: Any,
python_class: Any,
proto_field_name: str,
):
self._maybe_init_project_metadata(project)
with self.engine.connect() as conn:
Expand Down