1+ import contextlib
2+ import logging
13import uuid
24from datetime import date , datetime
3- from typing import Any , Dict , List , Literal , Optional , Tuple , Union
5+ from typing import Any , Callable , ContextManager , Dict , Iterator , List , Literal , Optional , Tuple , Union
46
57import numpy as np
68import pandas as pd
3739from feast .repo_config import FeastConfigBaseModel , RepoConfig
3840from feast .saved_dataset import SavedDatasetStorage
3941
42+ logger = logging .getLogger (__name__ )
43+
4044
4145class BasicAuthModel (FeastConfigBaseModel ):
4246 username : StrictStr
@@ -177,14 +181,23 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel):
177181class TrinoRetrievalJob (RetrievalJob ):
178182 def __init__ (
179183 self ,
180- query : str ,
184+ query : Union [ str , Callable [[], ContextManager [ str ]]] ,
181185 client : Trino ,
182186 config : RepoConfig ,
183187 full_feature_names : bool ,
184188 on_demand_feature_views : Optional [List [OnDemandFeatureView ]] = None ,
185189 metadata : Optional [RetrievalMetadata ] = None ,
186190 ):
187- self ._query = query
191+ if not isinstance (query , str ):
192+ self ._query_generator = query
193+ else :
194+
195+ @contextlib .contextmanager
196+ def query_generator () -> Iterator [str ]:
197+ assert isinstance (query , str )
198+ yield query
199+
200+ self ._query_generator = query_generator
188201 self ._client = client
189202 self ._config = config
190203 self ._full_feature_names = full_feature_names
@@ -201,17 +214,19 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
201214
202215 def _to_df_internal (self , timeout : Optional [int ] = None ) -> pd .DataFrame :
203216 """Return dataset as Pandas DataFrame synchronously including on demand transforms"""
204- results = self ._client .execute_query (query_text = self ._query )
205- self .pyarrow_schema = results .pyarrow_schema
206- return results .to_dataframe ()
217+ with self ._query_generator () as query :
218+ results = self ._client .execute_query (query_text = query )
219+ self .pyarrow_schema = results .pyarrow_schema
220+ return results .to_dataframe ()
207221
208222 def _to_arrow_internal (self , timeout : Optional [int ] = None ) -> pyarrow .Table :
209223 """Return the dataset as a pyarrow Table synchronously, including on demand transforms"""
210224 return pyarrow .Table .from_pandas (self ._to_df_internal (timeout = timeout ))
211225
212226 def to_sql (self ) -> str :
213227 """Returns the SQL query that will be executed in Trino to build the historical feature table"""
214- return self ._query
228+ with self ._query_generator () as query :
229+ return query
215230
216231 def to_trino (
217232 self ,
@@ -234,8 +249,9 @@ def to_trino(
234249 destination_table = f"{ self ._client .catalog } .{ self ._config .offline_store .dataset } .historical_{ today } _{ rand_id } "
235250
236251 # TODO: Implement the timeout logic
237- query = f"CREATE TABLE { destination_table } AS ({ self ._query } )"
238- self ._client .execute_query (query_text = query )
252+ with self ._query_generator () as query :
253+ create_query = f"CREATE TABLE { destination_table } AS ({ query } )"
254+ self ._client .execute_query (query_text = create_query )
239255 return destination_table
240256
241257 def persist (
@@ -372,19 +388,36 @@ def get_historical_features(
372388 )
373389
374390 # Generate the Trino SQL query from the query context
391+ entity_table_ref = table_reference
375392 if type (entity_df ) is str :
376- table_reference = f"({ entity_df } )"
393+ entity_table_ref = f"({ entity_df } )"
377394 query = offline_utils .build_point_in_time_query (
378395 query_context ,
379- left_table_query_string = table_reference ,
396+ left_table_query_string = entity_table_ref ,
380397 entity_df_event_timestamp_col = entity_df_event_timestamp_col ,
381398 entity_df_columns = entity_schema .keys (),
382399 query_template = MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN ,
383400 full_feature_names = full_feature_names ,
384401 )
385402
403+ @contextlib .contextmanager
404+ def query_generator () -> Iterator [str ]:
405+ try :
406+ yield query
407+ finally :
408+ if isinstance (entity_df , pd .DataFrame ):
409+ try :
410+ client .execute_query (
411+ f"DROP TABLE IF EXISTS { table_reference } "
412+ )
413+ except Exception :
414+ logger .exception (
415+ "Failed to drop temporary entity table %s" ,
416+ table_reference ,
417+ )
418+
386419 return TrinoRetrievalJob (
387- query = query ,
420+ query = query_generator ,
388421 client = client ,
389422 config = config ,
390423 full_feature_names = full_feature_names ,
@@ -483,8 +516,6 @@ def _upload_entity_df_and_get_entity_schema(
483516 else :
484517 raise InvalidEntityType (type (entity_df ))
485518
486- # TODO: Ensure that the table expires after some time
487-
488519
489520def _get_trino_client (config : RepoConfig ) -> Trino :
490521 auth = None
0 commit comments