Skip to content

Commit 8dceb9d

Browse files
committed
fix: Fix shared SQL registry crash - avoid unnecessary UDF deserialization in proto cache building
Signed-off-by: ntkathole <nikhilkathole2683@gmail.com>
1 parent 7279c75 commit 8dceb9d

3 files changed

Lines changed: 250 additions & 50 deletions

File tree

sdk/python/feast/infra/registry/snowflake.py

Lines changed: 86 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,7 @@ def _list_objects(
957957
python_class: Any,
958958
proto_field_name: str,
959959
tags: Optional[dict[str, str]] = None,
960+
proto_only: bool = False,
960961
):
961962
with GetSnowflakeConnection(self.registry_config) as conn:
962963
query = f"""
@@ -971,11 +972,13 @@ def _list_objects(
971972
if not df.empty:
972973
objects = []
973974
for row in df.iterrows():
974-
obj = python_class.from_proto(
975-
proto_class.FromString(row[1][proto_field_name])
976-
)
977-
if has_all_tags(obj.tags, tags):
978-
objects.append(obj)
975+
proto = proto_class.FromString(row[1][proto_field_name])
976+
if proto_only:
977+
objects.append(proto)
978+
else:
979+
obj = python_class.from_proto(proto)
980+
if has_all_tags(obj.tags, tags):
981+
objects.append(obj)
979982
return objects
980983
return []
981984

@@ -1134,28 +1137,90 @@ def process_project(project: Project):
11341137
r.projects.extend([project.to_proto()])
11351138
last_updated_timestamps.append(last_updated_timestamp)
11361139

1137-
for lister, registry_proto_field in [
1138-
(self.list_entities, r.entities),
1139-
(self.list_feature_views, r.feature_views),
1140-
(self.list_data_sources, r.data_sources),
1141-
(self.list_on_demand_feature_views, r.on_demand_feature_views),
1142-
(self.list_stream_feature_views, r.stream_feature_views),
1143-
(self.list_feature_services, r.feature_services),
1144-
(self.list_saved_datasets, r.saved_datasets),
1145-
(self.list_validation_references, r.validation_references),
1146-
(self.list_permissions, r.permissions),
1140+
# proto_only=True: return raw protos without calling from_proto(),
1141+
# which would trigger dill.loads() on UDFs and fail for cross-project
1142+
# modules. _list_objects hits the DB directly (no cache), avoiding
1143+
# infinite recursion since proto() itself builds the cache.
1144+
for (
1145+
table,
1146+
proto_class,
1147+
python_class,
1148+
proto_field_name,
1149+
registry_proto_field,
1150+
) in [
1151+
("ENTITIES", EntityProto, Entity, "ENTITY_PROTO", r.entities),
1152+
(
1153+
"FEATURE_VIEWS",
1154+
FeatureViewProto,
1155+
FeatureView,
1156+
"FEATURE_VIEW_PROTO",
1157+
r.feature_views,
1158+
),
1159+
(
1160+
"DATA_SOURCES",
1161+
DataSourceProto,
1162+
DataSource,
1163+
"DATA_SOURCE_PROTO",
1164+
r.data_sources,
1165+
),
1166+
(
1167+
"ON_DEMAND_FEATURE_VIEWS",
1168+
OnDemandFeatureViewProto,
1169+
OnDemandFeatureView,
1170+
"ON_DEMAND_FEATURE_VIEW_PROTO",
1171+
r.on_demand_feature_views,
1172+
),
1173+
(
1174+
"STREAM_FEATURE_VIEWS",
1175+
StreamFeatureViewProto,
1176+
StreamFeatureView,
1177+
"STREAM_FEATURE_VIEW_PROTO",
1178+
r.stream_feature_views,
1179+
),
1180+
(
1181+
"FEATURE_SERVICES",
1182+
FeatureServiceProto,
1183+
FeatureService,
1184+
"FEATURE_SERVICE_PROTO",
1185+
r.feature_services,
1186+
),
1187+
(
1188+
"SAVED_DATASETS",
1189+
SavedDatasetProto,
1190+
SavedDataset,
1191+
"SAVED_DATASET_PROTO",
1192+
r.saved_datasets,
1193+
),
1194+
(
1195+
"VALIDATION_REFERENCES",
1196+
ValidationReferenceProto,
1197+
ValidationReference,
1198+
"VALIDATION_REFERENCE_PROTO",
1199+
r.validation_references,
1200+
),
1201+
(
1202+
"PERMISSIONS",
1203+
PermissionProto,
1204+
Permission,
1205+
"PERMISSION_PROTO",
1206+
r.permissions,
1207+
),
11471208
]:
1148-
# Always bypass cache here: proto() builds the cache, so using
1149-
# allow_cache=True would cause infinite recursion via refresh().
1150-
objs: List[Any] = lister(project_name, False) # type: ignore
1209+
objs = self._list_objects(
1210+
table,
1211+
project_name,
1212+
proto_class,
1213+
python_class,
1214+
proto_field_name,
1215+
proto_only=True,
1216+
)
11511217
if objs:
1152-
obj_protos = [obj.to_proto() for obj in objs]
1153-
for obj_proto in obj_protos:
1218+
for obj_proto in objs:
11541219
if "spec" in obj_proto.DESCRIPTOR.fields_by_name:
11551220
obj_proto.spec.project = project_name
11561221
else:
11571222
obj_proto.project = project_name
1158-
registry_proto_field.extend(obj_protos)
1223+
registry_proto_field.extend(objs)
11591224

11601225
# This is a hack: because of https://github.com/feast-dev/feast/issues/2783,
11611226
# the registry proto only has a single infra field, which we're currently setting as the "last" project.

sdk/python/feast/infra/registry/sql.py

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def _get_stream_feature_view(self, name: str, project: str):
387387
)
388388

389389
def _list_stream_feature_views(
390-
self, project: str, tags: Optional[dict[str, str]]
390+
self, project: str, tags: Optional[dict[str, str]], **kwargs
391391
) -> List[StreamFeatureView]:
392392
return self._list_objects(
393393
stream_feature_views,
@@ -396,6 +396,7 @@ def _list_stream_feature_views(
396396
StreamFeatureView,
397397
"feature_view_proto",
398398
tags=tags,
399+
**kwargs,
399400
)
400401

401402
def apply_entity(self, entity: Entity, project: str, commit: bool = True):
@@ -537,7 +538,7 @@ def _get_validation_reference(self, name: str, project: str) -> ValidationRefere
537538
)
538539

539540
def _list_validation_references(
540-
self, project: str, tags: Optional[dict[str, str]] = None
541+
self, project: str, tags: Optional[dict[str, str]] = None, **kwargs
541542
) -> List[ValidationReference]:
542543
return self._list_objects(
543544
table=validation_references,
@@ -546,13 +547,20 @@ def _list_validation_references(
546547
python_class=ValidationReference,
547548
proto_field_name="validation_reference_proto",
548549
tags=tags,
550+
**kwargs,
549551
)
550552

551553
def _list_entities(
552-
self, project: str, tags: Optional[dict[str, str]]
554+
self, project: str, tags: Optional[dict[str, str]], **kwargs
553555
) -> List[Entity]:
554556
return self._list_objects(
555-
entities, project, EntityProto, Entity, "entity_proto", tags=tags
557+
entities,
558+
project,
559+
EntityProto,
560+
Entity,
561+
"entity_proto",
562+
tags=tags,
563+
**kwargs,
556564
)
557565

558566
def delete_entity(self, name: str, project: str, commit: bool = True):
@@ -614,7 +622,7 @@ def _get_data_source(self, name: str, project: str) -> DataSource:
614622
)
615623

616624
def _list_data_sources(
617-
self, project: str, tags: Optional[dict[str, str]]
625+
self, project: str, tags: Optional[dict[str, str]], **kwargs
618626
) -> List[DataSource]:
619627
return self._list_objects(
620628
data_sources,
@@ -623,6 +631,7 @@ def _list_data_sources(
623631
DataSource,
624632
"data_source_proto",
625633
tags=tags,
634+
**kwargs,
626635
)
627636

628637
def apply_data_source(
@@ -878,7 +887,7 @@ def delete_data_source(self, name: str, project: str, commit: bool = True):
878887
raise DataSourceObjectNotFoundException(name, project)
879888

880889
def _list_feature_services(
881-
self, project: str, tags: Optional[dict[str, str]]
890+
self, project: str, tags: Optional[dict[str, str]], **kwargs
882891
) -> List[FeatureService]:
883892
return self._list_objects(
884893
feature_services,
@@ -887,10 +896,11 @@ def _list_feature_services(
887896
FeatureService,
888897
"feature_service_proto",
889898
tags=tags,
899+
**kwargs,
890900
)
891901

892902
def _list_feature_views(
893-
self, project: str, tags: Optional[dict[str, str]]
903+
self, project: str, tags: Optional[dict[str, str]], **kwargs
894904
) -> List[FeatureView]:
895905
return self._list_objects(
896906
feature_views,
@@ -899,10 +909,11 @@ def _list_feature_views(
899909
FeatureView,
900910
"feature_view_proto",
901911
tags=tags,
912+
**kwargs,
902913
)
903914

904915
def _list_saved_datasets(
905-
self, project: str, tags: Optional[dict[str, str]] = None
916+
self, project: str, tags: Optional[dict[str, str]] = None, **kwargs
906917
) -> List[SavedDataset]:
907918
return self._list_objects(
908919
saved_datasets,
@@ -911,10 +922,11 @@ def _list_saved_datasets(
911922
SavedDataset,
912923
"saved_dataset_proto",
913924
tags=tags,
925+
**kwargs,
914926
)
915927

916928
def _list_on_demand_feature_views(
917-
self, project: str, tags: Optional[dict[str, str]]
929+
self, project: str, tags: Optional[dict[str, str]], **kwargs
918930
) -> List[OnDemandFeatureView]:
919931
return self._list_objects(
920932
on_demand_feature_views,
@@ -923,6 +935,7 @@ def _list_on_demand_feature_views(
923935
OnDemandFeatureView,
924936
"feature_view_proto",
925937
tags=tags,
938+
**kwargs,
926939
)
927940

928941
def _list_project_metadata(self, project: str) -> List[ProjectMetadata]:
@@ -1232,26 +1245,29 @@ def process_project(project: Project):
12321245
r.projects.extend([project.to_proto()])
12331246
last_updated_timestamps.append(last_updated_timestamp)
12341247

1248+
# proto_only=True: return raw protos without calling from_proto(),
1249+
# which would trigger dill.loads() on UDFs and fail for cross-project
1250+
# modules. The _list_* helpers hit the DB directly (no cache), avoiding
1251+
# infinite recursion since proto() itself builds the cache.
12351252
for lister, registry_proto_field in [
1236-
(self.list_entities, r.entities),
1237-
(self.list_feature_views, r.feature_views),
1238-
(self.list_data_sources, r.data_sources),
1239-
(self.list_on_demand_feature_views, r.on_demand_feature_views),
1240-
(self.list_stream_feature_views, r.stream_feature_views),
1241-
(self.list_feature_services, r.feature_services),
1242-
(self.list_saved_datasets, r.saved_datasets),
1243-
(self.list_validation_references, r.validation_references),
1244-
(self.list_permissions, r.permissions),
1253+
(self._list_entities, r.entities),
1254+
(self._list_feature_views, r.feature_views),
1255+
(self._list_data_sources, r.data_sources),
1256+
(self._list_on_demand_feature_views, r.on_demand_feature_views),
1257+
(self._list_stream_feature_views, r.stream_feature_views),
1258+
(self._list_feature_services, r.feature_services),
1259+
(self._list_saved_datasets, r.saved_datasets),
1260+
(self._list_validation_references, r.validation_references),
1261+
(self._list_permissions, r.permissions),
12451262
]:
1246-
objs: List[Any] = lister(project_name, allow_cache=False) # type: ignore
1263+
objs: List[Any] = lister(project_name, tags=None, proto_only=True) # type: ignore
12471264
if objs:
1248-
obj_protos = [obj.to_proto() for obj in objs]
1249-
for obj_proto in obj_protos:
1265+
for obj_proto in objs:
12501266
if "spec" in obj_proto.DESCRIPTOR.fields_by_name:
12511267
obj_proto.spec.project = project_name
12521268
else:
12531269
obj_proto.project = project_name
1254-
registry_proto_field.extend(obj_protos)
1270+
registry_proto_field.extend(objs)
12551271

12561272
# This is a hack: because of https://github.com/feast-dev/feast/issues/2783,
12571273
# the registry proto only has a single infra field, which we're currently setting as the "last" project.
@@ -1486,18 +1502,21 @@ def _list_objects(
14861502
python_class: Any,
14871503
proto_field_name: str,
14881504
tags: Optional[dict[str, str]] = None,
1505+
proto_only: bool = False,
14891506
):
14901507
with self.read_engine.begin() as conn:
14911508
stmt = select(table).where(table.c.project_id == project)
14921509
rows = conn.execute(stmt).all()
14931510
if rows:
14941511
objects = []
14951512
for row in rows:
1496-
obj = python_class.from_proto(
1497-
proto_class.FromString(row._mapping[proto_field_name])
1498-
)
1499-
if utils.has_all_tags(obj.tags, tags):
1500-
objects.append(obj)
1513+
proto = proto_class.FromString(row._mapping[proto_field_name])
1514+
if proto_only:
1515+
objects.append(proto)
1516+
else:
1517+
obj = python_class.from_proto(proto)
1518+
if utils.has_all_tags(obj.tags, tags):
1519+
objects.append(obj)
15011520
return objects
15021521
return []
15031522

@@ -1568,7 +1587,7 @@ def _get_permission(self, name: str, project: str) -> Permission:
15681587
)
15691588

15701589
def _list_permissions(
1571-
self, project: str, tags: Optional[dict[str, str]]
1590+
self, project: str, tags: Optional[dict[str, str]], **kwargs
15721591
) -> List[Permission]:
15731592
return self._list_objects(
15741593
permissions,
@@ -1577,6 +1596,7 @@ def _list_permissions(
15771596
Permission,
15781597
"permission_proto",
15791598
tags=tags,
1599+
**kwargs,
15801600
)
15811601

15821602
def apply_permission(

0 commit comments

Comments
 (0)