diff --git a/mlx_lm/tokenizer_utils.py b/mlx_lm/tokenizer_utils.py index c7e50fbe7..b2314d5ef 100644 --- a/mlx_lm/tokenizer_utils.py +++ b/mlx_lm/tokenizer_utils.py @@ -572,6 +572,8 @@ def _infer_tool_parser(chat_template): return "mistral" elif "" in chat_template and "tool_call.name" in chat_template: return "json_tools" + elif " + Beijing + 2024-06-27 + + +Multi-line or special-character values may be wrapped in CDATA: + + + +Ported from SGLang's MiniCPM5Detector (PR #25600). +""" + +import ast +import json +from typing import Any, Optional + +import regex as re + +tool_call_start: str = " opening tag (when the segment includes outer tags, +# as it does for unit-test inputs). +_func_name_full_regex = re.compile( + r"]*>", re.DOTALL +) + +# Bare leading `name="..."` (when the state machine has stripped the outer +# tag and the segment starts with the attribute body). +_func_name_bare_regex = re.compile( + r"^\s*name=[\"']([^\"']+)[\"'][^>]*>", re.DOTALL +) + +_param_regex = re.compile( + r"(.*?)", re.DOTALL +) + +# A tag with no name= attribute invalidates the whole call. +_param_missing_name_regex = re.compile(r"]*\bname=)[^>]*>", re.DOTALL) + +_cdata_regex = re.compile(r"^$", re.DOTALL) + + +def _coerce_value(value: str, want_type: Optional[str]) -> Any: + """Coerce a raw param string into the type declared by the tool schema. + + Strings pass through. Other types try strict JSON, then Python-literal, + then fall back to the raw string. + """ + if want_type == "string": + return value + try: + return json.loads(value) + except (json.JSONDecodeError, ValueError): + pass + try: + return ast.literal_eval(value) + except (ValueError, SyntaxError): + return value + + +def _schema_for(tools: Optional[list], func_name: str): + """Return (param_types, allowed_props, required_props) for func_name.""" + if not tools: + return {}, set(), set() + for tool in tools: + func = tool.get("function") if isinstance(tool, dict) else None + if not func or func.get("name") != func_name: + continue + params = func.get("parameters") or {} + if not isinstance(params, dict): + return {}, set(), set() + props = params.get("properties") or {} + if not isinstance(props, dict): + return {}, set(), set() + types = { + k: (v.get("type") if isinstance(v, dict) else None) + for k, v in props.items() + } + required = set(params.get("required") or []) + return types, set(props.keys()), required + return {}, set(), set() + + +def parse_tool_call(text: str, tools: Optional[list] = None): + """Parse one MiniCPM5 XML tool call. + + The mlx-lm state machine emits one segment per `...` + pair, so this function returns a single call dict rather than a list. + + Raises ValueError on malformed XML or schema-violating calls; the server + layer (`ToolCallFormatter`) converts that into a logged warning and drops + the call. + """ + m = _func_name_full_regex.search(text) or _func_name_bare_regex.match(text) + if not m: + raise ValueError("No tool call found") + func_name = m.group(1) + + if _param_missing_name_regex.search(text): + raise ValueError(f"Tool call '{func_name}' has without name= attribute") + + param_types, allowed_props, required_props = _schema_for(tools, func_name) + + arguments: dict = {} + for pm in _param_regex.finditer(text): + key = pm.group(1) + if allowed_props and key not in allowed_props: + raise ValueError(f"Tool call '{func_name}' uses unknown param '{key}'") + if key in arguments: + raise ValueError(f"Tool call '{func_name}' has duplicate param '{key}'") + raw = pm.group(2).strip() + cdata = _cdata_regex.match(raw) + value = cdata.group(1) if cdata else raw + arguments[key] = _coerce_value(value, param_types.get(key)) + + missing = required_props - arguments.keys() + if missing: + raise ValueError( + f"Tool call '{func_name}' missing required params: {sorted(missing)}" + ) + + return {"name": func_name, "arguments": arguments} diff --git a/tests/test_tool_parsing.py b/tests/test_tool_parsing.py index 52892b7ff..43ca470c1 100644 --- a/tests/test_tool_parsing.py +++ b/tests/test_tool_parsing.py @@ -8,6 +8,7 @@ json_tools, kimi_k2, longcat, + minicpm5, minimax_m2, mistral, pythonic, @@ -37,6 +38,10 @@ def test_parsers(self): '\n12234585\n48838483920\n', minimax_m2, ), + ( + '1223458548838483920', + minicpm5, + ), ( "\n\n12234585\n\n\n48838483920\n\n", qwen3_coder, @@ -107,6 +112,10 @@ def test_parsers(self): '\nLondon\n', minimax_m2, ), + ( + 'London', + minicpm5, + ), ( "\n\nLondon\n\n", qwen3_coder, @@ -313,6 +322,109 @@ def test_kimi_k2(self): ] self.assertEqual(tool_calls, expected) + def test_minicpm5(self): + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "required": ["city"], + "properties": { + "city": {"type": "string"}, + "date": {"type": "string"}, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "sum_values", + "parameters": { + "type": "object", + "required": ["nums"], + "properties": { + "nums": {"type": "array"}, + "exact": {"type": "boolean"}, + }, + }, + }, + }, + ] + + # CDATA-wrapped multi-line param + test_case = ( + '' + '' + '2024-06-27' + "" + ) + tool_call = minicpm5.parse_tool_call(test_case, tools) + self.assertEqual(tool_call["name"], "get_weather") + self.assertEqual(tool_call["arguments"]["city"], "Bei\njing") + self.assertEqual(tool_call["arguments"]["date"], "2024-06-27") + + # Non-string typed params (array, boolean) + test_case = ( + '' + '[1, 2, 3]' + 'true' + "" + ) + tool_call = minicpm5.parse_tool_call(test_case, tools) + self.assertEqual(tool_call["arguments"]["nums"], [1, 2, 3]) + self.assertEqual(tool_call["arguments"]["exact"], True) + + # Body-only form (state machine stripped outer tags) + test_case = ' name="get_weather">Tokyo' + tool_call = minicpm5.parse_tool_call(test_case, tools) + self.assertEqual(tool_call["name"], "get_weather") + self.assertEqual(tool_call["arguments"]["city"], "Tokyo") + + # Missing required param → ValueError + test_case = ( + '2024-06-27' + ) + with self.assertRaises(ValueError): + minicpm5.parse_tool_call(test_case, tools) + + # Unknown param → ValueError + test_case = ( + 'x' + ) + with self.assertRaises(ValueError): + minicpm5.parse_tool_call(test_case, tools) + + # Duplicate param → ValueError + test_case = ( + '' + 'A' + 'B' + "" + ) + with self.assertRaises(ValueError): + minicpm5.parse_tool_call(test_case, tools) + + # missing name attr → ValueError + test_case = ( + '' + "nope" + "" + ) + with self.assertRaises(ValueError): + minicpm5.parse_tool_call(test_case, tools) + + # Single-quoted attributes (both function and param) + test_case = ( + "" + "Osaka" + "" + ) + tool_call = minicpm5.parse_tool_call(test_case, tools) + self.assertEqual(tool_call["arguments"]["city"], "Osaka") + def test_minimax_m2(self): test_case = ( '\n'