From c5f0d2f803526c98de6ec5849f4f517ad71c6ee4 Mon Sep 17 00:00:00 2001
From: Tat Dat Duong
Date: Thu, 23 Jan 2025 16:40:59 +0100
Subject: [PATCH] fix(messages): reuse the first valid message ID for subsequent chunks

---
 libs/langgraph/src/pregel/messages.ts   | 43 +++++++++------
 libs/langgraph/src/tests/pregel.test.ts | 70 +++++++++++--------------
 libs/langgraph/src/tests/utils.ts       | 45 +++++++++++++---
 3 files changed, 96 insertions(+), 62 deletions(-)

diff --git a/libs/langgraph/src/pregel/messages.ts b/libs/langgraph/src/pregel/messages.ts
index 82a972d35..86c2cc019 100644
--- a/libs/langgraph/src/pregel/messages.ts
+++ b/libs/langgraph/src/pregel/messages.ts
@@ -1,4 +1,3 @@
-import { v4 } from "uuid";
 import {
   BaseCallbackHandler,
   HandleLLMNewTokenCallbackFields,
@@ -44,6 +43,8 @@ export class StreamMessagesHandler extends BaseCallbackHandler {

   emittedChatModelRunIds: Record<string, boolean> = {};

+  stableMessageIdMap: Record<string, string> = {};
+
   lc_prefer_streaming = true;

   constructor(streamFn: (streamChunk: StreamChunk) => void) {
@@ -51,7 +52,7 @@ export class StreamMessagesHandler extends BaseCallbackHandler {
     this.streamFn = streamFn;
   }

-  _emit(meta: Meta, message: BaseMessage, dedupe = false) {
+  _emit(meta: Meta, message: BaseMessage, runId: string, dedupe = false) {
     if (
       dedupe &&
       message.id !== undefined &&
@@ -59,13 +60,25 @@ export class StreamMessagesHandler extends BaseCallbackHandler {
     ) {
       return;
     }
+
+    // For instance in ChatAnthropic, the first chunk has a message ID
+    // but the subsequent chunks do not. To avoid clients seeing two messages,
+    // we rename the message ID if it's being auto-set to `run-${runId}`
+    // (see https://github.com/langchain-ai/langchainjs/pull/6646).
+    let messageId = message.id;
+    if (messageId == null || messageId === `run-${runId}`) {
+      messageId = this.stableMessageIdMap[runId] ?? messageId ?? `run-${runId}`;
+    }
+    this.stableMessageIdMap[runId] ??= messageId;
+
-    if (message.id === undefined) {
-      const id = v4();
+    if (messageId !== message.id) {
       // eslint-disable-next-line no-param-reassign
-      message.id = id;
+      message.id = messageId;
+
       // eslint-disable-next-line no-param-reassign
-      message.lc_kwargs.id = id;
+      message.lc_kwargs.id = messageId;
     }
+
     this.seen[message.id!] = message;
     this.streamFn([meta[0], "messages", [message, meta[1]]]);
   }
@@ -104,13 +117,12 @@ export class StreamMessagesHandler extends BaseCallbackHandler {
       this.emittedChatModelRunIds[runId] = true;
       if (this.metadatas[runId] !== undefined) {
         if (isChatGenerationChunk(chunk)) {
-          this._emit(this.metadatas[runId], chunk.message);
+          this._emit(this.metadatas[runId], chunk.message, runId);
         } else {
           this._emit(
             this.metadatas[runId],
-            new AIMessageChunk({
-              content: token,
-            })
+            new AIMessageChunk({ content: token }),
+            runId
           );
         }
       }
@@ -121,11 +133,12 @@ export class StreamMessagesHandler extends BaseCallbackHandler {
     if (!this.emittedChatModelRunIds[runId]) {
       const chatGeneration = output.generations?.[0]?.[0] as ChatGeneration;
       if (isBaseMessage(chatGeneration?.message)) {
-        this._emit(this.metadatas[runId], chatGeneration?.message, true);
+        this._emit(this.metadatas[runId], chatGeneration?.message, runId, true);
       }
       delete this.emittedChatModelRunIds[runId];
     }
     delete this.metadatas[runId];
+    delete this.stableMessageIdMap[runId];
   }

   // eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -160,21 +173,21 @@ export class StreamMessagesHandler extends BaseCallbackHandler {
     delete this.metadatas[runId];
     if (metadata !== undefined) {
       if (isBaseMessage(outputs)) {
-        this._emit(metadata, outputs, true);
+        this._emit(metadata, outputs, runId, true);
       } else if (Array.isArray(outputs)) {
         for (const value of outputs) {
           if (isBaseMessage(value)) {
-            this._emit(metadata, value, true);
+            this._emit(metadata, value, runId, true);
           }
         }
       } else if (outputs != null && typeof outputs === "object") {
         for (const value of Object.values(outputs)) {
           if (isBaseMessage(value)) {
-            this._emit(metadata, value, true);
+            this._emit(metadata, value, runId, true);
          } else if (Array.isArray(value)) {
             for (const item of value) {
               if (isBaseMessage(item)) {
-                this._emit(metadata, item, true);
+                this._emit(metadata, item, runId, true);
               }
             }
           }
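To make the new `_emit` behavior concrete, here is a minimal sketch of the ID-stabilization rule in isolation. `stabilizeMessageId` is a hypothetical helper written for illustration only; the patch inlines this logic in `_emit`, keyed by the callback `runId`:

```ts
// First valid message ID seen for each model run, keyed by runId.
const stableMessageIdMap: Record<string, string> = {};

function stabilizeMessageId(runId: string, incomingId?: string): string {
  let id = incomingId;
  // A missing ID or the auto-set `run-${runId}` placeholder is not "valid":
  // prefer the ID remembered from an earlier chunk of the same run.
  if (id == null || id === `run-${runId}`) {
    id = stableMessageIdMap[runId] ?? id ?? `run-${runId}`;
  }
  // Remember the first ID we settle on so later chunks reuse it.
  stableMessageIdMap[runId] ??= id;
  return id;
}

// ChatAnthropic-style stream: the first chunk carries a provider ID,
// subsequent chunks arrive with none (or with the run placeholder).
stabilizeMessageId("a1b2", "msg_abc");  // => "msg_abc" (remembered)
stabilizeMessageId("a1b2", "run-a1b2"); // => "msg_abc" (placeholder renamed)
stabilizeMessageId("a1b2", undefined);  // => "msg_abc" (reused)
stabilizeMessageId("c3d4", undefined);  // => "run-c3d4" (no valid ID ever seen)
```

Because the map entry is removed in `handleLLMEnd` (`delete this.stableMessageIdMap[runId]`), memory cost stays bounded by the number of in-flight model runs.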
diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts
index 37d5360a7..042911293 100644
--- a/libs/langgraph/src/tests/pregel.test.ts
+++ b/libs/langgraph/src/tests/pregel.test.ts
@@ -8564,24 +8564,6 @@ graph TD;
           tags: ["c_two_chat_model"],
         },
       ],
-      [
-        new _AnyIdAIMessageChunk({
-          content: "2",
-        }),
-        {
-          langgraph_step: 2,
-          langgraph_node: "c_two",
-          langgraph_triggers: ["c_one"],
-          langgraph_path: [PULL, "c_two"],
-          langgraph_checkpoint_ns: expect.stringMatching(/^p_two:.*\|c_two:.*/),
-          __pregel_resuming: false,
-          __pregel_task_id: expect.any(String),
-          checkpoint_ns: expect.stringMatching(/^p_two:/),
-          name: "c_two",
-          tags: ["graph:step:2"],
-          ls_stop: undefined,
-        },
-      ],
       [
         new _AnyIdAIMessageChunk({
           content: "x",
@@ -8737,27 +8719,6 @@ graph TD;
           },
         ],
       ],
-      [
-        "messages",
-        [
-          new _AnyIdAIMessageChunk({
-            content: "2",
-          }),
-          {
-            langgraph_step: 2,
-            langgraph_node: "c_two",
-            langgraph_triggers: ["c_one"],
-            langgraph_path: [PULL, "c_two"],
-            langgraph_checkpoint_ns:
-              expect.stringMatching(/^p_two:.*\|c_two:.*/),
-            __pregel_resuming: false,
-            __pregel_task_id: expect.any(String),
-            checkpoint_ns: expect.stringMatching(/^p_two:/),
-            tags: ["graph:step:2"],
-            name: "c_two",
-          },
-        ],
-      ],
       [
         "messages",
         [
@@ -9429,6 +9390,37 @@ graph TD;
     const thirdState = await graph.getState(config);
     expect(thirdState.tasks).toHaveLength(0);
   });
+
+  it.each(["omit", "first-only", "always"] as const)(
+    "`messages` inherits message ID - %p",
+    async (streamMessageId) => {
+      const checkpointer = await createCheckpointer();
+
+      const graph = new StateGraph(MessagesAnnotation)
+        .addNode("one", async () => {
+          const model = new FakeChatModel({
+            responses: [new AIMessage({ id: "123", content: "Output" })],
+            streamMessageId,
+          });
+
+          const invoke = await model.invoke([new HumanMessage("Input")]);
+          return { messages: invoke };
+        })
+        .addEdge(START, "one")
+        .compile({ checkpointer });
+
+      const messages = await gatherIterator(
+        graph.stream(
+          { messages: [] },
+          { configurable: { thread_id: "1" }, streamMode: "messages" }
+        )
+      );
+
+      const messageIds = [...new Set(messages.map(([m]) => m.id))];
+      expect(messageIds).toHaveLength(1);
+      if (streamMessageId !== "omit") expect(messageIds[0]).toBe("123");
+    }
+  );
 }

 runPregelTests(() => new MemorySaverAssertImmutable());
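From a client's perspective, the fix means every `streamMode: "messages"` chunk of a single model call now shares one message ID, so chunks can be accumulated by ID instead of surfacing as separate messages. A hedged sketch of such a consumer, assuming a compiled `graph` like the one in the test above:

```ts
// Accumulate streamed tokens per stable message ID.
const byId = new Map<string, string>();

const stream = await graph.stream(
  { messages: [] },
  { configurable: { thread_id: "1" }, streamMode: "messages" }
);

for await (const [message] of stream) {
  const prev = byId.get(message.id!) ?? "";
  byId.set(message.id!, prev + (message.content as string));
}

// Before the fix, the first chunk (with the provider ID) and each later
// ID-less chunk (previously assigned a fresh v4() UUID) would scatter
// across entries; now byId holds one entry per run, e.g. "123" => "Output".
```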
diff --git a/libs/langgraph/src/tests/utils.ts b/libs/langgraph/src/tests/utils.ts
index f060cfe64..99d2af886 100644
--- a/libs/langgraph/src/tests/utils.ts
+++ b/libs/langgraph/src/tests/utils.ts
@@ -2,6 +2,7 @@
 /* eslint-disable import/no-extraneous-dependencies */
 import assert from "node:assert";
 import { expect } from "@jest/globals";
+import { v4 as uuidv4 } from "uuid";
 import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
 import {
   BaseChatModel,
@@ -44,9 +45,16 @@ export class FakeChatModel extends BaseChatModel {

   callCount = 0;

-  constructor(fields: FakeChatModelArgs) {
+  streamMessageId: "omit" | "first-only" | "always";
+
+  constructor(
+    fields: FakeChatModelArgs & {
+      streamMessageId?: "omit" | "first-only" | "always";
+    }
+  ) {
     super(fields);
     this.responses = fields.responses;
+    this.streamMessageId = fields.streamMessageId ?? "omit";
   }

   _combineLLMOutput() {
@@ -91,14 +99,35 @@ export class FakeChatModel extends BaseChatModel {
     runManager?: CallbackManagerForLLMRun
   ) {
     const response = this.responses[this.callCount % this.responses.length];
-    for (const text of (response.content as string).split("")) {
-      yield new ChatGenerationChunk({
-        message: new AIMessageChunk({
-          content: text as string,
-        }),
-        text,
+
+    let isFirstChunk = true;
+    const completionId = response.id ?? uuidv4();
+
+    for (const content of (response.content as string).split("")) {
+      let id: string | undefined;
+      if (
+        this.streamMessageId === "always" ||
+        (this.streamMessageId === "first-only" && isFirstChunk)
+      ) {
+        id = completionId;
+      }
+
+      const chunk = new ChatGenerationChunk({
+        message: new AIMessageChunk({ content, id }),
+        text: content,
       });
-      await runManager?.handleLLMNewToken(text as string);
+
+      yield chunk;
+      await runManager?.handleLLMNewToken(
+        content,
+        undefined,
+        undefined,
+        undefined,
+        undefined,
+        { chunk }
+      );
+
+      isFirstChunk = false;
     }
     this.callCount += 1;
   }
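For reference, the three `streamMessageId` modes let the fake model mimic different provider behaviors ("first-only" mirrors ChatAnthropic-style streams). A rough sketch of what each mode yields for a response `"abc"` with ID `"123"`, calling the test utility's generator directly (exact ID handling may differ when going through `.stream()`, where core may auto-assign `run-${runId}` IDs):

```ts
import { AIMessage, HumanMessage } from "@langchain/core/messages";
import { FakeChatModel } from "./utils.js"; // assumed path to the test utility

const model = new FakeChatModel({
  responses: [new AIMessage({ id: "123", content: "abc" })],
  streamMessageId: "first-only",
});

// Per-chunk message IDs by mode:
//   "omit"       -> [undefined, undefined, undefined]
//   "first-only" -> ["123", undefined, undefined]
//   "always"     -> ["123", "123", "123"]
const ids: (string | undefined)[] = [];
for await (const chunk of model._streamResponseChunks(
  [new HumanMessage("hi")],
  {}
)) {
  ids.push(chunk.message.id);
}
// ids: ["123", undefined, undefined]
```

Passing `{ chunk }` as the sixth `handleLLMNewToken` argument matters here: those callback fields hand the full `ChatGenerationChunk` to `StreamMessagesHandler`, so it takes the `isChatGenerationChunk` branch and sees the chunk's message ID instead of wrapping the bare token in a fresh `AIMessageChunk`.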