Merged
18 changes: 11 additions & 7 deletions sdk/python/feast/data_source.py
@@ -176,7 +176,7 @@ class DataSource(ABC):
was created, used for deduplicating rows.
field_mapping (optional): A dictionary mapping of column names in this data
source to feature names in a feature table or view. Only used for feature
columns, not entity or timestamp columns.
columns and timestamp columns, not entity columns.
description (optional): A human-readable description.
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
owner (optional): The owner of the data source, typically the email of the primary
@@ -463,9 +463,11 @@ def from_proto(data_source: DataSourceProto):
description=data_source.description,
tags=dict(data_source.tags),
owner=data_source.owner,
batch_source=DataSource.from_proto(data_source.batch_source)
if data_source.batch_source
else None,
batch_source=(
DataSource.from_proto(data_source.batch_source)
if data_source.batch_source
else None
),
)

def to_proto(self) -> DataSourceProto:
@@ -643,9 +645,11 @@ def from_proto(data_source: DataSourceProto):
description=data_source.description,
tags=dict(data_source.tags),
owner=data_source.owner,
batch_source=DataSource.from_proto(data_source.batch_source)
if data_source.batch_source
else None,
batch_source=(
DataSource.from_proto(data_source.batch_source)
if data_source.batch_source
else None
),
)

@staticmethod
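For context, a minimal sketch of what the clarified field_mapping docstring means in practice. It mirrors the SparkSource used in the new Spark offline store tests later in this PR; the source name and table are placeholders taken from those tests, and the mapping renames a nested timestamp column rather than an entity column.

from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
    SparkSource,
)

# Hypothetical example source. field_mapping renames a (nested) source column to the
# timestamp name used by Feast; entity (join key) columns are not remapped.
source = SparkSource(
    name="test_nested_batch_source",
    table="offline_store_database_name.offline_store_table_name",
    timestamp_field="nested_timestamp",
    field_mapping={
        "event_header.event_published_datetime_utc": "nested_timestamp",
    },
)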
@@ -240,6 +240,8 @@ def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArti
) = spark_serialized_artifacts.unserialize()

if feature_view.batch_source.field_mapping is not None:
# The Spark offline store already does the field mapping in the pull_latest_from_table_or_query() call.
# This may be needed in the future if this materialization engine supports other offline stores.
table = _run_pyarrow_field_mapping(
table, feature_view.batch_source.field_mapping
)
@@ -33,6 +33,7 @@
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.type_map import spark_schema_to_np_dtypes
from feast.utils import _get_fields_with_aliases

# Make sure spark warning doesn't raise more than once.
warnings.simplefilter("once", RuntimeWarning)
@@ -90,16 +91,22 @@ def pull_latest_from_table_or_query(
if created_timestamp_column:
timestamps.append(created_timestamp_column)
timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)
(fields_with_aliases, aliases) = _get_fields_with_aliases(
fields=join_key_columns + feature_name_columns + timestamps,
field_mappings=data_source.field_mapping,
)

fields_as_string = ", ".join(fields_with_aliases)
aliases_as_string = ", ".join(aliases)

start_date_str = _format_datetime(start_date)
end_date_str = _format_datetime(end_date)
query = f"""
SELECT
{field_string}
{aliases_as_string}
{f", {repr(DUMMY_ENTITY_VAL)} AS {DUMMY_ENTITY_ID}" if not join_key_columns else ""}
FROM (
SELECT {field_string},
SELECT {fields_as_string},
ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS feast_row_
FROM {from_expression} t1
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}')
@@ -279,14 +286,19 @@ def pull_all_from_table_or_query(
spark_session = get_spark_session_or_start_new_with_repoconfig(
store_config=config.offline_store
)
(fields_with_aliases, aliases) = _get_fields_with_aliases(
fields=join_key_columns + feature_name_columns + [timestamp_field],
field_mappings=data_source.field_mapping,
)

fields_with_alias_string = ", ".join(fields_with_aliases)

fields = ", ".join(join_key_columns + feature_name_columns + [timestamp_field])
from_expression = data_source.get_table_query_string()
start_date = start_date.astimezone(tz=timezone.utc)
end_date = end_date.astimezone(tz=timezone.utc)

query = f"""
SELECT {fields}
SELECT {fields_with_alias_string}
FROM {from_expression}
WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}'
"""
33 changes: 29 additions & 4 deletions sdk/python/feast/utils.py
@@ -106,7 +106,8 @@ def _get_requested_feature_views_to_features_dict(
on_demand_feature_views: List["OnDemandFeatureView"],
) -> Tuple[Dict["FeatureView", List[str]], Dict["OnDemandFeatureView", List[str]]]:
"""Create a dict of FeatureView -> List[Feature] for all requested features.
Set full_feature_names to True to have feature names prefixed by their feature view name."""
Set full_feature_names to True to have feature names prefixed by their feature view name.
"""

feature_views_to_feature_map: Dict["FeatureView", List[str]] = defaultdict(list)
on_demand_feature_views_to_feature_map: Dict["OnDemandFeatureView", List[str]] = (
@@ -212,6 +213,28 @@ def _run_pyarrow_field_mapping(
return table


Review comment: probably would be good to add a test for this but i won't block on it.

Contributor Author: Don't have a direct test but pull_latest_query tests will use it. I'll add explicit tests in next PRs.

def _get_fields_with_aliases(
    fields: List[str],
    field_mappings: Dict[str, str],
) -> Tuple[List[str], List[str]]:
    """
    Get a list of fields with aliases based on the field mappings.
    """
    for field in fields:
        if "." in field and field not in field_mappings:
            raise ValueError(
                f"Feature {field} contains a '.' character, which is not allowed in field names. Use field mappings to rename fields."
            )
    fields_with_aliases = [
        f"{field} AS {field_mappings[field]}" if field in field_mappings else field
        for field in fields
    ]
    aliases = [
        field_mappings[field] if field in field_mappings else field for field in fields
    ]
    return (fields_with_aliases, aliases)
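
A quick, hypothetical usage sketch of the helper above (not part of this PR): the driver_id and conv_rate columns are made-up placeholders, and the mapping reuses the nested timestamp column from the Spark tests below.

from feast.utils import _get_fields_with_aliases

fields = ["driver_id", "conv_rate", "event_header.event_published_datetime_utc"]
field_mappings = {"event_header.event_published_datetime_utc": "nested_timestamp"}

fields_with_aliases, aliases = _get_fields_with_aliases(
    fields=fields, field_mappings=field_mappings
)
print(fields_with_aliases)
# ['driver_id', 'conv_rate', 'event_header.event_published_datetime_utc AS nested_timestamp']
print(aliases)
# ['driver_id', 'conv_rate', 'nested_timestamp']

# A dotted column without a mapping is rejected, since '.' is not allowed in field names:
try:
    _get_fields_with_aliases(fields=["event_header.foo"], field_mappings={})
except ValueError as err:
    print(err)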


def _coerce_datetime(ts):
"""
Depending on underlying time resolution, arrow to_pydict() sometimes returns pd
@@ -781,9 +804,11 @@ def _populate_response_from_feature_data(
"""
# Add the feature names to the response.
requested_feature_refs = [
f"{table.projection.name_to_use()}__{feature_name}"
if full_feature_names
else feature_name
(
f"{table.projection.name_to_use()}__{feature_name}"
if full_feature_names
else feature_name
)
for feature_name in requested_features
]
online_features_response.metadata.feature_names.val.extend(requested_feature_refs)
@@ -1828,3 +1828,51 @@ def test_apply_entity_to_sql_registry_and_reinitialize_sql_registry(test_registr

updated_test_registry.teardown()
test_registry.teardown()


@pytest.mark.integration
def test_commit_for_read_only_user():
fd, registry_path = mkstemp()
registry_config = RegistryConfig(path=registry_path, cache_ttl_seconds=600)
write_registry = Registry("project", registry_config, None)

entity = Entity(
name="driver_car_id",
description="Car driver id",
tags={"team": "matchmaking"},
)

project = "project"

# Register Entity without committing
write_registry.apply_entity(entity, project, commit=False)
assert write_registry.cached_registry_proto
project_obj = write_registry.cached_registry_proto.projects[0]
assert project == Project.from_proto(project_obj).name
assert_project(project, write_registry, True)

# Retrieving the entity should still succeed
entities = write_registry.list_entities(project, allow_cache=True, tags=entity.tags)
entity = entities[0]
assert (
len(entities) == 1
and entity.name == "driver_car_id"
and entity.description == "Car driver id"
and "team" in entity.tags
and entity.tags["team"] == "matchmaking"
)

# commit from the original registry
write_registry.commit()

# Reconstruct the new registry in order to read the newly written store
with mock.patch.object(
Registry,
"commit",
side_effect=Exception("Read only users are not allowed to commit"),
):
read_registry = Registry("project", registry_config, None)
entities = read_registry.list_entities(project, tags=entity.tags)
assert len(entities) == 1

write_registry.teardown()
@@ -0,0 +1,129 @@
from datetime import datetime
from unittest.mock import MagicMock, patch

from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
SparkOfflineStore,
SparkOfflineStoreConfig,
)
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
SparkSource,
)
from feast.infra.offline_stores.offline_store import RetrievalJob
from feast.repo_config import RepoConfig


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
def test_pull_latest_from_table_with_nested_timestamp_or_query(mock_get_spark_session):
mock_spark_session = MagicMock()
mock_get_spark_session.return_value = mock_spark_session

test_repo_config = RepoConfig(
project="test_project",
registry="test_registry",
provider="local",
offline_store=SparkOfflineStoreConfig(type="spark"),
)

test_data_source = SparkSource(
name="test_nested_batch_source",
description="test_nested_batch_source",
table="offline_store_database_name.offline_store_table_name",
timestamp_field="nested_timestamp",
field_mapping={
"event_header.event_published_datetime_utc": "nested_timestamp",
},
)

# Define the parameters for the method
join_key_columns = ["key1", "key2"]
feature_name_columns = ["feature1", "feature2"]
timestamp_field = "event_header.event_published_datetime_utc"
created_timestamp_column = "created_timestamp"
start_date = datetime(2021, 1, 1)
end_date = datetime(2021, 1, 2)

# Call the method
retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
config=test_repo_config,
data_source=test_data_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_date,
end_date=end_date,
)

expected_query = """SELECT
key1, key2, feature1, feature2, nested_timestamp, created_timestamp

FROM (
SELECT key1, key2, feature1, feature2, event_header.event_published_datetime_utc AS nested_timestamp, created_timestamp,
ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_header.event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
FROM `offline_store_database_name`.`offline_store_table_name` t1
WHERE event_header.event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000')
) t2
WHERE feast_row_ = 1""" # noqa: W293

assert isinstance(retrieval_job, RetrievalJob)
assert retrieval_job.query.strip() == expected_query.strip()


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
def test_pull_latest_from_table_without_nested_timestamp_or_query(
mock_get_spark_session,
):
mock_spark_session = MagicMock()
mock_get_spark_session.return_value = mock_spark_session

test_repo_config = RepoConfig(
project="test_project",
registry="test_registry",
provider="local",
offline_store=SparkOfflineStoreConfig(type="spark"),
)

test_data_source = SparkSource(
name="test_batch_source",
description="test_nested_batch_source",
table="offline_store_database_name.offline_store_table_name",
timestamp_field="event_published_datetime_utc",
)

# Define the parameters for the method
join_key_columns = ["key1", "key2"]
feature_name_columns = ["feature1", "feature2"]
timestamp_field = "event_published_datetime_utc"
created_timestamp_column = "created_timestamp"
start_date = datetime(2021, 1, 1)
end_date = datetime(2021, 1, 2)

# Call the method
retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
config=test_repo_config,
data_source=test_data_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_date,
end_date=end_date,
)

expected_query = """SELECT
key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp

FROM (
SELECT key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp,
ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
FROM `offline_store_database_name`.`offline_store_table_name` t1
WHERE event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000')
) t2
WHERE feast_row_ = 1""" # noqa: W293

assert isinstance(retrieval_job, RetrievalJob)
assert retrieval_job.query.strip() == expected_query.strip()