Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[patch]: Adds support for plain message objects as shorthand for messages #5954

Merged
merged 4 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions langchain-core/src/language_models/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,25 @@ import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import { FakeChatModel, FakeListChatModel } from "../../utils/testing/index.js";

test("Test ChatModel accepts array shorthand for messages", async () => {
const model = new FakeChatModel({});
const response = await model.invoke([["human", "Hello there!"]]);
expect(response.content).toEqual("Hello there!");
});

test("Test ChatModel accepts object shorthand for messages", async () => {
const model = new FakeChatModel({});
const response = await model.invoke([
{
type: "human",
content: "Hello there!",
additional_kwargs: {},
example: true,
},
]);
expect(response.content).toEqual("Hello there!");
});

test("Test ChatModel uses callbacks", async () => {
const model = new FakeChatModel({});
let acc = "";
Expand Down
4 changes: 4 additions & 0 deletions langchain-core/src/messages/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@ export abstract class BaseMessageChunk extends BaseMessage {

export type BaseMessageLike =
| BaseMessage
| ({
type: MessageType | "user" | "assistant" | "placeholder";
} & BaseMessageFields &
Record<string, unknown>)
| [
StringWithAutocomplete<
MessageType | "user" | "assistant" | "placeholder"
Expand Down
32 changes: 22 additions & 10 deletions langchain-core/src/messages/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
isBaseMessage,
StoredMessage,
StoredMessageV1,
BaseMessageFields,
} from "./base.js";
import {
ChatMessage,
Expand All @@ -20,6 +21,23 @@ import { HumanMessage, HumanMessageChunk } from "./human.js";
import { SystemMessage, SystemMessageChunk } from "./system.js";
import { ToolMessage, ToolMessageFieldsWithToolCallId } from "./tool.js";

function _constructMessageFromParams(
params: BaseMessageFields & { type: string }
) {
const { type, ...rest } = params;
if (type === "human" || type === "user") {
return new HumanMessage(rest);
} else if (type === "ai" || type === "assistant") {
return new AIMessage(rest);
} else if (type === "system") {
return new SystemMessage(rest);
} else {
throw new Error(
`Unable to coerce message from array: only human, AI, or system message coercion is currently supported.`
);
}
}

export function coerceMessageLikeToMessage(
messageLike: BaseMessageLike
): BaseMessage {
Expand All @@ -28,17 +46,11 @@ export function coerceMessageLikeToMessage(
} else if (isBaseMessage(messageLike)) {
return messageLike;
}
const [type, content] = messageLike;
if (type === "human" || type === "user") {
return new HumanMessage({ content });
} else if (type === "ai" || type === "assistant") {
return new AIMessage({ content });
} else if (type === "system") {
return new SystemMessage({ content });
if (Array.isArray(messageLike)) {
const [type, content] = messageLike;
return _constructMessageFromParams({ type, content });
} else {
throw new Error(
`Unable to coerce message from array: only human, AI, or system message coercion is currently supported.`
);
return _constructMessageFromParams(messageLike);
}
}

Expand Down
49 changes: 21 additions & 28 deletions langchain-core/src/prompts/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,50 +134,43 @@ export class MessagesPlaceholder<
return [this.variableName];
}

validateInputOrThrow(
input: Array<unknown> | undefined,
variableName: Extract<keyof RunInput, string>
): input is BaseMessage[] {
async formatMessages(
values: TypedPromptInputValues<RunInput>
): Promise<BaseMessage[]> {
const input = values[this.variableName];
if (this.optional && !input) {
return false;
return [];
} else if (!input) {
const error = new Error(
`Error: Field "${variableName}" in prompt uses a MessagesPlaceholder, which expects an array of BaseMessages as an input value. Received: undefined`
`Field "${this.variableName}" in prompt uses a MessagesPlaceholder, which expects an array of BaseMessages as an input value. Received: undefined`
);
error.name = "InputFormatError";
throw error;
}

let isInputBaseMessage = false;

if (Array.isArray(input)) {
isInputBaseMessage = input.every((message) =>
isBaseMessage(message as BaseMessage)
);
} else {
isInputBaseMessage = isBaseMessage(input as BaseMessage);
}

if (!isInputBaseMessage) {
let formattedMessages;
try {
if (Array.isArray(input)) {
formattedMessages = input.map(coerceMessageLikeToMessage);
} else {
formattedMessages = [coerceMessageLikeToMessage(input)];
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
const readableInput =
typeof input === "string" ? input : JSON.stringify(input, null, 2);

const error = new Error(
`Error: Field "${variableName}" in prompt uses a MessagesPlaceholder, which expects an array of BaseMessages as an input value. Received: ${readableInput}`
[
`Field "${this.variableName}" in prompt uses a MessagesPlaceholder, which expects an array of BaseMessages or coerceable values as input.`,
`Received value: ${readableInput}`,
`Additional message: ${e.message}`,
].join("\n\n")
);
error.name = "InputFormatError";
throw error;
}

return true;
}

async formatMessages(
values: TypedPromptInputValues<RunInput>
): Promise<BaseMessage[]> {
this.validateInputOrThrow(values[this.variableName], this.variableName);

return values[this.variableName] ?? [];
return formattedMessages;
}
}

Expand Down
56 changes: 55 additions & 1 deletion langchain-core/src/prompts/tests/chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ test("Test MessagesPlaceholder not optional", async () => {
});
// eslint-disable-next-line @typescript-eslint/no-explicit-any
await expect(prompt.formatMessages({} as any)).rejects.toThrow(
'Error: Field "foo" in prompt uses a MessagesPlaceholder, which expects an array of BaseMessages as an input value. Received: undefined'
'Field "foo" in prompt uses a MessagesPlaceholder, which expects an array of BaseMessages as an input value. Received: undefined'
);
});

Expand All @@ -323,6 +323,60 @@ test("Test MessagesPlaceholder shorthand in a chat prompt template", async () =>
]);
});

test("Test MessagesPlaceholder shorthand in a chat prompt template with object format", async () => {
const prompt = ChatPromptTemplate.fromMessages([["placeholder", "{foo}"]]);
const messages = await prompt.formatMessages({
foo: [
{
type: "system",
content: "some initial content",
},
{
type: "human",
content: [
{
text: "page: 1\ndescription: One Purchase Flow\ntimestamp: '2024-06-04T14:46:46.062Z'\ntype: navigate\nscreenshot_present: true\n",
type: "text",
},
{
text: "page: 3\ndescription: intent_str=buy,mode_str=redirect,screenName_str=order-completed,\ntimestamp: '2024-06-04T14:46:58.846Z'\ntype: Screen View\nscreenshot_present: false\n",
type: "text",
},
],
},
{
type: "assistant",
content: "some captivating response",
},
],
});
expect(messages).toEqual([
new SystemMessage("some initial content"),
new HumanMessage({
content: [
{
text: "page: 1\ndescription: One Purchase Flow\ntimestamp: '2024-06-04T14:46:46.062Z'\ntype: navigate\nscreenshot_present: true\n",
type: "text",
},
{
text: "page: 3\ndescription: intent_str=buy,mode_str=redirect,screenName_str=order-completed,\ntimestamp: '2024-06-04T14:46:58.846Z'\ntype: Screen View\nscreenshot_present: false\n",
type: "text",
},
],
}),
new AIMessage("some captivating response"),
]);
});

test("Test MessagesPlaceholder with invalid shorthand should throw", async () => {
const prompt = ChatPromptTemplate.fromMessages([["placeholder", "{foo}"]]);
await expect(() =>
prompt.formatMessages({
foo: [{ badFormatting: true }],
})
).rejects.toThrow();
});

test("Test using partial", async () => {
const userPrompt = new PromptTemplate({
template: "{foo}{bar}",
Expand Down
Loading