|
|
import ast |
|
|
import json |
|
|
import logging |
|
|
import re |
|
|
from typing import Any, List, Optional |
|
|
|
|
|
from sglang.srt.entrypoints.openai.protocol import Tool |
|
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector |
|
|
from sglang.srt.function_call.core_types import ( |
|
|
StreamingParseResult, |
|
|
ToolCallItem, |
|
|
_GetInfoFunc, |
|
|
) |
|
|
|
|
|
# Module-level logger, named after this module, shared by everything in the file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class Qwen3CoderDetector(BaseFormatDetector):
    """Detect and parse tool calls written in an XML-like tag format.

    The layout recognised by the tokens and regexes below is::

        <tool_call>
        <function=NAME>
        <parameter=KEY>
        VALUE
        </parameter>
        </function>
        </tool_call>

    A parameter value ends at the first ``</parameter>``, the next
    ``<parameter=`` or the enclosing ``</function>``, so the closing
    parameter tag is optional.

    Two entry points are offered: ``detect_and_parse`` for a complete text
    and ``parse_streaming_increment`` for text that arrives in chunks.
    """

    def __init__(self):
        super().__init__()

        # Literal tags that delimit the different parts of a tool call.
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        self.function_end_token: str = "</function>"
        self.parameter_prefix: str = "<parameter="
        self.parameter_end_token: str = "</parameter>"

        # Regexes used by the one-shot parser. DOTALL lets values span lines.
        self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
        # Second alternative matches a function block that was never closed.
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
        )
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
            re.DOTALL,
        )

        # --- Streaming state -------------------------------------------------
        # Offset into the buffer up to which text has been consumed.
        self.parsed_pos: int = 0
        # Number of parameters already emitted for the current function.
        self.current_tool_param_count: int = 0
        # Whether the opening "{" of the current arguments object was emitted.
        self.json_started: bool = False
        # True between <tool_call> and </tool_call>; text there is not echoed.
        self.is_inside_tool_call: bool = False
        # Name of the function whose parameters are currently being streamed.
        self.current_func_name: Optional[str] = None

    def has_tool_call(self, text: str) -> bool:
        """Return True when ``text`` contains the ``<tool_call>`` start tag."""
        return self.tool_call_start_token in text

    def _get_arguments_config(
        self, func_name: str, tools: Optional[list[Tool]]
    ) -> dict:
        """Extract argument configuration for a function.

        Args:
            func_name: Name of the function to look up.
            tools: Tool definitions to search; ``None`` yields ``{}``.

        Returns:
            The ``"properties"`` mapping of the matching tool's parameters
            when present, the parameters dict itself when it has no
            ``"properties"`` key, and ``{}`` when nothing usable is found.
        """
        if tools is None:
            return {}

        for config in tools:
            # Entries lacking the expected attributes are skipped, not fatal.
            try:
                config_type = config.type
                config_function = config.function
                config_function_name = config_function.name
            except AttributeError:
                continue

            if config_type != "function" or config_function_name != func_name:
                continue

            try:
                params = config_function.parameters
            except AttributeError:
                return {}

            if not isinstance(params, dict):
                return {}
            return params["properties"] if "properties" in params else params

        logger.warning(f"Tool '{func_name}' is not defined in the tools list.")
        return {}

    @staticmethod
    def _is_json_safe(value: Any) -> bool:
        """Return True when ``value`` can be emitted as strict JSON.

        ``allow_nan=False`` makes NaN/Infinity count as unsafe, in addition
        to types ``json.dumps`` cannot serialise (set, bytes, complex, ...).
        """
        try:
            json.dumps(value, allow_nan=False)
        except (TypeError, ValueError, RecursionError):
            return False
        return True

    @staticmethod
    def _trim_value(value: str) -> str:
        """Drop at most one leading and one trailing newline from ``value``."""
        if value.startswith("\n"):
            value = value[1:]
        if value.endswith("\n"):
            value = value[:-1]
        return value

    def _convert_param_value(
        self, param_value: str, param_name: str, param_config: dict, func_name: str
    ) -> Any:
        """Convert parameter value based on its type in the schema.

        Args:
            param_value: Raw text of the parameter.
            param_name: Parameter name, used to look up its schema entry.
            param_config: Mapping of parameter name to schema for the tool.
            func_name: Tool name, used only in log messages.

        Returns:
            ``None`` for the literal ``null`` (any case); otherwise the value
            converted according to the declared type. Whenever a conversion
            fails, or would yield something that cannot be emitted as JSON,
            the original string is returned instead (booleans fall back to
            ``False``).
        """
        if param_value.lower() == "null":
            return None

        if param_name not in param_config:
            # Only warn when a schema exists but lacks this parameter.
            if param_config:
                logger.warning(
                    f"Parsed parameter '{param_name}' is not defined in the tool "
                    f"parameters for tool '{func_name}', directly returning the string value."
                )
            return param_value

        schema = param_config[param_name]
        if isinstance(schema, dict) and "type" in schema:
            param_type = str(schema["type"]).strip().lower()
        else:
            param_type = "string"

        if param_type in ("string", "str", "text", "varchar", "char", "enum"):
            return param_value

        if param_type.startswith(("int", "uint", "long", "short", "unsigned")):
            try:
                return int(param_value)
            except ValueError:
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool "
                    f"'{func_name}', degenerating to string."
                )
                return param_value

        if param_type.startswith(("num", "float")):
            # Text without "." or an exponent is reported as int when integral.
            looks_integral = "." not in param_value and "e" not in param_value.lower()
            try:
                number = float(param_value)
            except ValueError:
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool "
                    f"'{func_name}', degenerating to string."
                )
                return param_value

            # NaN / Infinity would be serialised as invalid JSON.
            if not self._is_json_safe(number):
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' is not JSON-serializable in tool "
                    f"'{func_name}', degenerating to string."
                )
                return param_value

            if looks_integral and number.is_integer():
                # Parse the text directly: int(float(s)) loses precision
                # for magnitudes above 2**53.
                try:
                    return int(param_value)
                except ValueError:
                    return int(number)
            return number

        if param_type in ("boolean", "bool", "binary"):
            lowered = param_value.lower()
            if lowered not in ("true", "false"):
                logger.warning(
                    f"Parsed value '{lowered}' of parameter '{param_name}' is not a boolean (`true` or `false`) in tool '{func_name}', degenerating to false."
                )
            return lowered == "true"

        # Structured types try JSON first; everything else (and JSON
        # failures) falls through to a Python-literal parse.
        if param_type in ("object", "array", "arr") or param_type.startswith(
            ("dict", "list")
        ):
            try:
                return json.loads(param_value)
            except (ValueError, RecursionError):
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' cannot be parsed with json.loads in tool "
                    f"'{func_name}', will try other methods to parse it."
                )

        try:
            literal = ast.literal_eval(param_value)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            logger.warning(
                f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted via Python `ast.literal_eval()` in tool '{func_name}', degenerating to string."
            )
            return param_value

        # literal_eval can yield sets, bytes, complex numbers, etc. Callers
        # pass the result to json.dumps, which would raise on those.
        if not self._is_json_safe(literal):
            logger.warning(
                f"Parsed value '{param_value}' of parameter '{param_name}' is not JSON-serializable in tool "
                f"'{func_name}', degenerating to string."
            )
            return param_value
        return literal

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """One-shot parsing for non-streaming scenarios.

        Args:
            text: Complete text to scan.
            tools: Tool definitions used to type-convert parameter values.

        Returns:
            A result whose ``normal_text`` is the text preceding the first
            ``<tool_call>`` and whose ``calls`` holds one item per parsed
            function, with arguments serialised as a JSON string. If the
            start tag is absent, or parsing raises, the whole ``text`` is
            returned as ``normal_text`` with no calls.
        """
        if self.tool_call_start_token not in text:
            return StreamingParseResult(normal_text=text)

        calls = []
        try:
            raw_tool_calls = self.tool_call_regex.findall(text)
            # An unclosed <tool_call> yields no match; fall back to scanning
            # the whole text for function blocks.
            if not raw_tool_calls and self.tool_call_prefix in text:
                raw_tool_calls = [text]

            tool_idx = 0
            for tool_content in raw_tool_calls:
                function_matches = self.tool_call_function_regex.findall(tool_content)
                for closed_body, open_body in function_matches:
                    func_body = closed_body or open_body
                    func_name, name_sep, params_str = func_body.partition(">")
                    if not name_sep:
                        # "<function=" header was never closed with ">".
                        continue

                    param_config = self._get_arguments_config(func_name, tools)
                    parsed_params = {}

                    for p_match in self.tool_call_parameter_regex.findall(params_str):
                        p_name, p_sep, p_val = p_match.partition(">")
                        if not p_sep:
                            continue
                        parsed_params[p_name] = self._convert_param_value(
                            self._trim_value(p_val), p_name, param_config, func_name
                        )

                    calls.append(
                        ToolCallItem(
                            tool_index=tool_idx,
                            name=func_name,
                            parameters=json.dumps(parsed_params, ensure_ascii=False),
                        )
                    )
                    tool_idx += 1

            # The start token is known to be present (checked above).
            normal_text = text[: text.find(self.tool_call_start_token)]
            return StreamingParseResult(normal_text=normal_text, calls=calls)

        except Exception as e:
            # Best-effort boundary: hand the text back rather than failing.
            logger.error(f"Error in detect_and_parse: {e}")
            return StreamingParseResult(normal_text=text)

    def _emit_arguments_open(self, calls: list) -> None:
        """Emit the opening "{" of the arguments object once per function."""
        if not self.json_started:
            calls.append(ToolCallItem(tool_index=self.current_tool_id, parameters="{"))
            self.json_started = True

    def _consume_function_start(self, current_slice: str, calls: list) -> bool:
        """Consume a ``<function=NAME>`` header at the start of the slice.

        Returns False, consuming nothing, while the closing ">" has not yet
        arrived.
        """
        end_angle = current_slice.find(">")
        if end_angle == -1:
            return False

        func_name = current_slice[len(self.tool_call_prefix) : end_angle]

        # A new function starts: advance the tool index, reset per-call state.
        self.current_tool_id += 1
        self.current_tool_name_sent = True
        self.current_tool_param_count = 0
        self.json_started = False
        self.current_func_name = func_name

        calls.append(
            ToolCallItem(
                tool_index=self.current_tool_id,
                name=func_name,
                parameters="",
            )
        )
        self.parsed_pos += end_angle + 1
        return True

    def _consume_parameter(
        self, current_slice: str, tools: List[Tool], calls: list
    ) -> bool:
        """Consume one complete ``<parameter=KEY>VALUE`` at the slice start.

        The value ends at the earliest of ``</parameter>`` (consumed), the
        next ``<parameter=`` or ``</function>`` (both left in the buffer).
        Returns False, consuming nothing, when the header or a terminator
        has not arrived yet.
        """
        name_end = current_slice.find(">")
        if name_end == -1:
            return False

        rest_of_slice = current_slice[name_end + 1 :]

        # (terminator, number of its characters to consume with the value)
        terminators = (
            (self.parameter_end_token, len(self.parameter_end_token)),
            (self.parameter_prefix, 0),
            (self.function_end_token, 0),
        )
        candidates = []
        for token, consumed_len in terminators:
            pos = rest_of_slice.find(token)
            if pos != -1:
                candidates.append((pos, consumed_len))
        if not candidates:
            return False

        end_pos, end_token_len = min(candidates, key=lambda cand: cand[0])

        param_name = current_slice[len(self.parameter_prefix) : name_end]
        raw_value = self._trim_value(rest_of_slice[:end_pos])

        self._emit_arguments_open(calls)

        param_config = self._get_arguments_config(self.current_func_name, tools)
        converted_val = self._convert_param_value(
            raw_value, param_name, param_config, self.current_func_name
        )

        json_key_val = f"{json.dumps(param_name)}: {json.dumps(converted_val, ensure_ascii=False)}"
        # Every parameter after the first is preceded by a separator.
        if self.current_tool_param_count > 0:
            fragment = f", {json_key_val}"
        else:
            fragment = json_key_val

        calls.append(ToolCallItem(tool_index=self.current_tool_id, parameters=fragment))
        self.current_tool_param_count += 1

        self.parsed_pos += (name_end + 1) + end_pos + end_token_len
        return True

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Robust cursor-based streaming parser.

        Appends ``new_text`` to the internal buffer and consumes as much of
        it as can be decided now. Incomplete tags and parameter values stay
        in the buffer for the next call.

        Args:
            new_text: Newly received chunk of text.
            tools: Tool definitions used to type-convert parameter values.

        Returns:
            A result carrying the tool-call fragments produced by this chunk
            (function name, then JSON argument pieces "{", key/value pairs,
            "}") and any text found outside ``<tool_call>`` blocks.

        NOTE(review): ``_buffer``, ``current_tool_id`` and
        ``current_tool_name_sent`` are not set in this class's ``__init__``;
        presumably ``BaseFormatDetector`` initialises them - confirm.
        """
        self._buffer += new_text

        if not self._buffer:
            return StreamingParseResult()

        calls = []
        normal_text_chunks = []

        # Tags whose partial prefix at the end of the buffer must be held.
        possible_tags = (
            self.tool_call_start_token,
            self.tool_call_end_token,
            self.tool_call_prefix,
            self.function_end_token,
            self.parameter_prefix,
            self.parameter_end_token,
        )

        while True:
            current_slice = self._buffer[self.parsed_pos :]
            if not current_slice:
                break

            # 1. <tool_call>
            if current_slice.startswith(self.tool_call_start_token):
                self.parsed_pos += len(self.tool_call_start_token)
                self.is_inside_tool_call = True
                continue

            # 2. <function=NAME>
            if current_slice.startswith(self.tool_call_prefix):
                if self._consume_function_start(current_slice, calls):
                    continue
                break  # header incomplete, wait for more text

            # 3. <parameter=KEY>VALUE
            if current_slice.startswith(self.parameter_prefix):
                if self._consume_parameter(current_slice, tools, calls):
                    continue
                break  # header or value incomplete, wait for more text

            # 4. </function> closes the arguments object.
            if current_slice.startswith(self.function_end_token):
                # A function without parameters still needs "{" before "}".
                self._emit_arguments_open(calls)
                calls.append(
                    ToolCallItem(tool_index=self.current_tool_id, parameters="}")
                )
                self.parsed_pos += len(self.function_end_token)
                self.current_func_name = None
                continue

            # 5. </tool_call>
            if current_slice.startswith(self.tool_call_end_token):
                self.parsed_pos += len(self.tool_call_end_token)
                self.is_inside_tool_call = False
                continue

            # 6. Plain text. Text inside a tool call is consumed silently.
            next_open_angle = current_slice.find("<")

            if next_open_angle == -1:
                # No tag ahead: the whole remainder is plain text.
                if not self.is_inside_tool_call:
                    normal_text_chunks.append(current_slice)
                self.parsed_pos += len(current_slice)
                continue

            if next_open_angle > 0:
                # Consume the text before the next "<" and re-examine there.
                if not self.is_inside_tool_call:
                    normal_text_chunks.append(current_slice[:next_open_angle])
                self.parsed_pos += next_open_angle
                continue

            # "<" is first: hold it if the slice could still grow into a tag.
            if any(tag.startswith(current_slice) for tag in possible_tags):
                break

            # Not a known tag: treat this "<" as an ordinary character.
            if not self.is_inside_tool_call:
                normal_text_chunks.append("<")
            self.parsed_pos += 1

        # Drop the consumed prefix so the buffer does not grow without bound.
        if self.parsed_pos > 0:
            self._buffer = self._buffer[self.parsed_pos :]
            self.parsed_pos = 0

        return StreamingParseResult(
            calls=calls, normal_text="".join(normal_text_chunks)
        )

    def supports_structural_tag(self) -> bool:
        """Report that structural-tag support is not available."""
        return False

    def structure_info(self) -> _GetInfoFunc:
        """Not implemented for this detector; always raises.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError
|
|
|