Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions sdk/python/feast/infra/contrib/spark_kafka_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from types import MethodType
from typing import List, Union

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.avro.functions import from_avro
from pyspark.sql.functions import col, from_json
from pyspark.sql.streaming import StreamingQuery

from feast.data_format import AvroFormat, JsonFormat
from feast.data_source import KafkaSource
from feast.infra.contrib.stream_processor import (
    ProcessorConfig,
    StreamProcessor,
    StreamTable,
)
from feast.stream_feature_view import StreamFeatureView


class SparkProcessorConfig(ProcessorConfig):
    """
    Processor configuration for Spark-based stream processors.

    Attributes:
        spark_session: The Spark session handed to the stream processor.
    """

    spark_session: SparkSession


class SparkKafkaProcessor(StreamProcessor):
    """
    Stream processor that reads a Kafka topic with Spark Structured Streaming, applies the
    stream feature view's transformation (if any) and hands every micro-batch to a write function.

    Attributes:
        spark: The Spark session taken from the processor config.
        format: Message format of the Kafka payload, either "json" or "avro".
        write_function: Callable invoked with each micro-batch converted to a pandas DataFrame.
        join_keys: Declared for typing only; this class never assigns it.
        processing_time: Trigger interval passed to the streaming writer (e.g. "30 seconds").
        query_timeout: How long to wait on the started query, as a number of seconds or a
            string such as "15 seconds".
        checkpoint_location: Directory used by Spark for the streaming checkpoint.
    """

    spark: SparkSession
    format: str
    write_function: MethodType
    join_keys: List[str]

    def __init__(
        self,
        sfv: StreamFeatureView,
        config: ProcessorConfig,
        write_function: MethodType,
        processing_time: str = "30 seconds",
        query_timeout: Union[int, float, str] = "15 seconds",
        checkpoint_location: str = "/tmp/checkpoint/",
    ):
        """
        Validates the stream feature view and config, then stores the processor settings.

        Args:
            sfv: Stream feature view whose stream source must be a KafkaSource.
            config: Must be a SparkProcessorConfig; provides the Spark session.
            write_function: Called as write_function(pandas_df, input_timestamp=..., output_timestamp=...).
            processing_time: Trigger interval for the streaming query.
            query_timeout: Seconds (number) or "<amount> <unit>" string to wait for the query.
            checkpoint_location: Checkpoint directory; use a distinct one per concurrent query.

        Raises:
            ValueError: If the source is not Kafka, the message format is neither JSON nor
                Avro, or the config is not a SparkProcessorConfig.
        """
        if not isinstance(sfv.stream_source, KafkaSource):
            raise ValueError("data source is not kafka source")

        message_format = sfv.stream_source.kafka_options.message_format
        if not isinstance(message_format, (AvroFormat, JsonFormat)):
            raise ValueError(
                "spark streaming currently only supports json or avro format for kafka source schema"
            )
        self.format = "json" if isinstance(message_format, JsonFormat) else "avro"

        if not isinstance(config, SparkProcessorConfig):
            raise ValueError("config is not spark processor config")
        self.spark = config.spark_session
        self.write_function = write_function
        self.processing_time = processing_time
        self.query_timeout = query_timeout
        self.checkpoint_location = checkpoint_location
        super().__init__(sfv=sfv, data_source=sfv.stream_source)

    @staticmethod
    def _timeout_in_seconds(timeout: Union[int, float, str]) -> float:
        """
        Converts a timeout given as a number of seconds or as "<amount> <unit>" into seconds.

        Supported units are milliseconds, seconds, minutes and hours (singular or plural);
        a bare numeric string is interpreted as seconds.

        Raises:
            ValueError: If the value cannot be parsed.
        """
        if isinstance(timeout, (int, float)):
            return float(timeout)

        parts = str(timeout).split()
        if len(parts) not in (1, 2):
            raise ValueError(f"invalid query timeout: {timeout!r}")
        try:
            amount = float(parts[0])
        except ValueError:
            raise ValueError(f"invalid query timeout: {timeout!r}") from None

        unit = parts[1].lower() if len(parts) == 2 else "seconds"
        singular = unit[:-1] if unit.endswith("s") else unit
        seconds_per_unit = {
            "millisecond": 0.001,
            "second": 1.0,
            "minute": 60.0,
            "hour": 3600.0,
        }
        if singular not in seconds_per_unit:
            raise ValueError(f"invalid query timeout unit: {timeout!r}")
        return amount * seconds_per_unit[singular]

    def ingest_stream_feature_view(self) -> StreamingQuery:
        """
        Runs the full pipeline: read from Kafka, transform, and start the write query.

        Returns:
            The streaming query started by _write_to_online_store.
        """
        ingested_stream_df = self._ingest_stream_data()
        transformed_df = self._construct_transformation_plan(ingested_stream_df)
        return self._write_to_online_store(transformed_df)

    def _read_kafka_stream(self) -> DataFrame:
        """Builds the raw Kafka streaming DataFrame for the configured servers and topic."""
        kafka_options = self.data_source.kafka_options
        return (
            self.spark.readStream.format("kafka")
            .option("kafka.bootstrap.servers", kafka_options.bootstrap_servers)
            .option("subscribe", kafka_options.topic)
            .option("startingOffsets", "latest")  # Query start
            .load()
        )

    def _ingest_stream_data(self) -> StreamTable:
        """
        Reads the Kafka topic and decodes the message value into columns.

        Only supports json and avro formats currently.

        Raises:
            ValueError: If the source's message format does not match self.format.
        """
        message_format = self.data_source.kafka_options.message_format
        raw_df = self._read_kafka_stream()

        if self.format == "json":
            if not isinstance(message_format, JsonFormat):
                raise ValueError("kafka source message format is not jsonformat")
            # JSON is text, so the binary Kafka value is decoded to a string before parsing.
            decoded = raw_df.selectExpr("CAST(value AS STRING)").select(
                from_json(col("value"), message_format.schema_json).alias("table")
            )
        else:
            if not isinstance(message_format, AvroFormat):
                raise ValueError("kafka source message format is not avro format")
            # Avro is a binary encoding: the value must reach from_avro as raw bytes.
            # Casting it to STRING first would mangle any bytes that are not valid UTF-8.
            decoded = raw_df.select(
                from_avro(col("value"), message_format.schema_json).alias("table")
            )
        return decoded.select("table.*")

    def _construct_transformation_plan(self, df: StreamTable) -> StreamTable:
        """Applies the stream feature view's udf to df, or returns df unchanged if there is none."""
        return self.sfv.udf(df) if self.sfv.udf else df

    def _write_to_online_store(self, df: StreamTable) -> StreamingQuery:
        """
        Starts a streaming query that passes each micro-batch to the write function.

        Returns:
            The started streaming query, after waiting on it for at most query_timeout.

        Raises:
            ValueError: If query_timeout cannot be converted to seconds.
        """
        # Resolve the timeout before starting the query so a bad value cannot leave an
        # orphaned running query behind.
        timeout_seconds = self._timeout_in_seconds(self.query_timeout)

        # Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema.
        def batch_write(batch_df: DataFrame, batch_id: int):
            self.write_function(
                batch_df.toPandas(),
                input_timestamp="event_timestamp",
                output_timestamp="",
            )

        query = (
            df.writeStream.outputMode("update")
            .option("checkpointLocation", self.checkpoint_location)
            .trigger(processingTime=self.processing_time)
            .foreachBatch(batch_write)
            .start()
        )

        # awaitTermination takes a numeric number of seconds, not an interval string.
        query.awaitTermination(timeout=timeout_seconds)
        return query
87 changes: 87 additions & 0 deletions sdk/python/feast/infra/contrib/stream_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from abc import ABC
from typing import Callable

import pandas as pd
from pyspark.sql import DataFrame

from feast.data_source import DataSource
from feast.importer import import_class
from feast.repo_config import FeastConfigBaseModel
from feast.stream_feature_view import StreamFeatureView

# Maps a (processor mode, stream source type) pair to the dotted import path of the
# stream processor class that handles it.
STREAM_PROCESSOR_CLASS_FOR_TYPE = {
("spark", "kafka"): "feast.infra.contrib.spark_kafka_processor.SparkKafkaProcessor",
}

# TODO: support more types other than just Spark.
# Alias for the table type passed between stream processor stages; currently a Spark DataFrame.
StreamTable = DataFrame


class ProcessorConfig(FeastConfigBaseModel):
    """
    Configuration selecting which stream processor to use.

    Attributes:
        mode: Processor mode (spark, etc).
        source: Ingestion source (kafka, kinesis, etc).
    """

    mode: str
    source: str


class StreamProcessor(ABC):
    """
    Base class for processors that ingest a stream, transform it for one stream feature view,
    and persist the result to the online store.

    The methods below are placeholders that do nothing and return None; concrete processors
    (for example SparkKafkaProcessor) override them.

    Attributes:
        sfv: The stream feature view on which the stream processor operates.
        data_source: The stream data source from which data will be ingested.
    """

    sfv: StreamFeatureView
    data_source: DataSource

    def __init__(self, sfv: StreamFeatureView, data_source: DataSource):
        """Stores the stream feature view and the data source it is fed from."""
        self.data_source = data_source
        self.sfv = sfv

    def ingest_stream_feature_view(self) -> None:
        """
        Entry point: ingest from the stream source attached to the stream feature view,
        transform the data, and persist it to the online store.
        """

    def _ingest_stream_data(self) -> StreamTable:
        """
        Reads the stream source into a StreamTable.
        """

    def _construct_transformation_plan(self, table: StreamTable) -> StreamTable:
        """
        Applies transformations on top of a StreamTable. Stream engines evaluate lazily, so
        the StreamTable is not materialized until it is actually evaluated, e.g. by
        df.collect() in Spark or tbl.execute() in Flink.
        """

    def _write_to_online_store(self, table: StreamTable) -> None:
        """
        Hook for persisting the table to the online store; overriding processors return
        the query that performs the write.
        """


def get_stream_processor_object(
    config: ProcessorConfig,
    sfv: StreamFeatureView,
    write_function: Callable[[pd.DataFrame, str, str], None],
):
    """
    Builds the stream processor matching the config's mode and stream source type.

    Args:
        config: Processor config; only mode "spark" with source "kafka" is accepted.
        sfv: Stream feature view the processor will operate on.
        write_function: Function wrapping the feature store "write_to_online_store" capability.

    Returns:
        An instance of the processor class registered for ("spark", "kafka").

    Raises:
        ValueError: If the mode/source combination is anything other than spark/kafka.
    """
    # Guard clause: reject every unsupported combination up front.
    if config.mode != "spark" or config.source != "kafka":
        raise ValueError("other processors besides spark-kafka not supported")

    # The registry stores a dotted path; split it into module and class for the importer.
    processor_path = STREAM_PROCESSOR_CLASS_FOR_TYPE[("spark", "kafka")]
    module_name, class_name = processor_path.rsplit(".", 1)
    processor_cls = import_class(module_name, class_name, "Processor")
    return processor_cls(sfv=sfv, config=config, write_function=write_function)