Skip to content

Commit 6713384

Browse files
author
Tsotne Tabidze
authored
Clean up uploaded entities in Redshift offline store (#1730)
Signed-off-by: Tsotne Tabidze <tsotne@tecton.ai>
1 parent 8099ea7 commit 6713384

3 files changed

Lines changed: 128 additions & 101 deletions

File tree

sdk/python/feast/infra/offline_stores/redshift.py

Lines changed: 102 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import contextlib
12
import uuid
23
from datetime import datetime
3-
from typing import Dict, List, Optional, Union
4+
from typing import Callable, ContextManager, Dict, Iterator, List, Optional, Union
45

56
import numpy as np
67
import pandas as pd
@@ -113,40 +114,53 @@ def get_historical_features(
113114
)
114115
s3_resource = aws_utils.get_s3_resource(config.offline_store.region)
115116

116-
table_name = offline_utils.get_temp_entity_table_name()
117+
@contextlib.contextmanager
118+
def query_generator() -> Iterator[str]:
119+
table_name = offline_utils.get_temp_entity_table_name()
117120

118-
entity_schema = _upload_entity_df_and_get_entity_schema(
119-
entity_df, redshift_client, config, s3_resource, table_name
120-
)
121+
entity_schema = _upload_entity_df_and_get_entity_schema(
122+
entity_df, redshift_client, config, s3_resource, table_name
123+
)
121124

122-
entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
123-
entity_schema
124-
)
125+
entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
126+
entity_schema
127+
)
125128

126-
expected_join_keys = offline_utils.get_expected_join_keys(
127-
project, feature_views, registry
128-
)
129+
expected_join_keys = offline_utils.get_expected_join_keys(
130+
project, feature_views, registry
131+
)
129132

130-
offline_utils.assert_expected_columns_in_entity_df(
131-
entity_schema, expected_join_keys, entity_df_event_timestamp_col
132-
)
133+
offline_utils.assert_expected_columns_in_entity_df(
134+
entity_schema, expected_join_keys, entity_df_event_timestamp_col
135+
)
133136

134-
# Build a query context containing all information required to template the Redshift SQL query
135-
query_context = offline_utils.get_feature_view_query_context(
136-
feature_refs, feature_views, registry, project,
137-
)
137+
# Build a query context containing all information required to template the Redshift SQL query
138+
query_context = offline_utils.get_feature_view_query_context(
139+
feature_refs, feature_views, registry, project,
140+
)
138141

139-
# Generate the Redshift SQL query from the query context
140-
query = offline_utils.build_point_in_time_query(
141-
query_context,
142-
left_table_query_string=table_name,
143-
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
144-
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
145-
full_feature_names=full_feature_names,
146-
)
142+
# Generate the Redshift SQL query from the query context
143+
query = offline_utils.build_point_in_time_query(
144+
query_context,
145+
left_table_query_string=table_name,
146+
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
147+
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
148+
full_feature_names=full_feature_names,
149+
)
150+
151+
yield query
152+
153+
# Clean up the uploaded Redshift table
154+
aws_utils.execute_redshift_statement(
155+
redshift_client,
156+
config.offline_store.cluster_id,
157+
config.offline_store.database,
158+
config.offline_store.user,
159+
f"DROP TABLE {table_name}",
160+
)
147161

148162
return RedshiftRetrievalJob(
149-
query=query,
163+
query=query_generator,
150164
redshift_client=redshift_client,
151165
s3_resource=s3_resource,
152166
config=config,
@@ -161,7 +175,7 @@ def get_historical_features(
161175
class RedshiftRetrievalJob(RetrievalJob):
162176
def __init__(
163177
self,
164-
query: str,
178+
query: Union[str, Callable[[], ContextManager[str]]],
165179
redshift_client,
166180
s3_resource,
167181
config: RepoConfig,
@@ -170,14 +184,23 @@ def __init__(
170184
"""Initialize RedshiftRetrievalJob object.
171185
172186
Args:
173-
query: Redshift SQL query to execute.
187+
query: Redshift SQL query to execute. Either a string, or a generator function that handles the artifact cleanup.
174188
redshift_client: boto3 redshift-data client
175189
s3_resource: boto3 s3 resource object
176190
config: Feast repo config
177191
drop_columns: Optionally a list of columns to drop before unloading to S3.
178192
This is a convenient field, since "SELECT ... EXCEPT col" isn't supported in Redshift.
179193
"""
180-
self.query = query
194+
if not isinstance(query, str):
195+
self._query_generator = query
196+
else:
197+
198+
@contextlib.contextmanager
199+
def query_generator() -> Iterator[str]:
200+
assert isinstance(query, str)
201+
yield query
202+
203+
self._query_generator = query_generator
181204
self._redshift_client = redshift_client
182205
self._s3_resource = s3_resource
183206
self._config = config
@@ -189,59 +212,63 @@ def __init__(
189212
self._drop_columns = drop_columns
190213

191214
def to_df(self) -> pd.DataFrame:
192-
return aws_utils.unload_redshift_query_to_df(
193-
self._redshift_client,
194-
self._config.offline_store.cluster_id,
195-
self._config.offline_store.database,
196-
self._config.offline_store.user,
197-
self._s3_resource,
198-
self._s3_path,
199-
self._config.offline_store.iam_role,
200-
self.query,
201-
self._drop_columns,
202-
)
215+
with self._query_generator() as query:
216+
return aws_utils.unload_redshift_query_to_df(
217+
self._redshift_client,
218+
self._config.offline_store.cluster_id,
219+
self._config.offline_store.database,
220+
self._config.offline_store.user,
221+
self._s3_resource,
222+
self._s3_path,
223+
self._config.offline_store.iam_role,
224+
query,
225+
self._drop_columns,
226+
)
203227

204228
def to_arrow(self) -> pa.Table:
205-
return aws_utils.unload_redshift_query_to_pa(
206-
self._redshift_client,
207-
self._config.offline_store.cluster_id,
208-
self._config.offline_store.database,
209-
self._config.offline_store.user,
210-
self._s3_resource,
211-
self._s3_path,
212-
self._config.offline_store.iam_role,
213-
self.query,
214-
self._drop_columns,
215-
)
229+
with self._query_generator() as query:
230+
return aws_utils.unload_redshift_query_to_pa(
231+
self._redshift_client,
232+
self._config.offline_store.cluster_id,
233+
self._config.offline_store.database,
234+
self._config.offline_store.user,
235+
self._s3_resource,
236+
self._s3_path,
237+
self._config.offline_store.iam_role,
238+
query,
239+
self._drop_columns,
240+
)
216241

217242
def to_s3(self) -> str:
218243
""" Export dataset to S3 in Parquet format and return path """
219-
aws_utils.execute_redshift_query_and_unload_to_s3(
220-
self._redshift_client,
221-
self._config.offline_store.cluster_id,
222-
self._config.offline_store.database,
223-
self._config.offline_store.user,
224-
self._s3_path,
225-
self._config.offline_store.iam_role,
226-
self.query,
227-
self._drop_columns,
228-
)
229-
return self._s3_path
244+
with self._query_generator() as query:
245+
aws_utils.execute_redshift_query_and_unload_to_s3(
246+
self._redshift_client,
247+
self._config.offline_store.cluster_id,
248+
self._config.offline_store.database,
249+
self._config.offline_store.user,
250+
self._s3_path,
251+
self._config.offline_store.iam_role,
252+
query,
253+
self._drop_columns,
254+
)
255+
return self._s3_path
230256

231257
def to_redshift(self, table_name: str) -> None:
232258
""" Save dataset as a new Redshift table """
233-
query = f'CREATE TABLE "{table_name}" AS ({self.query});\n'
234-
if self._drop_columns is not None:
235-
for column in self._drop_columns:
236-
query += f"ALTER TABLE {table_name} DROP COLUMN {column};\n"
237-
238-
aws_utils.execute_redshift_statement(
239-
self._redshift_client,
240-
self._config.offline_store.cluster_id,
241-
self._config.offline_store.database,
242-
self._config.offline_store.user,
243-
query,
244-
)
259+
with self._query_generator() as query:
260+
query = f'CREATE TABLE "{table_name}" AS ({query});\n'
261+
if self._drop_columns is not None:
262+
for column in self._drop_columns:
263+
query += f"ALTER TABLE {table_name} DROP COLUMN {column};\n"
264+
265+
aws_utils.execute_redshift_statement(
266+
self._redshift_client,
267+
self._config.offline_store.cluster_id,
268+
self._config.offline_store.database,
269+
self._config.offline_store.user,
270+
query,
271+
)
245272

246273

247274
def _upload_entity_df_and_get_entity_schema(

sdk/python/feast/infra/utils/aws_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import tempfile
44
import uuid
5-
from typing import Generator, List, Optional, Tuple
5+
from typing import Iterator, List, Optional, Tuple
66

77
import pandas as pd
88
import pyarrow as pa
@@ -222,7 +222,7 @@ def temporarily_upload_df_to_redshift(
222222
iam_role: str,
223223
table_name: str,
224224
df: pd.DataFrame,
225-
) -> Generator[None, None, None]:
225+
) -> Iterator[None]:
226226
"""Uploads a Pandas DataFrame to Redshift as a new table with cleanup logic.
227227
228228
This is essentially the same as upload_df_to_redshift (check out its docstring for full details),

sdk/python/tests/integration/offline_store/test_historical_retrieval.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -864,18 +864,18 @@ def test_historical_features_from_redshift_sources(
864864
f"order_is_success, {timestamp_column} FROM {table_name}"
865865
)
866866
# Rename the join key; this should now raise an error.
867-
assertpy.assert_that(store.get_historical_features).raises(
868-
errors.FeastEntityDFMissingColumnsError
869-
).when_called_with(
870-
entity_df=entity_df_query_with_invalid_join_key,
871-
features=[
872-
"driver_stats:conv_rate",
873-
"driver_stats:avg_daily_trips",
874-
"customer_profile:current_balance",
875-
"customer_profile:avg_passenger_count",
876-
"customer_profile:lifetime_trip_count",
877-
],
878-
)
867+
assertpy.assert_that(
868+
store.get_historical_features(
869+
entity_df=entity_df_query_with_invalid_join_key,
870+
features=[
871+
"driver_stats:conv_rate",
872+
"driver_stats:avg_daily_trips",
873+
"customer_profile:current_balance",
874+
"customer_profile:avg_passenger_count",
875+
"customer_profile:lifetime_trip_count",
876+
],
877+
).to_df
878+
).raises(errors.FeastEntityDFMissingColumnsError).when_called_with()
879879

880880
job_from_df = store.get_historical_features(
881881
entity_df=orders_df,
@@ -893,18 +893,18 @@ def test_historical_features_from_redshift_sources(
893893
orders_df_with_invalid_join_key = orders_df.rename(
894894
{"customer_id": "customer"}, axis="columns"
895895
)
896-
assertpy.assert_that(store.get_historical_features).raises(
897-
errors.FeastEntityDFMissingColumnsError
898-
).when_called_with(
899-
entity_df=orders_df_with_invalid_join_key,
900-
features=[
901-
"driver_stats:conv_rate",
902-
"driver_stats:avg_daily_trips",
903-
"customer_profile:current_balance",
904-
"customer_profile:avg_passenger_count",
905-
"customer_profile:lifetime_trip_count",
906-
],
907-
)
896+
assertpy.assert_that(
897+
store.get_historical_features(
898+
entity_df=orders_df_with_invalid_join_key,
899+
features=[
900+
"driver_stats:conv_rate",
901+
"driver_stats:avg_daily_trips",
902+
"customer_profile:current_balance",
903+
"customer_profile:avg_passenger_count",
904+
"customer_profile:lifetime_trip_count",
905+
],
906+
).to_df
907+
).raises(errors.FeastEntityDFMissingColumnsError).when_called_with()
908908

909909
start_time = datetime.utcnow()
910910
actual_df_from_df_entities = job_from_df.to_df()

0 commit comments

Comments
 (0)