Skip to content

Commit 2747405

Browse files
committed
feat: added batching to feature server /push to offline store (#5683)
Signed-off-by: Jacob Weinhold <29459386+jfw-ppi@users.noreply.github.com>
1 parent bb299d9 commit 2747405

File tree

2 files changed

+86
-39
lines changed

2 files changed

+86
-39
lines changed

sdk/python/feast/feature_server.py

Lines changed: 78 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from datetime import datetime
2424
from importlib import resources as importlib_resources
2525
from types import SimpleNamespace
26-
from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Union
26+
from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Set, Union
2727

2828
import pandas as pd
2929
import psutil
@@ -395,7 +395,7 @@ async def retrieve_online_documents(
395395
return response_dict
396396

397397
@app.post("/push", dependencies=[Depends(inject_user_details)])
398-
async def push(request: PushFeaturesRequest) -> None:
398+
async def push(request: PushFeaturesRequest) -> Response:
399399
df = pd.DataFrame(request.df)
400400
actions = []
401401
if request.to == "offline":
@@ -470,6 +470,8 @@ async def _push_with_to(push_to: PushMode) -> None:
470470
needs_online = to in (PushMode.ONLINE, PushMode.ONLINE_AND_OFFLINE)
471471
needs_offline = to in (PushMode.OFFLINE, PushMode.ONLINE_AND_OFFLINE)
472472

473+
status_code = status.HTTP_200_OK
474+
473475
if offline_batcher is None or not needs_offline:
474476
await _push_with_to(to)
475477
else:
@@ -482,6 +484,9 @@ async def _push_with_to(push_to: PushMode) -> None:
482484
allow_registry_cache=request.allow_registry_cache,
483485
transform_on_write=request.transform_on_write,
484486
)
487+
status_code = status.HTTP_202_ACCEPTED
488+
489+
return Response(status_code=status_code)
485490

486491
async def _get_feast_object(
487492
feature_view_name: str, allow_registry_cache: bool
@@ -851,6 +856,7 @@ def __init__(self, store: "feast.FeatureStore", cfg: Any):
851856
list
852857
)
853858
self._last_flush: DefaultDict[_OfflineBatchKey, float] = defaultdict(time.time)
859+
self._inflight: Set[_OfflineBatchKey] = set()
854860

855861
self._lock = threading.Lock()
856862
self._stop_event = threading.Event()
@@ -889,24 +895,25 @@ def enqueue(
889895
with self._lock:
890896
self._buffers[key].append(df)
891897
total_rows = sum(len(d) for d in self._buffers[key])
898+
should_flush = total_rows >= self._cfg.batch_size
892899

900+
if should_flush:
893901
# Size-based flush
894-
if total_rows >= self._cfg.batch_size:
895-
logger.debug(
896-
"OfflineWriteBatcher size threshold reached for %s: %s rows",
897-
key,
898-
total_rows,
899-
)
900-
self._flush_locked(key)
902+
logger.debug(
903+
"OfflineWriteBatcher size threshold reached for %s: %s rows",
904+
key,
905+
total_rows,
906+
)
907+
self._flush(key)
901908

902909
def flush_all(self) -> None:
903910
"""
904911
Flush all buffers synchronously. Intended for graceful shutdown.
905912
"""
906913
with self._lock:
907914
keys = list(self._buffers.keys())
908-
for key in keys:
909-
self._flush_locked(key)
915+
for key in keys:
916+
self._flush(key)
910917

911918
def shutdown(self, timeout: float = 5.0) -> None:
912919
"""
@@ -942,6 +949,7 @@ def _run(self) -> None:
942949
now = time.time()
943950
try:
944951
with self._lock:
952+
keys_to_flush: List[_OfflineBatchKey] = []
945953
for key, dfs in list(self._buffers.items()):
946954
if not dfs:
947955
continue
@@ -955,38 +963,75 @@ def _run(self) -> None:
955963
key,
956964
age,
957965
)
958-
self._flush_locked(key)
966+
keys_to_flush.append(key)
967+
for key in keys_to_flush:
968+
self._flush(key)
959969
except Exception:
960970
logger.exception("Error in OfflineWriteBatcher background loop")
961971

962972
logger.debug("OfflineWriteBatcher background loop exiting")
963973

964-
def _flush_locked(self, key: _OfflineBatchKey) -> None:
974+
def _drain_locked(self, key: _OfflineBatchKey) -> Optional[List[pd.DataFrame]]:
965975
"""
966-
Flush a single buffer; caller must hold self._lock.
976+
Drain a single buffer; caller must hold self._lock.
967977
"""
978+
if key in self._inflight:
979+
return None
980+
968981
dfs = self._buffers.get(key)
969982
if not dfs:
970-
return
983+
return None
971984

972-
batch_df = pd.concat(dfs, ignore_index=True)
973-
self._buffers[key].clear()
974-
self._last_flush[key] = time.time()
985+
self._buffers[key] = []
986+
self._inflight.add(key)
987+
return dfs
975988

976-
logger.debug(
977-
"Flushing offline batch for push_source=%s with %s rows",
978-
key.push_source_name,
979-
len(batch_df),
980-
)
989+
def _flush(self, key: _OfflineBatchKey) -> None:
990+
"""
991+
Flush a single buffer. Extracts data under lock, then does I/O without lock.
992+
"""
993+
while True:
994+
with self._lock:
995+
dfs = self._drain_locked(key)
981996

982-
# NOTE: offline writes are currently synchronous only, so we call directly
983-
try:
984-
self._store.push(
985-
push_source_name=key.push_source_name,
986-
df=batch_df,
987-
allow_registry_cache=key.allow_registry_cache,
988-
to=PushMode.OFFLINE,
989-
transform_on_write=key.transform_on_write,
997+
if not dfs:
998+
return
999+
1000+
batch_df = pd.concat(dfs, ignore_index=True)
1001+
1002+
# NOTE: offline writes are currently synchronous only, so we call directly
1003+
try:
1004+
self._store.push(
1005+
push_source_name=key.push_source_name,
1006+
df=batch_df,
1007+
allow_registry_cache=key.allow_registry_cache,
1008+
to=PushMode.OFFLINE,
1009+
transform_on_write=key.transform_on_write,
1010+
)
1011+
except Exception:
1012+
logger.exception("Error flushing offline batch for %s", key)
1013+
with self._lock:
1014+
self._buffers[key] = dfs + self._buffers[key]
1015+
self._inflight.discard(key)
1016+
return
1017+
1018+
logger.debug(
1019+
"Flushing offline batch for push_source=%s with %s rows",
1020+
key.push_source_name,
1021+
len(batch_df),
1022+
)
1023+
1024+
with self._lock:
1025+
self._last_flush[key] = time.time()
1026+
self._inflight.discard(key)
1027+
pending_rows = sum(len(d) for d in self._buffers.get(key, []))
1028+
should_flush = pending_rows >= self._cfg.batch_size
1029+
1030+
if not should_flush:
1031+
return
1032+
1033+
logger.debug(
1034+
"OfflineWriteBatcher size threshold reached for %s: %s rows",
1035+
key,
1036+
pending_rows,
9901037
)
991-
except Exception:
992-
logger.exception("Error flushing offline batch for %s", key)

sdk/python/tests/unit/test_feature_server.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def _enable_offline_batching_config(
229229
fs, enabled: bool = True, batch_size: int = 1, batch_interval_seconds: int = 60
230230
):
231231
"""
232-
Attach a minimal feature_server.offline_push_batching config
232+
Attach a minimal feature_server.offline_push_batching config
233233
to a mocked FeatureStore.
234234
"""
235235
if not hasattr(fs, "config") or fs.config is None:
@@ -301,7 +301,9 @@ def test_push_batched_matrix(
301301

302302
# use a multi-row payload to ensure we test non-trivial dfs
303303
resp = client.post("/push", json=push_body_many(push_mode, count=2, id_start=100))
304-
assert resp.status_code == 200
304+
needs_offline = push_mode in (PushMode.OFFLINE, PushMode.ONLINE_AND_OFFLINE)
305+
expected_status = 202 if batching_enabled and needs_offline else 200
306+
assert resp.status_code == expected_status
305307

306308
# Collect calls
307309
sync_calls = fs.push.call_args_list
@@ -391,19 +393,19 @@ def test_offline_batches_are_separated_by_flags(mock_fs_factory):
391393

392394
# 1) Default flags: allow_registry_cache=True, transform_on_write=True
393395
resp1 = client.post("/push", json=body_base)
394-
assert resp1.status_code == 200
396+
assert resp1.status_code == 202
395397

396398
# 2) Different allow_registry_cache
397399
body_allow_false = dict(body_base)
398400
body_allow_false["allow_registry_cache"] = False
399401
resp2 = client.post("/push", json=body_allow_false)
400-
assert resp2.status_code == 200
402+
assert resp2.status_code == 202
401403

402404
# 3) Different transform_on_write
403405
body_transform_false = dict(body_base)
404406
body_transform_false["transform_on_write"] = False
405407
resp3 = client.post("/push", json=body_transform_false)
406-
assert resp3.status_code == 200
408+
assert resp3.status_code == 202
407409

408410
# Immediately after: no flush expected yet (interval-based)
409411
assert fs.push.call_count == 0
@@ -447,7 +449,7 @@ def test_offline_batcher_interval_flush(mock_fs_factory):
447449
resp = client.post(
448450
"/push", json=push_body_many(PushMode.OFFLINE, count=2, id_start=500)
449451
)
450-
assert resp.status_code == 200
452+
assert resp.status_code == 202
451453

452454
# Immediately after: no sync push yet (buffer only)
453455
assert fs.push.call_count == 0

0 commit comments

Comments
 (0)