From 6caa011d1d58feadea2dbd05b68ac8babe8121e0 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Fri, 12 Jul 2024 13:26:46 -0700 Subject: [PATCH 01/14] anthropic[minor]: Implement actual anthropic tool call streaming --- libs/langchain-anthropic/src/chat_models.ts | 228 ++++++++++-------- .../src/tests/chat_models-tools.int.test.ts | 40 ++- 2 files changed, 161 insertions(+), 107 deletions(-) diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index f5e788bb2e29..d616f3ffb16f 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -33,7 +33,10 @@ import { } from "@langchain/core/language_models/base"; import { StructuredToolInterface } from "@langchain/core/tools"; import { zodToJsonSchema } from "zod-to-json-schema"; -import { BaseLLMOutputParser } from "@langchain/core/output_parsers"; +import { + BaseLLMOutputParser, + parsePartialJson, +} from "@langchain/core/output_parsers"; import { Runnable, RunnablePassthrough, @@ -41,7 +44,7 @@ import { RunnableToolLike, } from "@langchain/core/runnables"; import { isZodSchema } from "@langchain/core/utils/types"; -import { ToolCall } from "@langchain/core/messages/tool"; +import { ToolCall, ToolCallChunk } from "@langchain/core/messages/tool"; import { z } from "zod"; import type { MessageCreateParams, @@ -674,126 +677,139 @@ export class ChatAnthropicMessages< ): AsyncGenerator { const params = this.invocationParams(options); const formattedMessages = _formatMessagesForAnthropic(messages); - if (options.tools !== undefined && options.tools.length > 0) { - const { generations } = await this._generateNonStreaming( - messages, - params, - { - signal: options.signal, + + const stream = await this.createStreamWithRetry({ + ...params, + ...formattedMessages, + stream: true, + }); + let usageData = { input_tokens: 0, output_tokens: 0 }; + + let toolCallChunksMsg: ToolCallChunk | undefined; + let aggregatePartialJsonString = ""; + + for await (const data of stream) { + if (options.signal?.aborted) { + stream.controller.abort(); + throw new Error("AbortError: User aborted the request."); + } + if (data.type === "message_start") { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { content, usage, ...additionalKwargs } = data.message; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const filteredAdditionalKwargs: Record = {}; + for (const [key, value] of Object.entries(additionalKwargs)) { + if (value !== undefined && value !== null) { + filteredAdditionalKwargs[key] = value; + } } - ); - const result = generations[0].message as AIMessage; - const toolCallChunks = result.tool_calls?.map( - (toolCall: ToolCall, index: number) => ({ - name: toolCall.name, - args: JSON.stringify(toolCall.args), - id: toolCall.id, - index, - }) - ); - yield new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: result.content, - additional_kwargs: result.additional_kwargs, - tool_call_chunks: toolCallChunks, - usage_metadata: result.usage_metadata, - response_metadata: result.response_metadata, - }), - text: generations[0].text, - }); - } else { - const stream = await this.createStreamWithRetry({ - ...params, - ...formattedMessages, - stream: true, - }); - let usageData = { input_tokens: 0, output_tokens: 0 }; - for await (const data of stream) { - if (options.signal?.aborted) { - stream.controller.abort(); - throw new Error("AbortError: User aborted the request."); + usageData = usage; + let usageMetadata: UsageMetadata | undefined; + if (this.streamUsage || options.streamUsage) { + usageMetadata = { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + total_tokens: usage.input_tokens + usage.output_tokens, + }; } - if (data.type === "message_start") { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { content, usage, ...additionalKwargs } = data.message; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const filteredAdditionalKwargs: Record = {}; - for (const [key, value] of Object.entries(additionalKwargs)) { - if (value !== undefined && value !== null) { - filteredAdditionalKwargs[key] = value; - } - } - usageData = usage; - let usageMetadata: UsageMetadata | undefined; - if (this.streamUsage || options.streamUsage) { - usageMetadata = { - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - total_tokens: usage.input_tokens + usage.output_tokens, - }; - } + yield new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: "", + additional_kwargs: filteredAdditionalKwargs, + response_metadata: { ...filteredAdditionalKwargs }, + usage_metadata: usageMetadata, + }), + text: "", + }); + } else if (data.type === "message_delta") { + let usageMetadata: UsageMetadata | undefined; + if (this.streamUsage || options.streamUsage) { + usageMetadata = { + input_tokens: data.usage.output_tokens, + output_tokens: 0, + total_tokens: data.usage.output_tokens, + }; + } + yield new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: "", + additional_kwargs: { ...data.delta }, + response_metadata: { ...data.delta }, + usage_metadata: usageMetadata, + }), + text: "", + }); + if (data?.usage !== undefined) { + usageData.output_tokens += data.usage.output_tokens; + } + } else if ( + data.type === "content_block_delta" && + data.delta.type === "text_delta" + ) { + const content = data.delta?.text; + if (content !== undefined) { yield new ChatGenerationChunk({ message: new AIMessageChunk({ - content: "", - additional_kwargs: filteredAdditionalKwargs, - usage_metadata: usageMetadata, + content, + additional_kwargs: {}, + response_metadata: { ...data.delta }, }), - text: "", + text: content, }); - } else if (data.type === "message_delta") { - let usageMetadata: UsageMetadata | undefined; - if (this.streamUsage || options.streamUsage) { - usageMetadata = { - input_tokens: data.usage.output_tokens, - output_tokens: 0, - total_tokens: data.usage.output_tokens, - }; - } + await runManager?.handleLLMNewToken(content); + } + } else if ( + data.type === "content_block_start" && + data.content_block.type === "tool_use" + ) { + toolCallChunksMsg = { + name: data.content_block.name, + args: JSON.stringify(data.content_block.input), + id: data.content_block.id, + index: data.index, + } + } else if ( + data.type === "content_block_delta" && + data.delta.type === "input_json_delta" && + toolCallChunksMsg + ) { + aggregatePartialJsonString += data.delta.partial_json; + const parsedPartial = parsePartialJson(aggregatePartialJsonString); + if (parsedPartial) { yield new ChatGenerationChunk({ message: new AIMessageChunk({ content: "", + tool_call_chunks: [ + { + ...toolCallChunksMsg, + args: JSON.stringify(parsedPartial, null, 2), + }, + ], additional_kwargs: { ...data.delta }, - usage_metadata: usageMetadata, + response_metadata: { ...data.delta }, }), text: "", }); - if (data?.usage !== undefined) { - usageData.output_tokens += data.usage.output_tokens; - } - } else if ( - data.type === "content_block_delta" && - data.delta.type === "text_delta" - ) { - const content = data.delta?.text; - if (content !== undefined) { - yield new ChatGenerationChunk({ - message: new AIMessageChunk({ - content, - additional_kwargs: {}, - }), - text: content, - }); - await runManager?.handleLLMNewToken(content); - } } } - let usageMetadata: UsageMetadata | undefined; - if (this.streamUsage || options.streamUsage) { - usageMetadata = { - input_tokens: usageData.input_tokens, - output_tokens: usageData.output_tokens, - total_tokens: usageData.input_tokens + usageData.output_tokens, - }; - } - yield new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: "", - additional_kwargs: { usage: usageData }, - usage_metadata: usageMetadata, - }), - text: "", - }); } + + let usageMetadata: UsageMetadata | undefined; + if (this.streamUsage || options.streamUsage) { + usageMetadata = { + input_tokens: usageData.input_tokens, + output_tokens: usageData.output_tokens, + total_tokens: usageData.input_tokens + usageData.output_tokens, + }; + } + yield new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: "", + additional_kwargs: { usage: usageData }, + usage_metadata: usageMetadata, + }), + text: "", + }); } /** @ignore */ diff --git a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts index a6cbdcbd956b..953531c8b7da 100644 --- a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts @@ -2,7 +2,7 @@ import { expect, test } from "@jest/globals"; import { AIMessage, HumanMessage, ToolMessage } from "@langchain/core/messages"; -import { StructuredTool } from "@langchain/core/tools"; +import { StructuredTool, tool } from "@langchain/core/tools"; import { z } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; import { ChatAnthropic } from "../chat_models.js"; @@ -386,3 +386,41 @@ test("withStructuredOutput will always force tool usage", async () => { expect(castMessage.tool_calls).toHaveLength(1); expect(castMessage.tool_calls?.[0].name).toBe("get_weather"); }); + +test.only("Can stream tool calls", async () => { + const weatherTool = tool( + (_) => { + return "no-op"; + }, + { + name: "get_weather", + description: zodSchema.description, + schema: zodSchema, + } + ); + + const modelWithTools = model.bindTools([weatherTool]); + const stream = await modelWithTools.stream( + "What is the weather in San Francisco CA?" + ); + + let argsStringArr: string[] = []; + + for await (const chunk of stream) { + const toolCall = chunk.tool_calls?.[0]; + if (!toolCall) continue; + const stringifiedArgs = JSON.stringify(toolCall.args, null, 2); + + // Push each new "chunk" of args to the array. We'll check the array's + // length at the end to verify multiple tool call chunks were streamed. + if (argsStringArr[argsStringArr.length - 1] !== stringifiedArgs) { + argsStringArr.push(stringifiedArgs); + } + } + + expect(argsStringArr.length).toBeGreaterThan(1); + console.log("argsStringArr.length", argsStringArr.length) + const finalToolCall = JSON.parse(argsStringArr[argsStringArr.length - 1]); + console.log("finalToolCall", finalToolCall) + expect(finalToolCall.location).toBeDefined(); +}); From fdc3e3844fcc211d46f23d9b571fdbbcd792d784 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Fri, 12 Jul 2024 13:26:57 -0700 Subject: [PATCH 02/14] chore: lint files --- libs/langchain-anthropic/src/chat_models.ts | 2 +- .../src/tests/chat_models-tools.int.test.ts | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index d616f3ffb16f..08ea303f0c34 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -767,7 +767,7 @@ export class ChatAnthropicMessages< args: JSON.stringify(data.content_block.input), id: data.content_block.id, index: data.index, - } + }; } else if ( data.type === "content_block_delta" && data.delta.type === "input_json_delta" && diff --git a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts index 953531c8b7da..f3108113dec1 100644 --- a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts @@ -389,9 +389,7 @@ test("withStructuredOutput will always force tool usage", async () => { test.only("Can stream tool calls", async () => { const weatherTool = tool( - (_) => { - return "no-op"; - }, + (_) => "no-op", { name: "get_weather", description: zodSchema.description, @@ -404,7 +402,7 @@ test.only("Can stream tool calls", async () => { "What is the weather in San Francisco CA?" ); - let argsStringArr: string[] = []; + const argsStringArr: string[] = []; for await (const chunk of stream) { const toolCall = chunk.tool_calls?.[0]; @@ -419,8 +417,8 @@ test.only("Can stream tool calls", async () => { } expect(argsStringArr.length).toBeGreaterThan(1); - console.log("argsStringArr.length", argsStringArr.length) + console.log("argsStringArr.length", argsStringArr.length); const finalToolCall = JSON.parse(argsStringArr[argsStringArr.length - 1]); - console.log("finalToolCall", finalToolCall) + console.log("finalToolCall", finalToolCall); expect(finalToolCall.location).toBeDefined(); }); From 3d6a52c810a49c18b24234560767a7d9966fd5d7 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Fri, 12 Jul 2024 13:27:22 -0700 Subject: [PATCH 03/14] chore: lint files --- .../src/tests/chat_models-tools.int.test.ts | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts index f3108113dec1..471cfc15c538 100644 --- a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts @@ -388,14 +388,11 @@ test("withStructuredOutput will always force tool usage", async () => { }); test.only("Can stream tool calls", async () => { - const weatherTool = tool( - (_) => "no-op", - { - name: "get_weather", - description: zodSchema.description, - schema: zodSchema, - } - ); + const weatherTool = tool((_) => "no-op", { + name: "get_weather", + description: zodSchema.description, + schema: zodSchema, + }); const modelWithTools = model.bindTools([weatherTool]); const stream = await modelWithTools.stream( From 48782707e34bb2289293ddcddc5f8ff09b5e23a4 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Fri, 12 Jul 2024 13:27:50 -0700 Subject: [PATCH 04/14] chore: lint files --- .../src/tests/chat_models-tools.int.test.ts | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts index 471cfc15c538..e2df16ba9a6b 100644 --- a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts @@ -387,19 +387,24 @@ test("withStructuredOutput will always force tool usage", async () => { expect(castMessage.tool_calls?.[0].name).toBe("get_weather"); }); -test.only("Can stream tool calls", async () => { - const weatherTool = tool((_) => "no-op", { - name: "get_weather", - description: zodSchema.description, - schema: zodSchema, - }); +test("Can stream tool calls", async () => { + const weatherTool = tool( + (_) => { + return "no-op"; + }, + { + name: "get_weather", + description: zodSchema.description, + schema: zodSchema, + } + ); const modelWithTools = model.bindTools([weatherTool]); const stream = await modelWithTools.stream( "What is the weather in San Francisco CA?" ); - const argsStringArr: string[] = []; + let argsStringArr: string[] = []; for await (const chunk of stream) { const toolCall = chunk.tool_calls?.[0]; From e78c686599a78042bbdd57345b3f42522a6bdc28 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Fri, 12 Jul 2024 14:05:20 -0700 Subject: [PATCH 05/14] I broke it --- libs/langchain-anthropic/src/chat_models.ts | 280 +++++++++++------- .../src/tests/chat_models-tools.int.test.ts | 20 +- 2 files changed, 193 insertions(+), 107 deletions(-) diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index 08ea303f0c34..d585b6cfde9b 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -88,6 +88,160 @@ export interface ChatAnthropicCallOptions type AnthropicMessageResponse = Anthropic.ContentBlock | AnthropicToolResponse; +function _toolsInParams(params: AnthropicMessageCreateParams): boolean { + return !!( + params.tools && + params.tools.length > 0 + ); +} + +/** + * Convert Anthropic event to AIMessageChunk. + * + * Note that not all events will result in a message chunk. + * In these cases we return null. + * @param {Anthropic.Messages.RawMessageStreamEvent} event The event to convert. + * @param {boolean} streamUsage Whether to include token usage data in streamed chunks. + * @param {boolean} coerceContentToString Whether to coerce content to a string. + */ +function _makeMessageChunkFromAnthropicEvent(event: Anthropic.Messages.RawMessageStreamEvent, extra: { + streamUsage: boolean, + coerceContentToString: boolean, + usageData: { input_tokens: number, output_tokens: number }, + toolCallChunksMsg: ToolCallChunk[], + aggregatePartialJsonString: string +}): { + chatGenerationChunk: ChatGenerationChunk, + aggregatePartialJsonString: string + toolCallChunksMsg: ToolCallChunk[] +} | null { + if (event.type === "message_start") { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { content, usage, ...additionalKwargs } = event.message; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const filteredAdditionalKwargs: Record = {}; + for (const [key, value] of Object.entries(additionalKwargs)) { + if (value !== undefined && value !== null) { + filteredAdditionalKwargs[key] = value; + } + } + extra.usageData = usage; + let usageMetadata: UsageMetadata | undefined; + if (extra.streamUsage) { + usageMetadata = { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + total_tokens: usage.input_tokens + usage.output_tokens, + }; + } + return { + chatGenerationChunk: new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: extra.coerceContentToString ? "" : [], + additional_kwargs: filteredAdditionalKwargs, + response_metadata: { ...filteredAdditionalKwargs }, + usage_metadata: usageMetadata, + }), + text: "", + }), + aggregatePartialJsonString: extra.aggregatePartialJsonString, + toolCallChunksMsg: extra.toolCallChunksMsg, + }; + } else if (event.type === "message_delta") { + let usageMetadata: UsageMetadata | undefined; + if (extra.streamUsage) { + usageMetadata = { + input_tokens: event.usage.output_tokens, + output_tokens: 0, + total_tokens: event.usage.output_tokens, + }; + } + if (event?.usage !== undefined) { + extra.usageData.output_tokens += event.usage.output_tokens; + } + return { + chatGenerationChunk: new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: extra.coerceContentToString ? "" : [], + additional_kwargs: { ...event.delta }, + response_metadata: { ...event.delta }, + usage_metadata: usageMetadata, + }), + text: "", + }), + aggregatePartialJsonString: extra.aggregatePartialJsonString, + toolCallChunksMsg: extra.toolCallChunksMsg, + }; + } else if ( + event.type === "content_block_delta" && + event.delta.type === "text_delta" + ) { + const content = event.delta?.text; + if (content !== undefined) { + return { + chatGenerationChunk: new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: extra.coerceContentToString ? content : [{ type: "text", text: content }], + additional_kwargs: {}, + response_metadata: { ...event.delta }, + }), + text: content, + }), + aggregatePartialJsonString: extra.aggregatePartialJsonString, + toolCallChunksMsg: extra.toolCallChunksMsg, + }; + } + } else if ( + event.type === "content_block_start" && + event.content_block.type === "tool_use" + ) { + extra.toolCallChunksMsg.push({ + name: event.content_block.name, + args: JSON.stringify(event.content_block.input), + id: event.content_block.id, + index: event.index, + }); + } else if ( + event.type === "content_block_delta" && + event.delta.type === "input_json_delta" && + extra.toolCallChunksMsg.find((chunk) => chunk.index === event.index) + ) { + const toolCallChunk = extra.toolCallChunksMsg.find( + (chunk) => chunk.index === event.index + ); + extra.aggregatePartialJsonString += event.delta.partial_json; + const parsedPartial = parsePartialJson(extra.aggregatePartialJsonString); + + if (parsedPartial) { + return { + chatGenerationChunk: new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: [{ + id: toolCallChunk?.id, + name: toolCallChunk?.name, + input: parsedPartial, + type: "tool_use", + }], + tool_call_chunks: [ + { + ...toolCallChunk, + args: JSON.stringify(parsedPartial, null, 2), + }, + ], + additional_kwargs: { ...event.delta }, + response_metadata: { ...event.delta }, + }), + text: "", + }), + aggregatePartialJsonString: extra.aggregatePartialJsonString, + toolCallChunksMsg: extra.toolCallChunksMsg, + }; + } + } + + return null; +} + function _formatImage(imageUrl: string) { const regex = /^data:(image\/.+);base64,(.+)$/; const match = imageUrl.match(regex); @@ -684,113 +838,39 @@ export class ChatAnthropicMessages< stream: true, }); let usageData = { input_tokens: 0, output_tokens: 0 }; + const coerceContentToString = _toolsInParams({ + ...params, + ...formattedMessages, + stream: false, + }) ? false : true; // Do not coerce content to string if tools are present - let toolCallChunksMsg: ToolCallChunk | undefined; + let toolCallChunksMsg: ToolCallChunk[] = []; let aggregatePartialJsonString = ""; + console.log("!!!!coerceContentToString!!!!!", coerceContentToString); + for await (const data of stream) { if (options.signal?.aborted) { stream.controller.abort(); throw new Error("AbortError: User aborted the request."); } - if (data.type === "message_start") { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { content, usage, ...additionalKwargs } = data.message; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const filteredAdditionalKwargs: Record = {}; - for (const [key, value] of Object.entries(additionalKwargs)) { - if (value !== undefined && value !== null) { - filteredAdditionalKwargs[key] = value; - } - } - usageData = usage; - let usageMetadata: UsageMetadata | undefined; - if (this.streamUsage || options.streamUsage) { - usageMetadata = { - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - total_tokens: usage.input_tokens + usage.output_tokens, - }; - } - yield new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: "", - additional_kwargs: filteredAdditionalKwargs, - response_metadata: { ...filteredAdditionalKwargs }, - usage_metadata: usageMetadata, - }), - text: "", - }); - } else if (data.type === "message_delta") { - let usageMetadata: UsageMetadata | undefined; - if (this.streamUsage || options.streamUsage) { - usageMetadata = { - input_tokens: data.usage.output_tokens, - output_tokens: 0, - total_tokens: data.usage.output_tokens, - }; - } - yield new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: "", - additional_kwargs: { ...data.delta }, - response_metadata: { ...data.delta }, - usage_metadata: usageMetadata, - }), - text: "", - }); - if (data?.usage !== undefined) { - usageData.output_tokens += data.usage.output_tokens; - } - } else if ( - data.type === "content_block_delta" && - data.delta.type === "text_delta" - ) { - const content = data.delta?.text; - if (content !== undefined) { - yield new ChatGenerationChunk({ - message: new AIMessageChunk({ - content, - additional_kwargs: {}, - response_metadata: { ...data.delta }, - }), - text: content, - }); - await runManager?.handleLLMNewToken(content); - } - } else if ( - data.type === "content_block_start" && - data.content_block.type === "tool_use" - ) { - toolCallChunksMsg = { - name: data.content_block.name, - args: JSON.stringify(data.content_block.input), - id: data.content_block.id, - index: data.index, - }; - } else if ( - data.type === "content_block_delta" && - data.delta.type === "input_json_delta" && - toolCallChunksMsg - ) { - aggregatePartialJsonString += data.delta.partial_json; - const parsedPartial = parsePartialJson(aggregatePartialJsonString); - if (parsedPartial) { - yield new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: "", - tool_call_chunks: [ - { - ...toolCallChunksMsg, - args: JSON.stringify(parsedPartial, null, 2), - }, - ], - additional_kwargs: { ...data.delta }, - response_metadata: { ...data.delta }, - }), - text: "", - }); - } + console.log("before", aggregatePartialJsonString); + const result = _makeMessageChunkFromAnthropicEvent(data, { + streamUsage: !!(this.streamUsage || options.streamUsage), + coerceContentToString, + usageData, + toolCallChunksMsg, + aggregatePartialJsonString + }); + if (!result) continue; + + const { chatGenerationChunk, aggregatePartialJsonString: updatedAggregatePartialJsonString, toolCallChunksMsg: updatedToolCallChunksMsg } = result; + aggregatePartialJsonString = updatedAggregatePartialJsonString; + toolCallChunksMsg = updatedToolCallChunksMsg; + + yield chatGenerationChunk; + if (chatGenerationChunk.text !== "") { + await runManager?.handleLLMNewToken(chatGenerationChunk.text); } } @@ -804,7 +884,7 @@ export class ChatAnthropicMessages< } yield new ChatGenerationChunk({ message: new AIMessageChunk({ - content: "", + content: coerceContentToString ? "" : [], additional_kwargs: { usage: usageData }, usage_metadata: usageMetadata, }), diff --git a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts index e2df16ba9a6b..ea50a355f9a5 100644 --- a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts @@ -1,12 +1,13 @@ /* eslint-disable no-process-env */ import { expect, test } from "@jest/globals"; -import { AIMessage, HumanMessage, ToolMessage } from "@langchain/core/messages"; +import { AIMessage, AIMessageChunk, HumanMessage, ToolMessage } from "@langchain/core/messages"; import { StructuredTool, tool } from "@langchain/core/tools"; import { z } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; import { ChatAnthropic } from "../chat_models.js"; import { AnthropicToolResponse } from "../types.js"; +import { concat } from "@langchain/core/utils/stream"; const zodSchema = z .object({ @@ -167,7 +168,7 @@ test("Can bind & invoke AnthropicTools", async () => { expect(input.location).toBeTruthy(); }); -test("Can bind & stream AnthropicTools", async () => { +test.only("Can bind & stream AnthropicTools", async () => { const modelWithTools = model.bind({ tools: [anthropicTool], }); @@ -176,11 +177,18 @@ test("Can bind & stream AnthropicTools", async () => { "What is the weather in London today?" ); let finalMessage; + let finalMessageConcated: AIMessageChunk | undefined; for await (const item of result) { - console.log("item", JSON.stringify(item, null, 2)); + if (!finalMessageConcated) { + finalMessageConcated = item; + } else { + finalMessageConcated = concat(finalMessageConcated, item); + } finalMessage = item; } + console.log("finalMessageConcated", finalMessageConcated) + if (!finalMessage) { throw new Error("No final message returned"); } @@ -389,9 +397,7 @@ test("withStructuredOutput will always force tool usage", async () => { test("Can stream tool calls", async () => { const weatherTool = tool( - (_) => { - return "no-op"; - }, + (_) => "no-op", { name: "get_weather", description: zodSchema.description, @@ -404,7 +410,7 @@ test("Can stream tool calls", async () => { "What is the weather in San Francisco CA?" ); - let argsStringArr: string[] = []; + const argsStringArr: string[] = []; for await (const chunk of stream) { const toolCall = chunk.tool_calls?.[0]; From 6d5f184ec813269a3695349bee5b7540770236e6 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Fri, 12 Jul 2024 14:47:03 -0700 Subject: [PATCH 06/14] core[minor]: Allow concatenation of messages with multi part content --- langchain-core/src/messages/base.ts | 26 +++- .../src/messages/tests/base_message.test.ts | 143 +++++++++++++++++- 2 files changed, 165 insertions(+), 4 deletions(-) diff --git a/langchain-core/src/messages/base.ts b/langchain-core/src/messages/base.ts index 41a7cb2c2f70..8e1e29950ac5 100644 --- a/langchain-core/src/messages/base.ts +++ b/langchain-core/src/messages/base.ts @@ -127,7 +127,13 @@ export function mergeContent( } // If both are arrays } else if (Array.isArray(secondContent)) { - return [...firstContent, ...secondContent]; + // find and merge all objects with the same "index" and "type" fields. + return ( + _mergeLists(firstContent, secondContent) ?? [ + ...firstContent, + ...secondContent, + ] + ); // If the first content is a list and second is a string } else { // Otherwise, add the second content as a new element of the list @@ -258,8 +264,16 @@ export function _mergeDicts( throw new Error( `field[${key}] already exists in the message chunk, but with a different type.` ); - } else if (typeof merged[key] === "string") { - merged[key] = (merged[key] as string) + value; + } else if ( + typeof merged[key] === "string" && + !(key === "type" && merged[key] === value) + ) { + if (key === "type") { + // Do not merge 'type' fields + continue; + } else { + merged[key] = (merged[key] as string) + value; + } } else if (!Array.isArray(merged[key]) && typeof merged[key] === "object") { merged[key] = _mergeDicts(merged[key], value); } else if (Array.isArray(merged[key])) { @@ -297,6 +311,12 @@ export function _mergeLists(left?: any[], right?: any[]) { } else { merged.push(item); } + } else if ( + typeof item === "object" && + "text" in item && + item.text === "" + ) { + // Skip empty text blocks } else { merged.push(item); } diff --git a/langchain-core/src/messages/tests/base_message.test.ts b/langchain-core/src/messages/tests/base_message.test.ts index fb6f3797f31f..2a46f2fe014e 100644 --- a/langchain-core/src/messages/tests/base_message.test.ts +++ b/langchain-core/src/messages/tests/base_message.test.ts @@ -1,10 +1,11 @@ -import { test } from "@jest/globals"; +import { test, describe, it, expect } from "@jest/globals"; import { ChatPromptTemplate } from "../../prompts/chat.js"; import { HumanMessage, AIMessage, ToolMessage, ToolMessageChunk, + AIMessageChunk, } from "../index.js"; import { load } from "../../load/index.js"; @@ -193,3 +194,143 @@ test("Can concat raw_output (object) of ToolMessageChunk", () => { bar: "baz", }); }); + +describe("Complex AIMessageChunk concat", () => { + it("concatenates content arrays of strings", () => { + expect( + new AIMessageChunk({ + content: [{ type: "text", text: "I am" }], + id: "ai4", + }).concat( + new AIMessageChunk({ content: [{ type: "text", text: " indeed." }] }) + ) + ).toEqual( + new AIMessageChunk({ + id: "ai4", + content: [ + { type: "text", text: "I am" }, + { type: "text", text: " indeed." }, + ], + }) + ); + }); + + it("concatenates mixed content arrays", () => { + expect( + new AIMessageChunk({ + content: [{ index: 0, type: "text", text: "I am" }], + }).concat( + new AIMessageChunk({ content: [{ type: "text", text: " indeed." }] }) + ) + ).toEqual( + new AIMessageChunk({ + content: [ + { index: 0, type: "text", text: "I am" }, + { type: "text", text: " indeed." }, + ], + }) + ); + }); + + it("merges content arrays with same index", () => { + expect( + new AIMessageChunk({ content: [{ index: 0, text: "I am" }] }).concat( + new AIMessageChunk({ content: [{ index: 0, text: " indeed." }] }) + ) + ).toEqual( + new AIMessageChunk({ content: [{ index: 0, text: "I am indeed." }] }) + ); + }); + + it("does not merge when one chunk is missing an index", () => { + expect( + new AIMessageChunk({ content: [{ index: 0, text: "I am" }] }).concat( + new AIMessageChunk({ content: [{ text: " indeed." }] }) + ) + ).toEqual( + new AIMessageChunk({ + content: [{ index: 0, text: "I am" }, { text: " indeed." }], + }) + ); + }); + + it("does not create a holey array when there's a gap between indexes", () => { + expect( + new AIMessageChunk({ content: [{ index: 0, text: "I am" }] }).concat( + new AIMessageChunk({ content: [{ index: 2, text: " indeed." }] }) + ) + ).toEqual( + new AIMessageChunk({ + content: [ + { index: 0, text: "I am" }, + { index: 2, text: " indeed." }, + ], + }) + ); + }); + + it("does not merge content arrays with separate indexes", () => { + expect( + new AIMessageChunk({ content: [{ index: 0, text: "I am" }] }).concat( + new AIMessageChunk({ content: [{ index: 1, text: " indeed." }] }) + ) + ).toEqual( + new AIMessageChunk({ + content: [ + { index: 0, text: "I am" }, + { index: 1, text: " indeed." }, + ], + }) + ); + }); + + it("merges content arrays with same index and type", () => { + expect( + new AIMessageChunk({ + content: [{ index: 0, text: "I am", type: "text_block" }], + }).concat( + new AIMessageChunk({ + content: [{ index: 0, text: " indeed.", type: "text_block" }], + }) + ) + ).toEqual( + new AIMessageChunk({ + content: [{ index: 0, text: "I am indeed.", type: "text_block" }], + }) + ); + }); + + it("merges content arrays with same index and different types without updating type", () => { + expect( + new AIMessageChunk({ + content: [{ index: 0, text: "I am", type: "text_block" }], + }).concat( + new AIMessageChunk({ + content: [{ index: 0, text: " indeed.", type: "text_block_delta" }], + }) + ) + ).toEqual( + new AIMessageChunk({ + content: [{ index: 0, text: "I am indeed.", type: "text_block" }], + }) + ); + }); + + it("concatenates empty string content and merges other fields", () => { + expect( + new AIMessageChunk({ + content: [{ index: 0, type: "text", text: "I am" }], + }).concat( + new AIMessageChunk({ + content: [{ type: "text", text: "" }], + response_metadata: { extra: "value" }, + }) + ) + ).toEqual( + new AIMessageChunk({ + content: [{ index: 0, type: "text", text: "I am" }], + response_metadata: { extra: "value" }, + }) + ); + }); +}); From afb5b2fef91975e2708fcc378581738468a10cd3 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Fri, 12 Jul 2024 14:48:54 -0700 Subject: [PATCH 07/14] cr --- langchain-core/src/messages/base.ts | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/langchain-core/src/messages/base.ts b/langchain-core/src/messages/base.ts index 8e1e29950ac5..7e286420b950 100644 --- a/langchain-core/src/messages/base.ts +++ b/langchain-core/src/messages/base.ts @@ -264,17 +264,13 @@ export function _mergeDicts( throw new Error( `field[${key}] already exists in the message chunk, but with a different type.` ); - } else if ( - typeof merged[key] === "string" && - !(key === "type" && merged[key] === value) - ) { + } else if (typeof merged[key] === "string") { if (key === "type") { // Do not merge 'type' fields continue; - } else { - merged[key] = (merged[key] as string) + value; } - } else if (!Array.isArray(merged[key]) && typeof merged[key] === "object") { + merged[key] += value; + } else if (typeof merged[key] === "object" && !Array.isArray(merged[key])) { merged[key] = _mergeDicts(merged[key], value); } else if (Array.isArray(merged[key])) { merged[key] = _mergeLists(merged[key], value); From a996ad0e5a75f765a8a5135bbd2db4c7874d4b6a Mon Sep 17 00:00:00 2001 From: bracesproul Date: Fri, 12 Jul 2024 16:02:20 -0700 Subject: [PATCH 08/14] cr --- libs/langchain-anthropic/src/chat_models.ts | 356 +++++++++--------- .../src/tests/chat_models-tools.int.test.ts | 24 +- 2 files changed, 192 insertions(+), 188 deletions(-) diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index d585b6cfde9b..004ca12f3d54 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -55,6 +55,7 @@ import { extractToolCalls, } from "./output_parsers.js"; import { AnthropicToolResponse } from "./types.js"; +import { concat } from "@langchain/core/utils/stream"; type AnthropicMessage = Anthropic.MessageParam; type AnthropicMessageCreateParams = Anthropic.MessageCreateParamsNonStreaming; @@ -89,157 +90,7 @@ export interface ChatAnthropicCallOptions type AnthropicMessageResponse = Anthropic.ContentBlock | AnthropicToolResponse; function _toolsInParams(params: AnthropicMessageCreateParams): boolean { - return !!( - params.tools && - params.tools.length > 0 - ); -} - -/** - * Convert Anthropic event to AIMessageChunk. - * - * Note that not all events will result in a message chunk. - * In these cases we return null. - * @param {Anthropic.Messages.RawMessageStreamEvent} event The event to convert. - * @param {boolean} streamUsage Whether to include token usage data in streamed chunks. - * @param {boolean} coerceContentToString Whether to coerce content to a string. - */ -function _makeMessageChunkFromAnthropicEvent(event: Anthropic.Messages.RawMessageStreamEvent, extra: { - streamUsage: boolean, - coerceContentToString: boolean, - usageData: { input_tokens: number, output_tokens: number }, - toolCallChunksMsg: ToolCallChunk[], - aggregatePartialJsonString: string -}): { - chatGenerationChunk: ChatGenerationChunk, - aggregatePartialJsonString: string - toolCallChunksMsg: ToolCallChunk[] -} | null { - if (event.type === "message_start") { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { content, usage, ...additionalKwargs } = event.message; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const filteredAdditionalKwargs: Record = {}; - for (const [key, value] of Object.entries(additionalKwargs)) { - if (value !== undefined && value !== null) { - filteredAdditionalKwargs[key] = value; - } - } - extra.usageData = usage; - let usageMetadata: UsageMetadata | undefined; - if (extra.streamUsage) { - usageMetadata = { - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - total_tokens: usage.input_tokens + usage.output_tokens, - }; - } - return { - chatGenerationChunk: new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: extra.coerceContentToString ? "" : [], - additional_kwargs: filteredAdditionalKwargs, - response_metadata: { ...filteredAdditionalKwargs }, - usage_metadata: usageMetadata, - }), - text: "", - }), - aggregatePartialJsonString: extra.aggregatePartialJsonString, - toolCallChunksMsg: extra.toolCallChunksMsg, - }; - } else if (event.type === "message_delta") { - let usageMetadata: UsageMetadata | undefined; - if (extra.streamUsage) { - usageMetadata = { - input_tokens: event.usage.output_tokens, - output_tokens: 0, - total_tokens: event.usage.output_tokens, - }; - } - if (event?.usage !== undefined) { - extra.usageData.output_tokens += event.usage.output_tokens; - } - return { - chatGenerationChunk: new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: extra.coerceContentToString ? "" : [], - additional_kwargs: { ...event.delta }, - response_metadata: { ...event.delta }, - usage_metadata: usageMetadata, - }), - text: "", - }), - aggregatePartialJsonString: extra.aggregatePartialJsonString, - toolCallChunksMsg: extra.toolCallChunksMsg, - }; - } else if ( - event.type === "content_block_delta" && - event.delta.type === "text_delta" - ) { - const content = event.delta?.text; - if (content !== undefined) { - return { - chatGenerationChunk: new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: extra.coerceContentToString ? content : [{ type: "text", text: content }], - additional_kwargs: {}, - response_metadata: { ...event.delta }, - }), - text: content, - }), - aggregatePartialJsonString: extra.aggregatePartialJsonString, - toolCallChunksMsg: extra.toolCallChunksMsg, - }; - } - } else if ( - event.type === "content_block_start" && - event.content_block.type === "tool_use" - ) { - extra.toolCallChunksMsg.push({ - name: event.content_block.name, - args: JSON.stringify(event.content_block.input), - id: event.content_block.id, - index: event.index, - }); - } else if ( - event.type === "content_block_delta" && - event.delta.type === "input_json_delta" && - extra.toolCallChunksMsg.find((chunk) => chunk.index === event.index) - ) { - const toolCallChunk = extra.toolCallChunksMsg.find( - (chunk) => chunk.index === event.index - ); - extra.aggregatePartialJsonString += event.delta.partial_json; - const parsedPartial = parsePartialJson(extra.aggregatePartialJsonString); - - if (parsedPartial) { - return { - chatGenerationChunk: new ChatGenerationChunk({ - message: new AIMessageChunk({ - content: [{ - id: toolCallChunk?.id, - name: toolCallChunk?.name, - input: parsedPartial, - type: "tool_use", - }], - tool_call_chunks: [ - { - ...toolCallChunk, - args: JSON.stringify(parsedPartial, null, 2), - }, - ], - additional_kwargs: { ...event.delta }, - response_metadata: { ...event.delta }, - }), - text: "", - }), - aggregatePartialJsonString: extra.aggregatePartialJsonString, - toolCallChunksMsg: extra.toolCallChunksMsg, - }; - } - } - - return null; + return !!(params.tools && params.tools.length > 0); } function _formatImage(imageUrl: string) { @@ -311,6 +162,112 @@ function isAnthropicTool(tool: any): tool is AnthropicTool { return "input_schema" in tool; } +function _makeMessageChunkFromAnthropicEvent( + data: Anthropic.Messages.RawMessageStreamEvent, + fields: { + streamUsage: boolean; + coerceContentToString: boolean; + usageData: { input_tokens: number; output_tokens: number }; + } +): AIMessageChunk | null { + if (data.type === "message_start") { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { content, usage, ...additionalKwargs } = data.message; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const filteredAdditionalKwargs: Record = {}; + for (const [key, value] of Object.entries(additionalKwargs)) { + if (value !== undefined && value !== null) { + filteredAdditionalKwargs[key] = value; + } + } + fields.usageData = usage; + let usageMetadata: UsageMetadata | undefined; + if (fields.streamUsage) { + usageMetadata = { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + total_tokens: usage.input_tokens + usage.output_tokens, + }; + } + return new AIMessageChunk({ + content: fields.coerceContentToString ? "" : [], + additional_kwargs: filteredAdditionalKwargs, + usage_metadata: usageMetadata, + }); + } else if (data.type === "message_delta") { + let usageMetadata: UsageMetadata | undefined; + if (fields.streamUsage) { + usageMetadata = { + input_tokens: data.usage.output_tokens, + output_tokens: 0, + total_tokens: data.usage.output_tokens, + }; + } + if (data?.usage !== undefined) { + fields.usageData.output_tokens += data.usage.output_tokens; + } + + return new AIMessageChunk({ + content: fields.coerceContentToString ? "" : [], + additional_kwargs: { ...data.delta }, + usage_metadata: usageMetadata, + }); + } else if ( + data.type === "content_block_start" && + data.content_block.type === "tool_use" + ) { + return new AIMessageChunk({ + content: fields.coerceContentToString + ? "" + : [ + { + index: data.index, + ...data.content_block, + input: "", + }, + ], + additional_kwargs: {}, + }); + } else if ( + data.type === "content_block_delta" && + data.delta.type === "text_delta" + ) { + const content = data.delta?.text; + if (content !== undefined) { + return new AIMessageChunk({ + content: fields.coerceContentToString + ? content + : [ + { + index: data.index, + ...data.delta, + }, + ], + additional_kwargs: {}, + }); + } + } else if ( + data.type === "content_block_delta" && + data.delta.type === "input_json_delta" + ) { + // partial JSON incoming! + return new AIMessageChunk({ + content: fields.coerceContentToString + ? "" + : [ + { + index: data.index, + input: data.delta.partial_json, + type: data.delta.type, + }, + ], + additional_kwargs: {}, + }); + } + + return null; +} + /** * Input to AnthropicChat class. */ @@ -831,49 +788,94 @@ export class ChatAnthropicMessages< ): AsyncGenerator { const params = this.invocationParams(options); const formattedMessages = _formatMessagesForAnthropic(messages); + const coerceContentToString = !_toolsInParams({ + ...params, + ...formattedMessages, + stream: false, + }); const stream = await this.createStreamWithRetry({ ...params, ...formattedMessages, stream: true, }); - let usageData = { input_tokens: 0, output_tokens: 0 }; - const coerceContentToString = _toolsInParams({ - ...params, - ...formattedMessages, - stream: false, - }) ? false : true; // Do not coerce content to string if tools are present - - let toolCallChunksMsg: ToolCallChunk[] = []; - let aggregatePartialJsonString = ""; - - console.log("!!!!coerceContentToString!!!!!", coerceContentToString); - + const usageData = { input_tokens: 0, output_tokens: 0 }; + let finalChunk: AIMessageChunk | undefined; + let aggregateToolCalls: { + index: number; + type: "tool_use"; + id: string; + name: string; + input: string; + }[] = []; + console.log(aggregateToolCalls.length) for await (const data of stream) { if (options.signal?.aborted) { stream.controller.abort(); throw new Error("AbortError: User aborted the request."); } - console.log("before", aggregatePartialJsonString); - const result = _makeMessageChunkFromAnthropicEvent(data, { + const chunk = _makeMessageChunkFromAnthropicEvent(data, { streamUsage: !!(this.streamUsage || options.streamUsage), coerceContentToString, usageData, - toolCallChunksMsg, - aggregatePartialJsonString }); - if (!result) continue; - - const { chatGenerationChunk, aggregatePartialJsonString: updatedAggregatePartialJsonString, toolCallChunksMsg: updatedToolCallChunksMsg } = result; - aggregatePartialJsonString = updatedAggregatePartialJsonString; - toolCallChunksMsg = updatedToolCallChunksMsg; - - yield chatGenerationChunk; - if (chatGenerationChunk.text !== "") { - await runManager?.handleLLMNewToken(chatGenerationChunk.text); + if (!chunk) { + continue; + } + + if (!finalChunk) { + finalChunk = chunk; + } else { + finalChunk = concat(finalChunk, chunk); + } + + let toolCallChunks: ToolCallChunk[] = []; + const toolUseContentBlocks = Array.isArray(finalChunk.content) ? finalChunk.content.filter((item) => item.type === "tool_use") : []; + if (toolUseContentBlocks.length) { + toolCallChunks = toolUseContentBlocks.flatMap((item) => { + if (!("index" in item && "id" in item && "name" in item && "input" in item)) { + return []; + } + + const parsedPartialJson = parsePartialJson(item.input); + if (!parsedPartialJson) { + return []; + } + + return { + index: item.index, + id: item.id, + name: item.name, + args: JSON.stringify(item.input, null, 2), + }; + }); + console.log("toolCallChunks.length", toolCallChunks.length) } - } + const newChunk = new AIMessageChunk({ + content: coerceContentToString ? "" : [], + tool_call_chunks: toolCallChunks, + }); + finalChunk = concat(finalChunk, newChunk); + console.log("finalChunk.tool_calls", finalChunk.tool_calls) + + const token: string | undefined = + typeof chunk.content === "string" && chunk.content !== "" + ? chunk.content + : undefined; + token ? await runManager?.handleLLMNewToken(token) : null; + + yield new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: chunk.content, + additional_kwargs: chunk.additional_kwargs, + tool_call_chunks: toolCallChunks, + usage_metadata: chunk.usage_metadata, + response_metadata: chunk.response_metadata, + }), + text: token ?? "", + }); + } let usageMetadata: UsageMetadata | undefined; if (this.streamUsage || options.streamUsage) { usageMetadata = { diff --git a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts index ea50a355f9a5..b539c3586e99 100644 --- a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts @@ -1,13 +1,18 @@ /* eslint-disable no-process-env */ import { expect, test } from "@jest/globals"; -import { AIMessage, AIMessageChunk, HumanMessage, ToolMessage } from "@langchain/core/messages"; +import { + AIMessage, + AIMessageChunk, + HumanMessage, + ToolMessage, +} from "@langchain/core/messages"; import { StructuredTool, tool } from "@langchain/core/tools"; import { z } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; +import { concat } from "@langchain/core/utils/stream"; import { ChatAnthropic } from "../chat_models.js"; import { AnthropicToolResponse } from "../types.js"; -import { concat } from "@langchain/core/utils/stream"; const zodSchema = z .object({ @@ -187,7 +192,7 @@ test.only("Can bind & stream AnthropicTools", async () => { finalMessage = item; } - console.log("finalMessageConcated", finalMessageConcated) + console.log("finalMessageConcated", finalMessageConcated); if (!finalMessage) { throw new Error("No final message returned"); @@ -396,14 +401,11 @@ test("withStructuredOutput will always force tool usage", async () => { }); test("Can stream tool calls", async () => { - const weatherTool = tool( - (_) => "no-op", - { - name: "get_weather", - description: zodSchema.description, - schema: zodSchema, - } - ); + const weatherTool = tool((_) => "no-op", { + name: "get_weather", + description: zodSchema.description, + schema: zodSchema, + }); const modelWithTools = model.bindTools([weatherTool]); const stream = await modelWithTools.stream( From 792a04269633fb022dc8c96b2188023b02a9955e Mon Sep 17 00:00:00 2001 From: bracesproul Date: Fri, 12 Jul 2024 16:44:12 -0700 Subject: [PATCH 09/14] proper streaming --- libs/langchain-anthropic/src/chat_models.ts | 213 ++++++++++-------- .../src/tests/chat_models-tools.int.test.ts | 64 +++--- 2 files changed, 147 insertions(+), 130 deletions(-) diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index 004ca12f3d54..e3bd162ab5fa 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -33,10 +33,7 @@ import { } from "@langchain/core/language_models/base"; import { StructuredToolInterface } from "@langchain/core/tools"; import { zodToJsonSchema } from "zod-to-json-schema"; -import { - BaseLLMOutputParser, - parsePartialJson, -} from "@langchain/core/output_parsers"; +import { BaseLLMOutputParser } from "@langchain/core/output_parsers"; import { Runnable, RunnablePassthrough, @@ -55,7 +52,6 @@ import { extractToolCalls, } from "./output_parsers.js"; import { AnthropicToolResponse } from "./types.js"; -import { concat } from "@langchain/core/utils/stream"; type AnthropicMessage = Anthropic.MessageParam; type AnthropicMessageCreateParams = Anthropic.MessageCreateParamsNonStreaming; @@ -169,7 +165,12 @@ function _makeMessageChunkFromAnthropicEvent( coerceContentToString: boolean; usageData: { input_tokens: number; output_tokens: number }; } -): AIMessageChunk | null { +): { + chunk: AIMessageChunk; + usageData: { input_tokens: number; output_tokens: number }; +} | null { + let usageDataCopy = { ...fields.usageData }; + if (data.type === "message_start") { // eslint-disable-next-line @typescript-eslint/no-unused-vars const { content, usage, ...additionalKwargs } = data.message; @@ -180,7 +181,7 @@ function _makeMessageChunkFromAnthropicEvent( filteredAdditionalKwargs[key] = value; } } - fields.usageData = usage; + usageDataCopy = usage; let usageMetadata: UsageMetadata | undefined; if (fields.streamUsage) { usageMetadata = { @@ -189,11 +190,14 @@ function _makeMessageChunkFromAnthropicEvent( total_tokens: usage.input_tokens + usage.output_tokens, }; } - return new AIMessageChunk({ - content: fields.coerceContentToString ? "" : [], - additional_kwargs: filteredAdditionalKwargs, - usage_metadata: usageMetadata, - }); + return { + chunk: new AIMessageChunk({ + content: fields.coerceContentToString ? "" : [], + additional_kwargs: filteredAdditionalKwargs, + usage_metadata: usageMetadata, + }), + usageData: usageDataCopy, + }; } else if (data.type === "message_delta") { let usageMetadata: UsageMetadata | undefined; if (fields.streamUsage) { @@ -204,65 +208,75 @@ function _makeMessageChunkFromAnthropicEvent( }; } if (data?.usage !== undefined) { - fields.usageData.output_tokens += data.usage.output_tokens; + usageDataCopy.output_tokens += data.usage.output_tokens; } - - return new AIMessageChunk({ - content: fields.coerceContentToString ? "" : [], - additional_kwargs: { ...data.delta }, - usage_metadata: usageMetadata, - }); + return { + chunk: new AIMessageChunk({ + content: fields.coerceContentToString ? "" : [], + additional_kwargs: { ...data.delta }, + usage_metadata: usageMetadata, + }), + usageData: usageDataCopy, + }; } else if ( data.type === "content_block_start" && data.content_block.type === "tool_use" ) { - return new AIMessageChunk({ - content: fields.coerceContentToString - ? "" - : [ - { - index: data.index, - ...data.content_block, - input: "", - }, - ], - additional_kwargs: {}, - }); + return { + chunk: new AIMessageChunk({ + content: fields.coerceContentToString + ? "" + : [ + { + index: data.index, + ...data.content_block, + input: "", + }, + ], + additional_kwargs: {}, + }), + usageData: usageDataCopy, + }; } else if ( data.type === "content_block_delta" && data.delta.type === "text_delta" ) { const content = data.delta?.text; if (content !== undefined) { - return new AIMessageChunk({ + return { + chunk: new AIMessageChunk({ + content: fields.coerceContentToString + ? content + : [ + { + index: data.index, + ...data.delta, + }, + ], + additional_kwargs: {}, + }), + usageData: usageDataCopy, + }; + } + } else if ( + data.type === "content_block_delta" && + data.delta.type === "input_json_delta" + ) { + return { + chunk: new AIMessageChunk({ content: fields.coerceContentToString - ? content + ? "" : [ { index: data.index, - ...data.delta, + input: data.delta.partial_json, + type: data.delta.type, }, ], additional_kwargs: {}, - }); - } - } else if ( - data.type === "content_block_delta" && - data.delta.type === "input_json_delta" - ) { - // partial JSON incoming! - return new AIMessageChunk({ - content: fields.coerceContentToString - ? "" - : [ - { - index: data.index, - input: data.delta.partial_json, - type: data.delta.type, - }, - ], - additional_kwargs: {}, - }); + }), + usageData: usageDataCopy, + }; } return null; @@ -799,82 +813,85 @@ export class ChatAnthropicMessages< ...formattedMessages, stream: true, }); - const usageData = { input_tokens: 0, output_tokens: 0 }; - let finalChunk: AIMessageChunk | undefined; - let aggregateToolCalls: { - index: number; - type: "tool_use"; - id: string; - name: string; - input: string; - }[] = []; - console.log(aggregateToolCalls.length) + let usageData = { input_tokens: 0, output_tokens: 0 }; + for await (const data of stream) { if (options.signal?.aborted) { stream.controller.abort(); throw new Error("AbortError: User aborted the request."); } - const chunk = _makeMessageChunkFromAnthropicEvent(data, { + const result = _makeMessageChunkFromAnthropicEvent(data, { streamUsage: !!(this.streamUsage || options.streamUsage), coerceContentToString, usageData, }); - if (!chunk) { + if (!result) { continue; } + const { chunk, usageData: updatedUsageData } = result; + usageData = updatedUsageData; - if (!finalChunk) { - finalChunk = chunk; - } else { - finalChunk = concat(finalChunk, chunk); - } + let newToolCallChunk: ToolCallChunk | undefined; - let toolCallChunks: ToolCallChunk[] = []; - const toolUseContentBlocks = Array.isArray(finalChunk.content) ? finalChunk.content.filter((item) => item.type === "tool_use") : []; - if (toolUseContentBlocks.length) { - toolCallChunks = toolUseContentBlocks.flatMap((item) => { - if (!("index" in item && "id" in item && "name" in item && "input" in item)) { - return []; - } + // Initial chunk for tool calls from anthropic contains identifying information like ID and name. + // This chunk does not contain any input JSON. + const toolUseChunks = Array.isArray(chunk.content) + ? chunk.content.find((c) => c.type === "tool_use") + : undefined; + if ( + toolUseChunks && + "index" in toolUseChunks && + "name" in toolUseChunks && + "id" in toolUseChunks + ) { + newToolCallChunk = { + args: "", + id: toolUseChunks.id, + name: toolUseChunks.name, + index: toolUseChunks.index, + }; + } - const parsedPartialJson = parsePartialJson(item.input); - if (!parsedPartialJson) { - return []; - } - - return { - index: item.index, - id: item.id, - name: item.name, - args: JSON.stringify(item.input, null, 2), + // Chunks after the initial chunk only contain the index and partial JSON. + const inputJsonDeltaChunks = Array.isArray(chunk.content) + ? chunk.content.find((c) => c.type === "input_json_delta") + : undefined; + if ( + inputJsonDeltaChunks && + "index" in inputJsonDeltaChunks && + "input" in inputJsonDeltaChunks + ) { + if (typeof inputJsonDeltaChunks.input === "string") { + newToolCallChunk = { + args: inputJsonDeltaChunks.input, + index: inputJsonDeltaChunks.index, + }; + } else { + newToolCallChunk = { + args: JSON.stringify(inputJsonDeltaChunks.input, null, 2), + index: inputJsonDeltaChunks.index, }; - }); - console.log("toolCallChunks.length", toolCallChunks.length) + } } - const newChunk = new AIMessageChunk({ - content: coerceContentToString ? "" : [], - tool_call_chunks: toolCallChunks, - }); - finalChunk = concat(finalChunk, newChunk); - console.log("finalChunk.tool_calls", finalChunk.tool_calls) - const token: string | undefined = typeof chunk.content === "string" && chunk.content !== "" ? chunk.content : undefined; - token ? await runManager?.handleLLMNewToken(token) : null; yield new ChatGenerationChunk({ message: new AIMessageChunk({ content: chunk.content, additional_kwargs: chunk.additional_kwargs, - tool_call_chunks: toolCallChunks, + tool_call_chunks: newToolCallChunk ? [newToolCallChunk] : undefined, usage_metadata: chunk.usage_metadata, response_metadata: chunk.response_metadata, }), text: token ?? "", }); + if (token) { + await runManager?.handleLLMNewToken(token); + } } let usageMetadata: UsageMetadata | undefined; if (this.streamUsage || options.streamUsage) { diff --git a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts index b539c3586e99..3726164bc7ee 100644 --- a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts @@ -173,7 +173,7 @@ test("Can bind & invoke AnthropicTools", async () => { expect(input.location).toBeTruthy(); }); -test.only("Can bind & stream AnthropicTools", async () => { +test("Can bind & stream AnthropicTools", async () => { const modelWithTools = model.bind({ tools: [anthropicTool], }); @@ -181,34 +181,26 @@ test.only("Can bind & stream AnthropicTools", async () => { const result = await modelWithTools.stream( "What is the weather in London today?" ); - let finalMessage; - let finalMessageConcated: AIMessageChunk | undefined; + let finalMessage: AIMessageChunk | undefined; for await (const item of result) { - if (!finalMessageConcated) { - finalMessageConcated = item; + if (!finalMessage) { + finalMessage = item; } else { - finalMessageConcated = concat(finalMessageConcated, item); + finalMessage = concat(finalMessage, item); } - finalMessage = item; } - console.log("finalMessageConcated", finalMessageConcated); - + expect(finalMessage).toBeDefined(); if (!finalMessage) { throw new Error("No final message returned"); } - console.log( - { - tool_calls: JSON.stringify(finalMessage.content, null, 2), - }, - "Can bind & invoke StructuredTools" - ); expect(Array.isArray(finalMessage.content)).toBeTruthy(); if (!Array.isArray(finalMessage.content)) { throw new Error("Content is not an array"); } - let toolCall: AnthropicToolResponse | undefined; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let toolCall: Record | undefined; finalMessage.content.forEach((item) => { if (item.type === "tool_use") { toolCall = item as AnthropicToolResponse; @@ -221,7 +213,7 @@ test.only("Can bind & stream AnthropicTools", async () => { const { name, input } = toolCall; expect(name).toBe("get_weather"); expect(input).toBeTruthy(); - expect(input.location).toBeTruthy(); + expect(JSON.parse(input).location).toBeTruthy(); }); test("withStructuredOutput with zod schema", async () => { @@ -412,23 +404,31 @@ test("Can stream tool calls", async () => { "What is the weather in San Francisco CA?" ); - const argsStringArr: string[] = []; - + let realToolCallChunkStreams = 0; + let prevToolCallChunkArgs = ""; + let finalChunk: AIMessageChunk | undefined; for await (const chunk of stream) { - const toolCall = chunk.tool_calls?.[0]; - if (!toolCall) continue; - const stringifiedArgs = JSON.stringify(toolCall.args, null, 2); - - // Push each new "chunk" of args to the array. We'll check the array's - // length at the end to verify multiple tool call chunks were streamed. - if (argsStringArr[argsStringArr.length - 1] !== stringifiedArgs) { - argsStringArr.push(stringifiedArgs); + if (!finalChunk) { + finalChunk = chunk; + } else { + finalChunk = concat(finalChunk, chunk); + } + if (chunk.tool_call_chunks?.[0]?.args) { + if ( + !prevToolCallChunkArgs || + prevToolCallChunkArgs !== chunk.tool_call_chunks[0].args + ) { + realToolCallChunkStreams += 1; + } + prevToolCallChunkArgs = chunk.tool_call_chunks[0].args; } } - expect(argsStringArr.length).toBeGreaterThan(1); - console.log("argsStringArr.length", argsStringArr.length); - const finalToolCall = JSON.parse(argsStringArr[argsStringArr.length - 1]); - console.log("finalToolCall", finalToolCall); - expect(finalToolCall.location).toBeDefined(); + expect(finalChunk?.tool_calls?.[0]).toBeDefined(); + if (!finalChunk?.tool_calls?.[0]) { + return; + } + expect(finalChunk?.tool_calls?.[0].name).toBe("get_weather"); + expect(finalChunk?.tool_calls?.[0].args.location).toBeDefined(); + expect(realToolCallChunkStreams).toBeGreaterThan(1); }); From 2528af7baa98da76a719d616c9d6756e901607c8 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Mon, 15 Jul 2024 13:30:45 -0700 Subject: [PATCH 10/14] fixed --- libs/langchain-anthropic/src/chat_models.ts | 57 ++++++++++++++++++- .../src/tests/chat_models-tools.int.test.ts | 51 +++-------------- 2 files changed, 65 insertions(+), 43 deletions(-) diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index e3bd162ab5fa..33dd00980a10 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -47,6 +47,7 @@ import type { MessageCreateParams, Tool as AnthropicTool, } from "@anthropic-ai/sdk/resources/index.mjs"; +import { concat } from "@langchain/core/utils/stream"; import { AnthropicToolsOutputParser, extractToolCalls, @@ -815,6 +816,8 @@ export class ChatAnthropicMessages< }); let usageData = { input_tokens: 0, output_tokens: 0 }; + let concatenatedChunks: AIMessageChunk | undefined; + for await (const data of stream) { if (options.signal?.aborted) { stream.controller.abort(); @@ -879,9 +882,61 @@ export class ChatAnthropicMessages< ? chunk.content : undefined; + // Remove `tool_use` content types until the last chunk. + let toolUseContent: + | { + id: string; + type: "tool_use"; + name: string; + input: Record; + } + | undefined; + if (!concatenatedChunks) { + concatenatedChunks = chunk; + } else { + concatenatedChunks = concat(concatenatedChunks, chunk); + } + if ( + Array.isArray(concatenatedChunks.content) && + concatenatedChunks.content.find((c) => c.type === "tool_use") + ) { + try { + const toolUseMsg = concatenatedChunks.content.find( + (c) => c.type === "tool_use" + ); + if ( + !toolUseMsg || + !( + "input" in toolUseMsg || + "name" in toolUseMsg || + "id" in toolUseMsg + ) + ) + return; + const parsedArgs = JSON.parse(toolUseMsg.input); + if (parsedArgs) { + toolUseContent = { + type: "tool_use", + id: toolUseMsg.id, + name: toolUseMsg.name, + input: parsedArgs, + }; + } + } catch (_) { + // no-op + } + } + + const chunkContentWithoutToolUse = Array.isArray(chunk.content) + ? chunk.content.filter((c) => c.type !== "tool_use") + : chunk.content; + if (Array.isArray(chunkContentWithoutToolUse) && toolUseContent) { + chunkContentWithoutToolUse.push(toolUseContent); + } + yield new ChatGenerationChunk({ message: new AIMessageChunk({ - content: chunk.content, + content: chunkContentWithoutToolUse, additional_kwargs: chunk.additional_kwargs, tool_call_chunks: newToolCallChunk ? [newToolCallChunk] : undefined, usage_metadata: chunk.usage_metadata, diff --git a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts index 3726164bc7ee..a638c9eaa5a8 100644 --- a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts @@ -33,7 +33,6 @@ class WeatherTool extends StructuredTool { name = "get_weather"; async _call(input: z.infer) { - console.log(`WeatherTool called with input: ${input}`); return `The weather in ${input.location} is 25°C`; } } @@ -84,7 +83,6 @@ test("Few shotting with tool calls", async () => { ), new HumanMessage("What did you say the weather was?"), ]); - console.log(res); expect(res.content).toContain("24"); }); @@ -96,12 +94,7 @@ test("Can bind & invoke StructuredTools", async () => { const result = await modelWithTools.invoke( "What is the weather in SF today?" ); - console.log( - { - tool_calls: JSON.stringify(result.content, null, 2), - }, - "Can bind & invoke StructuredTools" - ); + expect(Array.isArray(result.content)).toBeTruthy(); if (!Array.isArray(result.content)) { throw new Error("Content is not an array"); @@ -117,7 +110,7 @@ test("Can bind & invoke StructuredTools", async () => { } expect(toolCall).toBeTruthy(); const { name, input } = toolCall; - expect(toolCall.input).toEqual(result.tool_calls?.[0].args); + expect(input).toEqual(result.tool_calls?.[0].args); expect(name).toBe("get_weather"); expect(input).toBeTruthy(); expect(input.location).toBeTruthy(); @@ -134,7 +127,6 @@ test("Can bind & invoke StructuredTools", async () => { ), new HumanMessage("What did you say the weather was?"), ]); - console.log(result2); // This should work, but Anthorpic is too skeptical expect(result2.content).toContain("59"); }); @@ -147,12 +139,7 @@ test("Can bind & invoke AnthropicTools", async () => { const result = await modelWithTools.invoke( "What is the weather in London today?" ); - console.log( - { - tool_calls: JSON.stringify(result.content, null, 2), - }, - "Can bind & invoke StructuredTools" - ); + expect(Array.isArray(result.content)).toBeTruthy(); if (!Array.isArray(result.content)) { throw new Error("Content is not an array"); @@ -201,6 +188,7 @@ test("Can bind & stream AnthropicTools", async () => { } // eslint-disable-next-line @typescript-eslint/no-explicit-any let toolCall: Record | undefined; + finalMessage.content.forEach((item) => { if (item.type === "tool_use") { toolCall = item as AnthropicToolResponse; @@ -213,7 +201,7 @@ test("Can bind & stream AnthropicTools", async () => { const { name, input } = toolCall; expect(name).toBe("get_weather"); expect(input).toBeTruthy(); - expect(JSON.parse(input).location).toBeTruthy(); + expect(input.location).toBeTruthy(); }); test("withStructuredOutput with zod schema", async () => { @@ -227,12 +215,6 @@ test("withStructuredOutput with zod schema", async () => { const result = await modelWithTools.invoke( "What is the weather in London today?" ); - console.log( - { - result, - }, - "withStructuredOutput with zod schema" - ); expect(typeof result.location).toBe("string"); }); @@ -247,12 +229,7 @@ test("withStructuredOutput with AnthropicTool", async () => { const result = await modelWithTools.invoke( "What is the weather in London today?" ); - console.log( - { - result, - }, - "withStructuredOutput with AnthropicTool" - ); + expect(typeof result.location).toBe("string"); }); @@ -268,12 +245,7 @@ test("withStructuredOutput JSON Schema only", async () => { const result = await modelWithTools.invoke( "What is the weather in London today?" ); - console.log( - { - result, - }, - "withStructuredOutput JSON Schema only" - ); + expect(typeof result.location).toBe("string"); }); @@ -319,12 +291,7 @@ test("Can pass tool_choice", async () => { const result = await modelWithTools.invoke( "What is the sum of 272818 and 281818?" ); - console.log( - { - tool_calls: JSON.stringify(result.content, null, 2), - }, - "Can bind & invoke StructuredTools" - ); + expect(Array.isArray(result.content)).toBeTruthy(); if (!Array.isArray(result.content)) { throw new Error("Content is not an array"); @@ -340,7 +307,7 @@ test("Can pass tool_choice", async () => { } expect(toolCall).toBeTruthy(); const { name, input } = toolCall; - expect(toolCall.input).toEqual(result.tool_calls?.[0].args); + expect(input).toEqual(result.tool_calls?.[0].args); expect(name).toBe("get_weather"); expect(input).toBeTruthy(); expect(input.location).toBeTruthy(); From de903a6300f2e33b21b7edf56ce589ea321a258c Mon Sep 17 00:00:00 2001 From: bracesproul Date: Mon, 15 Jul 2024 13:49:56 -0700 Subject: [PATCH 11/14] cleanup --- libs/langchain-anthropic/src/chat_models.ts | 228 +++++++++++--------- 1 file changed, 130 insertions(+), 98 deletions(-) diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index 33dd00980a10..4a65f0df1137 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -549,6 +549,113 @@ function _formatMessagesForAnthropic(messages: BaseMessage[]): { }; } +function extractToolCallChunk( + chunk: AIMessageChunk +): ToolCallChunk | undefined { + let newToolCallChunk: ToolCallChunk | undefined; + + // Initial chunk for tool calls from anthropic contains identifying information like ID and name. + // This chunk does not contain any input JSON. + const toolUseChunks = Array.isArray(chunk.content) + ? chunk.content.find((c) => c.type === "tool_use") + : undefined; + if ( + toolUseChunks && + "index" in toolUseChunks && + "name" in toolUseChunks && + "id" in toolUseChunks + ) { + newToolCallChunk = { + args: "", + id: toolUseChunks.id, + name: toolUseChunks.name, + index: toolUseChunks.index, + }; + } + + // Chunks after the initial chunk only contain the index and partial JSON. + const inputJsonDeltaChunks = Array.isArray(chunk.content) + ? chunk.content.find((c) => c.type === "input_json_delta") + : undefined; + if ( + inputJsonDeltaChunks && + "index" in inputJsonDeltaChunks && + "input" in inputJsonDeltaChunks + ) { + if (typeof inputJsonDeltaChunks.input === "string") { + newToolCallChunk = { + args: inputJsonDeltaChunks.input, + index: inputJsonDeltaChunks.index, + }; + } else { + newToolCallChunk = { + args: JSON.stringify(inputJsonDeltaChunks.input, null, 2), + index: inputJsonDeltaChunks.index, + }; + } + } + + return newToolCallChunk; +} + +function extractToken(chunk: AIMessageChunk): string | undefined { + return typeof chunk.content === "string" && chunk.content !== "" + ? chunk.content + : undefined; +} + +function extractToolUseContent( + chunk: AIMessageChunk, + concatenatedChunks: AIMessageChunk | undefined +) { + let newConcatenatedChunks = concatenatedChunks; + // Remove `tool_use` content types until the last chunk. + let toolUseContent: + | { + id: string; + type: "tool_use"; + name: string; + input: Record; + } + | undefined; + if (!newConcatenatedChunks) { + newConcatenatedChunks = chunk; + } else { + newConcatenatedChunks = concat(newConcatenatedChunks, chunk); + } + if ( + Array.isArray(newConcatenatedChunks.content) && + newConcatenatedChunks.content.find((c) => c.type === "tool_use") + ) { + try { + const toolUseMsg = newConcatenatedChunks.content.find( + (c) => c.type === "tool_use" + ); + if ( + !toolUseMsg || + !("input" in toolUseMsg || "name" in toolUseMsg || "id" in toolUseMsg) + ) + return; + const parsedArgs = JSON.parse(toolUseMsg.input); + if (parsedArgs) { + toolUseContent = { + type: "tool_use", + id: toolUseMsg.id, + name: toolUseMsg.name, + input: parsedArgs, + }; + } + } catch (_) { + // no-op + } + } + + return { + toolUseContent, + concatenatedChunks: newConcatenatedChunks, + }; +} + /** * Wrapper around Anthropic large language models. * @@ -823,120 +930,43 @@ export class ChatAnthropicMessages< stream.controller.abort(); throw new Error("AbortError: User aborted the request."); } + const result = _makeMessageChunkFromAnthropicEvent(data, { streamUsage: !!(this.streamUsage || options.streamUsage), coerceContentToString, usageData, }); - if (!result) { - continue; - } + if (!result) continue; + const { chunk, usageData: updatedUsageData } = result; usageData = updatedUsageData; - let newToolCallChunk: ToolCallChunk | undefined; - - // Initial chunk for tool calls from anthropic contains identifying information like ID and name. - // This chunk does not contain any input JSON. - const toolUseChunks = Array.isArray(chunk.content) - ? chunk.content.find((c) => c.type === "tool_use") - : undefined; - if ( - toolUseChunks && - "index" in toolUseChunks && - "name" in toolUseChunks && - "id" in toolUseChunks - ) { - newToolCallChunk = { - args: "", - id: toolUseChunks.id, - name: toolUseChunks.name, - index: toolUseChunks.index, - }; + const newToolCallChunk = extractToolCallChunk(chunk); + // Maintain concatenatedChunks for accessing the complete `tool_use` content block. + concatenatedChunks = concatenatedChunks + ? concat(concatenatedChunks, chunk) + : chunk; + + let toolUseContent; + const extractedContent = extractToolUseContent(chunk, concatenatedChunks); + if (extractedContent) { + toolUseContent = extractedContent.toolUseContent; + concatenatedChunks = extractedContent.concatenatedChunks; } - // Chunks after the initial chunk only contain the index and partial JSON. - const inputJsonDeltaChunks = Array.isArray(chunk.content) - ? chunk.content.find((c) => c.type === "input_json_delta") - : undefined; - if ( - inputJsonDeltaChunks && - "index" in inputJsonDeltaChunks && - "input" in inputJsonDeltaChunks - ) { - if (typeof inputJsonDeltaChunks.input === "string") { - newToolCallChunk = { - args: inputJsonDeltaChunks.input, - index: inputJsonDeltaChunks.index, - }; - } else { - newToolCallChunk = { - args: JSON.stringify(inputJsonDeltaChunks.input, null, 2), - index: inputJsonDeltaChunks.index, - }; - } - } - - const token: string | undefined = - typeof chunk.content === "string" && chunk.content !== "" - ? chunk.content - : undefined; - - // Remove `tool_use` content types until the last chunk. - let toolUseContent: - | { - id: string; - type: "tool_use"; - name: string; - input: Record; - } - | undefined; - if (!concatenatedChunks) { - concatenatedChunks = chunk; - } else { - concatenatedChunks = concat(concatenatedChunks, chunk); - } - if ( - Array.isArray(concatenatedChunks.content) && - concatenatedChunks.content.find((c) => c.type === "tool_use") - ) { - try { - const toolUseMsg = concatenatedChunks.content.find( - (c) => c.type === "tool_use" - ); - if ( - !toolUseMsg || - !( - "input" in toolUseMsg || - "name" in toolUseMsg || - "id" in toolUseMsg - ) - ) - return; - const parsedArgs = JSON.parse(toolUseMsg.input); - if (parsedArgs) { - toolUseContent = { - type: "tool_use", - id: toolUseMsg.id, - name: toolUseMsg.name, - input: parsedArgs, - }; - } - } catch (_) { - // no-op - } - } - - const chunkContentWithoutToolUse = Array.isArray(chunk.content) + // Filter partial `tool_use` content, and only add `tool_use` chunks if complete JSON available. + const chunkContent = Array.isArray(chunk.content) ? chunk.content.filter((c) => c.type !== "tool_use") : chunk.content; - if (Array.isArray(chunkContentWithoutToolUse) && toolUseContent) { - chunkContentWithoutToolUse.push(toolUseContent); + if (Array.isArray(chunkContent) && toolUseContent) { + chunkContent.push(toolUseContent); } + // Extract the text content token for text field and runManager. + const token = extractToken(chunk); yield new ChatGenerationChunk({ message: new AIMessageChunk({ - content: chunkContentWithoutToolUse, + content: chunkContent, additional_kwargs: chunk.additional_kwargs, tool_call_chunks: newToolCallChunk ? [newToolCallChunk] : undefined, usage_metadata: chunk.usage_metadata, @@ -944,10 +974,12 @@ export class ChatAnthropicMessages< }), text: token ?? "", }); + if (token) { await runManager?.handleLLMNewToken(token); } } + let usageMetadata: UsageMetadata | undefined; if (this.streamUsage || options.streamUsage) { usageMetadata = { From 8bd0ab9aef39d09bb0585ce9ee1297e73ef571c0 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Mon, 15 Jul 2024 13:57:59 -0700 Subject: [PATCH 12/14] add comment --- .../src/tests/chat_models-tools.int.test.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts index a638c9eaa5a8..3b30be821a15 100644 --- a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts @@ -381,6 +381,9 @@ test("Can stream tool calls", async () => { finalChunk = concat(finalChunk, chunk); } if (chunk.tool_call_chunks?.[0]?.args) { + // Check if the args have changed since the last chunk. + // This helps count the number of unique arg updates in the stream, + // ensuring we're receiving multiple chunks with different arg content. if ( !prevToolCallChunkArgs || prevToolCallChunkArgs !== chunk.tool_call_chunks[0].args From 854c7ffb7242ca328e68f28e576be6261d206bb6 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Mon, 15 Jul 2024 16:28:33 -0700 Subject: [PATCH 13/14] add tool call type field --- libs/langchain-anthropic/src/chat_models.ts | 3 +++ libs/langchain-anthropic/src/output_parsers.ts | 7 ++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index 4a65f0df1137..8aa8295f2676 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -570,6 +570,7 @@ function extractToolCallChunk( id: toolUseChunks.id, name: toolUseChunks.name, index: toolUseChunks.index, + type: "tool_call_chunk", }; } @@ -586,11 +587,13 @@ function extractToolCallChunk( newToolCallChunk = { args: inputJsonDeltaChunks.input, index: inputJsonDeltaChunks.index, + type: "tool_call_chunk", }; } else { newToolCallChunk = { args: JSON.stringify(inputJsonDeltaChunks.input, null, 2), index: inputJsonDeltaChunks.index, + type: "tool_call_chunk", }; } } diff --git a/libs/langchain-anthropic/src/output_parsers.ts b/libs/langchain-anthropic/src/output_parsers.ts index 1168b8d54d14..c5608900b4b9 100644 --- a/libs/langchain-anthropic/src/output_parsers.ts +++ b/libs/langchain-anthropic/src/output_parsers.ts @@ -82,7 +82,12 @@ export function extractToolCalls(content: Record[]) { const toolCalls: ToolCall[] = []; for (const block of content) { if (block.type === "tool_use") { - toolCalls.push({ name: block.name, args: block.input, id: block.id }); + toolCalls.push({ + name: block.name, + args: block.input, + id: block.id, + type: "tool_call", + }); } } return toolCalls; From f66eb09f2e4ea553af42b490ea485cc4a0b8e3f0 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Mon, 15 Jul 2024 18:01:03 -0700 Subject: [PATCH 14/14] Fix test --- .../src/tests/chat_models-tools.int.test.ts | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts index 3b30be821a15..0bd1fe766875 100644 --- a/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models-tools.int.test.ts @@ -8,9 +8,9 @@ import { ToolMessage, } from "@langchain/core/messages"; import { StructuredTool, tool } from "@langchain/core/tools"; +import { concat } from "@langchain/core/utils/stream"; import { z } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; -import { concat } from "@langchain/core/utils/stream"; import { ChatAnthropic } from "../chat_models.js"; import { AnthropicToolResponse } from "../types.js"; @@ -163,6 +163,10 @@ test("Can bind & invoke AnthropicTools", async () => { test("Can bind & stream AnthropicTools", async () => { const modelWithTools = model.bind({ tools: [anthropicTool], + tool_choice: { + type: "tool", + name: "get_weather", + }, }); const result = await modelWithTools.stream( @@ -187,21 +191,15 @@ test("Can bind & stream AnthropicTools", async () => { throw new Error("Content is not an array"); } // eslint-disable-next-line @typescript-eslint/no-explicit-any - let toolCall: Record | undefined; - - finalMessage.content.forEach((item) => { - if (item.type === "tool_use") { - toolCall = item as AnthropicToolResponse; - } - }); - if (!toolCall) { + const toolCall = finalMessage.tool_calls?.[0]; + if (toolCall === undefined) { throw new Error("No tool call found"); } expect(toolCall).toBeTruthy(); - const { name, input } = toolCall; + const { name, args } = toolCall; expect(name).toBe("get_weather"); - expect(input).toBeTruthy(); - expect(input.location).toBeTruthy(); + expect(args).toBeTruthy(); + expect(args.location).toBeTruthy(); }); test("withStructuredOutput with zod schema", async () => { @@ -366,7 +364,12 @@ test("Can stream tool calls", async () => { schema: zodSchema, }); - const modelWithTools = model.bindTools([weatherTool]); + const modelWithTools = model.bindTools([weatherTool], { + tool_choice: { + type: "tool", + name: "get_weather", + }, + }); const stream = await modelWithTools.stream( "What is the weather in San Francisco CA?" ); @@ -395,9 +398,6 @@ test("Can stream tool calls", async () => { } expect(finalChunk?.tool_calls?.[0]).toBeDefined(); - if (!finalChunk?.tool_calls?.[0]) { - return; - } expect(finalChunk?.tool_calls?.[0].name).toBe("get_weather"); expect(finalChunk?.tool_calls?.[0].args.location).toBeDefined(); expect(realToolCallChunkStreams).toBeGreaterThan(1);