Skip to content

Commit 1378f31

Browse files
committed
Add tests
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
1 parent ab047fe commit 1378f31

File tree

2 files changed

+79
-18
lines changed

2 files changed

+79
-18
lines changed

tensorrt_llm/serve/responses_utils.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,46 @@ def finish_reason_mapping(finish_reason: str) -> str:
620620
raise RuntimeError("Should never reach here!")
621621

622622

623+
def _response_output_item_to_chat_completion_message(
624+
item: Union[Dict,
625+
ResponseInputOutputItem]) -> ChatCompletionMessageParam:
626+
if not isinstance(item, dict):
627+
item = item.model_dump()
628+
629+
item_type = item.get("type", "")
630+
631+
match item_type:
632+
case "":
633+
if "role" in item:
634+
return item
635+
else:
636+
raise ValueError(f"Invalid input message item: {item}")
637+
case "message":
638+
return {
639+
"role": "assistant",
640+
"content": item["content"][0]["text"],
641+
}
642+
case "reasoning":
643+
return {
644+
"role": "assistant",
645+
"reasoning": item["content"][0]["text"],
646+
}
647+
case "function_call":
648+
return {
649+
"role": "function",
650+
"content": item["arguments"],
651+
}
652+
case "function_call_output":
653+
return {
654+
"role": "tool",
655+
"content": item["output"],
656+
"tool_call_id": item["call_id"],
657+
}
658+
case _:
659+
raise ValueError(
660+
f"Unsupported input item type: {item_type}, item: {item}")
661+
662+
623663
async def _create_input_messages(
624664
request: ResponsesRequest,
625665
prev_msgs: List[ChatCompletionMessageParam],
@@ -643,15 +683,8 @@ async def _create_input_messages(
643683
messages.append({"role": "user", "content": request.input})
644684
else:
645685
for inp in request.input:
646-
if inp.get("type", "") == "function_call_output":
647-
tool_call_inp = {
648-
"role": "tool",
649-
"content": inp["output"],
650-
"tool_call_id": inp["call_id"],
651-
}
652-
messages.append(tool_call_inp)
653-
else:
654-
messages.append(inp)
686+
messages.append(
687+
_response_output_item_to_chat_completion_message(inp))
655688

656689
return messages
657690

@@ -824,7 +857,7 @@ async def request_preprocess(
824857
sampling_params.return_perf_metrics = True
825858

826859
prev_msgs = []
827-
if enable_store:
860+
if enable_store and prev_response_id is not None:
828861
prev_msgs = await conversation_store.get_conversation_history(
829862
prev_response_id)
830863

tests/unittest/llmapi/apps/_test_openai_responses.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,29 @@
1212
pytestmark = pytest.mark.threadleak(enabled=False)
1313

1414

15-
@pytest.fixture(
    scope="module",
    params=[
        "gpt_oss/gpt-oss-20b",
        "DeepSeek-R1-Distill-Qwen-1.5B",
        "Qwen3/Qwen3-0.6B",
    ],
)
def model(request):
    """Parametrized model id; the whole suite runs once per listed model."""
    return request.param
1822

1923

2024
@pytest.fixture(scope="module")
def server(model: str):
    """Launch a RemoteOpenAIServer for *model* with per-family parser flags."""
    model_path = get_model_path(model)

    # Pick the reasoning parser matching the model family; gpt_oss needs none.
    extra_args = []
    if model.startswith("Qwen3"):
        extra_args += ["--reasoning_parser", "qwen3"]
    elif model.startswith("DeepSeek-R1"):
        extra_args += ["--reasoning_parser", "deepseek-r1"]

    # All non-gpt_oss models under test share the qwen3 tool parser.
    if not model.startswith("gpt_oss"):
        extra_args += ["--tool_parser", "qwen3"]

    with RemoteOpenAIServer(model_path, extra_args) as remote_server:
        yield remote_server
2539

2640

@@ -43,24 +57,30 @@ def check_reponse(response, prefix=""):
4357

4458
def check_tool_calling(response, first_resp=True, prefix=""):
    """Assert the expected mix of output items in a tool-calling exchange.

    First response: reasoning + function_call present, no final message;
    returns the function_call item. Second response: reasoning + message
    present, no function_call; returns None.
    """
    reasoning_exist = tool_call_exist = message_exist = False
    reasoning_content, message_content = "", ""
    function_call = None

    for output in response.output:
        kind = output.type
        if kind == "reasoning":
            reasoning_exist = True
            reasoning_content = output.content[0].text
        elif kind == "function_call":
            tool_call_exist = True
            function_call = output
        elif kind == "message":
            message_exist = True
            message_content = output.content[0].text

    err_msg = f"{prefix}Invalid tool calling {'1st' if first_resp else '2nd'} response:"
    if first_resp:
        assert reasoning_exist, f"{err_msg} reasoning content not exists! ({reasoning_content})"
        assert tool_call_exist, f"{err_msg} tool call content not exists! ({function_call})"
        assert not message_exist, f"{err_msg} message content should not exist! ({message_content})"
        return function_call

    assert reasoning_exist, f"{err_msg} reasoning content not exists! ({reasoning_content})"
    assert message_exist, f"{err_msg} message content not exists! ({message_content})"
    assert not tool_call_exist, f"{err_msg} tool call content should not exist! ({function_call})"
6484

6585

6686
@pytest.mark.asyncio(loop_scope="module")
@@ -124,6 +144,9 @@ def get_current_weather(location: str, format: str = "celsius") -> dict:
124144

125145
@pytest.mark.asyncio(loop_scope="module")
126146
async def test_tool_calls(client: openai.AsyncOpenAI, model: str):
147+
if model.startswith("DeepSeek-R1"):
148+
pytest.skip("DeepSeek-R1 does not support tool calls")
149+
127150
tool_get_current_weather = {
128151
"type": "function",
129152
"name": "get_current_weather",
@@ -193,6 +216,9 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
193216

194217
@pytest.mark.asyncio(loop_scope="module")
195218
async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
219+
if model.startswith("DeepSeek-R1"):
220+
pytest.skip("DeepSeek-R1 does not support tool calls")
221+
196222
tool_get_current_weather = {
197223
"type": "function",
198224
"name": "get_current_weather",
@@ -231,6 +257,8 @@ async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
231257
elif isinstance(event, ResponseReasoningTextDeltaEvent):
232258
reasoning_deltas.append(event.delta)
233259

260+
assert function_call is not None, "function call not exists!"
261+
234262
reasoning = "".join(reasoning_deltas)
235263
tool_args = json.loads(function_call.arguments)
236264

0 commit comments

Comments
 (0)