131 changes: 127 additions & 4 deletions libs/community/langchain_community/chat_models/mlx.py
@@ -1,5 +1,9 @@
"""MLX Chat Wrapper."""

import json
import logging
import re
import uuid
from typing import (
Any,
Callable,
@@ -9,6 +13,7 @@
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
)
@@ -24,7 +29,13 @@
AIMessageChunk,
BaseMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
)
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import (
ChatGeneration,
@@ -38,9 +49,57 @@

from langchain_community.llms.mlx_pipeline import MLXPipeline

logger = logging.getLogger(__name__)

DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""


def _parse_react_tool_calls(
text: str,
) -> Tuple[Optional[List[ToolCall]], List[InvalidToolCall]]:
"""Extract ReAct-style tool calls from plain text output.

Args:
text: Raw model generation text.

Returns:
A tuple containing a list of parsed ``ToolCall`` objects if any were
detected, otherwise ``None``, and a list of ``InvalidToolCall`` objects
for unparseable patterns.
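
    Example:
        A minimal sketch of the two layouts the regexes below recognize; the
        tool name and argument values are purely illustrative::

            text = 'Action: multiply[{"a": 2, "b": 3}]'
            # or, on two lines:
            #   Action: multiply
            #   Action Input: {"a": 2, "b": 3}
            tool_calls, invalid = _parse_react_tool_calls(text)
            # tool_calls[0]["name"] == "multiply"
            # tool_calls[0]["args"] == {"a": 2, "b": 3}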
"""

tool_calls: list[ToolCall] = []
invalid_tool_calls: list[InvalidToolCall] = []

bracket_pattern = r"Action:\s*(?P<name>[\w.-]+)\[(?P<input>[^\]]+)\]"
separate_pattern = (
r"Action:\s*(?P<name>[^\n]+)\nAction Input:\s*(?P<input>[^\n]+)"
)

matches = list(re.finditer(bracket_pattern, text))
if not matches:
matches = list(re.finditer(separate_pattern, text))

for match in matches:
name = match.group("name").strip()
arg_text = match.group("input").strip()
try:
args = json.loads(arg_text)
if not isinstance(args, dict):
args = {"input": args}
except Exception:
args = {"input": arg_text}
tool_calls.append(ToolCall(id=str(uuid.uuid4()), name=name, args=args))

if not tool_calls and "Action:" in text:
        invalid_tool_calls.append(
            InvalidToolCall(
                name=None, args=text, id=None, error="Could not parse ReAct tool call"
            )
        )
return None, invalid_tool_calls

return tool_calls or None, invalid_tool_calls


class ChatMLX(BaseChatModel):
"""MLX chat models.

@@ -69,14 +128,34 @@ def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.tokenizer = self.llm.tokenizer

def _parse_tool_args(self, arg_text: str) -> Dict[str, Any]:
"""Parse the arguments for a tool call.

Args:
arg_text: JSON string representation of the tool arguments.

Returns:
Parsed arguments dictionary. If parsing fails, returns a dict with
the original text under the ``input`` key.
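
        Example:
            A rough sketch of both paths (the inputs are illustrative)::

                self._parse_tool_args('{"a": 2}')   # -> {"a": 2}
                self._parse_tool_args("not json")   # -> {"input": "not json"}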
"""
try:
args = json.loads(arg_text)
except json.JSONDecodeError:
args = {"input": arg_text}
except Exception as e: # pragma: no cover - defensive
logger.warning("Unexpected error during tool argument parsing: %s", e)
args = {"input": arg_text}
return args

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
tools = kwargs.pop("tools", None)
llm_input = self._to_chat_prompt(messages, tools=tools)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
@@ -89,7 +168,8 @@ async def _agenerate(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
tools = kwargs.pop("tools", None)
llm_input = self._to_chat_prompt(messages, tools=tools)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
@@ -100,8 +180,17 @@ def _to_chat_prompt(
messages: List[BaseMessage],
tokenize: bool = False,
return_tensors: Optional[str] = None,
        tools: Optional[Sequence[dict]] = None,
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
"""Convert messages to the prompt format expected by the wrapped LLM.

Args:
messages: Chat messages to include in the prompt.
tokenize: Whether to return token IDs instead of text.
return_tensors: Framework for returned tensors when ``tokenize`` is
True.
tools: Optional tool definitions to include in the prompt.
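
        Example:
            A sketch of the OpenAI-style schema commonly passed as ``tools``;
            the ``multiply`` definition is illustrative, not part of this
            module::

                tools = [
                    {
                        "type": "function",
                        "function": {
                            "name": "multiply",
                            "description": "Multiply two integers.",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "a": {"type": "integer"},
                                    "b": {"type": "integer"},
                                },
                            },
                        },
                    }
                ]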
"""
if not messages:
raise ValueError("At least one HumanMessage must be provided!")

@@ -114,6 +203,7 @@ def _to_chat_prompt(
tokenize=tokenize,
add_generation_prompt=True,
return_tensors=return_tensors,
tools=tools,
)

def _to_chatml_format(self, message: BaseMessage) -> dict:
@@ -135,8 +225,41 @@ def _to_chat_result(llm_result: LLMResult) -> ChatResult:
chat_generations = []

for g in llm_result.generations[0]:
tool_calls: list[ToolCall] = []
invalid_tool_calls: list[InvalidToolCall] = []
additional_kwargs: Dict[str, Any] = {}

if isinstance(g.generation_info, dict):
raw_tool_calls = g.generation_info.get("tool_calls")
else:
raw_tool_calls = None

if raw_tool_calls:
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tc = parse_tool_call(raw_tool_call, return_id=True)
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
else:
if tc:
tool_calls.append(tc)
else:
react_tool_calls, invalid_reacts = _parse_react_tool_calls(g.text)
if react_tool_calls is not None:
tool_calls.extend(react_tool_calls)
invalid_tool_calls.extend(invalid_reacts)

chat_generation = ChatGeneration(
message=AIMessage(content=g.text), generation_info=g.generation_info
message=AIMessage(
content=g.text,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
),
generation_info=g.generation_info,
)
chat_generations.append(chat_generation)

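Taken together, a plain ReAct-style generation is surfaced as structured tool calls on the resulting AIMessage. A minimal sketch of that mapping, assuming the module above is importable (the generation text is invented):

from langchain_core.outputs import Generation, LLMResult

from langchain_community.chat_models.mlx import ChatMLX

llm_result = LLMResult(
    generations=[[Generation(text='Action: multiply[{"a": 2, "b": 3}]')]]
)
message = ChatMLX._to_chat_result(llm_result).generations[0].message
# Expected roughly:
# message.tool_calls[0]["name"] == "multiply"
# message.tool_calls[0]["args"] == {"a": 2, "b": 3}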
@@ -0,0 +1,55 @@
"""Tests ChatMLX tool calling."""

from typing import Dict

import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool

from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline

# Use a Phi-3 model for more reliable tool-calling behavior
MODEL_ID = "mlx-community/phi-3-mini-128k-instruct"


@tool
def multiply(a: int, b: int) -> int:
"""Multiply two integers."""
return a * b


@pytest.fixture(scope="module")
def chat() -> ChatMLX:
"""Return ChatMLX bound with the multiply tool or skip if unavailable."""
try:
llm = MLXPipeline.from_model_id(
model_id=MODEL_ID, pipeline_kwargs={"max_new_tokens": 150}
)
except Exception:
pytest.skip("Required MLX model isn't available.", allow_module_level=True)
chat_model = ChatMLX(llm=llm)
return chat_model.bind_tools(tools=[multiply], tool_choice=True) # type: ignore[return-value]


def _call_tool(tool_call: Dict) -> ToolMessage:
result = multiply.invoke(tool_call["args"])
return ToolMessage(content=str(result), tool_call_id=tool_call.get("id", ""))


def test_mlx_tool_calls_soft(chat: ChatMLX) -> None:
messages = [HumanMessage(content="Use the multiply tool to compute 2 * 3.")]
ai_msg = chat.invoke(messages)
tool_msg = _call_tool(ai_msg.tool_calls[0])
final = chat.invoke(messages + [ai_msg, tool_msg])
assert "6" in final.content


def test_mlx_tool_calls_hard(chat: ChatMLX) -> None:
messages = [HumanMessage(content="Use the multiply tool to compute 2 * 3.")]
ai_msg = chat.invoke(messages)
assert isinstance(ai_msg, AIMessage)
assert ai_msg.tool_calls
tool_call = ai_msg.tool_calls[0]
assert tool_call["name"] == "multiply"
assert tool_call["args"] == {"a": 2, "b": 3}
71 changes: 71 additions & 0 deletions libs/community/tests/unit_tests/chat_models/test_mlx.py
@@ -2,6 +2,42 @@

from importlib import import_module

import pytest
from langchain_core.messages import HumanMessage

from langchain_community.chat_models.mlx import ChatMLX


class _FakeTokenizer:
def __init__(self) -> None:
self.tools = None

def apply_chat_template(
self,
messages,
tokenize=False,
add_generation_prompt=True,
return_tensors=None,
tools=None,
) -> str:
self.tools = tools
return "prompt"


class _FakeLLM:
def __init__(self) -> None:
self.tokenizer = _FakeTokenizer()

def _generate(self, prompts, stop=None, run_manager=None, **kwargs):
class _Res:
generations = [[type("G", (), {"text": "", "generation_info": {}})]]
llm_output = {}

return _Res()

async def _agenerate(self, prompts, stop=None, run_manager=None, **kwargs):
return self._generate(prompts, stop=stop, run_manager=run_manager, **kwargs)


def test_import_class() -> None:
"""Test that the class can be imported."""
@@ -10,3 +46,38 @@ def test_import_class() -> None:

module = import_module(module_name)
assert hasattr(module, class_name)


def test_generate_passes_tools_to_tokenizer() -> None:
llm = _FakeLLM()
chat = ChatMLX(llm=llm)
tools = [
{
"type": "function",
"function": {
"name": "foo",
"description": "",
"parameters": {"type": "object", "properties": {}},
},
}
]
chat._generate([HumanMessage(content="hi")], tools=tools)
assert llm.tokenizer.tools == tools


@pytest.mark.asyncio
async def test_agenerate_passes_tools_to_tokenizer() -> None:
llm = _FakeLLM()
chat = ChatMLX(llm=llm)
tools = [
{
"type": "function",
"function": {
"name": "foo",
"description": "",
"parameters": {"type": "object", "properties": {}},
},
}
]
await chat._agenerate([HumanMessage(content="hi")], tools=tools)
assert llm.tokenizer.tools == tools