
Commit 48a7606

feat (ai): support changing the system prompt in prepareSteps (#6503)
## Background

Changing the system prompt for specific steps is necessary for swarm-like multi-agent behavior.

## Summary

Add a `system` option to the `prepareStep` result.
1 parent 3c3803c commit 48a7606
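
For orientation, here is a minimal sketch of what the new option enables, assuming the `streamText` / `prepareStep` API used in the example and tests below; the prompts and the per-step condition are illustrative, not part of this commit:

```ts
import { openai } from '@ai-sdk/openai';
import { stepCountIs, streamText } from 'ai';

async function main() {
  // Sketch: prepareStep may now also return `system`, which replaces the
  // system prompt for that step only; returning undefined keeps the defaults.
  const result = streamText({
    model: openai('gpt-4o'),
    system: 'You are a helpful agent.',
    prompt: 'Introduce yourself.',
    stopWhen: stepCountIs(3),
    prepareStep: ({ stepNumber }) =>
      stepNumber === 0 ? { system: 'Only speak in Haikus.' } : undefined,
  });

  for await (const chunk of result.textStream) {
    process.stdout.write(chunk);
  }
}

main().catch(console.error);
```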

File tree

8 files changed, +82 −14 lines changed

8 files changed

+82
-14
lines changed

.changeset/proud-dancers-doubt.md

Lines changed: 5 additions & 0 deletions
```diff
@@ -0,0 +1,5 @@
+---
+'ai': patch
+---
+
+feat (ai): support changing the system prompt in prepareSteps
```
Lines changed: 41 additions & 0 deletions
```diff
@@ -0,0 +1,41 @@
+import { openai } from '@ai-sdk/openai';
+import { stepCountIs, streamText, tool } from 'ai';
+import 'dotenv/config';
+import { z } from 'zod';
+
+main().catch(console.error);
+
+async function main() {
+  const agentA = {
+    system: 'You are a helpful agent.',
+    activeTools: ['transferToAgentB'] as 'transferToAgentB'[],
+  };
+
+  const agentB = {
+    system: 'Only speak in Haikus.',
+    activeTools: [],
+  };
+
+  let activeAgent = agentA;
+
+  const result = streamText({
+    model: openai('gpt-4o'),
+    tools: {
+      transferToAgentB: tool({
+        description: 'Transfer to agent B.',
+        parameters: z.object({}),
+        execute: async () => {
+          activeAgent = agentB;
+          return 'Transferred to agent B.';
+        },
+      }),
+    },
+    stopWhen: stepCountIs(5),
+    prepareStep: () => activeAgent,
+    prompt: 'I want to talk to agent B.',
+  });
+
+  for await (const chunk of result.textStream) {
+    process.stdout.write(chunk);
+  }
+}
```

packages/ai/core/generate-text/generate-text.test.ts

Lines changed: 14 additions & 1 deletion
```diff
@@ -828,6 +828,10 @@ describe('options.stopWhen', () => {
       ]);

       expect(prompt).toStrictEqual([
+        {
+          role: 'system',
+          content: 'system-message-0',
+        },
         {
           role: 'user',
           content: [{ type: 'text', text: 'test-input' }],
@@ -873,6 +877,10 @@ describe('options.stopWhen', () => {
       expect(toolChoice).toStrictEqual({ type: 'auto' });

       expect(prompt).toStrictEqual([
+        {
+          role: 'system',
+          content: 'system-message-1',
+        },
         {
           role: 'user',
           content: [{ type: 'text', text: 'test-input' }],
@@ -955,12 +963,17 @@ describe('options.stopWhen', () => {
             type: 'tool',
             toolName: 'tool1' as const,
           },
+          system: 'system-message-0',
         };
       }

       if (stepNumber === 1) {
         expect(steps.length).toStrictEqual(1);
-        return { model: trueModel, activeTools: [] };
+        return {
+          model: trueModel,
+          activeTools: [],
+          system: 'system-message-1',
+        };
       }
     },
   });
```

packages/ai/core/generate-text/generate-text.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -280,7 +280,7 @@ A function that attempts to repair a tool call that failed to parse.

       const promptMessages = await convertToLanguageModelPrompt({
         prompt: {
-          system: initialPrompt.system,
+          system: prepareStepResult?.system ?? initialPrompt.system,
           messages: stepInputMessages,
         },
         supportedUrls: await model.supportedUrls,
```
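Note the fallback: the per-step `system` from `prepareStep` wins via nullish coalescing, and when `prepareStep` returns no `system`, the original `initialPrompt.system` is still used, so existing callers are unaffected. `stream-text.ts` below mirrors the same change.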

packages/ai/core/generate-text/index.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,7 +6,7 @@ export type {
   GeneratedFile,
 } from './generated-file';
 export * as Output from './output';
-export type { PrepareStepFunction } from './prepare-step';
+export type { PrepareStepFunction, PrepareStepResult } from './prepare-step';
 export { smoothStream, type ChunkDetector } from './smooth-stream';
 export type { StepResult } from './step-result';
 export { hasToolCall, stepCountIs, type StopCondition } from './stop-condition';
```

packages/ai/core/generate-text/prepare-step.ts

Lines changed: 6 additions & 9 deletions
```diff
@@ -19,18 +19,15 @@ export type PrepareStepFunction<
   steps: Array<StepResult<NoInfer<TOOLS>>>;
   stepNumber: number;
   model: LanguageModel;
-}) =>
-  | PromiseLike<
-      | {
-          model?: LanguageModel;
-          toolChoice?: ToolChoice<NoInfer<TOOLS>>;
-          activeTools?: Array<keyof NoInfer<TOOLS>>;
-        }
-      | undefined
-    >
+}) => PromiseLike<PrepareStepResult<TOOLS>> | PrepareStepResult<TOOLS>;
+
+export type PrepareStepResult<
+  TOOLS extends Record<string, Tool> = Record<string, Tool>,
+> =
   | {
       model?: LanguageModel;
       toolChoice?: ToolChoice<NoInfer<TOOLS>>;
       activeTools?: Array<keyof NoInfer<TOOLS>>;
+      system?: string;
     }
   | undefined;
```
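
Since `PrepareStepResult` is now exported (see `index.ts` above), per-step settings can be typed explicitly. A small sketch, assuming the type is re-exported from the `ai` package and using made-up prompt strings:

```ts
import type { PrepareStepResult } from 'ai';

// Illustrative per-step overrides keyed by step number (values are made up).
const stepOverrides: Record<number, PrepareStepResult> = {
  0: { system: 'Classify the request first.', toolChoice: 'required' },
  1: { system: 'Answer concisely.', activeTools: [] },
};

// Steps without an entry fall back to the defaults (undefined is allowed).
const prepareStep = ({ stepNumber }: { stepNumber: number }): PrepareStepResult =>
  stepOverrides[stepNumber];
```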

packages/ai/core/generate-text/stream-text.test.ts

Lines changed: 13 additions & 1 deletion
```diff
@@ -2887,11 +2887,15 @@ describe('streamText', () => {
             type: 'tool',
             toolName: 'tool1' as const,
           },
+          system: 'system-message-0',
         };
       }

       if (stepNumber === 1) {
-        return { activeTools: [] };
+        return {
+          activeTools: [],
+          system: 'system-message-1',
+        };
       }
     },
   });
@@ -2908,6 +2912,10 @@ describe('streamText', () => {
             "maxOutputTokens": undefined,
             "presencePenalty": undefined,
             "prompt": [
+              {
+                "content": "system-message-0",
+                "role": "system",
+              },
               {
                 "content": [
                   {
@@ -2958,6 +2966,10 @@ describe('streamText', () => {
             "maxOutputTokens": undefined,
             "presencePenalty": undefined,
             "prompt": [
+              {
+                "content": "system-message-1",
+                "role": "system",
+              },
               {
                 "content": [
                   {
```

packages/ai/core/generate-text/stream-text.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -856,7 +856,7 @@ class DefaultStreamTextResult<TOOLS extends ToolSet, OUTPUT, PARTIAL_OUTPUT>

       const promptMessages = await convertToLanguageModelPrompt({
         prompt: {
-          system: initialPrompt.system,
+          system: prepareStepResult?.system ?? initialPrompt.system,
           messages: stepInputMessages,
         },
         supportedUrls: await model.supportedUrls,
```
