Skip to content

Commit fe5b222

Browse files
Add simple ToolOverrides for tools with variant definitions
Provides a lightweight mechanism for tools that have different definitions based on runtime conditions (e.g., feature flags, capabilities). Designed for the small number of tools (~2) that need variant handling: - ToolOverride: condition + replacement definition - ToolOverrides: map[string]ToolOverride with Apply() and ApplyToTools() - Lazy allocation: no allocs when no overrides match This is intentionally simple - only 60 lines - because most tools don't need variants and the full VariantIndex would be overkill.
1 parent d45f0e1 commit fe5b222

File tree

2 files changed

+271
-0
lines changed

2 files changed

+271
-0
lines changed

pkg/inventory/tool_variants.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package inventory
2+
3+
import "context"
4+
5+
// ToolOverride allows replacing a tool's definition based on runtime conditions.
6+
// Use this for the small number of tools that have different schemas/handlers
7+
// depending on features, capabilities, or environment.
8+
type ToolOverride struct {
9+
// ToolName is the canonical tool name to override
10+
ToolName string
11+
12+
// Condition returns true if this override should apply
13+
Condition func(ctx context.Context) (bool, error)
14+
15+
// Override is the replacement tool definition
16+
Override ServerTool
17+
}
18+
19+
// ToolOverrides is a simple map for the few tools that need variant handling.
20+
// Key is the tool name, value is the override to check.
21+
type ToolOverrides map[string]ToolOverride
22+
23+
// Apply checks if an override should be used for the given tool.
24+
// Returns the override if condition matches, nil otherwise.
25+
func (o ToolOverrides) Apply(ctx context.Context, toolName string) *ServerTool {
26+
override, ok := o[toolName]
27+
if !ok {
28+
return nil
29+
}
30+
31+
if override.Condition == nil {
32+
return &override.Override
33+
}
34+
35+
matches, err := override.Condition(ctx)
36+
if err != nil || !matches {
37+
return nil
38+
}
39+
40+
return &override.Override
41+
}
42+
43+
// ApplyToTools applies overrides to a list of tools, returning a new list
44+
// with overridden tools replaced. Tools without overrides are unchanged.
45+
// If no overrides match, returns the original slice (no allocation).
46+
func (o ToolOverrides) ApplyToTools(ctx context.Context, tools []*ServerTool) []*ServerTool {
47+
if len(o) == 0 {
48+
return tools
49+
}
50+
51+
// First pass: check if any overrides apply (avoid allocation if not)
52+
var result []*ServerTool
53+
for i, tool := range tools {
54+
override, hasOverride := o[tool.Tool.Name]
55+
if !hasOverride {
56+
if result != nil {
57+
result[i] = tool
58+
}
59+
continue
60+
}
61+
62+
// Check condition
63+
var applies bool
64+
if override.Condition == nil {
65+
applies = true
66+
} else if matches, err := override.Condition(ctx); err == nil && matches {
67+
applies = true
68+
}
69+
70+
if applies {
71+
// Lazy allocation only when we find a match
72+
if result == nil {
73+
result = make([]*ServerTool, len(tools))
74+
copy(result[:i], tools[:i])
75+
}
76+
result[i] = &override.Override
77+
} else if result != nil {
78+
result[i] = tool
79+
}
80+
}
81+
82+
if result == nil {
83+
return tools // No overrides matched, return original
84+
}
85+
return result
86+
}
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
package inventory
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/modelcontextprotocol/go-sdk/mcp"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func makeTool(name string) ServerTool {
12+
return ServerTool{
13+
Tool: mcp.Tool{
14+
Name: name,
15+
Description: "Tool " + name,
16+
},
17+
}
18+
}
19+
20+
func makeOverride(name, desc string) ServerTool {
21+
return ServerTool{
22+
Tool: mcp.Tool{
23+
Name: name,
24+
Description: desc,
25+
},
26+
}
27+
}
28+
29+
func TestToolOverrides_Apply(t *testing.T) {
30+
t.Parallel()
31+
32+
overrides := ToolOverrides{
33+
"create_issue": {
34+
ToolName: "create_issue",
35+
Condition: func(_ context.Context) (bool, error) { return true, nil },
36+
Override: makeOverride("create_issue", "Enterprise variant"),
37+
},
38+
}
39+
40+
ctx := context.Background()
41+
42+
// Tool with override
43+
result := overrides.Apply(ctx, "create_issue")
44+
assert.NotNil(t, result)
45+
assert.Equal(t, "Enterprise variant", result.Tool.Description)
46+
47+
// Tool without override
48+
result = overrides.Apply(ctx, "list_repos")
49+
assert.Nil(t, result)
50+
}
51+
52+
func TestToolOverrides_Apply_ConditionFalse(t *testing.T) {
53+
t.Parallel()
54+
55+
overrides := ToolOverrides{
56+
"create_issue": {
57+
ToolName: "create_issue",
58+
Condition: func(_ context.Context) (bool, error) { return false, nil },
59+
Override: makeOverride("create_issue", "Enterprise variant"),
60+
},
61+
}
62+
63+
ctx := context.Background()
64+
65+
// Condition doesn't match - no override
66+
result := overrides.Apply(ctx, "create_issue")
67+
assert.Nil(t, result)
68+
}
69+
70+
func TestToolOverrides_Apply_NilCondition(t *testing.T) {
71+
t.Parallel()
72+
73+
overrides := ToolOverrides{
74+
"create_issue": {
75+
ToolName: "create_issue",
76+
// nil Condition - always applies
77+
Override: makeOverride("create_issue", "Always applied"),
78+
},
79+
}
80+
81+
ctx := context.Background()
82+
83+
result := overrides.Apply(ctx, "create_issue")
84+
assert.NotNil(t, result)
85+
assert.Equal(t, "Always applied", result.Tool.Description)
86+
}
87+
88+
func TestToolOverrides_ApplyToTools(t *testing.T) {
89+
t.Parallel()
90+
91+
tools := []*ServerTool{
92+
ptr(makeTool("create_issue")),
93+
ptr(makeTool("list_repos")),
94+
ptr(makeTool("get_me")),
95+
}
96+
97+
overrides := ToolOverrides{
98+
"create_issue": {
99+
ToolName: "create_issue",
100+
Condition: func(_ context.Context) (bool, error) { return true, nil },
101+
Override: makeOverride("create_issue", "Enterprise create_issue"),
102+
},
103+
}
104+
105+
ctx := context.Background()
106+
result := overrides.ApplyToTools(ctx, tools)
107+
108+
assert.Len(t, result, 3)
109+
assert.Equal(t, "Enterprise create_issue", result[0].Tool.Description)
110+
assert.Equal(t, "Tool list_repos", result[1].Tool.Description)
111+
assert.Equal(t, "Tool get_me", result[2].Tool.Description)
112+
}
113+
114+
func TestToolOverrides_ApplyToTools_Empty(t *testing.T) {
115+
t.Parallel()
116+
117+
tools := []*ServerTool{
118+
ptr(makeTool("create_issue")),
119+
}
120+
121+
overrides := ToolOverrides{}
122+
123+
ctx := context.Background()
124+
result := overrides.ApplyToTools(ctx, tools)
125+
126+
// Empty overrides returns original slice
127+
assert.Equal(t, tools, result)
128+
}
129+
130+
func ptr(t ServerTool) *ServerTool {
131+
return &t
132+
}
133+
134+
func BenchmarkToolOverrides_ApplyToTools(b *testing.B) {
135+
// 130 tools, 2 overrides (realistic)
136+
tools := make([]*ServerTool, 130)
137+
for i := range tools {
138+
tools[i] = ptr(makeTool("tool_" + string(rune('a'+i%26))))
139+
}
140+
141+
overrides := ToolOverrides{
142+
"tool_a": {
143+
ToolName: "tool_a",
144+
Condition: func(_ context.Context) (bool, error) { return true, nil },
145+
Override: makeOverride("tool_a", "Override A"),
146+
},
147+
"tool_b": {
148+
ToolName: "tool_b",
149+
Condition: func(_ context.Context) (bool, error) { return true, nil },
150+
Override: makeOverride("tool_b", "Override B"),
151+
},
152+
}
153+
154+
ctx := context.Background()
155+
156+
b.ReportAllocs() // Only count allocs in the hot loop
157+
b.ResetTimer()
158+
for i := 0; i < b.N; i++ {
159+
_ = overrides.ApplyToTools(ctx, tools)
160+
}
161+
}
162+
163+
func BenchmarkToolOverrides_ApplyToTools_NoMatch(b *testing.B) {
164+
// 130 tools, overrides don't match any - should be zero alloc
165+
tools := make([]*ServerTool, 130)
166+
for i := range tools {
167+
tools[i] = ptr(makeTool("tool_" + string(rune('a'+i%26))))
168+
}
169+
170+
// Override exists but for a tool not in list
171+
overrides := ToolOverrides{
172+
"nonexistent_tool": {
173+
ToolName: "nonexistent_tool",
174+
Override: makeOverride("nonexistent_tool", "Override"),
175+
},
176+
}
177+
178+
ctx := context.Background()
179+
180+
b.ReportAllocs()
181+
b.ResetTimer()
182+
for i := 0; i < b.N; i++ {
183+
_ = overrides.ApplyToTools(ctx, tools)
184+
}
185+
}

0 commit comments

Comments
 (0)