Skip to content

Commit 41a01a3

Browse files
feat: add option to keep tool messages (#321)
Fixes #321
1 parent be5d914 commit 41a01a3

File tree

3 files changed

+94
-1
lines changed

3 files changed

+94
-1
lines changed

README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,34 @@ $eventDispatcher->addListener(ToolCallsExecuted::class, function (ToolCallsExecu
341341
});
342342
```
343343

344+
#### Keeping Tool Messages
345+
346+
Sometimes you might wish to keep the tool messages (`AssistantMessage` containing the `toolCalls` and `ToolCallMessage` containing the response) in the context.
347+
Enable the `keepToolMessages` flag to ensure those messages will be added to your `MessageBag`.
348+
349+
```php
350+
use PhpLlm\LlmChain\Chain\Toolbox\ChainProcessor;
351+
use PhpLlm\LlmChain\Chain\Toolbox\Toolbox;
352+
353+
// Platform & LLM instantiation
354+
$messages = new MessageBag(
355+
Message::forSystem(<<<PROMPT
356+
Please answer all user questions only using the similary_search tool. Do not add information and if you cannot
357+
find an answer, say so.
358+
PROMPT),
359+
Message::ofUser('...') // The user's question.
360+
);
361+
362+
$yourTool = new YourTool();
363+
364+
$toolbox = Toolbox::create($yourTool);
365+
$toolProcessor = new ChainProcessor($toolbox, keepToolMessages: true);
366+
367+
$chain = new Chain($platform, $llm, inputProcessor: [$toolProcessor], outputProcessor: [$toolProcessor]);
368+
$response = $chain->call($messages);
369+
// $messages will now include the tool messages
370+
```
371+
344372
#### Code Examples (with built-in tools)
345373
346374
1. [Brave Tool](examples/toolbox/brave.php)

src/Chain/Toolbox/ChainProcessor.php

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ public function __construct(
2828
private readonly ToolboxInterface $toolbox,
2929
private readonly ToolResultConverter $resultConverter = new ToolResultConverter(),
3030
private readonly ?EventDispatcherInterface $eventDispatcher = null,
31+
private readonly bool $keepToolMessages = false,
3132
) {
3233
}
3334

@@ -81,7 +82,7 @@ private function isFlatStringArray(array $tools): bool
8182
private function handleToolCallsCallback(Output $output): \Closure
8283
{
8384
return function (ToolCallResponse $response, ?AssistantMessage $streamedAssistantResponse = null) use ($output): ResponseInterface {
84-
$messages = clone $output->messages;
85+
$messages = $this->keepToolMessages ? $output->messages : clone $output->messages;
8586

8687
if (null !== $streamedAssistantResponse && '' !== $streamedAssistantResponse->content) {
8788
$messages->add($streamedAssistantResponse);

tests/Chain/Toolbox/ChainProcessorTest.php

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,26 @@
44

55
namespace PhpLlm\LlmChain\Tests\Chain\Toolbox;
66

7+
use PhpLlm\LlmChain\Chain;
78
use PhpLlm\LlmChain\Chain\Input;
9+
use PhpLlm\LlmChain\Chain\Output;
810
use PhpLlm\LlmChain\Chain\Toolbox\ChainProcessor;
911
use PhpLlm\LlmChain\Chain\Toolbox\ExecutionReference;
1012
use PhpLlm\LlmChain\Chain\Toolbox\Metadata;
1113
use PhpLlm\LlmChain\Chain\Toolbox\ToolboxInterface;
1214
use PhpLlm\LlmChain\Exception\MissingModelSupport;
1315
use PhpLlm\LlmChain\Model\LanguageModel;
16+
use PhpLlm\LlmChain\Model\Message\AssistantMessage;
1417
use PhpLlm\LlmChain\Model\Message\MessageBag;
18+
use PhpLlm\LlmChain\Model\Message\ToolCallMessage;
19+
use PhpLlm\LlmChain\Model\Response\ToolCall;
20+
use PhpLlm\LlmChain\Model\Response\ToolCallResponse;
21+
use PhpLlm\LlmChain\PlatformInterface;
1522
use PHPUnit\Framework\Attributes\CoversClass;
1623
use PHPUnit\Framework\Attributes\Test;
1724
use PHPUnit\Framework\Attributes\UsesClass;
1825
use PHPUnit\Framework\TestCase;
26+
use Symfony\Contracts\EventDispatcher\EventDispatcherInterface;
1927

2028
#[CoversClass(ChainProcessor::class)]
2129
#[UsesClass(Input::class)]
@@ -93,4 +101,60 @@ public function processInputWithUnsupportedToolCallingWillThrowException(): void
93101

94102
$chainProcessor->processInput($input);
95103
}
104+
105+
#[Test]
106+
public function processOutputWithToolCallResponseKeepingMessages(): void
107+
{
108+
$toolbox = $this->createMock(ToolboxInterface::class);
109+
$toolbox->expects($this->once())->method('execute')->willReturn('Test response');
110+
111+
$llm = $this->createStub(LanguageModel::class);
112+
113+
$eventDispatcher = $this->createMock(EventDispatcherInterface::class);
114+
$eventDispatcher->expects($this->once())->method('dispatch');
115+
116+
$messageBag = new MessageBag();
117+
118+
$response = new ToolCallResponse(new ToolCall('id1', 'tool1', ['arg1' => 'value1']));
119+
120+
$chain = new Chain($this->createStub(PlatformInterface::class), $llm);
121+
122+
$chainProcessor = new ChainProcessor($toolbox, eventDispatcher: $eventDispatcher, keepToolMessages: true);
123+
$chainProcessor->setChain($chain);
124+
125+
$output = new Output($llm, $response, $messageBag, []);
126+
127+
$chainProcessor->processOutput($output);
128+
129+
self::assertCount(2, $messageBag);
130+
self::assertInstanceOf(AssistantMessage::class, $messageBag->getMessages()[0]);
131+
self::assertInstanceOf(ToolCallMessage::class, $messageBag->getMessages()[1]);
132+
}
133+
134+
#[Test]
135+
public function processOutputWithToolCallResponseForgettingMessages(): void
136+
{
137+
$toolbox = $this->createMock(ToolboxInterface::class);
138+
$toolbox->expects($this->once())->method('execute')->willReturn('Test response');
139+
140+
$llm = $this->createStub(LanguageModel::class);
141+
142+
$eventDispatcher = $this->createMock(EventDispatcherInterface::class);
143+
$eventDispatcher->expects($this->once())->method('dispatch');
144+
145+
$messageBag = new MessageBag();
146+
147+
$response = new ToolCallResponse(new ToolCall('id1', 'tool1', ['arg1' => 'value1']));
148+
149+
$chain = new Chain($this->createStub(PlatformInterface::class), $llm);
150+
151+
$chainProcessor = new ChainProcessor($toolbox, eventDispatcher: $eventDispatcher, keepToolMessages: false);
152+
$chainProcessor->setChain($chain);
153+
154+
$output = new Output($llm, $response, $messageBag, []);
155+
156+
$chainProcessor->processOutput($output);
157+
158+
self::assertCount(0, $messageBag);
159+
}
96160
}

0 commit comments

Comments
 (0)