Skip to content

Commit d22bb4a

Browse files
committed
fix(trino): Clean up temporary entity tables after retrieval
TrinoOfflineStore.get_historical_features() creates a temporary table for the entity DataFrame but never drops it, leaking tables indefinitely. Apply the same context manager pattern used by BigQuery, Redshift, and Athena offline stores: wrap the query in a generator that issues DROP TABLE IF EXISTS in a finally block. Fixes #6306 Signed-off-by: Jonathan Wrede <wrede.jonathan00@gmail.com>
1 parent 0d51c93 commit d22bb4a

1 file changed

Lines changed: 45 additions & 14 deletions

File tree

  • sdk/python/feast/infra/offline_stores/contrib/trino_offline_store

sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import contextlib
2+
import logging
13
import uuid
24
from datetime import date, datetime
3-
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
5+
from typing import Any, Callable, ContextManager, Dict, Iterator, List, Literal, Optional, Tuple, Union
46

57
import numpy as np
68
import pandas as pd
@@ -37,6 +39,8 @@
3739
from feast.repo_config import FeastConfigBaseModel, RepoConfig
3840
from feast.saved_dataset import SavedDatasetStorage
3941

42+
logger = logging.getLogger(__name__)
43+
4044

4145
class BasicAuthModel(FeastConfigBaseModel):
4246
username: StrictStr
@@ -177,14 +181,23 @@ class TrinoOfflineStoreConfig(FeastConfigBaseModel):
177181
class TrinoRetrievalJob(RetrievalJob):
178182
def __init__(
179183
self,
180-
query: str,
184+
query: Union[str, Callable[[], ContextManager[str]]],
181185
client: Trino,
182186
config: RepoConfig,
183187
full_feature_names: bool,
184188
on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
185189
metadata: Optional[RetrievalMetadata] = None,
186190
):
187-
self._query = query
191+
if not isinstance(query, str):
192+
self._query_generator = query
193+
else:
194+
195+
@contextlib.contextmanager
196+
def query_generator() -> Iterator[str]:
197+
assert isinstance(query, str)
198+
yield query
199+
200+
self._query_generator = query_generator
188201
self._client = client
189202
self._config = config
190203
self._full_feature_names = full_feature_names
@@ -201,17 +214,19 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
201214

202215
def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
203216
"""Return dataset as Pandas DataFrame synchronously including on demand transforms"""
204-
results = self._client.execute_query(query_text=self._query)
205-
self.pyarrow_schema = results.pyarrow_schema
206-
return results.to_dataframe()
217+
with self._query_generator() as query:
218+
results = self._client.execute_query(query_text=query)
219+
self.pyarrow_schema = results.pyarrow_schema
220+
return results.to_dataframe()
207221

208222
def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
209223
"""Return payrrow dataset as synchronously including on demand transforms"""
210224
return pyarrow.Table.from_pandas(self._to_df_internal(timeout=timeout))
211225

212226
def to_sql(self) -> str:
213227
"""Returns the SQL query that will be executed in Trino to build the historical feature table"""
214-
return self._query
228+
with self._query_generator() as query:
229+
return query
215230

216231
def to_trino(
217232
self,
@@ -234,8 +249,9 @@ def to_trino(
234249
destination_table = f"{self._client.catalog}.{self._config.offline_store.dataset}.historical_{today}_{rand_id}"
235250

236251
# TODO: Implement the timeout logic
237-
query = f"CREATE TABLE {destination_table} AS ({self._query})"
238-
self._client.execute_query(query_text=query)
252+
with self._query_generator() as query:
253+
create_query = f"CREATE TABLE {destination_table} AS ({query})"
254+
self._client.execute_query(query_text=create_query)
239255
return destination_table
240256

241257
def persist(
@@ -372,19 +388,36 @@ def get_historical_features(
372388
)
373389

374390
# Generate the Trino SQL query from the query context
391+
entity_table_ref = table_reference
375392
if type(entity_df) is str:
376-
table_reference = f"({entity_df})"
393+
entity_table_ref = f"({entity_df})"
377394
query = offline_utils.build_point_in_time_query(
378395
query_context,
379-
left_table_query_string=table_reference,
396+
left_table_query_string=entity_table_ref,
380397
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
381398
entity_df_columns=entity_schema.keys(),
382399
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
383400
full_feature_names=full_feature_names,
384401
)
385402

403+
@contextlib.contextmanager
404+
def query_generator() -> Iterator[str]:
405+
try:
406+
yield query
407+
finally:
408+
if isinstance(entity_df, pd.DataFrame):
409+
try:
410+
client.execute_query(
411+
f"DROP TABLE IF EXISTS {table_reference}"
412+
)
413+
except Exception:
414+
logger.exception(
415+
"Failed to drop temporary entity table %s",
416+
table_reference,
417+
)
418+
386419
return TrinoRetrievalJob(
387-
query=query,
420+
query=query_generator,
388421
client=client,
389422
config=config,
390423
full_feature_names=full_feature_names,
@@ -483,8 +516,6 @@ def _upload_entity_df_and_get_entity_schema(
483516
else:
484517
raise InvalidEntityType(type(entity_df))
485518

486-
# TODO: Ensure that the table expires after some time
487-
488519

489520
def _get_trino_client(config: RepoConfig) -> Trino:
490521
auth = None

0 commit comments

Comments
 (0)