1+ import contextlib
12import uuid
23from datetime import datetime
3- from typing import Dict , List , Optional , Union
4+ from typing import Callable , ContextManager , Dict , Iterator , List , Optional , Union
45
56import numpy as np
67import pandas as pd
@@ -113,40 +114,53 @@ def get_historical_features(
113114 )
114115 s3_resource = aws_utils .get_s3_resource (config .offline_store .region )
115116
116- table_name = offline_utils .get_temp_entity_table_name ()
@contextlib.contextmanager
def query_generator() -> Iterator[str]:
    """Yield the point-in-time join SQL and drop the temporary entity table afterwards.

    On entry the entity dataframe is uploaded to a temporary Redshift table
    and the templated query referencing that table is built. On exit the
    temporary table is dropped, whether or not the caller's ``with`` body
    (or any validation step after the upload) raised.

    Yields:
        The Redshift SQL query string to execute.
    """
    table_name = offline_utils.get_temp_entity_table_name()

    entity_schema = _upload_entity_df_and_get_entity_schema(
        entity_df, redshift_client, config, s3_resource, table_name
    )

    # Everything after the upload runs under try/finally: contextmanager
    # re-raises exceptions from the ``with`` body at the ``yield``, so
    # without ``finally`` a failed query would leak the temporary table.
    try:
        entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
            entity_schema
        )

        expected_join_keys = offline_utils.get_expected_join_keys(
            project, feature_views, registry
        )

        offline_utils.assert_expected_columns_in_entity_df(
            entity_schema, expected_join_keys, entity_df_event_timestamp_col
        )

        # Build a query context containing all information required to template the Redshift SQL query
        query_context = offline_utils.get_feature_view_query_context(
            feature_refs, feature_views, registry, project,
        )

        # Generate the Redshift SQL query from the query context
        query = offline_utils.build_point_in_time_query(
            query_context,
            left_table_query_string=table_name,
            entity_df_event_timestamp_col=entity_df_event_timestamp_col,
            query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
            full_feature_names=full_feature_names,
        )

        yield query
    finally:
        # Clean up the uploaded Redshift table
        aws_utils.execute_redshift_statement(
            redshift_client,
            config.offline_store.cluster_id,
            config.offline_store.database,
            config.offline_store.user,
            f"DROP TABLE {table_name}",
        )
147161
148162 return RedshiftRetrievalJob (
149- query = query ,
163+ query = query_generator ,
150164 redshift_client = redshift_client ,
151165 s3_resource = s3_resource ,
152166 config = config ,
@@ -161,7 +175,7 @@ def get_historical_features(
161175class RedshiftRetrievalJob (RetrievalJob ):
162176 def __init__ (
163177 self ,
164- query : str ,
178+ query : Union [ str , Callable [[], ContextManager [ str ]]] ,
165179 redshift_client ,
166180 s3_resource ,
167181 config : RepoConfig ,
@@ -170,14 +184,23 @@ def __init__(
170184 """Initialize RedshiftRetrievalJob object.
171185
172186 Args:
173- query: Redshift SQL query to execute.
187+ query: Redshift SQL query to execute. Either a string, or a generator function that handles the artifact cleanup.
174188 redshift_client: boto3 redshift-data client
175189 s3_resource: boto3 s3 resource object
176190 config: Feast repo config
177191 drop_columns: Optionally a list of columns to drop before unloading to S3.
178192 This is a convenient field, since "SELECT ... EXCEPT col" isn't supported in Redshift.
179193 """
180- self .query = query
194+ if not isinstance (query , str ):
195+ self ._query_generator = query
196+ else :
197+
198+ @contextlib .contextmanager
199+ def query_generator () -> Iterator [str ]:
200+ assert isinstance (query , str )
201+ yield query
202+
203+ self ._query_generator = query_generator
181204 self ._redshift_client = redshift_client
182205 self ._s3_resource = s3_resource
183206 self ._config = config
@@ -189,59 +212,63 @@ def __init__(
189212 self ._drop_columns = drop_columns
190213
def to_df(self) -> pd.DataFrame:
    """Run the query and return its result as a pandas DataFrame.

    The query is obtained from the stored query generator, so any cleanup
    the generator performs happens once the unload has finished.
    """
    with self._query_generator() as sql:
        store = self._config.offline_store
        frame = aws_utils.unload_redshift_query_to_df(
            self._redshift_client,
            store.cluster_id,
            store.database,
            store.user,
            self._s3_resource,
            self._s3_path,
            store.iam_role,
            sql,
            self._drop_columns,
        )
    return frame
203227
def to_arrow(self) -> pa.Table:
    """Run the query and return its result as a pyarrow Table.

    The query is obtained from the stored query generator, so any cleanup
    the generator performs happens once the unload has finished.
    """
    with self._query_generator() as sql:
        store = self._config.offline_store
        table = aws_utils.unload_redshift_query_to_pa(
            self._redshift_client,
            store.cluster_id,
            store.database,
            store.user,
            self._s3_resource,
            self._s3_path,
            store.iam_role,
            sql,
            self._drop_columns,
        )
    return table
216241
def to_s3(self) -> str:
    """Export the dataset to S3 in Parquet format and return the S3 path."""
    with self._query_generator() as sql:
        store = self._config.offline_store
        aws_utils.execute_redshift_query_and_unload_to_s3(
            self._redshift_client,
            store.cluster_id,
            store.database,
            store.user,
            self._s3_path,
            store.iam_role,
            sql,
            self._drop_columns,
        )
        destination = self._s3_path
    return destination
230256
def to_redshift(self, table_name: str) -> None:
    """Save the dataset as a new Redshift table.

    Builds a ``CREATE TABLE ... AS`` statement from the generated query,
    followed by one ``ALTER TABLE ... DROP COLUMN`` statement per entry in
    ``self._drop_columns`` (when set), and submits them as one statement
    string.

    Args:
        table_name: Name of the Redshift table to create.
    """
    with self._query_generator() as query:
        statements = [f'CREATE TABLE "{table_name}" AS ({query});\n']
        if self._drop_columns is not None:
            # Quote the table name exactly as in CREATE TABLE so both
            # statements resolve to the same identifier.
            statements.extend(
                f'ALTER TABLE "{table_name}" DROP COLUMN {column};\n'
                for column in self._drop_columns
            )

        aws_utils.execute_redshift_statement(
            self._redshift_client,
            self._config.offline_store.cluster_id,
            self._config.offline_store.database,
            self._config.offline_store.user,
            "".join(statements),
        )
245272
246273
247274def _upload_entity_df_and_get_entity_schema (
0 commit comments