Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-- CreateTable
CREATE TABLE IF NOT EXISTS "LiteLLM_AccessGroupMembership" (
"id" TEXT NOT NULL,
"parent_group" TEXT NOT NULL,
"child_group" TEXT NOT NULL,
"created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,

CONSTRAINT "LiteLLM_AccessGroupMembership_pkey" PRIMARY KEY ("id")
);

-- CreateIndex
CREATE UNIQUE INDEX IF NOT EXISTS "LiteLLM_AccessGroupMembership_parent_group_child_group_key" ON "LiteLLM_AccessGroupMembership"("parent_group", "child_group");

-- CreateIndex
CREATE INDEX IF NOT EXISTS "LiteLLM_AccessGroupMembership_parent_group_idx" ON "LiteLLM_AccessGroupMembership"("parent_group");
Comment thread
wtfashwin marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- CreateIndex
CREATE INDEX IF NOT EXISTS "LiteLLM_AccessGroupMembership_child_group_idx" ON "LiteLLM_AccessGroupMembership"("child_group");
15 changes: 15 additions & 0 deletions litellm-proxy-extras/litellm_proxy_extras/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -1375,3 +1375,18 @@ model LiteLLM_WorkflowMessage {
@@unique([run_id, sequence_number])
@@index([run_id])
}

// Parent/child relationships between access groups. Used by
// _get_models_from_access_groups to expand a parent group into its
// transitively-included models. Resolution is depth-first with a
// visited set; cyclic edges are logged and skipped at read time.
model LiteLLM_AccessGroupMembership {
id String @id @default(uuid())
parent_group String
child_group String
created_at DateTime @default(now())

@@unique([parent_group, child_group])
@@index([parent_group])
@@index([child_group])
}
68 changes: 65 additions & 3 deletions litellm/proxy/auth/model_checks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# What is this?
## Common checks for /v1/models and `/model/info`
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, Tuple

import litellm
from litellm._logging import verbose_proxy_logger
Expand Down Expand Up @@ -42,20 +42,75 @@ def get_provider_models(
return None


def resolve_nested_groups(
group_name: str,
model_access_groups: Dict[str, List[str]],
group_memberships: Dict[str, List[str]],
visited: Set[str],
) -> List[str]:
"""
Expand a group name to the full list of model names it transitively includes,
following parent -> child edges in `group_memberships`.

Iterative DFS with an explicit stack (no Python recursion - the proxy's
recursive_detector forbids new recursive functions due to past CPU spikes).
Frames are tagged ENTER/EXIT so on-path tracking still works: we add to
`visited` on ENTER, discard on EXIT. A revisit while still on the path is
a real cycle - we log and skip. DAG-shared subtrees (e.g. A -> [B, C],
B -> [D], C -> [D]) re-traverse D, and the caller deduplicates.

Cyclic edges are logged and skipped rather than raised: this runs on the
auth path and a malformed row must not 500 the proxy.
"""
ENTER, EXIT = 0, 1
resolved: List[str] = []
stack: List[Tuple[int, str]] = [(ENTER, group_name)]

while stack:
action, node = stack.pop()
if action == EXIT:
visited.discard(node)
continue
if node in visited:
verbose_proxy_logger.warning(
"access group cycle detected at '%s' - skipping cyclic edge",
node,
)
continue
visited.add(node)
# Schedule EXIT now so visited is cleaned up after this subtree finishes
stack.append((EXIT, node))
resolved.extend(model_access_groups.get(node, []))
# Reverse so siblings are visited in original (left-to-right) order
for child in reversed(group_memberships.get(node, [])):
stack.append((ENTER, child))

return resolved


def _get_models_from_access_groups(
model_access_groups: Dict[str, List[str]],
all_models: List[str],
include_model_access_groups: Optional[bool] = False,
group_memberships: Optional[Dict[str, List[str]]] = None,
) -> List[str]:
memberships = group_memberships or {}
idx_to_remove = []
new_models = []
for idx, model in enumerate(all_models):
if model in model_access_groups:
if model in model_access_groups or model in memberships:
if (
not include_model_access_groups
): # remove access group, unless requested - e.g. when creating a key
idx_to_remove.append(idx)
new_models.extend(model_access_groups[model])
new_models.extend(
resolve_nested_groups(
group_name=model,
model_access_groups=model_access_groups,
group_memberships=memberships,
visited=set(),
)
)

for idx in sorted(idx_to_remove, reverse=True):
all_models.pop(idx)
Expand Down Expand Up @@ -96,6 +151,7 @@ def get_key_models(
model_access_groups: Dict[str, List[str]],
include_model_access_groups: Optional[bool] = False,
only_model_access_groups: Optional[bool] = False,
group_memberships: Optional[Dict[str, List[str]]] = None,
) -> List[str]:
"""
Returns:
Expand All @@ -104,6 +160,8 @@ def get_key_models(
- If model_access_groups is provided, only return models that are in the access groups
- If include_model_access_groups is True, it includes the 'keys' of the model_access_groups
in the response - {"beta-models": ["gpt-4", "claude-v1"]} -> returns 'beta-models'
- If group_memberships is provided, expands nested groups transitively
(parent -> child edges); cyclic edges are skipped
"""
all_models: List[str] = []
if len(user_api_key_dict.models) > 0:
Expand All @@ -123,6 +181,7 @@ def get_key_models(
model_access_groups=model_access_groups,
all_models=all_models,
include_model_access_groups=include_model_access_groups,
group_memberships=group_memberships,
)

# deduplicate while preserving order
Expand All @@ -137,12 +196,14 @@ def get_team_models(
proxy_model_list: List[str],
model_access_groups: Dict[str, List[str]],
include_model_access_groups: Optional[bool] = False,
group_memberships: Optional[Dict[str, List[str]]] = None,
) -> List[str]:
"""
Returns:
- List of model name strings
- Empty list if no models set
- If model_access_groups is provided, only return models that are in the access groups
- If group_memberships is provided, expands nested groups transitively
"""
all_models_set: Set[str] = set()
if len(team_models) > 0:
Expand All @@ -158,6 +219,7 @@ def get_team_models(
model_access_groups=model_access_groups,
all_models=list(all_models_set),
include_model_access_groups=include_model_access_groups,
group_memberships=group_memberships,
)

# deduplicate while preserving order
Expand Down
Loading
Loading