fix(messages): reuse the first valid message ID for subsequent chunks
dqbd committed Jan 23, 2025
1 parent 03d7003 commit 81a1345
Showing 3 changed files with 96 additions and 62 deletions.
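
At a glance: some chat model integrations (the code comment in the patch cites ChatAnthropic) attach a real message ID only to the first streamed chunk. The old handler minted a fresh UUID for every ID-less chunk, so consumers of `streamMode: "messages"` could see one completion surface under several message IDs. A sketch of the symptom and the new behavior (IDs invented for illustration):

// Before: chunks of one completion could carry different IDs.
//   chunk 1: { id: "msg_abc",       content: "Hel" }
//   chunk 2: { id: "<random-uuid>", content: "lo" }  // looks like a second message
// After: the first valid ID observed for the run is reused.
//   chunk 1: { id: "msg_abc", content: "Hel" }
//   chunk 2: { id: "msg_abc", content: "lo" }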
43 changes: 28 additions & 15 deletions libs/langgraph/src/pregel/messages.ts
@@ -1,4 +1,3 @@
-import { v4 } from "uuid";
 import {
   BaseCallbackHandler,
   HandleLLMNewTokenCallbackFields,
@@ -44,28 +43,42 @@ export class StreamMessagesHandler extends BaseCallbackHandler {
 
   emittedChatModelRunIds: Record<string, boolean> = {};
 
+  stableMessageIdMap: Record<string, string> = {};
+
   lc_prefer_streaming = true;
 
   constructor(streamFn: (streamChunk: StreamChunk) => void) {
     super();
     this.streamFn = streamFn;
   }
 
-  _emit(meta: Meta, message: BaseMessage, dedupe = false) {
+  _emit(meta: Meta, message: BaseMessage, runId: string, dedupe = false) {
     if (
       dedupe &&
       message.id !== undefined &&
       this.seen[message.id] !== undefined
     ) {
       return;
     }
-    if (message.id === undefined) {
-      const id = v4();
+
+    // For instance in ChatAnthropic, the first chunk has a message ID
+    // but the subsequent chunks do not. To avoid clients seeing two messages
+    // we rename the message ID if it's being auto-set to `run-${runId}`
+    // (see https://github.com/langchain-ai/langchainjs/pull/6646).
+    let messageId = message.id;
+    if (messageId == null || messageId === `run-${runId}`) {
+      messageId = this.stableMessageIdMap[runId] ?? messageId ?? `run-${runId}`;
+    }
+    this.stableMessageIdMap[runId] ??= messageId;
+
+    if (messageId !== message.id) {
       // eslint-disable-next-line no-param-reassign
-      message.id = id;
+      message.id = messageId;
+
       // eslint-disable-next-line no-param-reassign
-      message.lc_kwargs.id = id;
+      message.lc_kwargs.id = messageId;
     }
+
     this.seen[message.id!] = message;
     this.streamFn([meta[0], "messages", [message, meta[1]]]);
   }
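
The heart of the patch is the normalization in `_emit` above: the first usable ID seen for a run becomes the stable ID for every later chunk of that run. A minimal standalone sketch of that rule (the helper name and sample values are assumptions for illustration, not LangGraph API):

// Standalone sketch of the stable-ID rule from _emit above.
const stableMessageIdMap: Record<string, string> = {};

function resolveMessageId(runId: string, chunkId?: string): string {
  let messageId = chunkId;
  // A missing ID, or the auto-generated `run-${runId}` fallback, is not a
  // "real" ID: prefer whatever was already recorded for this run.
  if (messageId == null || messageId === `run-${runId}`) {
    messageId = stableMessageIdMap[runId] ?? messageId ?? `run-${runId}`;
  }
  stableMessageIdMap[runId] ??= messageId;
  return messageId;
}

// Anthropic-style stream: only the first chunk carries a provider ID.
console.log(resolveMessageId("r1", "msg_abc")); // "msg_abc"
console.log(resolveMessageId("r1", "run-r1"));  // "msg_abc" (reused)
console.log(resolveMessageId("r1", undefined)); // "msg_abc" (reused)
// A run that never sees a real ID still gets one consistent fallback:
console.log(resolveMessageId("r2", undefined)); // "run-r2"
console.log(resolveMessageId("r2", undefined)); // "run-r2"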
@@ -104,13 +117,12 @@ export class StreamMessagesHandler extends BaseCallbackHandler {
     this.emittedChatModelRunIds[runId] = true;
     if (this.metadatas[runId] !== undefined) {
       if (isChatGenerationChunk(chunk)) {
-        this._emit(this.metadatas[runId], chunk.message);
+        this._emit(this.metadatas[runId], chunk.message, runId);
       } else {
         this._emit(
           this.metadatas[runId],
-          new AIMessageChunk({
-            content: token,
-          })
+          new AIMessageChunk({ content: token }),
+          runId
         );
       }
     }
@@ -121,11 +133,12 @@
     if (!this.emittedChatModelRunIds[runId]) {
       const chatGeneration = output.generations?.[0]?.[0] as ChatGeneration;
       if (isBaseMessage(chatGeneration?.message)) {
-        this._emit(this.metadatas[runId], chatGeneration?.message, true);
+        this._emit(this.metadatas[runId], chatGeneration?.message, runId, true);
       }
       delete this.emittedChatModelRunIds[runId];
     }
     delete this.metadatas[runId];
+    delete this.stableMessageIdMap[runId];
   }
 
   // eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -160,21 +173,21 @@ export class StreamMessagesHandler extends BaseCallbackHandler {
     delete this.metadatas[runId];
     if (metadata !== undefined) {
       if (isBaseMessage(outputs)) {
-        this._emit(metadata, outputs, true);
+        this._emit(metadata, outputs, runId, true);
       } else if (Array.isArray(outputs)) {
         for (const value of outputs) {
           if (isBaseMessage(value)) {
-            this._emit(metadata, value, true);
+            this._emit(metadata, value, runId, true);
           }
         }
       } else if (outputs != null && typeof outputs === "object") {
         for (const value of Object.values(outputs)) {
           if (isBaseMessage(value)) {
-            this._emit(metadata, value, true);
+            this._emit(metadata, value, runId, true);
           } else if (Array.isArray(value)) {
             for (const item of value) {
               if (isBaseMessage(item)) {
-                this._emit(metadata, item, true);
+                this._emit(metadata, item, runId, true);
               }
             }
           }
70 changes: 31 additions & 39 deletions libs/langgraph/src/tests/pregel.test.ts
@@ -8564,24 +8564,6 @@ graph TD;
         tags: ["c_two_chat_model"],
       },
     ],
-    [
-      new _AnyIdAIMessageChunk({
-        content: "2",
-      }),
-      {
-        langgraph_step: 2,
-        langgraph_node: "c_two",
-        langgraph_triggers: ["c_one"],
-        langgraph_path: [PULL, "c_two"],
-        langgraph_checkpoint_ns: expect.stringMatching(/^p_two:.*\|c_two:.*/),
-        __pregel_resuming: false,
-        __pregel_task_id: expect.any(String),
-        checkpoint_ns: expect.stringMatching(/^p_two:/),
-        name: "c_two",
-        tags: ["graph:step:2"],
-        ls_stop: undefined,
-      },
-    ],
     [
       new _AnyIdAIMessageChunk({
         content: "x",
@@ -8737,27 +8719,6 @@ graph TD;
       },
     ],
   ],
-  [
-    "messages",
-    [
-      new _AnyIdAIMessageChunk({
-        content: "2",
-      }),
-      {
-        langgraph_step: 2,
-        langgraph_node: "c_two",
-        langgraph_triggers: ["c_one"],
-        langgraph_path: [PULL, "c_two"],
-        langgraph_checkpoint_ns:
-          expect.stringMatching(/^p_two:.*\|c_two:.*/),
-        __pregel_resuming: false,
-        __pregel_task_id: expect.any(String),
-        checkpoint_ns: expect.stringMatching(/^p_two:/),
-        tags: ["graph:step:2"],
-        name: "c_two",
-      },
-    ],
-  ],
   [
     "messages",
     [
@@ -9470,6 +9431,37 @@ graph TD;
     expect(oneCount).toEqual(1);
     expect(twoCount).toEqual(0);
   });
+
+  it.each(["omit", "first-only", "always"] as const)(
+    "`messages` inherits message ID - %p",
+    async (streamMessageId) => {
+      const checkpointer = await createCheckpointer();
+
+      const graph = new StateGraph(MessagesAnnotation)
+        .addNode("one", async () => {
+          const model = new FakeChatModel({
+            responses: [new AIMessage({ id: "123", content: "Output" })],
+            streamMessageId,
+          });
+
+          const invoke = await model.invoke([new HumanMessage("Input")]);
+          return { messages: invoke };
+        })
+        .addEdge(START, "one")
+        .compile({ checkpointer });
+
+      const messages = await gatherIterator(
+        graph.stream(
+          { messages: [] },
+          { configurable: { thread_id: "1" }, streamMode: "messages" }
+        )
+      );
+
+      const messageIds = [...new Set(messages.map(([m]) => m.id))];
+      expect(messageIds).toHaveLength(1);
+      if (streamMessageId !== "omit") expect(messageIds[0]).toBe("123");
+    }
+  );
 }
 
 runPregelTests(() => new MemorySaverAssertImmutable());
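
The parameterization covers the three ID-emission styles implemented in the fake model below: "always" (every chunk carries the ID), "first-only" (the Anthropic-like case this commit targets), and "omit" (no chunk carries one). In every mode the streamed chunks must collapse to a single message ID; only when the model actually emits an ID must it equal the response's "123" — in the "omit" case the handler falls back to one consistent run-scoped ID.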
45 changes: 37 additions & 8 deletions libs/langgraph/src/tests/utils.ts
@@ -2,6 +2,7 @@
 /* eslint-disable import/no-extraneous-dependencies */
 import assert from "node:assert";
 import { expect } from "@jest/globals";
+import { v4 as uuidv4 } from "uuid";
 import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
 import {
   BaseChatModel,
@@ -44,9 +45,16 @@ export class FakeChatModel extends BaseChatModel {
 
   callCount = 0;
 
-  constructor(fields: FakeChatModelArgs) {
+  streamMessageId: "omit" | "first-only" | "always";
+
+  constructor(
+    fields: FakeChatModelArgs & {
+      streamMessageId?: "omit" | "first-only" | "always";
+    }
+  ) {
     super(fields);
     this.responses = fields.responses;
+    this.streamMessageId = fields.streamMessageId ?? "omit";
   }
 
   _combineLLMOutput() {
@@ -91,14 +99,35 @@
     runManager?: CallbackManagerForLLMRun
   ) {
     const response = this.responses[this.callCount % this.responses.length];
-    for (const text of (response.content as string).split("")) {
-      yield new ChatGenerationChunk({
-        message: new AIMessageChunk({
-          content: text as string,
-        }),
-        text,
+
+    let isFirstChunk = true;
+    const completionId = response.id ?? uuidv4();
+
+    for (const content of (response.content as string).split("")) {
+      let id: string | undefined;
+      if (
+        this.streamMessageId === "always" ||
+        (this.streamMessageId === "first-only" && isFirstChunk)
+      ) {
+        id = completionId;
+      }
+
+      const chunk = new ChatGenerationChunk({
+        message: new AIMessageChunk({ content, id }),
+        text: content,
       });
-      await runManager?.handleLLMNewToken(text as string);
+
+      yield chunk;
+      await runManager?.handleLLMNewToken(
+        content,
+        undefined,
+        undefined,
+        undefined,
+        undefined,
+        { chunk }
+      );
+
+      isFirstChunk = false;
     }
     this.callCount += 1;
   }
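
For context, a hedged usage sketch of the extended fake model (the import path and exact chunking are assumptions based on the code above):

import { AIMessage } from "@langchain/core/messages";
import { FakeChatModel } from "./utils"; // illustrative path

const model = new FakeChatModel({
  responses: [new AIMessage({ id: "123", content: "Hi" })],
  streamMessageId: "first-only",
});

// _streamResponseChunks yields one chunk per character; with "first-only",
// only the first chunk carries id "123".
for await (const chunk of await model.stream("Input")) {
  console.log(chunk.id, chunk.content); // "123" "H", then undefined "i"
}

Passing `{ chunk }` as the final argument to `handleLLMNewToken` is the piece the handler relies on: it lets `StreamMessagesHandler` take its `isChatGenerationChunk` branch and read the chunk's real ID, instead of wrapping the bare token in a fresh `AIMessageChunk`.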
