1- import contextlib
21import logging
32import uuid
43from datetime import date , datetime
54from typing import (
65 Any ,
7- Callable ,
8- ContextManager ,
96 Dict ,
10- Iterator ,
117 List ,
128 Literal ,
139 Optional ,
@@ -192,28 +188,22 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel):
192188class TrinoRetrievalJob (RetrievalJob ):
def __init__(
    self,
    query: str,
    client: Trino,
    config: RepoConfig,
    full_feature_names: bool,
    on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
    metadata: Optional[RetrievalMetadata] = None,
    temp_table: Optional[str] = None,
):
    """Store everything the job needs to run ``query`` later.

    Args:
        query: SQL text that the job executes against Trino.
        client: Trino client used to execute queries.
        config: Repository configuration.
        full_feature_names: Exposed through the ``full_feature_names`` property.
        on_demand_feature_views: On-demand feature views attached to the
            job; ``None`` (or an empty list) is stored as a new empty list.
        metadata: Optional retrieval metadata kept on the job.
        temp_table: Name of a temporary table that the job drops once it
            has finished executing the query; ``None`` when there is
            nothing to drop.
    """
    # Temp-table bookkeeping: the table name (if any) and whether the drop
    # has already been attempted.
    self._temp_table = temp_table
    self._cleaned_up = False

    # Query execution context.
    self._query = query
    self._client = client
    self._config = config

    # Retrieval options and metadata.
    self._full_feature_names = full_feature_names
    self._on_demand_feature_views = (
        on_demand_feature_views if on_demand_feature_views else []
    )
    self._metadata = metadata
217207
218208 @property
219209 def full_feature_names (self ) -> bool :
@@ -223,21 +213,37 @@ def full_feature_names(self) -> bool:
223213 def on_demand_feature_views (self ) -> List [OnDemandFeatureView ]:
224214 return self ._on_demand_feature_views
225215
216+ def _drop_temp_table (self ) -> None :
217+ if self ._cleaned_up or not self ._temp_table :
218+ return
219+ self ._cleaned_up = True
220+ try :
221+ self ._client .execute_query (f"DROP TABLE IF EXISTS { self ._temp_table } " )
222+ except Exception :
223+ logger .exception (
224+ "Failed to drop temporary entity table %s" ,
225+ self ._temp_table ,
226+ )
227+
228+ def __del__ (self ) -> None :
229+ self ._drop_temp_table ()
230+
226231 def _to_df_internal (self , timeout : Optional [int ] = None ) -> pd .DataFrame :
227232 """Return dataset as Pandas DataFrame synchronously including on demand transforms"""
228- with self . _query_generator () as query :
229- results = self ._client .execute_query (query_text = query )
233+ try :
234+ results = self ._client .execute_query (query_text = self . _query )
230235 self .pyarrow_schema = results .pyarrow_schema
231236 return results .to_dataframe ()
237+ finally :
238+ self ._drop_temp_table ()
232239
def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
    """Return the query result as a pyarrow Table, synchronously.

    The result is produced by converting the DataFrame returned by
    ``_to_df_internal``; ``timeout`` is forwarded to that call.
    """
    frame = self._to_df_internal(timeout=timeout)
    return pyarrow.Table.from_pandas(frame)
236243
def to_sql(self) -> str:
    """Return the SQL text this job runs in Trino to build the historical feature table.

    Nothing is executed; the stored query string is returned as-is.
    """
    sql_text: str = self._query
    return sql_text
241247
242248 def to_trino (
243249 self ,
@@ -260,9 +266,11 @@ def to_trino(
260266 destination_table = f"{ self ._client .catalog } .{ self ._config .offline_store .dataset } .historical_{ today } _{ rand_id } "
261267
262268 # TODO: Implement the timeout logic
263- with self . _query_generator () as query :
264- create_query = f"CREATE TABLE { destination_table } AS ({ query } )"
269+ try :
270+ create_query = f"CREATE TABLE { destination_table } AS ({ self . _query } )"
265271 self ._client .execute_query (query_text = create_query )
272+ finally :
273+ self ._drop_temp_table ()
266274 return destination_table
267275
268276 def persist (
@@ -411,22 +419,9 @@ def get_historical_features(
411419 full_feature_names = full_feature_names ,
412420 )
413421
414- @contextlib .contextmanager
415- def query_generator () -> Iterator [str ]:
416- try :
417- yield query
418- finally :
419- if isinstance (entity_df , pd .DataFrame ):
420- try :
421- client .execute_query (f"DROP TABLE IF EXISTS { table_reference } " )
422- except Exception :
423- logger .exception (
424- "Failed to drop temporary entity table %s" ,
425- table_reference ,
426- )
427-
428422 return TrinoRetrievalJob (
429- query = query_generator ,
423+ query = query ,
424+ temp_table = table_reference if isinstance (entity_df , pd .DataFrame ) else None ,
430425 client = client ,
431426 config = config ,
432427 full_feature_names = full_feature_names ,
0 commit comments