Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions sdk/python/feast/infra/registry/caching_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,12 +425,24 @@ def list_projects(
return self._list_projects(tags)

def refresh(self, project: Optional[str] = None):
self.cached_registry_proto = self.proto()
self.cached_registry_proto_created = _utc_now()
if self._refresh_lock.locked():
logger.info("Skipping refresh if already in progress")
return
try:
self.cached_registry_proto = self.proto()
self.cached_registry_proto_created = _utc_now()
except Exception as e:
logger.error(f"Error while refreshing registry: {e}", exc_info=True)

def _refresh_cached_registry_if_necessary(self):
if self.cache_mode == "sync":
with self._refresh_lock:
# Try acquiring the lock without blocking
if not self._refresh_lock.acquire(blocking=False):
logger.info(
"Skipping refresh if lock is already held by another thread"
)
return
try:
if self.cached_registry_proto == RegistryProto():
# Avoids the need to refresh the registry when cache is not populated yet
# Specially during the __init__ phase
Expand All @@ -454,6 +466,13 @@ def _refresh_cached_registry_if_necessary(self):
if expired:
logger.info("Registry cache expired, so refreshing")
self.refresh()
except Exception as e:
logger.error(
f"Error in _refresh_cached_registry_if_necessary: {e}",
exc_info=True,
)
finally:
self._refresh_lock.release() # Always release the lock safely

def _start_thread_async_refresh(self, cache_ttl_seconds):
self.refresh()
Expand Down
Empty file.
197 changes: 197 additions & 0 deletions sdk/python/tests/unit/infra/registry/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
from datetime import datetime, timedelta, timezone
from unittest.mock import patch

import pytest

from feast.infra.registry.caching_registry import CachingRegistry


class TestCachingRegistry(CachingRegistry):
"""Test subclass that implements abstract methods as no-ops"""

def _get_any_feature_view(self, *args, **kwargs):
pass

def _get_data_source(self, *args, **kwargs):
pass

def _get_entity(self, *args, **kwargs):
pass

def _get_feature_service(self, *args, **kwargs):
pass

def _get_feature_view(self, *args, **kwargs):
pass

def _get_infra(self, *args, **kwargs):
pass

def _get_on_demand_feature_view(self, *args, **kwargs):
pass

def _get_permission(self, *args, **kwargs):
pass

def _get_project(self, *args, **kwargs):
pass

def _get_saved_dataset(self, *args, **kwargs):
pass

def _get_stream_feature_view(self, *args, **kwargs):
pass

def _get_validation_reference(self, *args, **kwargs):
pass

def _list_all_feature_views(self, *args, **kwargs):
pass

def _list_data_sources(self, *args, **kwargs):
pass

def _list_entities(self, *args, **kwargs):
pass

def _list_feature_services(self, *args, **kwargs):
pass

def _list_feature_views(self, *args, **kwargs):
pass

def _list_on_demand_feature_views(self, *args, **kwargs):
pass

def _list_permissions(self, *args, **kwargs):
pass

def _list_project_metadata(self, *args, **kwargs):
pass

def _list_projects(self, *args, **kwargs):
pass

def _list_saved_datasets(self, *args, **kwargs):
pass

def _list_stream_feature_views(self, *args, **kwargs):
pass

def _list_validation_references(self, *args, **kwargs):
pass

def apply_data_source(self, *args, **kwargs):
pass

def apply_entity(self, *args, **kwargs):
pass

def apply_feature_service(self, *args, **kwargs):
pass

def apply_feature_view(self, *args, **kwargs):
pass

def apply_materialization(self, *args, **kwargs):
pass

def apply_permission(self, *args, **kwargs):
pass

def apply_project(self, *args, **kwargs):
pass

def apply_saved_dataset(self, *args, **kwargs):
pass

def apply_user_metadata(self, *args, **kwargs):
pass

def apply_validation_reference(self, *args, **kwargs):
pass

def commit(self, *args, **kwargs):
pass

def delete_data_source(self, *args, **kwargs):
pass

def delete_entity(self, *args, **kwargs):
pass

def delete_feature_service(self, *args, **kwargs):
pass

def delete_feature_view(self, *args, **kwargs):
pass

def delete_permission(self, *args, **kwargs):
pass

def delete_project(self, *args, **kwargs):
pass

def delete_validation_reference(self, *args, **kwargs):
pass

def get_user_metadata(self, *args, **kwargs):
pass

def proto(self, *args, **kwargs):
pass

def update_infra(self, *args, **kwargs):
pass


@pytest.fixture
def registry():
"""Fixture to create a real instance of CachingRegistry"""
return TestCachingRegistry(
project="test_example", cache_ttl_seconds=2, cache_mode="sync"
)


def test_cache_expiry_triggers_refresh(registry):
"""Test that an expired cache triggers a refresh"""
# Set cache creation time to a value that is expired
registry.cached_registry_proto = "some_cached_data"
registry.cached_registry_proto_created = datetime.now(timezone.utc) - timedelta(
seconds=5
)

# Mock _refresh_cached_registry_if_necessary to check if it is called
with patch.object(
CachingRegistry,
"_refresh_cached_registry_if_necessary",
wraps=registry._refresh_cached_registry_if_necessary,
) as mock_refresh_check:
registry._refresh_cached_registry_if_necessary()
mock_refresh_check.assert_called_once()

# Now check if the refresh was actually triggered
with patch.object(
CachingRegistry, "refresh", wraps=registry.refresh
) as mock_refresh:
registry._refresh_cached_registry_if_necessary()
mock_refresh.assert_called_once()


def test_skip_refresh_if_lock_held(registry):
"""Test that refresh is skipped if the lock is already held by another thread"""
registry.cached_registry_proto = "some_cached_data"
registry.cached_registry_proto_created = datetime.now(timezone.utc) - timedelta(
seconds=5
)

# Acquire the lock manually to simulate another thread holding it
registry._refresh_lock.acquire()
with patch.object(
CachingRegistry, "refresh", wraps=registry.refresh
) as mock_refresh:
registry._refresh_cached_registry_if_necessary()

# Since the lock was already held, refresh should NOT be called
mock_refresh.assert_not_called()
registry._refresh_lock.release()
Loading