1616from feast .infra .compute_engines .dag .node import DAGNode
1717from feast .infra .compute_engines .dag .value import DAGValue
1818from feast .infra .compute_engines .flink .utils import (
19- flink_table_to_pandas ,
19+ flink_table_to_arrow_batches ,
2020 pandas_to_flink_table ,
21+ register_flink_temporary_view ,
2122)
2223from feast .infra .compute_engines .utils import create_offline_store_retrieval_job
2324from feast .infra .offline_stores .offline_utils import (
@@ -47,15 +48,15 @@ def _select_column(alias: str, column: str, output_name: Optional[str] = None) -
4748 return expr
4849
4950
50- def _flink_interval_literal (value : timedelta ) -> str :
51+ def _flink_interval_literals (value : timedelta ) -> List [ str ] :
5152 total_seconds = int (value .total_seconds ())
5253 if total_seconds <= 0 :
53- return "INTERVAL '0' SECOND"
54+ return [ "INTERVAL '0' SECOND" ]
5455
5556 days , remainder = divmod (total_seconds , 24 * 60 * 60 )
5657 hours , remainder = divmod (remainder , 60 * 60 )
5758 minutes , seconds = divmod (remainder , 60 )
58- parts = []
59+ parts : List [ str ] = []
5960 if days :
6061 parts .append (f"INTERVAL '{ days } ' DAY" )
6162 if hours :
@@ -64,7 +65,14 @@ def _flink_interval_literal(value: timedelta) -> str:
6465 parts .append (f"INTERVAL '{ minutes } ' MINUTE" )
6566 if seconds :
6667 parts .append (f"INTERVAL '{ seconds } ' SECOND" )
67- return " + " .join (parts )
68+ return parts
69+
70+
def _subtract_flink_intervals(timestamp_expr: str, value: timedelta) -> str:
    """Return a SQL expression subtracting ``value`` from ``timestamp_expr``.

    Every literal returned by ``_flink_interval_literals`` becomes its own
    ``- <literal>`` term appended left to right after ``timestamp_expr``,
    rather than being combined into a single parenthesised sum.

    Args:
        timestamp_expr: SQL text of the timestamp operand, used verbatim.
        value: Duration to subtract.

    Returns:
        ``timestamp_expr`` followed by one ``" - <literal>"`` suffix per
        interval literal.
    """
    subtractions = "".join(
        f" - {literal}" for literal in _flink_interval_literals(value)
    )
    return f"{timestamp_expr}{subtractions}"
6876
6977
7078def _get_columns_from_schema (table : Any ) -> Optional [List [str ]]:
@@ -107,6 +115,7 @@ def _require_sql(table_env: Any, node_name: str) -> None:
def _register_table(table_env: Any, table: Any, prefix: str) -> str:
    """Register ``table`` as a uniquely named temporary view and return the name.

    The name has the form ``__feast_<prefix>_<uuid4 hex>``, so repeated calls
    with the same prefix do not collide.

    Args:
        table_env: Object exposing ``create_temporary_view(name, table)``.
        table: Table object to expose under the generated view name.
        prefix: Label embedded in the generated view name.

    Returns:
        The generated view name.
    """
    unique_suffix = uuid.uuid4().hex
    view_name = f"__feast_{prefix}_{unique_suffix}"
    table_env.create_temporary_view(view_name, table)
    # NOTE(review): presumably records the view so it can be dropped later;
    # confirm against feast.infra.compute_engines.flink.utils.
    register_flink_temporary_view(table_env, view_name)
    return view_name
111120
112121
@@ -447,11 +456,11 @@ def _execute_sql_filter(self, input_value: DAGValue) -> DAGValue:
447456 f"{ _quote_identifier (ENTITY_TS_ALIAS )} "
448457 )
449458 if self .ttl :
450- ttl_interval = _flink_interval_literal (self .ttl )
459+ lower_bound = _subtract_flink_intervals (
460+ _quote_identifier (ENTITY_TS_ALIAS ), self .ttl
461+ )
451462 conditions .append (
452- f"{ _quote_identifier (timestamp_column )} >= "
453- f"{ _quote_identifier (ENTITY_TS_ALIAS )} - "
454- f"({ ttl_interval } )"
463+ f"{ _quote_identifier (timestamp_column )} >= { lower_bound } "
455464 )
456465
457466 if self .filter_expr :
@@ -708,44 +717,49 @@ def execute(self, context: ExecutionContext) -> DAGValue:
708717 if not self .write_output :
709718 return output_value
710719
711- output_df = flink_table_to_pandas (output_table )
712- output_arrow = pa .Table .from_pandas (output_df )
713-
714- if output_arrow .num_rows == 0 :
715- return output_value
716-
720+ columns = _get_columns (output_value )
721+ batch_size = context .repo_config .materialization_config .online_write_batch_size
717722 if self .feature_view .online :
718723 join_key_to_value_type = {
719724 entity .name : entity .dtype .to_value_type ()
720725 for entity in self .feature_view .entity_columns
721726 }
722- batch_size = (
723- context .repo_config .materialization_config .online_write_batch_size
724- )
725- batches = (
726- [output_arrow ]
727- if batch_size is None
728- else output_arrow .to_batches (max_chunksize = batch_size )
729- )
730- for batch in batches :
731- rows_to_write = _convert_arrow_to_proto (
732- batch , self .feature_view , join_key_to_value_type
727+ else :
728+ join_key_to_value_type = {}
729+
730+ for output_arrow in flink_table_to_arrow_batches (
731+ output_table ,
732+ columns ,
733+ batch_size ,
734+ ):
735+ if output_arrow .num_rows == 0 :
736+ continue
737+
738+ if self .feature_view .online :
739+ arrow_batches = (
740+ [output_arrow ]
741+ if batch_size is None
742+ else output_arrow .to_batches (max_chunksize = batch_size )
733743 )
734- context .online_store .online_write_batch (
744+ for batch in arrow_batches :
745+ rows_to_write = _convert_arrow_to_proto (
746+ batch , self .feature_view , join_key_to_value_type
747+ )
748+ context .online_store .online_write_batch (
749+ config = context .repo_config ,
750+ table = self .feature_view ,
751+ data = rows_to_write ,
752+ progress = lambda x : None ,
753+ )
754+
755+ if self .feature_view .offline :
756+ context .offline_store .offline_write_batch (
735757 config = context .repo_config ,
736- table = self .feature_view ,
737- data = rows_to_write ,
758+ feature_view = self .feature_view ,
759+ table = output_arrow ,
738760 progress = lambda x : None ,
739761 )
740762
741- if self .feature_view .offline :
742- context .offline_store .offline_write_batch (
743- config = context .repo_config ,
744- feature_view = self .feature_view ,
745- table = output_arrow ,
746- progress = lambda x : None ,
747- )
748-
749763 return output_value
750764
751765 def _drop_internal_columns (self , input_value : DAGValue ) -> DAGValue :
0 commit comments