Skip to content

Commit 4f1252c

Browse files
committed
fix(proxy): address greptile review on nested access groups
Four issues from the bot review:

1. validate_models_exist no longer short-circuits to "all missing" when llm_router is None. In DB-only deployments (no in-memory router), known_access_groups is still authoritative for nested-group composition; only names absent from it are reported missing. This unblocks the core use case of creating pure-composition parent groups in that config.

2. get_group_memberships_from_db now catches Exception, not just AttributeError/TypeError. Transient DB/network errors degrade to an empty map instead of 500-ing model-listing requests. Matches the docstring's "resilient by design" promise.

3. Added @@index([child_group]) to LiteLLM_AccessGroupMembership across all three schema.prisma files + a new migration. Previously delete_group_membership_edges' WHERE parent_group = X OR child_group = X would full-scan on the child side as the table grew.

4. Added a 60s per-process TTL cache for the membership map. Hot-path callers (get_available_models_for_user, model_info_v1) now go through get_cached_group_memberships, which is explicitly invalidated by every write (upsert, delete edges, parent-edge clear in update_access_group). Matches the per-process cache model already used for llm_router.get_model_access_groups().

Added 9 tests covering the new validate fall-through (2), broader exception handling (1), and cache hit/miss/invalidation/TTL/error-cache semantics (6). Total nested-group tests: 61.

Refs #28032
1 parent f8db829 commit 4f1252c

10 files changed

Lines changed: 259 additions & 23 deletions

File tree

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
-- CreateIndex
2+
CREATE INDEX IF NOT EXISTS "LiteLLM_AccessGroupMembership_child_group_idx" ON "LiteLLM_AccessGroupMembership"("child_group");

litellm-proxy-extras/litellm_proxy_extras/schema.prisma

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,4 +1388,5 @@ model LiteLLM_AccessGroupMembership {
13881388
13891389
@@unique([parent_group, child_group])
13901390
@@index([parent_group])
1391+
@@index([child_group])
13911392
}

litellm/proxy/management_endpoints/model_access_group_management_endpoints.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
import json
9+
import time
910
from typing import Any, Dict, List, Optional, Set, Tuple
1011

1112
from fastapi import APIRouter, Depends, HTTPException
@@ -32,6 +33,29 @@
3233
router = APIRouter()
3334

3435

36+
# ---------------------------------------------------------------------------
37+
# Per-process membership-map cache
38+
# ---------------------------------------------------------------------------
39+
# get_group_memberships_from_db is called on every /v1/models and /model/info
40+
# request via get_available_models_for_user. Without a cache that's one extra
41+
# Prisma roundtrip per request - bad under burst traffic. We cache the map
42+
# in process memory for a short TTL and invalidate explicitly on writes
43+
# (upsert/delete) so the consistency window inside the writing process is
44+
# zero. Across processes, eventual consistency is bounded by the TTL -
45+
# matches today's behavior for llm_router.get_model_access_groups() which is
46+
# also per-process.
47+
48+
_MEMBERSHIPS_CACHE_TTL_SECONDS = 60.0
49+
_MEMBERSHIPS_CACHE: Optional[Tuple[float, Dict[str, List[str]]]] = None
50+
51+
52+
def invalidate_group_memberships_cache() -> None:
    """Forget the in-process membership map.

    Resets the module-level ``_MEMBERSHIPS_CACHE`` slot to ``None`` so the
    next ``get_cached_group_memberships`` call finds no entry and reads the
    table again. Intended to be called after any write that changes rows in
    LiteLLM_AccessGroupMembership.
    """
    global _MEMBERSHIPS_CACHE
    # None is the "no entry" marker that the cached reader checks for.
    _MEMBERSHIPS_CACHE = None
57+
58+
3559
def validate_models_exist(
3660
model_names: List[str],
3761
llm_router,
@@ -44,11 +68,17 @@ def validate_models_exist(
4468
Returns:
4569
Tuple[bool, List[str]]: (all_valid, missing_names)
4670
"""
71+
known_groups = known_access_groups or set()
72+
4773
if llm_router is None:
48-
return False, model_names
74+
# DB-only deployment: no in-memory router means we cannot validate
75+
# real model names, but known_access_groups is still authoritative
76+
# for nested-group composition. Anything not in known_groups is
77+
# reported as missing (fail-closed).
78+
missing = [m for m in model_names if m not in known_groups]
79+
return (len(missing) == 0, missing)
4980

5081
router_model_names = set(llm_router.get_model_names())
51-
known_groups = known_access_groups or set()
5282
missing = [
5383
m for m in model_names if m not in router_model_names and m not in known_groups
5484
]
@@ -87,17 +117,18 @@ async def get_group_memberships_from_db(
87117
Build parent_group -> [child_groups] map from the membership table.
88118
Single query, in-memory bucketing - no N+1.
89119
90-
Resilient by design: if the table isn't available (Prisma client predates
91-
this migration, the proxy started before `prisma migrate deploy` finished,
92-
or the membership Prisma model was stripped from a downstream build) we
93-
return an empty map. The auth path then falls back to today's flat-group
94-
semantics instead of 500-ing the whole request.
120+
Resilient by design: any failure to read the membership table (missing
121+
Prisma model, migration race, transient DB/network error, query timeout)
122+
degrades to an empty map. The auth path then falls back to today's
123+
flat-group semantics instead of 500-ing the whole request. We log at
124+
debug so ops can correlate fallback periods with incidents without
125+
drowning normal traffic in warnings.
95126
"""
96127
try:
97128
rows = await prisma_client.db.litellm_accessgroupmembership.find_many()
98-
except (AttributeError, TypeError) as e:
129+
except Exception as e: # noqa: BLE001 - intentional broad catch on auth path
99130
verbose_proxy_logger.debug(
100-
"litellm_accessgroupmembership unavailable - "
131+
"litellm_accessgroupmembership read failed - "
101132
"skipping nested group resolution: %s",
102133
e,
103134
)
@@ -109,6 +140,25 @@ async def get_group_memberships_from_db(
109140
return memberships
110141

111142

143+
async def get_cached_group_memberships(
    prisma_client: PrismaClient,
) -> Dict[str, List[str]]:
    """
    Return the parent_group -> [child_groups] map, served from a per-process
    TTL cache in front of get_group_memberships_from_db.

    A cached entry is reused while it is younger than
    _MEMBERSHIPS_CACHE_TTL_SECONDS; otherwise the table is read again and
    the result replaces the entry. Whatever the underlying helper returns is
    cached as-is, so a read that degraded to an empty map stays cached for
    the TTL as well.

    Args:
        prisma_client: client handed through to get_group_memberships_from_db.

    Returns:
        The membership map. Cache hits return the same dict object that was
        stored, so callers should treat it as read-only.

    NOTE(review): the entry is written unconditionally after the await. If
    invalidate_group_memberships_cache() runs while a fetch is in flight,
    the fetch result is still stored afterwards and may predate the write -
    confirm whether that window matters for callers.
    """
    global _MEMBERSHIPS_CACHE
    started_at = time.monotonic()

    entry = _MEMBERSHIPS_CACHE
    if entry is not None:
        stored_at, memberships = entry
        age = started_at - stored_at
        if age < _MEMBERSHIPS_CACHE_TTL_SECONDS:
            return memberships

    # Miss or expired entry: read the table and stamp the new entry with the
    # time this call began (taken before the await, not after it).
    memberships = await get_group_memberships_from_db(prisma_client=prisma_client)
    _MEMBERSHIPS_CACHE = (started_at, memberships)
    return memberships
160+
161+
112162
async def upsert_group_memberships(
113163
parent_group: str,
114164
child_groups: List[str],
@@ -142,6 +192,7 @@ async def upsert_group_memberships(
142192
data=rows,
143193
skip_duplicates=True,
144194
)
195+
invalidate_group_memberships_cache()
145196
return result
146197

147198

@@ -164,6 +215,7 @@ async def delete_group_membership_edges(
164215
]
165216
}
166217
)
218+
invalidate_group_memberships_cache()
167219
return result
168220

169221

@@ -850,6 +902,7 @@ async def update_access_group(
850902
await prisma_client.db.litellm_accessgroupmembership.delete_many(
851903
where={"parent_group": access_group}
852904
)
905+
invalidate_group_memberships_cache()
853906

854907
# Step 2: re-add membership using the appropriate write path
855908
if use_model_ids:

litellm/proxy/proxy_server.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def generate_feedback_box():
367367
router as key_management_router,
368368
)
369369
from litellm.proxy.management_endpoints.model_access_group_management_endpoints import (
370-
get_group_memberships_from_db,
370+
get_cached_group_memberships,
371371
router as model_access_group_management_router,
372372
)
373373
from litellm.proxy.management_endpoints.model_management_endpoints import (
@@ -11978,11 +11978,11 @@ async def model_info_v1( # noqa: PLR0915
1197811978
proxy_model_list = llm_router.get_model_names()
1197911979
model_access_groups = llm_router.get_model_access_groups()
1198011980

11981-
# Parent->child edges for nested access groups. Empty when no DB is
11982-
# configured, preserving today's flat behavior.
11981+
# Parent->child edges for nested access groups (TTL-cached per process).
11982+
# Empty when no DB is configured, preserving today's flat behavior.
1198311983
group_memberships: Dict[str, List[str]] = {}
1198411984
if prisma_client is not None:
11985-
group_memberships = await get_group_memberships_from_db(
11985+
group_memberships = await get_cached_group_memberships(
1198611986
prisma_client=prisma_client
1198711987
)
1198811988

litellm/proxy/schema.prisma

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,4 +1388,5 @@ model LiteLLM_AccessGroupMembership {
13881388
13891389
@@unique([parent_group, child_group])
13901390
@@index([parent_group])
1391+
@@index([child_group])
13911392
}

litellm/proxy/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5848,7 +5848,7 @@ async def get_available_models_for_user(
58485848
get_team_models,
58495849
)
58505850
from litellm.proxy.management_endpoints.model_access_group_management_endpoints import (
5851-
get_group_memberships_from_db,
5851+
get_cached_group_memberships,
58525852
)
58535853
from litellm.proxy.management_endpoints.team_endpoints import validate_membership
58545854

@@ -5860,11 +5860,12 @@ async def get_available_models_for_user(
58605860
proxy_model_list = llm_router.get_model_names()
58615861
model_access_groups = llm_router.get_model_access_groups()
58625862

5863-
# Parent->child edges for nested access groups. Empty when no DB is
5864-
# configured (e.g. SDK-only mode), preserving today's flat behavior.
5863+
# Parent->child edges for nested access groups (TTL-cached per process).
5864+
# Empty when no DB is configured (e.g. SDK-only mode), preserving
5865+
# today's flat behavior.
58655866
group_memberships: Dict[str, List[str]] = {}
58665867
if prisma_client is not None:
5867-
group_memberships = await get_group_memberships_from_db(
5868+
group_memberships = await get_cached_group_memberships(
58685869
prisma_client=prisma_client
58695870
)
58705871

schema.prisma

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,4 +1388,5 @@ model LiteLLM_AccessGroupMembership {
13881388
13891389
@@unique([parent_group, child_group])
13901390
@@index([parent_group])
1391+
@@index([child_group])
13911392
}

tests/test_litellm/proxy/auth/test_nested_access_groups.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -541,16 +541,29 @@ def get_model_names(self):
541541
assert missing == ["z-missing", "y-missing"]
542542

543543

544-
def test_validate_models_exist_with_null_router_returns_false():
545-
"""No router - everything reports as missing (matches today's defensive behavior)."""
544+
def test_validate_models_exist_with_null_router_still_accepts_known_groups():
    """With no router, names that are all known access groups validate cleanly.

    Covers the DB-only deployment case: llm_router is None, yet every
    requested name is present in known_access_groups, so nothing is reported
    missing.
    """
    requested = ["image", "reasoning"]

    all_valid, missing = validate_models_exist(
        model_names=requested,
        llm_router=None,
        known_access_groups={"image", "reasoning"},
    )

    assert all_valid is True
    assert missing == []
555+
556+
557+
def test_validate_models_exist_with_null_router_rejects_unknown_real_models():
558+
"""Without a router we can't validate real model names, so anything not in
559+
known_access_groups is fail-closed reported as missing."""
546560
all_valid, missing = validate_models_exist(
547-
model_names=["any"],
561+
model_names=["gpt-4", "image"],
548562
llm_router=None,
549-
known_access_groups={"any"},
563+
known_access_groups={"image"},
550564
)
551-
# Without a router we can't say what's a model, so we fall back to fail-closed
552565
assert all_valid is False
553-
assert missing == ["any"]
566+
assert missing == ["gpt-4"]
554567

555568

556569
def test_resolve_with_empty_models_and_empty_memberships_returns_empty():
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""
2+
Cache-behavior tests for the nested-access-group membership map (#28032).
3+
4+
Hot-path callers go through get_cached_group_memberships() which TTL-caches
5+
get_group_memberships_from_db() and is invalidated by every membership
6+
write. These tests pin the cache hit/miss/invalidation semantics so the
7+
optimization can't silently break later.
8+
"""
9+
10+
import os
11+
import sys
12+
from types import SimpleNamespace
13+
from unittest.mock import AsyncMock, MagicMock
14+
15+
sys.path.insert(0, os.path.abspath("../../.."))
16+
17+
import pytest
18+
19+
import litellm.proxy.management_endpoints.model_access_group_management_endpoints as mgmt
20+
from litellm.proxy.management_endpoints.model_access_group_management_endpoints import (
21+
delete_group_membership_edges,
22+
get_cached_group_memberships,
23+
invalidate_group_memberships_cache,
24+
upsert_group_memberships,
25+
)
26+
27+
28+
def _row(parent: str, child: str) -> SimpleNamespace:
29+
return SimpleNamespace(parent_group=parent, child_group=child)
30+
31+
32+
def _make_prisma(membership_rows=None):
33+
membership_rows = membership_rows or []
34+
db = MagicMock()
35+
db.litellm_accessgroupmembership = MagicMock()
36+
db.litellm_accessgroupmembership.find_many = AsyncMock(return_value=membership_rows)
37+
db.litellm_accessgroupmembership.create_many = AsyncMock(return_value=0)
38+
db.litellm_accessgroupmembership.delete_many = AsyncMock(return_value=0)
39+
client = MagicMock()
40+
client.db = db
41+
return client
42+
43+
44+
@pytest.fixture(autouse=True)
def _reset_cache_between_tests():
    """Clear the membership cache on both sides of every test.

    The cache lives in a module-level variable, so without this an entry
    stored by one test would still be there when the next one runs.
    """
    invalidate_group_memberships_cache()  # clean slate going in
    yield
    invalidate_group_memberships_cache()  # drop whatever the test cached
50+
51+
52+
@pytest.mark.asyncio
async def test_cache_miss_then_hit_avoids_second_db_query():
    """Two back-to-back reads agree on the map but query the table only once."""
    client = _make_prisma(membership_rows=[_row("project-x", "image")])
    find_many = client.db.litellm_accessgroupmembership.find_many

    initial = await get_cached_group_memberships(prisma_client=client)
    repeat = await get_cached_group_memberships(prisma_client=client)

    expected = {"project-x": ["image"]}
    assert initial == repeat == expected
    # The repeat read must have been answered from the cache.
    find_many.assert_awaited_once()
62+
63+
64+
@pytest.mark.asyncio
async def test_cache_invalidation_forces_db_refetch():
    """An explicit invalidation between two reads sends the second to the DB."""
    client = _make_prisma(membership_rows=[_row("project-x", "image")])
    find_many = client.db.litellm_accessgroupmembership.find_many

    await get_cached_group_memberships(prisma_client=client)  # fills the cache
    invalidate_group_memberships_cache()
    await get_cached_group_memberships(prisma_client=client)  # has to re-query

    assert find_many.await_count == 2
73+
74+
75+
@pytest.mark.asyncio
async def test_upsert_invalidates_cache():
    """An upsert of membership edges drops the cache, so the next read re-queries."""
    client = _make_prisma(membership_rows=[_row("project-x", "image")])
    table = client.db.litellm_accessgroupmembership
    table.create_many = AsyncMock(return_value=1)

    # Read once to fill the cache, write a new edge, then read again.
    await get_cached_group_memberships(prisma_client=client)
    await upsert_group_memberships(
        parent_group="project-x",
        child_groups=["reasoning"],
        prisma_client=client,
    )
    await get_cached_group_memberships(prisma_client=client)

    # One table query per read means the write invalidated the cached entry.
    assert table.find_many.await_count == 2
90+
91+
92+
@pytest.mark.asyncio
async def test_delete_edges_invalidates_cache():
    """Removing a group's edges drops the cache just like an upsert does."""
    client = _make_prisma(membership_rows=[_row("project-x", "image")])
    table = client.db.litellm_accessgroupmembership
    table.delete_many = AsyncMock(return_value=1)

    # Read once to fill the cache, delete the edges, then read again.
    await get_cached_group_memberships(prisma_client=client)
    await delete_group_membership_edges(access_group="project-x", prisma_client=client)
    await get_cached_group_memberships(prisma_client=client)

    # One table query per read means the delete invalidated the cached entry.
    assert table.find_many.await_count == 2
103+
104+
105+
@pytest.mark.asyncio
async def test_cache_expires_after_ttl(monkeypatch):
    """When monotonic time advances past the TTL, the next read re-fetches."""
    prisma = _make_prisma(membership_rows=[_row("project-x", "image")])
    find_many = prisma.db.litellm_accessgroupmembership.find_many

    # Controllable clock. Only the `time` name inside the module under test
    # is swapped: patching `monotonic` on the real `time` module would change
    # it process-wide, including for the asyncio event loop running this
    # test, which reads time.monotonic() for its own clock. The stand-in
    # copies every attribute of the real module so any other `time.*` use in
    # the module keeps working.
    clock = [1000.0]
    fake_time = SimpleNamespace(**vars(mgmt.time))
    fake_time.monotonic = lambda: clock[0]
    monkeypatch.setattr(mgmt, "time", fake_time)

    await get_cached_group_memberships(prisma_client=prisma)
    clock[0] += mgmt._MEMBERSHIPS_CACHE_TTL_SECONDS + 1  # step past expiry
    await get_cached_group_memberships(prisma_client=prisma)

    assert find_many.await_count == 2
119+
120+
121+
@pytest.mark.asyncio
async def test_cache_within_ttl_does_not_refetch(monkeypatch):
    """Reads inside the TTL window stay served from cache."""
    prisma = _make_prisma(membership_rows=[_row("project-x", "image")])
    find_many = prisma.db.litellm_accessgroupmembership.find_many

    # Controllable clock. Only the `time` name inside the module under test
    # is swapped: patching `monotonic` on the real `time` module would change
    # it process-wide, including for the asyncio event loop running this
    # test, which reads time.monotonic() for its own clock. The stand-in
    # copies every attribute of the real module so any other `time.*` use in
    # the module keeps working.
    clock = [1000.0]
    fake_time = SimpleNamespace(**vars(mgmt.time))
    fake_time.monotonic = lambda: clock[0]
    monkeypatch.setattr(mgmt, "time", fake_time)

    await get_cached_group_memberships(prisma_client=prisma)
    clock[0] += mgmt._MEMBERSHIPS_CACHE_TTL_SECONDS - 1  # still one second short of expiry
    await get_cached_group_memberships(prisma_client=prisma)

    find_many.assert_awaited_once()
134+
135+
136+
@pytest.mark.asyncio
async def test_cache_falls_through_empty_dict_on_error_path():
    """An empty map produced by a failed DB read is cached like any other result.

    The table query raises on the first read; the second read must come from
    the cached {} rather than retrying the database on every request.
    """
    client = _make_prisma()
    failing_find_many = AsyncMock(side_effect=ConnectionError("postgres unreachable"))
    client.db.litellm_accessgroupmembership.find_many = failing_find_many

    initial = await get_cached_group_memberships(prisma_client=client)
    repeat = await get_cached_group_memberships(prisma_client=client)

    assert initial == repeat == {}
    # Exactly one attempt reached the DB; the repeat read used the cached {}.
    failing_find_many.assert_awaited_once()

0 commit comments

Comments
 (0)