Skip to content

Commit c0201ad

Browse files
committed
fix: build Spark DataFrame from Arrow with schema and empty handling (#5594)
Signed-off-by: Jacob Weinhold <29459386+jfw-ppi@users.noreply.github.com>
1 parent 59dbb33 commit c0201ad

File tree

1 file changed

+10
-2
lines changed
  • sdk/python/feast/infra/compute_engines/spark

1 file changed

+10
-2
lines changed

sdk/python/feast/infra/compute_engines/spark/nodes.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55
from pyspark.sql import DataFrame, SparkSession, Window
66
from pyspark.sql import functions as F
7+
from pyspark.sql.pandas.types import from_arrow_schema
78

89
from feast import BatchFeatureView, StreamFeatureView
910
from feast.aggregation import Aggregation
@@ -80,7 +81,15 @@ def execute(self, context: ExecutionContext) -> DAGValue:
8081
if isinstance(retrieval_job, SparkRetrievalJob):
8182
spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df()
8283
else:
83-
spark_df = self.spark_session.createDataFrame(retrieval_job.to_arrow())
84+
arrow_table = retrieval_job.to_arrow()
85+
if arrow_table.num_rows == 0:
86+
spark_schema = from_arrow_schema(arrow_table.schema)
87+
spark_df = self.spark_session.createDataFrame(
88+
self.spark_session.sparkContext.emptyRDD(), schema=spark_schema
89+
)
90+
else:
91+
spark_df = self.spark_session.createDataFrame(arrow_table.to_pandas())
92+
8493

8594
return DAGValue(
8695
data=spark_df,
@@ -94,7 +103,6 @@ def execute(self, context: ExecutionContext) -> DAGValue:
94103
},
95104
)
96105

97-
98106
class SparkAggregationNode(DAGNode):
99107
def __init__(
100108
self,

0 commit comments

Comments
 (0)