diff --git a/README.md b/README.md index 642d4b4e..5d28e840 100644 --- a/README.md +++ b/README.md @@ -347,6 +347,34 @@ $eventDispatcher->addListener(ToolCallsExecuted::class, function (ToolCallsExecu }); ``` +#### Keeping Tool Messages + +Sometimes you might wish to keep the tool messages (`AssistantMessage` containing the `toolCalls` and `ToolCallMessage` containing the response) in the context. +Enable the `keepToolMessages` flag of the toolbox' `ChainProcessor` to ensure those messages will be added to your `MessageBag`. + +```php +use PhpLlm\LlmChain\Chain\Toolbox\ChainProcessor; +use PhpLlm\LlmChain\Chain\Toolbox\Toolbox; + +// Platform & LLM instantiation +$messages = new MessageBag( + Message::forSystem(<<call($messages); +// $messages will now include the tool messages +``` + #### Code Examples (with built-in tools) 1. [Brave Tool](examples/toolbox/brave.php) diff --git a/src/Chain/Toolbox/ChainProcessor.php b/src/Chain/Toolbox/ChainProcessor.php index 703ed419..48c7d730 100644 --- a/src/Chain/Toolbox/ChainProcessor.php +++ b/src/Chain/Toolbox/ChainProcessor.php @@ -33,6 +33,7 @@ public function __construct( private readonly ToolboxInterface $toolbox, private readonly ToolResultConverter $resultConverter = new ToolResultConverter(), private readonly ?EventDispatcherInterface $eventDispatcher = null, + private readonly bool $keepToolMessages = false, ) { } @@ -86,7 +87,7 @@ private function isFlatStringArray(array $tools): bool private function handleToolCallsCallback(Output $output): \Closure { return function (ToolCallResponse $response, ?AssistantMessage $streamedAssistantResponse = null) use ($output): ResponseInterface { - $messages = clone $output->messages; + $messages = $this->keepToolMessages ? $output->messages : clone $output->messages; if (null !== $streamedAssistantResponse && '' !== $streamedAssistantResponse->content) { $messages->add($streamedAssistantResponse); diff --git a/tests/Chain/Toolbox/ChainProcessorTest.php b/tests/Chain/Toolbox/ChainProcessorTest.php index b85c3579..8b17017c 100644 --- a/tests/Chain/Toolbox/ChainProcessorTest.php +++ b/tests/Chain/Toolbox/ChainProcessorTest.php @@ -4,13 +4,19 @@ namespace PhpLlm\LlmChain\Tests\Chain\Toolbox; +use PhpLlm\LlmChain\Chain\ChainInterface; use PhpLlm\LlmChain\Chain\Exception\MissingModelSupportException; use PhpLlm\LlmChain\Chain\Input; +use PhpLlm\LlmChain\Chain\Output; use PhpLlm\LlmChain\Chain\Toolbox\ChainProcessor; use PhpLlm\LlmChain\Chain\Toolbox\ToolboxInterface; use PhpLlm\LlmChain\Platform\Capability; +use PhpLlm\LlmChain\Platform\Message\AssistantMessage; use PhpLlm\LlmChain\Platform\Message\MessageBag; +use PhpLlm\LlmChain\Platform\Message\ToolCallMessage; use PhpLlm\LlmChain\Platform\Model; +use PhpLlm\LlmChain\Platform\Response\ToolCall; +use PhpLlm\LlmChain\Platform\Response\ToolCallResponse; use PhpLlm\LlmChain\Platform\Tool\ExecutionReference; use PhpLlm\LlmChain\Platform\Tool\Tool; use PHPUnit\Framework\Attributes\CoversClass; @@ -20,7 +26,10 @@ #[CoversClass(ChainProcessor::class)] #[UsesClass(Input::class)] +#[UsesClass(Output::class)] #[UsesClass(Tool::class)] +#[UsesClass(ToolCall::class)] +#[UsesClass(ToolCallResponse::class)] #[UsesClass(ExecutionReference::class)] #[UsesClass(MessageBag::class)] #[UsesClass(MissingModelSupportException::class)] @@ -87,4 +96,54 @@ public function processInputWithUnsupportedToolCallingWillThrowException(): void $chainProcessor->processInput($input); } + + #[Test] + public function processOutputWithToolCallResponseKeepingMessages(): void + { + $toolbox = $this->createMock(ToolboxInterface::class); + $toolbox->expects($this->once())->method('execute')->willReturn('Test response'); + + $model = new Model('gpt-4', [Capability::TOOL_CALLING]); + + $messageBag = new MessageBag(); + + $response = new ToolCallResponse(new ToolCall('id1', 'tool1', ['arg1' => 'value1'])); + + $chain = $this->createStub(ChainInterface::class); + + $chainProcessor = new ChainProcessor($toolbox, keepToolMessages: true); + $chainProcessor->setChain($chain); + + $output = new Output($model, $response, $messageBag, []); + + $chainProcessor->processOutput($output); + + self::assertCount(2, $messageBag); + self::assertInstanceOf(AssistantMessage::class, $messageBag->getMessages()[0]); + self::assertInstanceOf(ToolCallMessage::class, $messageBag->getMessages()[1]); + } + + #[Test] + public function processOutputWithToolCallResponseForgettingMessages(): void + { + $toolbox = $this->createMock(ToolboxInterface::class); + $toolbox->expects($this->once())->method('execute')->willReturn('Test response'); + + $model = new Model('gpt-4', [Capability::TOOL_CALLING]); + + $messageBag = new MessageBag(); + + $response = new ToolCallResponse(new ToolCall('id1', 'tool1', ['arg1' => 'value1'])); + + $chain = $this->createStub(ChainInterface::class); + + $chainProcessor = new ChainProcessor($toolbox, keepToolMessages: false); + $chainProcessor->setChain($chain); + + $output = new Output($model, $response, $messageBag, []); + + $chainProcessor->processOutput($output); + + self::assertCount(0, $messageBag); + } }