Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Callable, List, Literal, Optional, Sequence, Union, cast

import dill
import pandas
import pandas as pd
import pyarrow
from tqdm import tqdm
Expand Down Expand Up @@ -178,9 +179,9 @@ def _materialize_one(
self.repo_config.batch_engine.partitions
)

spark_df.foreachPartition(
lambda x: _process_by_partition(x, spark_serialized_artifacts)
)
spark_df.mapInPandas(
lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int"
).count() # dummy action to force evaluation

return SparkMaterializationJob(
job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
Expand Down Expand Up @@ -225,38 +226,40 @@ def unserialize(self):
return feature_view, online_store, repo_config


def _map_by_partition(iterator, spark_serialized_artifacts: _SparkSerializedArtifacts):
    """Write each pandas chunk of one Spark partition to the online store.

    Intended for use with ``DataFrame.mapInPandas``: *iterator* yields zero or
    more ``pandas.DataFrame`` chunks for a single partition.  Every non-empty
    chunk is converted to Arrow, passed through the feature view's field
    mapping (if any), converted to protos, and written to the online store.
    A single one-row dummy frame is yielded at the end because
    ``mapInPandas`` requires some output (caller declares schema ``status int``
    and runs a dummy action to force evaluation).

    Args:
        iterator: Iterator of pandas DataFrames for this partition.
        spark_serialized_artifacts: Dill-serialized bundle that unserializes
            to ``(feature_view, online_store, repo_config)`` on the worker.

    Yields:
        A single dummy one-row DataFrame, used only to force evaluation.
    """
    # Unserialize the artifacts at most once per partition, and only when a
    # non-empty chunk actually needs them (the call is loop-invariant; the
    # original re-ran it for every chunk).
    artifacts = None

    for pdf in iterator:
        if pdf.shape[0] == 0:
            # BUG FIX: the original used `return` here, which terminated the
            # whole generator on the first empty chunk — any remaining chunks
            # in this partition were silently dropped and the dummy row was
            # never yielded.  `continue` skips only the empty chunk.
            print("Skipping")
            continue

        if artifacts is None:
            artifacts = spark_serialized_artifacts.unserialize()
        feature_view, online_store, repo_config = artifacts

        table = pyarrow.Table.from_pandas(pdf)

        # Rename source columns to feature-view names where a mapping exists.
        if feature_view.batch_source.field_mapping is not None:
            table = _run_pyarrow_field_mapping(
                table, feature_view.batch_source.field_mapping
            )

        join_key_to_value_type = {
            entity.name: entity.dtype.to_value_type()
            for entity in feature_view.entity_columns
        }

        rows_to_write = _convert_arrow_to_proto(
            table, feature_view, join_key_to_value_type
        )
        online_store.online_write_batch(
            repo_config,
            feature_view,
            rows_to_write,
            lambda x: None,  # no progress callback needed on the worker
        )

    yield pd.DataFrame(
        [pd.Series(range(1, 2))]
    )  # dummy result because mapInPandas needs to return something