Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions sdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,15 @@ def environment(request, worker_id):
request.param, worker_id=worker_id, fixture_request=request
)

e.setup()

if hasattr(e.data_source_creator, "mock_environ"):
with mock.patch.dict(os.environ, e.data_source_creator.mock_environ):
yield e
else:
yield e

e.feature_store.teardown()
e.data_source_creator.teardown()
if e.online_store_creator:
e.online_store_creator.teardown()
e.teardown()


_config_cache: Any = {}
Expand Down
76 changes: 45 additions & 31 deletions sdk/python/tests/integration/feature_repos/repo_configuration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import dataclasses
import importlib
import json
import os
import tempfile
import uuid
Expand All @@ -11,13 +10,15 @@

import pandas as pd
import pytest
import yaml

from feast import FeatureStore, FeatureView, OnDemandFeatureView, driver_test_data
from feast.constants import FULL_REPO_CONFIGS_MODULE_ENV_NAME
from feast.data_source import DataSource
from feast.errors import FeastModuleImportError
from feast.infra.feature_servers.base_config import FeatureLoggingConfig
from feast.infra.feature_servers.base_config import (
BaseFeatureServerConfig,
FeatureLoggingConfig,
)
from feast.infra.feature_servers.local_process.config import LocalFeatureServerConfig
from feast.repo_config import RegistryConfig, RepoConfig
from tests.integration.feature_repos.integration_test_repo_config import (
Expand Down Expand Up @@ -397,18 +398,48 @@ def construct_universal_feature_views(
@dataclass
class Environment:
name: str
test_repo_config: IntegrationTestRepoConfig
feature_store: FeatureStore
project: str
provider: str
registry: RegistryConfig
data_source_creator: DataSourceCreator
online_store_creator: Optional[OnlineStoreCreator]
online_store: Optional[Union[str, Dict]]
batch_engine: Optional[Union[str, Dict]]
python_feature_server: bool
worker_id: str
online_store_creator: Optional[OnlineStoreCreator] = None
feature_server: BaseFeatureServerConfig
entity_key_serialization_version: int
repo_dir_name: str
fixture_request: Optional[pytest.FixtureRequest] = None

def __post_init__(self):
self.end_date = datetime.utcnow().replace(microsecond=0, second=0, minute=0)
self.start_date: datetime = self.end_date - timedelta(days=3)

def setup(self):
self.data_source_creator.setup(self.registry)

self.config = RepoConfig(
registry=self.registry,
project=self.project,
provider=self.provider,
offline_store=self.data_source_creator.create_offline_store_config(),
online_store=self.online_store_creator.create_online_store()
if self.online_store_creator
else self.online_store,
batch_engine=self.batch_engine,
repo_path=self.repo_dir_name,
feature_server=self.feature_server,
entity_key_serialization_version=self.entity_key_serialization_version,
)
self.feature_store = FeatureStore(config=self.config)

def teardown(self):
self.feature_store.teardown()
self.data_source_creator.teardown()
if self.online_store_creator:
self.online_store_creator.teardown()


def table_name_from_data_source(ds: DataSource) -> Optional[str]:
if hasattr(ds, "table_ref"):
Expand Down Expand Up @@ -436,16 +467,13 @@ def construct_test_environment(
offline_creator: DataSourceCreator = test_repo_config.offline_store_creator(
project, fixture_request=fixture_request
)
offline_store_config = offline_creator.create_offline_store_config()

if test_repo_config.online_store_creator:
online_creator = test_repo_config.online_store_creator(
project, fixture_request=fixture_request
)
online_store = online_creator.create_online_store()
else:
online_creator = None
online_store = test_repo_config.online_store

if test_repo_config.python_feature_server and test_repo_config.provider == "aws":
from feast.infra.feature_servers.aws_lambda.config import (
Expand Down Expand Up @@ -481,35 +509,21 @@ def construct_test_environment(
cache_ttl_seconds=1,
)

config = RepoConfig(
registry=registry,
project=project,
provider=test_repo_config.provider,
offline_store=offline_store_config,
online_store=online_store,
batch_engine=test_repo_config.batch_engine,
repo_path=repo_dir_name,
feature_server=feature_server,
entity_key_serialization_version=entity_key_serialization_version,
)

# Create feature_store.yaml out of the config
with open(Path(repo_dir_name) / "feature_store.yaml", "w") as f:
yaml.safe_dump(json.loads(config.model_dump_json(by_alias=True)), f)

fs = FeatureStore(repo_dir_name)
# We need to initialize the registry, because if nothing is applied in the test before tearing down
# the feature store, that will cause the teardown method to blow up.
fs.registry._initialize_registry(project)
environment = Environment(
name=project,
test_repo_config=test_repo_config,
feature_store=fs,
provider=test_repo_config.provider,
data_source_creator=offline_creator,
python_feature_server=test_repo_config.python_feature_server,
worker_id=worker_id,
online_store_creator=online_creator,
fixture_request=fixture_request,
project=project,
registry=registry,
feature_server=feature_server,
entity_key_serialization_version=entity_key_serialization_version,
repo_dir_name=repo_dir_name,
batch_engine=test_repo_config.batch_engine,
online_store=test_repo_config.online_store,
)

return environment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from feast.data_source import DataSource
from feast.feature_logging import LoggingDestination
from feast.repo_config import FeastConfigBaseModel
from feast.repo_config import FeastConfigBaseModel, RegistryConfig
from feast.saved_dataset import SavedDatasetStorage


Expand Down Expand Up @@ -44,6 +44,9 @@ def create_data_source(
"""
raise NotImplementedError

def setup(self, registry: RegistryConfig):
pass

@abstractmethod
def create_offline_store_config(self) -> FeastConfigBaseModel:
raise NotImplementedError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def test_spark_materialization_consistency():
spark_config, None, entity_key_serialization_version=2
)

spark_environment.setup()

df = create_basic_driver_dataset()

ds = spark_environment.data_source_creator.create_data_source(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def test_snowflake_materialization_consistency(online_store):
batch_engine=SNOWFLAKE_ENGINE_CONFIG,
)
snowflake_environment = construct_test_environment(snowflake_config, None)
snowflake_environment.setup()

df = create_basic_driver_dataset()
ds = snowflake_environment.data_source_creator.create_data_source(
Expand Down Expand Up @@ -112,6 +113,7 @@ def test_snowflake_materialization_consistency_internal_with_lists(
batch_engine=SNOWFLAKE_ENGINE_CONFIG,
)
snowflake_environment = construct_test_environment(snowflake_config, None)
snowflake_environment.setup()

df = create_basic_driver_dataset(Int32, feature_dtype, True, feature_is_empty_list)
ds = snowflake_environment.data_source_creator.create_data_source(
Expand Down Expand Up @@ -195,6 +197,7 @@ def test_snowflake_materialization_entityless_fv():
batch_engine=SNOWFLAKE_ENGINE_CONFIG,
)
snowflake_environment = construct_test_environment(snowflake_config, None)
snowflake_environment.setup()

df = create_basic_driver_dataset()
entityless_df = df.drop("driver_id", axis=1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ def test_historical_features_with_entities_from_query(
if not orders_table:
raise pytest.skip("Offline source is not sql-based")

data_source_creator = environment.test_repo_config.offline_store_creator
if data_source_creator.__name__ == SnowflakeDataSourceCreator.__name__:
data_source_creator = environment.data_source_creator
if isinstance(data_source_creator, SnowflakeDataSourceCreator):
entity_df_query = f"""
SELECT "customer_id", "driver_id", "order_id", "origin_id", "destination_id", "event_timestamp"
FROM "{orders_table}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ def test_universal_cli(environment: Environment):
repo_path = Path(repo_dir_name)
feature_store_yaml = make_feature_store_yaml(
project,
environment.test_repo_config,
repo_path,
environment.data_source_creator,
environment.provider,
environment.online_store,
)

repo_config = repo_path / "feature_store.yaml"
Expand Down Expand Up @@ -124,9 +125,10 @@ def test_odfv_apply(environment) -> None:
repo_path = Path(repo_dir_name)
feature_store_yaml = make_feature_store_yaml(
project,
environment.test_repo_config,
repo_path,
environment.data_source_creator,
environment.provider,
environment.online_store,
)

repo_config = repo_path / "feature_store.yaml"
Expand Down Expand Up @@ -158,9 +160,10 @@ def test_nullable_online_store(test_nullable_online_store) -> None:
repo_path = Path(repo_dir_name)
feature_store_yaml = make_feature_store_yaml(
project,
test_nullable_online_store,
repo_path,
test_nullable_online_store.offline_store_creator(project),
test_nullable_online_store.provider,
test_nullable_online_store.online_store,
)

repo_config = repo_path / "feature_store.yaml"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_feature_get_historical_features_types_match(

if config.feature_is_list:
assert_feature_list_types(
environment.test_repo_config.provider,
environment.provider,
config.feature_dtype,
historical_features_df,
)
Expand All @@ -119,7 +119,7 @@ def test_feature_get_historical_features_types_match(
config.feature_dtype, historical_features_df
)
assert_expected_arrow_types(
environment.test_repo_config.provider,
environment.provider,
config.feature_dtype,
config.feature_is_list,
historical_features,
Expand Down Expand Up @@ -335,10 +335,7 @@ class TypeTestConfig:
)
def offline_types_test_fixtures(request, environment):
config: TypeTestConfig = request.param
if (
environment.test_repo_config.provider == "aws"
and config.feature_is_list is True
):
if environment.provider == "aws" and config.feature_is_list is True:
pytest.skip("Redshift doesn't support list features")

return get_fixtures(request, environment)
Expand Down
31 changes: 14 additions & 17 deletions sdk/python/tests/unit/infra/offline_stores/test_offline_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
AthenaRetrievalJob,
)
from feast.infra.offline_stores.contrib.mssql_offline_store.mssql import (
MsSqlServerOfflineStoreConfig,
MsSqlServerRetrievalJob,
)
from feast.infra.offline_stores.contrib.postgres_offline_store.postgres import (
Expand Down Expand Up @@ -120,12 +119,14 @@ def retrieval_job(request, environment):
iam_role="arn:aws:iam::585132637328:role/service-role/AmazonRedshift-CommandsAccessRole-20240403T092631",
workgroup="",
)
environment.test_repo_config.offline_store = offline_store_config
config = environment.config.copy(
update={"offline_config": offline_store_config}
)
return RedshiftRetrievalJob(
query="query",
redshift_client="",
s3_resource="",
config=environment.test_repo_config,
config=config,
full_feature_names=False,
)
elif request.param is SnowflakeRetrievalJob:
Expand All @@ -141,12 +142,14 @@ def retrieval_job(request, environment):
storage_integration_name="FEAST_S3",
blob_export_location="s3://feast-snowflake-offload/export",
)
environment.test_repo_config.offline_store = offline_store_config
environment.test_repo_config.project = "project"
config = environment.config.copy(
update={"offline_config": offline_store_config}
)
environment.project = "project"
return SnowflakeRetrievalJob(
query="query",
snowflake_conn=MagicMock(),
config=environment.test_repo_config,
config=config,
full_feature_names=False,
)
elif request.param is AthenaRetrievalJob:
Expand All @@ -158,21 +161,18 @@ def retrieval_job(request, environment):
s3_staging_location="athena",
)

environment.test_repo_config.offline_store = offline_store_config
return AthenaRetrievalJob(
query="query",
athena_client="client",
s3_resource="",
config=environment.test_repo_config.offline_store,
config=environment.config,
full_feature_names=False,
)
elif request.param is MsSqlServerRetrievalJob:
return MsSqlServerRetrievalJob(
query="query",
engine=MagicMock(),
config=MsSqlServerOfflineStoreConfig(
connection_string="str"
), # TODO: this does not match the RetrievalJob pattern. Suppose to be RepoConfig
config=environment.config,
full_feature_names=False,
)
elif request.param is PostgreSQLRetrievalJob:
Expand All @@ -182,28 +182,25 @@ def retrieval_job(request, environment):
user="str",
password="str",
)
environment.test_repo_config.offline_store = offline_store_config
return PostgreSQLRetrievalJob(
query="query",
config=environment.test_repo_config.offline_store,
config=environment.config,
full_feature_names=False,
)
elif request.param is SparkRetrievalJob:
offline_store_config = SparkOfflineStoreConfig()
environment.test_repo_config.offline_store = offline_store_config
return SparkRetrievalJob(
spark_session=MagicMock(),
query="str",
full_feature_names=False,
config=environment.test_repo_config,
config=environment.config,
)
elif request.param is TrinoRetrievalJob:
offline_store_config = SparkOfflineStoreConfig()
environment.test_repo_config.offline_store = offline_store_config
return TrinoRetrievalJob(
query="str",
client=MagicMock(),
config=environment.test_repo_config,
config=environment.config,
full_feature_names=False,
)
else:
Expand Down
Loading