Skip to content

Commit 3ddef73

Browse files
Jwrede authored and ntkathole committed
fix: decouple temp table cleanup from query access
Avoid dropping the temporary entity table on to_sql() calls. Previously, every method used a context manager that dropped the table on exit, so calling to_sql() before to_df() would destroy the table and cause subsequent queries to fail. Now the query is stored as a plain string and cleanup is handled by a dedicated _drop_temp_table() method called only after query execution (to_df, to_trino). A __del__ fallback ensures cleanup if execution methods are never called. The _cleaned_up flag makes the drop idempotent. Signed-off-by: Jonathan Wrede <wrede.jonathan00@gmail.com>
1 parent f1008f2 commit 3ddef73

1 file changed

Lines changed: 31 additions & 36 deletions

File tree

  • sdk/python/feast/infra/offline_stores/contrib/trino_offline_store

sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
import contextlib
21
import logging
32
import uuid
43
from datetime import date, datetime
54
from typing import (
65
Any,
7-
Callable,
8-
ContextManager,
96
Dict,
10-
Iterator,
117
List,
128
Literal,
139
Optional,
@@ -192,28 +188,22 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel):
192188
class TrinoRetrievalJob(RetrievalJob):
193189
def __init__(
194190
self,
195-
query: Union[str, Callable[[], ContextManager[str]]],
191+
query: str,
196192
client: Trino,
197193
config: RepoConfig,
198194
full_feature_names: bool,
199195
on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
200196
metadata: Optional[RetrievalMetadata] = None,
197+
temp_table: Optional[str] = None,
201198
):
202-
if not isinstance(query, str):
203-
self._query_generator = query
204-
else:
205-
206-
@contextlib.contextmanager
207-
def query_generator() -> Iterator[str]:
208-
assert isinstance(query, str)
209-
yield query
210-
211-
self._query_generator = query_generator
199+
self._query = query
212200
self._client = client
213201
self._config = config
214202
self._full_feature_names = full_feature_names
215203
self._on_demand_feature_views = on_demand_feature_views or []
216204
self._metadata = metadata
205+
self._temp_table = temp_table
206+
self._cleaned_up = False
217207

218208
@property
219209
def full_feature_names(self) -> bool:
@@ -223,21 +213,37 @@ def full_feature_names(self) -> bool:
223213
def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
224214
return self._on_demand_feature_views
225215

216+
def _drop_temp_table(self) -> None:
217+
if self._cleaned_up or not self._temp_table:
218+
return
219+
self._cleaned_up = True
220+
try:
221+
self._client.execute_query(f"DROP TABLE IF EXISTS {self._temp_table}")
222+
except Exception:
223+
logger.exception(
224+
"Failed to drop temporary entity table %s",
225+
self._temp_table,
226+
)
227+
228+
def __del__(self) -> None:
229+
self._drop_temp_table()
230+
226231
def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
227232
"""Return dataset as Pandas DataFrame synchronously including on demand transforms"""
228-
with self._query_generator() as query:
229-
results = self._client.execute_query(query_text=query)
233+
try:
234+
results = self._client.execute_query(query_text=self._query)
230235
self.pyarrow_schema = results.pyarrow_schema
231236
return results.to_dataframe()
237+
finally:
238+
self._drop_temp_table()
232239

233240
def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
234241
"""Return pyarrow dataset synchronously including on demand transforms"""
235242
return pyarrow.Table.from_pandas(self._to_df_internal(timeout=timeout))
236243

237244
def to_sql(self) -> str:
238245
"""Returns the SQL query that will be executed in Trino to build the historical feature table"""
239-
with self._query_generator() as query:
240-
return query
246+
return self._query
241247

242248
def to_trino(
243249
self,
@@ -260,9 +266,11 @@ def to_trino(
260266
destination_table = f"{self._client.catalog}.{self._config.offline_store.dataset}.historical_{today}_{rand_id}"
261267

262268
# TODO: Implement the timeout logic
263-
with self._query_generator() as query:
264-
create_query = f"CREATE TABLE {destination_table} AS ({query})"
269+
try:
270+
create_query = f"CREATE TABLE {destination_table} AS ({self._query})"
265271
self._client.execute_query(query_text=create_query)
272+
finally:
273+
self._drop_temp_table()
266274
return destination_table
267275

268276
def persist(
@@ -411,22 +419,9 @@ def get_historical_features(
411419
full_feature_names=full_feature_names,
412420
)
413421

414-
@contextlib.contextmanager
415-
def query_generator() -> Iterator[str]:
416-
try:
417-
yield query
418-
finally:
419-
if isinstance(entity_df, pd.DataFrame):
420-
try:
421-
client.execute_query(f"DROP TABLE IF EXISTS {table_reference}")
422-
except Exception:
423-
logger.exception(
424-
"Failed to drop temporary entity table %s",
425-
table_reference,
426-
)
427-
428422
return TrinoRetrievalJob(
429-
query=query_generator,
423+
query=query,
424+
temp_table=table_reference if isinstance(entity_df, pd.DataFrame) else None,
430425
client=client,
431426
config=config,
432427
full_feature_names=full_feature_names,

0 commit comments

Comments
 (0)