Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class VersionedOnlineReadNotSupported(FeastError):
def __init__(self, store_name: str, version: int):
    """Build the error raised for a version-qualified feature read against a
    store that does not implement versioned routing.

    Args:
        store_name: Class name of the online store that rejected the read.
        version: The requested feature view version (the N in ``@vN``).
    """
    super().__init__(
        f"Versioned feature reads (@v{version}) are not yet supported by {store_name}. "
        f"Currently only SQLite and SingleStore support version-qualified feature references. "
    )


Expand Down
23 changes: 17 additions & 6 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def get_online_features(
)

# Check for versioned reads on unsupported stores
self._check_versioned_read_support(grouped_refs)
self._check_versioned_read_support(grouped_refs, config)
_track_read = False
try:
from feast.metrics import _config as _metrics_config
Expand Down Expand Up @@ -254,19 +254,30 @@ def get_online_features(
)
return OnlineResponse(online_features_response)

def _check_versioned_read_support(self, grouped_refs, config: RepoConfig):
    """Raise an error if versioned reads are attempted on unsupported stores.

    Args:
        grouped_refs: Iterable of ``(feature_view, requested_features)`` pairs,
            as produced by the feature-reference grouping step.
        config: Repo configuration; its registry settings gate whether
            version-qualified references are allowed at all.

    Raises:
        VersionedOnlineReadNotSupported: If any reference carries a
            ``version_tag`` while online versioning is disabled, or while this
            store does not implement versioned table routing.
    """
    # Imported lazily to avoid a hard dependency on optional store modules
    # at import time.
    from feast.infra.online_stores.singlestore_online_store.singlestore import (
        SingleStoreOnlineStore,
    )
    from feast.infra.online_stores.sqlite import SqliteOnlineStore

    for table, _ in grouped_refs:
        version_tag = getattr(table.projection, "version_tag", None)
        if version_tag is None:
            continue

        # Version-qualified refs (e.g. @v2) are only supported when online
        # versioning is enabled in the registry config.
        if not config.registry.enable_online_feature_view_versioning:
            raise VersionedOnlineReadNotSupported(
                self.__class__.__name__, version_tag
            )

        # Online versioning enabled: allow stores that implement versioned
        # routing; anything else must fail loudly rather than silently read
        # the unversioned table.
        if isinstance(self, (SqliteOnlineStore, SingleStoreOnlineStore)):
            continue

        raise VersionedOnlineReadNotSupported(self.__class__.__name__, version_tag)

async def get_online_features_async(
self,
config: RepoConfig,
Expand Down Expand Up @@ -311,7 +322,7 @@ async def get_online_features_async(
)

# Check for versioned reads on unsupported stores
self._check_versioned_read_support(grouped_refs)
self._check_versioned_read_support(grouped_refs, config)

async def query_table(table, requested_features):
# Get the correct set of entity values with the correct join keys.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def online_write_batch(
for entity_key, values, timestamp, created_ts in data:
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=3,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
timestamp = _to_naive_utc(timestamp)
if created_ts is not None:
Expand All @@ -102,7 +102,7 @@ def online_write_batch(
current_batch = insert_values[i : i + batch_size]
cur.executemany(
f"""
INSERT INTO {_table_id(project, table)}
INSERT INTO {_table_id(project, table, config.registry.enable_online_feature_view_versioning)}
(entity_key, feature_name, value, event_ts, created_ts)
values (%s, %s, %s, %s, %s)
ON DUPLICATE KEY UPDATE
Expand Down Expand Up @@ -130,15 +130,15 @@ def online_read(
keys.append(
serialize_entity_key(
entity_key,
entity_key_serialization_version=3,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
)

if not requested_features:
entity_key_placeholders = ",".join(["%s" for _ in keys])
cur.execute(
f"""
SELECT entity_key, feature_name, value, event_ts FROM {_table_id(project, table)}
SELECT entity_key, feature_name, value, event_ts FROM {_table_id(project, table, config.registry.enable_online_feature_view_versioning)}
WHERE entity_key IN ({entity_key_placeholders})
ORDER BY event_ts;
""",
Expand All @@ -151,7 +151,7 @@ def online_read(
)
cur.execute(
f"""
SELECT entity_key, feature_name, value, event_ts FROM {_table_id(project, table)}
SELECT entity_key, feature_name, value, event_ts FROM {_table_id(project, table, config.registry.enable_online_feature_view_versioning)}
WHERE entity_key IN ({entity_key_placeholders}) and feature_name IN ({requested_features_placeholders})
ORDER BY event_ts;
""",
Expand Down Expand Up @@ -191,21 +191,23 @@ def update(
partial: bool,
) -> None:
project = config.project
versioning = config.registry.enable_online_feature_view_versioning
with self._get_cursor(config) as cur:
# We don't create any special state for the entities in this implementation.
for table in tables_to_keep:
table_name = _table_id(project, table, versioning)
cur.execute(
f"""CREATE TABLE IF NOT EXISTS {_table_id(project, table)} (entity_key VARCHAR(512),
f"""CREATE TABLE IF NOT EXISTS {table_name} (entity_key VARCHAR(512),
feature_name VARCHAR(256),
value BLOB,
event_ts timestamp NULL DEFAULT NULL,
created_ts timestamp NULL DEFAULT NULL,
PRIMARY KEY(entity_key, feature_name),
INDEX {_table_id(project, table)}_ek (entity_key))"""
INDEX {table_name}_ek (entity_key))"""
)

for table in tables_to_delete:
_drop_table_and_index(cur, project, table)
_drop_table_and_index(cur, project, table, versioning)

def teardown(
self,
Expand All @@ -214,16 +216,26 @@ def teardown(
entities: Sequence[Entity],
) -> None:
project = config.project
versioning = config.registry.enable_online_feature_view_versioning
with self._get_cursor(config) as cur:
for table in tables:
_drop_table_and_index(cur, project, table)
_drop_table_and_index(cur, project, table, versioning)


def _drop_table_and_index(
    cur: Cursor, project: str, table: FeatureView, enable_versioning: bool = False
) -> None:
    """Drop a feature view's online table and its entity-key index.

    Args:
        cur: Open database cursor used to execute the DDL statements.
        project: Feast project name (used as the table-name prefix).
        table: Feature view whose backing table should be removed.
        enable_versioning: Whether version-suffixed table names are in use.
            Must match the value used when the table was created so the
            correct (possibly ``_vN``-suffixed) table is targeted. Defaults
            to False for backward compatibility with unversioned callers.
    """
    table_name = _table_id(project, table, enable_versioning)
    # The index name is derived from the table name, so it must be dropped
    # using the same (possibly version-suffixed) identifier.
    # NOTE(review): DROP INDEX has no IF EXISTS guard while DROP TABLE does —
    # this will raise if the table was never created; confirm callers only
    # pass previously-applied views.
    cur.execute(f"DROP INDEX {table_name}_ek ON {table_name};")
    cur.execute(f"DROP TABLE IF EXISTS {table_name}")


def _table_id(project: str, table: FeatureView) -> str:
return f"{project}_{table.name}"
def _table_id(project: str, table: FeatureView, enable_versioning: bool = False) -> str:
name = table.name
if enable_versioning:
version = getattr(table.projection, "version_tag", None)
if version is None:
version = getattr(table, "current_version_number", None)
if version is not None and version > 0:
name = f"{table.name}_v{version}"
return f"{project}_{name}"
1 change: 0 additions & 1 deletion sdk/python/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ env =
IS_TEST=True
filterwarnings =
error::_pytest.warning_types.PytestConfigWarning
error::_pytest.warning_types.PytestUnhandledCoroutineWarning
ignore::DeprecationWarning:pyspark.sql.pandas.*:
ignore::DeprecationWarning:pyspark.sql.connect.*:
ignore::DeprecationWarning:httpx.*:
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@


def pytest_configure(config):
    """Select a multiprocessing start method appropriate for the platform.

    macOS and Windows cannot safely fork a running pytest process (macOS
    system frameworks are not fork-safe and Windows has no fork), so "spawn"
    is forced there; everywhere else the faster "fork" method is used.
    """
    # sys.platform is "win32" on Windows (never "windows"), so match by
    # prefix rather than exact string.
    if platform in ["darwin"] or platform.startswith("win"):
        multiprocessing.set_start_method("spawn", force=True)
    else:
        multiprocessing.set_start_method("fork")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,77 @@ def test_write_to_online_store(environment, universal_data_sources):
assertpy.assert_that(df["conv_rate"].iloc[0]).is_close_to(0.85, 1e-6)


@pytest.mark.integration
@pytest.mark.universal_online_stores(only=["singlestore"])
def test_singlestore_versioned_online_reads(environment, universal_data_sources):
    """End-to-end check that SingleStore serves version-qualified reads.

    Writes data under two schema versions of the same feature view, then
    reads both versions in a single get_online_features call using @v0/@v1
    qualified feature references and checks each resolves to its own data.
    """
    fs = environment.feature_store
    # Versioned online reads are gated by this registry flag; flip it on for
    # this store instance before any apply/write happens.
    fs.config.registry.enable_online_feature_view_versioning = True

    entities, datasets, data_sources = universal_data_sources
    driver_entity = driver()

    # Apply v0
    driver_hourly_stats_v0 = create_driver_hourly_stats_feature_view(
        data_sources.driver
    )
    fs.apply([driver_hourly_stats_v0, driver_entity])

    # Write v0 data
    df_v0 = pd.DataFrame(
        {
            "driver_id": [1],
            "conv_rate": [0.1],
            "acc_rate": [0.2],
            "avg_daily_trips": [10],
            "driver_metadata": [None],
            "driver_config": [None],
            "driver_profile": [None],
            "event_timestamp": [pd.Timestamp(_utc_now()).round("ms")],
            "created": [pd.Timestamp(_utc_now()).round("ms")],
        }
    )
    fs.write_to_online_store("driver_stats", df_v0)

    # Apply a schema change to create v1 — same view name, one extra field,
    # which presumably bumps the view's version on apply (TODO confirm the
    # registry's version-bump semantics).
    driver_hourly_stats_v1 = FeatureView(
        name="driver_stats",
        entities=[driver_entity],
        schema=driver_hourly_stats_v0.schema
        + [Field(name="new_feature", dtype=Float32)],
        source=data_sources.driver,
        ttl=driver_hourly_stats_v0.ttl,
        tags=TAGS,
    )
    fs.apply([driver_hourly_stats_v1, driver_entity])

    # Write v1 data (same entity key, different avg_daily_trips so the two
    # versions are distinguishable on read)
    df_v1 = pd.DataFrame(
        {
            "driver_id": [1],
            "conv_rate": [0.1],
            "acc_rate": [0.2],
            "avg_daily_trips": [20],
            "new_feature": [1.0],
            "driver_metadata": [None],
            "driver_config": [None],
            "driver_profile": [None],
            "event_timestamp": [pd.Timestamp(_utc_now()).round("ms")],
            "created": [pd.Timestamp(_utc_now()).round("ms")],
        }
    )
    fs.write_to_online_store("driver_stats", df_v1)

    # Read v0 and v1 explicitly
    df = fs.get_online_features(
        features=["driver_stats@v0:avg_daily_trips", "driver_stats@v1:avg_daily_trips"],
        entity_rows=[{"driver_id": 1}],
        full_feature_names=True,
    ).to_df()

    # avg_daily_trips differs between versions (10 vs 20), so each assertion
    # proves the read was routed to the correct versioned table.
    assertpy.assert_that(df["driver_stats@v0__avg_daily_trips"].iloc[0]).is_equal_to(10)
    assertpy.assert_that(df["driver_stats@v1__avg_daily_trips"].iloc[0]).is_equal_to(20)


def _get_online_features_dict_remotely(
endpoint: str,
features: Union[List[str], FeatureService],
Expand Down
Loading