-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_toolset.py
More file actions
355 lines (270 loc) · 12.1 KB
/
test_toolset.py
File metadata and controls
355 lines (270 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
"""Tests for StackOneToolSet."""
import asyncio
import base64
import fnmatch
import os
import string
from unittest.mock import patch
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from stackone_ai.toolset import (
StackOneToolSet,
ToolsetConfigError,
ToolsetError,
ToolsetLoadError,
_build_auth_header,
_run_async,
)
# Hypothesis strategies shared by the property-based tests in this module.

# Characters allowed in generated tool names: lowercase letters, digits, "_".
_IDENT_ALPHABET = string.ascii_lowercase + string.digits + "_"

# API keys drawn from printable ASCII (code points 32 through 126 inclusive).
api_key_strategy = st.text(
    alphabet="".join(map(chr, range(32, 127))),
    min_size=1,
    max_size=200,
)

# Tool names built from the identifier alphabet.
tool_name_strategy = st.text(
    alphabet=_IDENT_ALPHABET,
    min_size=1,
    max_size=50,
)

# Glob patterns: identifier characters plus the "*" and "?" wildcards.
glob_pattern_strategy = st.text(
    alphabet=_IDENT_ALPHABET + "*?",
    min_size=1,
    max_size=50,
)

# Provider names: 2 to 20 lowercase letters.
provider_name_strategy = st.text(
    alphabet=string.ascii_lowercase,
    min_size=2,
    max_size=20,
)
class TestToolsetErrors:
    """Check the shape of the toolset exception hierarchy."""

    def test_toolset_error_inheritance(self):
        """ToolsetError is an Exception whose str() is the message it was given."""
        message = "test error"
        err = ToolsetError(message)
        assert isinstance(err, Exception)
        assert str(err) == message

    def test_toolset_config_error_inheritance(self):
        """ToolsetConfigError is both a ToolsetError and an Exception."""
        err = ToolsetConfigError("config error")
        for expected_base in (ToolsetError, Exception):
            assert isinstance(err, expected_base)

    def test_toolset_load_error_inheritance(self):
        """ToolsetLoadError is both a ToolsetError and an Exception."""
        err = ToolsetLoadError("load error")
        for expected_base in (ToolsetError, Exception):
            assert isinstance(err, expected_base)
class TestBuildAuthHeader:
    """Test _build_auth_header function."""

    # Scheme prefix that every header produced by _build_auth_header carries.
    _SCHEME_PREFIX = "Basic "

    @classmethod
    def _decode_credentials(cls, header: str) -> str:
        """Return the decoded credential string carried by a Basic auth header.

        Asserts that the ``Basic `` prefix is present and slices exactly that
        prefix off. ``str.replace`` would also strip any later occurrence.
        The remainder is decoded with ``validate=True``, so a character outside
        the base64 alphabet fails the test instead of being silently dropped.
        """
        assert header.startswith(cls._SCHEME_PREFIX)
        encoded = header[len(cls._SCHEME_PREFIX):]
        return base64.b64decode(encoded, validate=True).decode("utf-8")

    def test_builds_basic_auth_header(self):
        """Test building Basic auth header from API key."""
        result = _build_auth_header("test_api_key")
        # Base64 of "test_api_key:"
        assert result.startswith("Basic ")
        assert result == "Basic dGVzdF9hcGlfa2V5Og=="

    def test_builds_auth_header_with_special_chars(self):
        """Test auth header with special characters in key.

        Colons inside the key must survive encoding unchanged. The credentials
        are still "<key>:", i.e. the key followed by a trailing colon.
        """
        result = _build_auth_header("key:with:colons")
        assert self._decode_credentials(result) == "key:with:colons:"

    @given(api_key=api_key_strategy)
    @settings(max_examples=100)
    def test_auth_header_format_pbt(self, api_key: str):
        """PBT: the header is "Basic " plus strictly valid base64 of "<api_key>:"."""
        result = _build_auth_header(api_key)
        assert self._decode_credentials(result) == f"{api_key}:"

    @given(api_key=api_key_strategy)
    @settings(max_examples=100)
    def test_auth_header_round_trip_pbt(self, api_key: str):
        """PBT: the original key can be recovered from the header.

        Generated keys may themselves contain colons, so split on the *last*
        colon. The separator must be present and the part after it must be
        empty; a header missing the trailing colon is a failure.
        """
        decoded = self._decode_credentials(_build_auth_header(api_key))
        key, separator, password = decoded.rpartition(":")
        assert separator == ":"
        assert password == ""
        assert key == api_key
class TestRunAsync:
    """Exercise _run_async both with and without a running event loop."""

    def test_run_async_outside_event_loop(self):
        """With no loop running, _run_async returns the coroutine's value."""

        async def produce():
            return "result"

        assert _run_async(produce()) == "result"

    def test_run_async_inside_event_loop(self):
        """Called from inside a running loop, _run_async still returns the value."""

        async def produce():
            return "inner_result"

        async def driver():
            # Runs under asyncio.run, so _run_async is invoked while an event
            # loop is already active.
            return _run_async(produce())

        assert asyncio.run(driver()) == "inner_result"

    def test_run_async_propagates_exceptions(self):
        """An exception raised by the coroutine reaches the caller unchanged."""

        async def explode():
            raise ValueError("test error")

        with pytest.raises(ValueError, match="test error"):
            _run_async(explode())

    def test_run_async_propagates_exceptions_from_thread(self):
        """Exceptions also propagate when _run_async is called inside a loop."""

        async def explode():
            raise RuntimeError("thread error")

        async def driver():
            return _run_async(explode())

        with pytest.raises(RuntimeError, match="thread error"):
            asyncio.run(driver())
class TestStackOneToolSetInit:
    """Test StackOneToolSet initialization."""

    def test_init_with_api_key(self):
        """Test initialization with explicit API key."""
        toolset = StackOneToolSet(api_key="test_key")
        assert toolset.api_key == "test_key"
        assert toolset.account_id is None
        assert toolset.base_url == "https://api.stackone.com"

    def test_init_with_env_api_key(self):
        """Test initialization with API key from environment."""
        with patch.dict(os.environ, {"STACKONE_API_KEY": "env_key"}):
            toolset = StackOneToolSet()
            assert toolset.api_key == "env_key"

    def test_init_without_api_key_raises(self):
        """Test that missing API key raises ToolsetConfigError."""
        # clear=True empties os.environ for the duration of the block and
        # restores it afterwards, so STACKONE_API_KEY is guaranteed to be unset.
        # No separate pop() is needed.
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(ToolsetConfigError, match="API key must be provided"):
                StackOneToolSet()

    def test_init_with_account_id(self):
        """Test initialization with account ID."""
        toolset = StackOneToolSet(api_key="test_key", account_id="acc123")
        assert toolset.account_id == "acc123"

    def test_init_with_custom_base_url(self):
        """Test initialization with custom base URL."""
        toolset = StackOneToolSet(api_key="test_key", base_url="https://custom.api.com")
        assert toolset.base_url == "https://custom.api.com"
class TestStackOneToolSetNormalizeSchemaProperties:
    """Exercise StackOneToolSet._normalize_schema_properties."""

    @staticmethod
    def _make_toolset():
        """Build a toolset with a dummy API key for calling the normalizer."""
        return StackOneToolSet(api_key="test_key")

    def test_normalizes_properties_with_required(self):
        """Required fields become non-nullable; all others become nullable."""
        props = {
            "required_field": {"type": "string", "description": "Required"},
            "optional_field": {"type": "string", "description": "Optional"},
        }
        schema = {"type": "object", "properties": props, "required": ["required_field"]}
        normalized = self._make_toolset()._normalize_schema_properties(schema)
        assert normalized["required_field"]["nullable"] is False
        assert normalized["optional_field"]["nullable"] is True

    def test_handles_non_dict_properties(self):
        """A bare (non-dict) property value ends up as that property's description."""
        schema = {"type": "object", "properties": {"simple_field": "string value"}}
        normalized = self._make_toolset()._normalize_schema_properties(schema)
        assert normalized["simple_field"]["description"] == "string value"

    def test_handles_missing_properties(self):
        """A schema with no "properties" key normalizes to an empty dict."""
        normalized = self._make_toolset()._normalize_schema_properties({"type": "object"})
        assert normalized == {}

    def test_handles_non_dict_properties_value(self):
        """A non-dict "properties" value normalizes to an empty dict."""
        schema = {"type": "object", "properties": "not a dict"}
        normalized = self._make_toolset()._normalize_schema_properties(schema)
        assert normalized == {}
class TestStackOneToolSetBuildMcpHeaders:
    """Exercise StackOneToolSet._build_mcp_headers."""

    def test_builds_headers_without_account(self):
        """With no account ID, auth and UA headers are present but no x-account-id."""
        hdrs = StackOneToolSet(api_key="test_key")._build_mcp_headers(None)
        for name in ("Authorization", "User-Agent"):
            assert name in hdrs
        assert "x-account-id" not in hdrs

    def test_builds_headers_with_account(self):
        """With an account ID, x-account-id carries it alongside auth and UA."""
        hdrs = StackOneToolSet(api_key="test_key")._build_mcp_headers("acc123")
        for name in ("Authorization", "User-Agent"):
            assert name in hdrs
        assert hdrs["x-account-id"] == "acc123"
def test_set_accounts():
    """set_accounts records the account IDs and returns the toolset for chaining."""
    toolset = StackOneToolSet(api_key="test_key")
    returned = toolset.set_accounts(["acc1", "acc2"])
    # A fluent API hands back the same instance.
    assert returned is toolset
    assert toolset._account_ids == ["acc1", "acc2"]
def test_filter_by_provider():
    """Provider filtering matches listed providers, case-insensitively."""
    toolset = StackOneToolSet(api_key="test_key")
    # (tool name, allowed providers, expected truthiness), in evaluation order.
    cases = [
        # matching providers
        ("hibob_list_employees", ["hibob", "bamboohr"], True),
        ("bamboohr_create_job", ["hibob", "bamboohr"], True),
        # provider not in the list
        ("workday_list_contacts", ["hibob", "bamboohr"], False),
        # case differences on either side still match
        ("HIBOB_list_employees", ["hibob"], True),
        ("hibob_list_employees", ["HIBOB"], True),
    ]
    for tool_name, providers, expected in cases:
        assert bool(toolset._filter_by_provider(tool_name, providers)) is expected
def test_filter_by_action():
    """Action filtering supports exact names and glob patterns."""
    toolset = StackOneToolSet(api_key="test_key")
    # (tool name, patterns, expected truthiness), in evaluation order.
    cases = [
        # exact name
        ("hibob_list_employees", ["hibob_list_employees"], True),
        # glob patterns that should match
        ("hibob_list_employees", ["*_list_employees"], True),
        ("bamboohr_list_employees", ["*_list_employees"], True),
        ("hibob_list_employees", ["hibob_*"], True),
        ("hibob_create_employee", ["hibob_*"], True),
        # glob patterns that should not match
        ("workday_list_contacts", ["*_list_employees"], False),
        ("bamboohr_create_job", ["hibob_*"], False),
    ]
    for tool_name, patterns, expected in cases:
        assert bool(toolset._filter_by_action(tool_name, patterns)) is expected
@given(
    tool_name=tool_name_strategy,
    pattern=glob_pattern_strategy,
)
@settings(max_examples=100)
def test_filter_by_action_matches_fnmatch_pbt(tool_name: str, pattern: str):
    """PBT: a single-pattern action filter agrees with fnmatch.fnmatch."""
    actual = StackOneToolSet(api_key="test_key")._filter_by_action(tool_name, [pattern])
    reference = fnmatch.fnmatch(tool_name, pattern)
    assert actual == reference, f"Mismatch for tool='{tool_name}', pattern='{pattern}'"
@given(
    provider=provider_name_strategy,
    action=st.text(alphabet=string.ascii_lowercase + "_", min_size=1, max_size=20),
    entity=st.text(alphabet=string.ascii_lowercase, min_size=1, max_size=20),
)
@settings(max_examples=100)
def test_filter_by_provider_case_insensitive_pbt(provider: str, action: str, entity: str):
    """PBT: provider filtering ignores case in both the tool name and the provider."""
    toolset = StackOneToolSet(api_key="test_key")
    name = f"{provider}_{action}_{entity}"
    # Every combination below must match, whatever the case of either side.
    variants = [
        (name, provider.lower()),
        (name, provider.upper()),
        (name.upper(), provider.lower()),
        (name.lower(), provider.upper()),
    ]
    for candidate, wanted in variants:
        assert toolset._filter_by_provider(candidate, [wanted])