
Commit 6db1ac4

Fix test and comments

Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
1 parent: 1378f31

File tree: 3 files changed, +40 -30 lines

tensorrt_llm/llmapi/reasoning_parser.py
Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ def parse(self, text: str) -> ReasoningParserResult:
         # text before reasoning start tag is dropped
         text = splits[2]
         splits = text.partition(self.reasoning_end)
-        reasoning_content, content = splits[0].strip(), splits[2].strip()
+        reasoning_content, content = splits[0], splits[2]
         return ReasoningParserResult(content=content,
                                      reasoning_content=reasoning_content)


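The `.strip()` removal is the behavioral fix here: `str.partition` splits without consuming anything, so returning the pieces unmodified keeps the reasoning/content split lossless. A standalone illustration in plain Python (the tag string and sample text are made up, not taken from the parser):

    text = "9.9 > 9.11 since 0.9 > 0.11</think>\n\nThe larger number is 9.9."
    reasoning, tag, content = text.partition("</think>")

    # Unstripped pieces reassemble byte-for-byte into the original output.
    assert reasoning + tag + content == text

    # Stripping destroys the separating whitespace, which clients that
    # re-join reasoning and answer text cannot recover.
    assert reasoning.strip() + tag + content.strip() != text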
tensorrt_llm/serve/responses_utils.py
Lines changed: 15 additions & 10 deletions

@@ -741,7 +741,7 @@ def _create_output_messages(

 def _get_chat_completion_function_tools(
         tools: Optional[List[Tool]]) -> List[ChatCompletionToolsParam]:
-    function_tools: List[ChatCompletionToolsParam]() = []
+    function_tools: List[ChatCompletionToolsParam] = []
     if tools is None:
         return function_tools

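This hunk removes stray call parentheses from a variable annotation. `List[ChatCompletionToolsParam]()` is a call expression, not a type; it slipped through because Python never evaluates annotations inside function bodies, but type checkers reject it. A minimal repro with a stub class (nothing below is imported from tensorrt_llm):

    from typing import List

    class ChatCompletionToolsParam:  # stub for the real protocol model
        pass

    def build_tools() -> List[ChatCompletionToolsParam]:
        # Before the fix: `function_tools: List[ChatCompletionToolsParam]() = []`
        # -- runs silently (function-body annotations are never evaluated),
        # but mypy/pyright flag the annotation as invalid.
        function_tools: List[ChatCompletionToolsParam] = []
        return function_tools

    print(build_tools())  # []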
@@ -962,7 +962,7 @@ async def _create_output_content(
         reasoning_parser: Optional[str] = None,
         tool_parser: Optional[str] = None,
         tools: Optional[List[Tool]] = None,
-) -> List[ResponseOutputItem]:
+) -> Tuple[List[ResponseOutputItem], List[ChatCompletionMessageParam]]:
     output_items: List[ResponseOutputItem] = []
     output_messages: List[ChatCompletionMessageParam] = []
     available_tools = _get_chat_completion_function_tools(tools)

@@ -1036,7 +1036,8 @@ async def _create_output_content(


 async def _create_output_content_harmony(
-        final_res: RequestOutput) -> List[ResponseOutputItem]:
+        final_res: RequestOutput
+) -> Tuple[List[ResponseOutputItem], List[Message]]:
     output_messages = _parse_output_tokens(final_res.outputs[0].token_ids)
     output_content = []

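Both output-content builders now return the accumulated messages alongside the response items; judging from the surrounding code and the multi-turn tests below, this lets callers persist the raw conversation (e.g. to serve `previous_response_id` follow-ups). A sketch of the new contract with stub types, since the real ones live in tensorrt_llm and openai_harmony:

    from typing import List, Tuple

    class ResponseOutputItem:  # stub: OpenAI Responses protocol item
        pass

    class Message:  # stub: harmony-format message
        pass

    def build_output_content(
            token_ids: List[int]) -> Tuple[List[ResponseOutputItem], List[Message]]:
        # Stand-in for _create_output_content_harmony: parse the tokens, then
        # return the protocol items *and* the messages they were derived from.
        return [], []

    output_content, output_messages = build_output_content([1, 2, 3])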
@@ -1423,6 +1424,17 @@ def _generate_streaming_event(
     delta_text = output.text_diff
     calls = []

+    def check_parser(parser_id: Optional[str],
+                     parser_dict: Optional[Dict[int, BaseReasoningParser]]):
+        if parser_id is not None:
+            if parser_dict is None:
+                raise RuntimeError(
+                    f"Parser({parser_id}) dictionary is not provided for streaming"
+                )
+
+    check_parser(reasoning_parser_id, reasoning_parser_dict)
+    check_parser(tool_parser_id, tool_parser_dict)
+
     delta_text, reasoning_delta_text = _apply_reasoning_parser(
         reasoning_parser_id=reasoning_parser_id,
         output_index=output_idx,

@@ -1447,13 +1459,6 @@ def _generate_streaming_event(
         f" ---------> delta text: {delta_text}, reasoning delta text: {reasoning_delta_text}, calls: {calls}"
     ))

-    if reasoning_parser_dict is None:
-        raise RuntimeError(
-            "Reasoning parser dictionary is not provided for streaming")
-    if output_idx not in reasoning_parser_dict:
-        raise RuntimeError(
-            f"Reasoning parser for output index {output_idx} is not found")
-
     # Check if we need to send done events for completed sections
     should_send_reasoning_done, should_send_text_done, reasoning_full_content, text_full_content = _should_send_done_events(
         output=output,
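Net effect of the last two hunks: the old code raised whenever the reasoning-parser dictionary was absent, even for requests that configured no reasoning parser at all, while the new `check_parser` helper raises only when a parser is requested without its dictionary, and the tool parser now gets the same guard. A runnable repro of the new behavior (`BaseReasoningParser` stubbed out, parser name hypothetical):

    from typing import Dict, Optional

    class BaseReasoningParser:  # stub for tensorrt_llm.llmapi.reasoning_parser
        pass

    def check_parser(parser_id: Optional[str],
                     parser_dict: Optional[Dict[int, BaseReasoningParser]]):
        if parser_id is not None:
            if parser_dict is None:
                raise RuntimeError(
                    f"Parser({parser_id}) dictionary is not provided for streaming")

    check_parser(None, None)  # no parser configured: passes (old code raised here)
    check_parser("deepseek_r1", {0: BaseReasoningParser()})  # dict present: passes
    try:
        check_parser("deepseek_r1", None)  # configured but no dict: raises
    except RuntimeError as err:
        print(err)  # Parser(deepseek_r1) dictionary is not provided for streaming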

tests/unittest/llmapi/apps/_test_openai_responses.py
Lines changed: 24 additions & 19 deletions

@@ -86,7 +86,9 @@ def check_tool_calling(response, first_resp=True, prefix=""):
 @pytest.mark.asyncio(loop_scope="module")
 async def test_reasoning(client: openai.AsyncOpenAI, model: str):
     response = await client.responses.create(
-        model=model, input="Which one is larger as numeric, 9.9 or 9.11?")
+        model=model,
+        input="Which one is larger as numeric, 9.9 or 9.11?",
+        max_output_tokens=1024)

     check_reponse(response, "test_reasoning: ")

@@ -96,9 +98,10 @@ async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str):
     for effort in ["low", "medium", "high"]:
         response = await client.responses.create(
             model=model,
-            instructions="Use less than 1024 tokens for reasoning",
+            instructions="Use less than 1024 tokens for the whole response",
             input="Which one is larger as numeric, 9.9 or 9.11?",
-            reasoning={"effort": effort})
+            reasoning={"effort": effort},
+            max_output_tokens=1024)
         check_reponse(response, f"test_reasoning_effort_{effort}: ")


@@ -121,20 +124,23 @@ async def test_chat(client: openai.AsyncOpenAI, model: str):
     }, {
         "role": "user",
         "content": "Tell me a joke."
-    }])
+    }],
+                                             max_output_tokens=1024)
     check_reponse(response, "test_chat: ")


 @pytest.mark.asyncio(loop_scope="module")
 async def test_multi_turn_chat(client: openai.AsyncOpenAI, model: str):
     response = await client.responses.create(model=model,
-                                             input="What is the answer of 1+1?")
+                                             input="What is the answer of 1+1?",
+                                             max_output_tokens=1024)
     check_reponse(response, "test_multi_turn_chat_1: ")

     response_2 = await client.responses.create(
         model=model,
         input="What is the answer of previous question?",
-        previous_response_id=response.id)
+        previous_response_id=response.id,
+        max_output_tokens=1024)
     check_reponse(response_2, "test_multi_turn_chat_2: ")


@@ -168,11 +174,10 @@ async def test_tool_calls(client: openai.AsyncOpenAI, model: str):
         }
     }
     messages = [{"role": "user", "content": "What is the weather like in SF?"}]
-    response = await client.responses.create(
-        model=model,
-        input=messages,
-        tools=[tool_get_current_weather],
-    )
+    response = await client.responses.create(model=model,
+                                             input=messages,
+                                             tools=[tool_get_current_weather],
+                                             max_output_tokens=1024)
     messages.extend(response.output)
     function_call = check_tool_calling(response, True, "test_tool_calls: ")

@@ -188,7 +193,8 @@ async def test_tool_calls(client: openai.AsyncOpenAI, model: str):

     response = await client.responses.create(model=model,
                                              input=messages,
-                                             tools=[tool_get_current_weather])
+                                             tools=[tool_get_current_weather],
+                                             max_output_tokens=1024)

     check_tool_calling(response, False, "test_tool_calls: ")

@@ -199,7 +205,7 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
         model=model,
         input="Explain the theory of relativity in brief.",
         stream=True,
-    )
+        max_output_tokens=1024)

     reasoning_deltas, message_deltas = list(), list()
     async for event in stream:
@@ -240,12 +246,11 @@ async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
         }
     }
     messages = [{"role": "user", "content": "What is the weather like in SF?"}]
-    stream = await client.responses.create(
-        model=model,
-        input=messages,
-        tools=[tool_get_current_weather],
-        stream=True,
-    )
+    stream = await client.responses.create(model=model,
+                                           input=messages,
+                                           tools=[tool_get_current_weather],
+                                           stream=True,
+                                           max_output_tokens=1024)

     function_call = None
     reasoning_deltas = list()

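Every test now passes `max_output_tokens=1024`, which bounds reasoning plus visible output so a chatty reasoning model cannot run a CI job open-ended. A standalone sketch of the pattern; the endpoint URL, API key, and model name below are assumptions, not values from this repo:

    import asyncio

    import openai

    async def main():
        client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                    api_key="dummy")  # hypothetical local server
        response = await client.responses.create(
            model="local-model",  # hypothetical model name
            input="Which one is larger as numeric, 9.9 or 9.11?",
            max_output_tokens=1024)  # hard cap keeps the request bounded
        print(response.output_text)

    asyncio.run(main())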