Skip to content
Merged
4 changes: 2 additions & 2 deletions strands-ts/src/agent/__tests__/agent.stateful-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ describe('Agent with stateful model', () => {
})

describe('invocation', () => {
it('passes agent.modelState to the model via streamOptions.modelState', async () => {
it('passes modelState snapshot to the model and writes back changes', async () => {
const model = new StatefulMockModel(['resp_first']).addTurn({ type: 'textBlock', text: 'Hi' })
const agent = new Agent({ model, printer: false })
await agent.invoke('Hello')
expect(model.receivedOptions[0]?.modelState).toBe(agent.modelState)
expect(model.receivedOptions[0]?.modelState).not.toBe(agent.modelState)
expect(agent.modelState.getAll()).toEqual({ responseId: 'resp_first' })
})

Expand Down
38 changes: 26 additions & 12 deletions strands-ts/src/agent/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,20 @@ import {
type ToolResultBlockData,
ToolUseBlock,
} from '../types/messages.js'
import { deepCopy } from '../types/json.js'
import type { JSONValue } from '../types/json.js'
import { McpClient } from '../mcp.js'
import { isValidToolName, type Tool, type ToolContext } from '../tools/tool.js'
import type { ToolChoice, ToolSpec } from '../tools/types.js'
import { systemPromptFromData } from '../types/messages.js'
import { cloneSystemPrompt, systemPromptFromData } from '../types/messages.js'
import { normalizeError, ConcurrentInvocationError, StructuredOutputError } from '../errors.js'
import { Model } from '../models/model.js'
import type { BaseModelConfig, StreamAggregatedResult, StreamOptions } from '../models/model.js'
import { ModelPlugin } from '../plugins/model-plugin.js'
import { isModelStreamEvent } from '../models/streaming.js'
import { ToolRegistry } from '../registry/tool-registry.js'
import { StateStore } from '../state-store.js'
import { serializeStateSerializable, loadStateSerializable } from '../types/serializable.js'
import { AgentPrinter, getDefaultAppender, type Printer } from './printer.js'
import type { Plugin } from '../plugins/plugin.js'
import type { InterventionHandler } from '../interventions/handler.js'
Expand Down Expand Up @@ -1919,14 +1921,19 @@ export class Agent implements LocalAgent, InvokableAgent {
): AsyncGenerator<AgentStreamEvent, StreamAggregatedResult, undefined> {
const context: InvokeModelContext = {
agent: this,
messages: this.messages,
...(this.systemPrompt !== undefined && { systemPrompt: this.systemPrompt }),
toolSpecs: this._toolRegistry.list().map((tool) => tool.toolSpec),
...(toolChoice !== undefined && { toolChoice }),
modelState: this.modelState,
messages: this.messages.map((msg) => msg.clone()),
Comment thread
zastrowm marked this conversation as resolved.
Comment thread
zastrowm marked this conversation as resolved.
...(this.systemPrompt !== undefined && { systemPrompt: cloneSystemPrompt(this.systemPrompt) }),
toolSpecs: deepCopy(this._toolRegistry.list().map((tool) => tool.toolSpec)) as unknown as ToolSpec[],
...(toolChoice !== undefined && { toolChoice: deepCopy(toolChoice) as unknown as ToolChoice }),
Comment thread
zastrowm marked this conversation as resolved.
invocationState,
}

// Snapshot model state before middleware runs so concurrent mutations don't leak in.
// The writeback happens after the entire middleware chain completes, so middleware
// cannot affect modelState at any point (before or after next()).
Comment thread
zastrowm marked this conversation as resolved.
const modelStateSnapshot = this.modelState.getAll()
let tempModelState: StateStore | undefined

// async function* doesn't bind lexical `this`; capture for the terminal callback.
// eslint-disable-next-line @typescript-eslint/no-this-alias
const self = this
Expand All @@ -1942,9 +1949,12 @@ export class Agent implements LocalAgent, InvokableAgent {
})

try {
// Wrap the snapshot into a StateStore for the model provider, which expects
// get/set methods.
tempModelState = new StateStore(modelStateSnapshot)
const streamOptions: StreamOptions = {
toolSpecs: ctx.toolSpecs as ToolSpec[],
modelState: ctx.modelState,
modelState: tempModelState,
...(ctx.systemPrompt !== undefined && { systemPrompt: ctx.systemPrompt }),
...(ctx.toolChoice && { toolChoice: ctx.toolChoice }),
}
Expand All @@ -1971,6 +1981,14 @@ export class Agent implements LocalAgent, InvokableAgent {
}
}
)

// Sync model state after the entire middleware chain has completed, so no
// middleware mutation to agent.modelState (before or after next()) takes effect.
Comment thread
zastrowm marked this conversation as resolved.
// Intentionally skipped on error — partial provider writes should not persist.
if (tempModelState) {
Comment thread
zastrowm marked this conversation as resolved.
loadStateSerializable(this.modelState, serializeStateSerializable(tempModelState))
}

return middlewareResult.result
}

Expand Down Expand Up @@ -2455,11 +2473,7 @@ export class Agent implements LocalAgent, InvokableAgent {
const context: ExecuteToolContext = {
agent: this,
tool,
toolUse: {
name: toolUse.name,
toolUseId: toolUse.toolUseId,
input: toolUse.input,
},
toolUse: deepCopy(toolUse) as unknown as ToolUseData,
invocationState,
interrupt: createMiddlewareInterrupt(this._interruptState, `middleware:executeTool:${toolUse.toolUseId}`),
}
Expand Down
4 changes: 1 addition & 3 deletions strands-ts/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ export { SandboxTimeoutError, SandboxAbortError, SandboxPathNotFoundError } from
export type { StreamType, StreamChunk, FileInfo, OutputFile, ExecutionResult } from './sandbox/types.js'

// Middleware system
export { InvokeModelStage, ExecuteToolStage, AgentStreamStage } from './middleware/index.js'
export { InvokeModelStage, ExecuteToolStage } from './middleware/index.js'
export type {
MiddlewareStage,
MiddlewareHandler,
Expand All @@ -323,8 +323,6 @@ export type {
InvokeModelResult,
ExecuteToolContext,
ExecuteToolResult,
AgentStreamContext,
AgentStreamResult,
MiddlewareInterruptResult,
MiddlewareInterruptible,
} from './middleware/index.js'
Expand Down
14 changes: 14 additions & 0 deletions strands-ts/src/middleware/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@ Middleware is intended to supersede hooks. The Input/Output phases already cover

`createStage` is not exported from the public API. Only the three built-in stages are public.

## Middleware cannot access or modify model state

`modelState` is intentionally excluded from `InvokeModelContext`. The agent snapshots model state before the middleware chain runs and writes back the model provider's changes after the entire chain completes. Any mutations middleware makes to `agent.modelState` — whether before or after `next()` — are overwritten by this writeback.

We may revisit this later.

## `AgentStreamStage` is internal

`AgentStreamStage` is not exported from the public API. The context for this stage passes `args` and `options` by reference, but the correct contract (copy vs. reference, readonly enforcement) hasn't been finalized. Rather than ship an inconsistent surface, we're keeping it internal until we decide whether `args` should be deep-copied (consistent with `InvokeModelStage.messages`) or remain a reference (simpler, but surprising given the other stages' guarantees).

## `invocationState` is shared by reference

`invocationState` is not copied. Tools and hooks write to it, and those mutations must appear on `AgentResult.invocationState`. The SDK should never write to it directly — the key space belongs to the caller.

## Middleware cannot resume interrupts

`AgentStreamStage` middleware cannot currently resume tool-level interrupts. Interrupt resolution (`_interruptState.resume()`) runs in `stream()`'s outer loop, outside the middleware chain. When a tool-level interrupt fires, `_stream` catches the `InterruptError` internally and returns a normal `AgentResult` with `stopReason: 'interrupt'` — the middleware sees the result but cannot re-enter the stream with interrupt responses.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ describe('Agent middleware integration — InvokeModelStage', () => {
systemPrompt: 'Be helpful',
messages: expect.arrayContaining([expect.any(Message)]),
toolSpecs: expect.arrayContaining([expect.objectContaining({ name: 'testTool' })]),
modelState: expect.anything(),
invocationState: expect.anything(),
})
})
Expand Down
Loading
Loading