Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 68 additions & 31 deletions sdk/python/feast/infra/registry_stores/sql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from pathlib import Path
from threading import Lock
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from sqlalchemy import ( # type: ignore
BigInteger,
Expand Down Expand Up @@ -39,6 +39,7 @@
FeatureService as FeatureServiceProto,
)
from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto
from feast.protos.feast.core.OnDemandFeatureView_pb2 import (
OnDemandFeatureView as OnDemandFeatureViewProto,
)
Expand Down Expand Up @@ -138,6 +139,14 @@
Column("validation_reference_proto", LargeBinary, nullable=False),
)

managed_infra = Table(
"managed_infra",
metadata,
Column("infra_name", String(50), primary_key=True),
Column("last_updated_timestamp", BigInteger, nullable=False),
Column("infra_proto", LargeBinary, nullable=False),
)


class SqlRegistry(BaseRegistry):
def __init__(
Expand Down Expand Up @@ -168,6 +177,7 @@ def teardown(self):
conn.execute(stmt)

def refresh(self):
# This method is a no-op since we're always reading the latest values from the db.
pass

def get_stream_feature_view(
Expand Down Expand Up @@ -353,16 +363,7 @@ def apply_data_source(
def apply_feature_view(
self, feature_view: BaseFeatureView, project: str, commit: bool = True
):
if isinstance(feature_view, StreamFeatureView):
fv_table = stream_feature_views
elif isinstance(feature_view, FeatureView):
fv_table = feature_views
elif isinstance(feature_view, OnDemandFeatureView):
fv_table = on_demand_feature_views
elif isinstance(feature_view, RequestFeatureView):
fv_table = request_feature_views
else:
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
fv_table = self._infer_fv_table(feature_view)

return self._apply_object(
fv_table, "feature_view_name", feature_view, "feature_view_proto"
Expand Down Expand Up @@ -457,7 +458,25 @@ def apply_materialization(
end_date: datetime,
commit: bool = True,
):
pass
table = self._infer_fv_table(feature_view)
python_class, proto_class = self._infer_fv_classes(feature_view)

if python_class in {RequestFeatureView, OnDemandFeatureView}:
raise ValueError(
f"Cannot apply materialization for feature {feature_view.name} of type {python_class}"
)
fv: Union[FeatureView, StreamFeatureView] = self._get_object(
table,
feature_view.name,
project,
proto_class,
python_class,
"feature_view_name",
"feature_view_proto",
FeatureViewNotFoundException,
)
fv.materialization_intervals.append((start_date, end_date))
self._apply_object(table, "feature_view_name", fv, "feature_view_proto")

def delete_validation_reference(self, name: str, project: str, commit: bool = True):
self._delete_object(
Expand All @@ -469,27 +488,29 @@ def delete_validation_reference(self, name: str, project: str, commit: bool = Tr
)

def update_infra(self, infra: Infra, project: str, commit: bool = True):
pass
self._apply_object(
managed_infra, "infra_name", infra, "infra_proto", name="infra_obj"
)

def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
return Infra()
return self._get_object(
managed_infra,
"infra_obj",
project,
InfraProto,
Infra,
"infra_name",
"infra_proto",
None,
)

def apply_user_metadata(
self,
project: str,
feature_view: BaseFeatureView,
metadata_bytes: Optional[bytes],
):
if isinstance(feature_view, StreamFeatureView):
table = stream_feature_views
elif isinstance(feature_view, FeatureView):
table = feature_views
elif isinstance(feature_view, OnDemandFeatureView):
table = on_demand_feature_views
elif isinstance(feature_view, RequestFeatureView):
table = request_feature_views
else:
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.connect() as conn:
Expand All @@ -511,9 +532,7 @@ def apply_user_metadata(
else:
raise FeatureViewNotFoundException(feature_view.name, project=project)

def get_user_metadata(
self, project: str, feature_view: BaseFeatureView
) -> Optional[bytes]:
def _infer_fv_table(self, feature_view):
if isinstance(feature_view, StreamFeatureView):
table = stream_feature_views
elif isinstance(feature_view, FeatureView):
Expand All @@ -524,6 +543,25 @@ def get_user_metadata(
table = request_feature_views
else:
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
return table

def _infer_fv_classes(self, feature_view):
if isinstance(feature_view, StreamFeatureView):
python_class, proto_class = StreamFeatureView, StreamFeatureViewProto
elif isinstance(feature_view, FeatureView):
python_class, proto_class = FeatureView, FeatureViewProto
elif isinstance(feature_view, OnDemandFeatureView):
python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto
elif isinstance(feature_view, RequestFeatureView):
python_class, proto_class = RequestFeatureView, RequestFeatureViewProto
else:
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
return python_class, proto_class

def get_user_metadata(
self, project: str, feature_view: BaseFeatureView
) -> Optional[bytes]:
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.connect() as conn:
Expand Down Expand Up @@ -556,12 +594,11 @@ def proto(self) -> RegistryProto:
return r

def commit(self):
# This method is a no-op since we're always writing values eagerly to the db.
pass

def _apply_object(
self, table, id_field_name, obj, proto_field_name,
):
name = obj.name
def _apply_object(self, table, id_field_name, obj, proto_field_name, name=None):
name = name or obj.name
with self.engine.connect() as conn:
stmt = select(table).where(getattr(table.c, id_field_name) == name)
row = conn.execute(stmt).first()
Expand Down