4 changes: 4 additions & 0 deletions protos/feast/core/FeatureView.proto
@@ -74,7 +74,11 @@ message FeatureViewSpec {
   DataSource stream_source = 9;
 
   // Whether these features should be served online or not
+  // This is also used to determine whether the features should be written to the online store
   bool online = 8;
+
+  // Whether these features should be written to the offline store
+  bool offline = 13;
 }
 
 message FeatureViewMeta {
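For orientation, a minimal usage sketch of the new flag from the Python SDK. The entity, source, and field names here are hypothetical, and it assumes the FeatureView constructor accepts `offline` the same way it accepts `online` (as the `__copy__` change below implies):

from datetime import timedelta

from feast import Entity, FeatureView, Field, FileSource
from feast.types import Float32

# Hypothetical definitions, for illustration only.
driver = Entity(name="driver", join_keys=["driver_id"])
driver_stats_source = FileSource(
    path="data/driver_stats.parquet",
    timestamp_field="event_timestamp",
)

driver_stats = FeatureView(
    name="driver_hourly_stats",
    entities=[driver],
    ttl=timedelta(days=1),
    schema=[Field(name="conv_rate", dtype=Float32)],
    online=True,   # serve these features from the online store (field 8)
    offline=True,  # new: also write materialized features to the offline store (field 13)
    source=driver_stats_source,
)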
4 changes: 4 additions & 0 deletions sdk/python/feast/feature_view.py
@@ -236,6 +236,7 @@ def __copy__(self):
             schema=self.schema,
             tags=self.tags,
             online=self.online,
+            offline=self.offline,
         )
 
         # This is deliberately set outside of the FV initialization as we do not have the Entity objects.
@@ -258,6 +259,7 @@ def __eq__(self, other):
             sorted(self.entities) != sorted(other.entities)
             or self.ttl != other.ttl
             or self.online != other.online
+            or self.offline != other.offline
             or self.batch_source != other.batch_source
             or self.stream_source != other.stream_source
             or sorted(self.entity_columns) != sorted(other.entity_columns)
@@ -363,6 +365,7 @@ def to_proto(self) -> FeatureViewProto:
             owner=self.owner,
             ttl=(ttl_duration if ttl_duration is not None else None),
             online=self.online,
+            offline=self.offline,
             batch_source=batch_source_proto,
             stream_source=stream_source_proto,
         )
@@ -412,6 +415,7 @@ def from_proto(cls, feature_view_proto: FeatureViewProto):
             tags=dict(feature_view_proto.spec.tags),
             owner=feature_view_proto.spec.owner,
             online=feature_view_proto.spec.online,
+            offline=feature_view_proto.spec.offline,
             ttl=(
                 timedelta(days=0)
                 if feature_view_proto.spec.ttl.ToNanoseconds() == 0
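A quick sanity-check sketch that the flag survives serialization, reusing the hypothetical `driver_stats` view from the sketch above:

# Sketch: to_proto/from_proto now carry `offline`, and __eq__ compares it,
# so a proto round trip should preserve both flags.
roundtripped = FeatureView.from_proto(driver_stats.to_proto())
assert roundtripped.online == driver_stats.online
assert roundtripped.offline == driver_stats.offline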
41 changes: 41 additions & 0 deletions sdk/python/feast/infra/common/serde.py
@@ -0,0 +1,41 @@
from dataclasses import dataclass

import dill

from feast import FeatureView
from feast.infra.passthrough_provider import PassthroughProvider
from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto


@dataclass
class SerializedArtifacts:
    """Class to assist with serializing unpicklable artifacts to be passed to the compute engine."""

    feature_view_proto: bytes
    repo_config_byte: bytes

    @classmethod
    def serialize(cls, feature_view, repo_config):
        # serialize the feature view to its proto representation
        feature_view_proto = feature_view.to_proto().SerializeToString()

        # serialize repo_config; it will be used to instantiate the online store
        repo_config_byte = dill.dumps(repo_config)

        return SerializedArtifacts(
            feature_view_proto=feature_view_proto, repo_config_byte=repo_config_byte
        )

    def unserialize(self):
        # rebuild the feature view from its serialized proto
        proto = FeatureViewProto()
        proto.ParseFromString(self.feature_view_proto)
        feature_view = FeatureView.from_proto(proto)

        # restore the repo config
        repo_config = dill.loads(self.repo_config_byte)

        provider = PassthroughProvider(repo_config)
        online_store = provider.online_store
        offline_store = provider.offline_store
        return feature_view, online_store, offline_store, repo_config
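A hedged round-trip sketch for `SerializedArtifacts` (assumes a FeatureView `fv` and a RepoConfig `repo_config` are already in scope):

# Driver side: pack the unpicklable artifacts into picklable fields.
artifacts = SerializedArtifacts.serialize(feature_view=fv, repo_config=repo_config)

# Worker side (e.g., inside a Spark task): rebuild the view and both stores.
feature_view, online_store, offline_store, config = artifacts.unserialize()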
@@ -91,6 +91,6 @@ def build_validation_node(self, input_node):
         return node
 
     def build_output_nodes(self, input_node):
-        node = LocalOutputNode("output")
+        node = LocalOutputNode("output", self.feature_view)
         node.add_input(input_node)
         self.nodes.append(node)
39 changes: 36 additions & 3 deletions sdk/python/feast/infra/compute_engines/local/nodes.py
@@ -1,8 +1,9 @@
 from datetime import datetime, timedelta
-from typing import Optional
+from typing import Optional, Union
 
 import pyarrow as pa
 
+from feast import BatchFeatureView, StreamFeatureView
 from feast.data_source import DataSource
 from feast.infra.compute_engines.dag.context import ExecutionContext
 from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue
@@ -11,6 +12,7 @@
 from feast.infra.offline_stores.offline_utils import (
     infer_event_timestamp_from_entity_df,
 )
+from feast.utils import _convert_arrow_to_proto
 
 ENTITY_TS_ALIAS = "__entity_event_timestamp"
 
@@ -207,11 +209,42 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue:


 class LocalOutputNode(LocalNode):
-    def __init__(self, name: str):
+    def __init__(
+        self, name: str, feature_view: Union[BatchFeatureView, StreamFeatureView]
+    ):
         super().__init__(name)
+        self.feature_view = feature_view
 
     def execute(self, context: ExecutionContext) -> ArrowTableValue:
         input_table = self.get_single_table(context).data
         context.node_outputs[self.name] = input_table
-        # TODO: implement the logic to write to offline store
+
+        if self.feature_view.online:
+            online_store = context.online_store
+
+            join_key_to_value_type = {
+                entity.name: entity.dtype.to_value_type()
+                for entity in self.feature_view.entity_columns
+            }
+
+            rows_to_write = _convert_arrow_to_proto(
+                input_table, self.feature_view, join_key_to_value_type
+            )
+
+            online_store.online_write_batch(
+                config=context.repo_config,
+                table=self.feature_view,
+                data=rows_to_write,
+                progress=lambda x: None,
+            )
+
+        if self.feature_view.offline:
+            offline_store = context.offline_store
+            offline_store.offline_write_batch(
+                config=context.repo_config,
+                feature_view=self.feature_view,
+                table=input_table,
+                progress=lambda x: None,
+            )
 
         return input_table
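For reference, a sketch of the conversion step in isolation, following the signature of `_convert_arrow_to_proto` in `feast.utils` (assumes `fv` is a feature view with populated `entity_columns` and `table` is a `pyarrow.Table` of its features):

from feast.utils import _convert_arrow_to_proto

join_key_to_value_type = {
    entity.name: entity.dtype.to_value_type() for entity in fv.entity_columns
}
rows_to_write = _convert_arrow_to_proto(table, fv, join_key_to_value_type)
# Each element is a row tuple in the shape online_write_batch expects:
# (entity_key_proto, {feature_name: ValueProto}, event_timestamp, created_timestamp)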
@@ -5,7 +5,7 @@
 from feast.infra.common.materialization_job import MaterializationTask
 from feast.infra.common.retrieval_task import HistoricalRetrievalTask
 from feast.infra.compute_engines.feature_builder import FeatureBuilder
-from feast.infra.compute_engines.spark.node import (
+from feast.infra.compute_engines.spark.nodes import (
     SparkAggregationNode,
     SparkDedupNode,
     SparkFilterNode,
@@ -73,7 +73,8 @@ def build_transformation_node(self, input_node):
         return node
 
     def build_output_nodes(self, input_node):
-        node = SparkWriteNode("output", input_node, self.feature_view)
+        node = SparkWriteNode("output", self.feature_view)
+        node.add_input(input_node)
         self.nodes.append(node)
         return node
 
@@ -7,18 +7,19 @@
 from feast import BatchFeatureView, StreamFeatureView
 from feast.aggregation import Aggregation
 from feast.data_source import DataSource
+from feast.infra.common.serde import SerializedArtifacts
 from feast.infra.compute_engines.dag.context import ExecutionContext
 from feast.infra.compute_engines.dag.model import DAGFormat
 from feast.infra.compute_engines.dag.node import DAGNode
 from feast.infra.compute_engines.dag.value import DAGValue
-from feast.infra.materialization.contrib.spark.spark_materialization_engine import (
-    _map_by_partition,
-    _SparkSerializedArtifacts,
-)
+from feast.infra.compute_engines.spark.utils import map_in_arrow
 from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
     SparkRetrievalJob,
     _get_entity_schema,
 )
+from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
+    SparkSource,
+)
 from feast.infra.offline_stores.offline_utils import (
     infer_event_timestamp_from_entity_df,
 )
@@ -273,30 +274,41 @@ class SparkWriteNode(DAGNode):
     def __init__(
         self,
         name: str,
-        input_node: DAGNode,
         feature_view: Union[BatchFeatureView, StreamFeatureView],
     ):
         super().__init__(name)
-        self.add_input(input_node)
         self.feature_view = feature_view
 
     def execute(self, context: ExecutionContext) -> DAGValue:
         spark_df: DataFrame = self.get_single_input_value(context).data
-        spark_serialized_artifacts = _SparkSerializedArtifacts.serialize(
+        serialized_artifacts = SerializedArtifacts.serialize(
             feature_view=self.feature_view, repo_config=context.repo_config
         )
 
-        # ✅ 1. Write to offline store (if enabled)
-        if self.feature_view.offline:
-            # TODO: Update _map_by_partition to be able to write to offline store
-            pass
-
-        # ✅ 2. Write to online store (if enabled)
+        # ✅ 1. Write to online store if online is enabled
         if self.feature_view.online:
-            spark_df.mapInPandas(
-                lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int"
+            spark_df.mapInArrow(
+                lambda x: map_in_arrow(x, serialized_artifacts, mode="online"),
+                spark_df.schema,
             ).count()
 
+        # ✅ 2. Write to offline store if offline is enabled
+        if self.feature_view.offline:
+            if not isinstance(self.feature_view.batch_source, SparkSource):
+                spark_df.mapInArrow(
+                    lambda x: map_in_arrow(x, serialized_artifacts, mode="offline"),
+                    spark_df.schema,
+                ).count()
+            else:
+                # Directly write the Spark df to the Spark offline store without mapInArrow
+                dest_path = self.feature_view.batch_source.path
+                file_format = self.feature_view.batch_source.file_format
+                if not dest_path or not file_format:
+                    raise ValueError(
+                        "Destination path and file format must be specified for SparkSource."
+                    )
+                spark_df.write.format(file_format).mode("append").save(dest_path)
 
         return DAGValue(
             data=spark_df,
             format=DAGFormat.SPARK,
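The trailing `.count()` calls matter: `mapInArrow` is lazy, so an action is needed to force the side-effecting writes to run. A standalone PySpark illustration of the pattern, with a toy schema and no Feast involved:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "val"])

def passthrough(batches):
    # mapInArrow hands each partition over as an iterator of pyarrow.RecordBatch.
    for batch in batches:
        # ... a side-effecting write would happen here ...
        yield batch  # yielded batches must match the declared schema

df.mapInArrow(passthrough, df.schema).count()  # count() triggers execution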
49 changes: 48 additions & 1 deletion sdk/python/feast/infra/compute_engines/spark/utils.py
@@ -1,8 +1,12 @@
-from typing import Dict, Optional
+from typing import Dict, Iterable, Literal, Optional
 
+import pyarrow as pa
 from pyspark import SparkConf
 from pyspark.sql import SparkSession
 
+from feast.infra.common.serde import SerializedArtifacts
+from feast.utils import _convert_arrow_to_proto
 
+
 def get_or_create_new_spark_session(
     spark_config: Optional[Dict[str, str]] = None,
@@ -16,4 +20,47 @@ def get_or_create_new_spark_session(
     )
 
     spark_session = spark_builder.getOrCreate()
+    spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
     return spark_session
+
+
+def map_in_arrow(
+    iterator: Iterable[pa.RecordBatch],
+    serialized_artifacts: "SerializedArtifacts",
+    mode: Literal["online", "offline"] = "online",
+):
+    for batch in iterator:
+        table = pa.Table.from_batches([batch])
+
+        (
+            feature_view,
+            online_store,
+            offline_store,
+            repo_config,
+        ) = serialized_artifacts.unserialize()
+
+        if mode == "online":
+            join_key_to_value_type = {
+                entity.name: entity.dtype.to_value_type()
+                for entity in feature_view.entity_columns
+            }
+
+            rows_to_write = _convert_arrow_to_proto(
+                table, feature_view, join_key_to_value_type
+            )
+
+            online_store.online_write_batch(
+                config=repo_config,
+                table=feature_view,
+                data=rows_to_write,
+                progress=lambda x: None,
+            )
+        if mode == "offline":
+            offline_store.offline_write_batch(
+                config=repo_config,
+                feature_view=feature_view,
+                table=table,
+                progress=lambda x: None,
+            )
+
+        yield batch
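Putting the helpers together, a hedged driver-side wiring sketch (assumes `fv` and `repo_config` are in scope; the input path and the `spark.master` config key are hypothetical):

spark = get_or_create_new_spark_session({"spark.master": "local[2]"})
df = spark.read.parquet("data/driver_stats.parquet")  # hypothetical input

artifacts = SerializedArtifacts.serialize(feature_view=fv, repo_config=repo_config)
df.mapInArrow(
    lambda it: map_in_arrow(it, artifacts, mode="online"), df.schema
).count()

Note that `map_in_arrow` calls `unserialize()` once per record batch; hoisting it above the loop would be a straightforward follow-up optimization.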