baidu-qianfan[patch]: Fix streaming mode of Qianfan #6661

Merged
4 commits merged on Sep 10, 2024
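
For context, here is a minimal usage sketch of the streaming mode this PR fixes. It assumes the standard LangChain chat-model interface exposed by @langchain/baidu-qianfan; the model name is a placeholder, and credentials are assumed to be configured through the environment variables read by the underlying @baiducloud/qianfan SDK.

import { ChatBaiduQianfan } from "@langchain/baidu-qianfan";
import { HumanMessage } from "@langchain/core/messages";

// Minimal streaming sketch. The model name is a placeholder and the
// Qianfan credentials are assumed to be set in the environment.
const chat = new ChatBaiduQianfan({
  model: "ERNIE-Lite-8K",
  streaming: true,
});

const stream = await chat.stream([new HumanMessage("Tell me a short joke")]);
for await (const chunk of stream) {
  // Each chunk is an AIMessageChunk carrying a partial result.
  process.stdout.write(chunk.content.toString());
}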
219 changes: 76 additions & 143 deletions libs/langchain-baidu-qianfan/src/chat_models.ts
@@ -15,7 +15,6 @@ import {
} from "@langchain/core/outputs";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { convertEventStreamToIterableReadableDataStream } from "@langchain/core/utils/event_source_parse";
import { ChatCompletion } from "@baiducloud/qianfan";

/**
@@ -60,8 +59,12 @@ interface ChatCompletionResponse {
id: string;
object: string;
created: number;
sentence_id: number;
is_end: boolean;
is_truncated: boolean;
result: string;
need_clear_history: boolean;
finish_reason: string;
usage: TokenUsage;
}
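
The new fields describe how a streamed response is segmented. Below is a minimal sketch of how a consumer might fold such chunks back into a single string, assuming `result` carries the partial text and `is_end` marks the final chunk; these semantics are inferred from this diff, not from the SDK documentation.

// Sketch only: accumulate partial results until the chunk flagged is_end.
async function collectStreamedText(
  chunks: AsyncIterableIterator<ChatCompletionResponse>
): Promise<string> {
  let text = "";
  for await (const chunk of chunks) {
    text += chunk.result;
    if (chunk.is_end) {
      break;
    }
  }
  return text;
}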

@@ -305,167 +308,96 @@ export class ChatBaiduQianfan
private _ensureMessages(messages: BaseMessage[]): Qianfan[] {
return messages.map((message) => ({
role: messageToQianfanRole(message),
content: message.text,
content: message.content.toString(),
}));
}

/** @ignore */
async _generate(
messages: BaseMessage[],
_options?: this["ParsedCallOptions"],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
const tokenUsage: TokenUsage = {};

const params = this.invocationParams();
if (this.streaming) {
let finalChunk: ChatGenerationChunk | undefined;
const stream = this._streamResponseChunks(messages, options, runManager);
for await (const chunk of stream) {
if (finalChunk === undefined) {
finalChunk = chunk;
} else {
finalChunk = finalChunk.concat(chunk);
}
}

// Qianfan requires the system message to be put in the params, not messages array
const systemMessage = messages.find(
(message) => message._getType() === "system"
);
if (systemMessage) {
// eslint-disable-next-line no-param-reassign
messages = messages.filter((message) => message !== systemMessage);
params.system = systemMessage.text;
}
const messagesMapped = this._ensureMessages(messages);
if (finalChunk === undefined) {
throw new Error("No chunks returned from BaiduQianFan API.");
}

const data = params.stream
? await new Promise<ChatCompletionResponse>((resolve, reject) => {
let rejected = false;
this.completionWithRetry(
{
...params,
messages: messagesMapped,
},
true,
(event) => {
resolve(event.data);
// eslint-disable-next-line no-void
void runManager?.handleLLMNewToken(event.data ?? "");
}
).catch((error) => {
if (!rejected) {
rejected = true;
reject(error);
}
});
})
: await this.completionWithRetry(
return {
generations: [
{
...params,
messages: messagesMapped,
text: finalChunk.text,
message: finalChunk.message,
},
false
).then((data) => {
if (data?.error_code) {
throw new Error(data?.error_msg);
}
return data;
});

const {
completion_tokens: completionTokens,
prompt_tokens: promptTokens,
total_tokens: totalTokens,
} = data.usage ?? {};

if (completionTokens) {
tokenUsage.completionTokens =
(tokenUsage.completionTokens ?? 0) + completionTokens;
}

if (promptTokens) {
tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens;
}

if (totalTokens) {
tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens;
],
llmOutput: finalChunk.generationInfo?.usage ?? {},
};
} else {
const params = this.invocationParams();

const systemMessage = messages.find(
(message) => message._getType() === "system"
);
if (systemMessage) {
// eslint-disable-next-line no-param-reassign
messages = messages.filter((message) => message !== systemMessage);
params.system = systemMessage.content.toString();
}
const messagesMapped = this._ensureMessages(messages);

const data = (await this.completionWithRetry(
{
...params,
messages: messagesMapped,
},
false
)) as ChatCompletionResponse;

const tokenUsage = data.usage || {};

const generations: ChatGeneration[] = [
{
text: data.result || "",
message: new AIMessage(data.result || ""),
},
];

return {
generations,
llmOutput: { tokenUsage },
};
}

const generations: ChatGeneration[] = [];
const text = data.result ?? "";
generations.push({
text,
message: new AIMessage(text),
});
return {
generations,
llmOutput: { tokenUsage },
};
}

/** @ignore */
async completionWithRetry(
request: ChatCompletionRequest,
stream: boolean,
onmessage?: (event: MessageEvent) => void
) {
stream: boolean
): Promise<
ChatCompletionResponse | AsyncIterableIterator<ChatCompletionResponse>
> {
const makeCompletionRequest = async () => {
console.log(request);
const response = await this.client.chat(request, this.model);

if (!stream) {
return response;
} else {
let streamResponse = { result: "" } as {
id: string;
object: string;
created: number;
sentence_id?: number;
result: string;
need_clear_history: boolean;
usage: TokenUsage;
};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
for await (const message of response as AsyncIterableIterator<any>) {
// Returned result
if (!streamResponse) {
streamResponse = {
id: message.id,
object: message.object,
created: message.created,
result: message.result,
need_clear_history: message.need_clear_history,
usage: message.usage,
};
} else {
streamResponse.result += message.result;
streamResponse.created = message.created;
streamResponse.need_clear_history = message.need_clear_history;
streamResponse.usage = message.usage;
}
}
const event = new MessageEvent("message", {
data: streamResponse,
});
onmessage?.(event);
return response as AsyncIterableIterator<ChatCompletionResponse>;
}
};

return this.caller.call(makeCompletionRequest);
}
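
Since completionWithRetry now returns either a single ChatCompletionResponse or an AsyncIterableIterator of them depending on the stream flag, call sites in this PR narrow the union with a cast (see _streamResponseChunks below). A hypothetical alternative, not part of this PR, would be a runtime type guard:

// Hypothetical helper (not in this PR): narrow the union returned by
// completionWithRetry without relying on the caller's `stream` flag.
function isChunkStream(
  value: ChatCompletionResponse | AsyncIterableIterator<ChatCompletionResponse>
): value is AsyncIterableIterator<ChatCompletionResponse> {
  return (
    typeof (value as AsyncIterableIterator<ChatCompletionResponse>)[
      Symbol.asyncIterator
    ] === "function"
  );
}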

private async createStream(request: ChatCompletionRequest) {
const response = await this.client.chat(
{
...request,
stream: true,
},
this.model
);

return convertEventStreamToIterableReadableDataStream(response);
}

private _deserialize(json: string) {
try {
return JSON.parse(json);
} catch (e) {
console.warn(`Received a non-JSON parseable chunk: ${json}`);
}
}

async *_streamResponseChunks(
messages: BaseMessage[],
_options?: this["ParsedCallOptions"],
@@ -476,27 +408,28 @@ export class ChatBaiduQianfan
stream: true,
};

// Qianfan requires the system message to be put in the params, not messages array
const systemMessage = messages.find(
(message) => message._getType() === "system"
);
if (systemMessage) {
// eslint-disable-next-line no-param-reassign
messages = messages.filter((message) => message !== systemMessage);
parameters.system = systemMessage.text;
parameters.system = systemMessage.content.toString();
}
const messagesMapped = this._ensureMessages(messages);

const stream = await this.caller.call(async () =>
this.createStream({
...parameters,
messages: messagesMapped,
})
);
const stream = (await this.caller.call(async () =>
this.completionWithRetry(
{
...parameters,
messages: messagesMapped,
},
true
)
)) as AsyncIterableIterator<ChatCompletionResponse>;

for await (const chunk of stream) {
const deserializedChunk = this._deserialize(chunk);
const { result, is_end, id } = deserializedChunk;
const { result, is_end, id } = chunk;
yield new ChatGenerationChunk({
text: result,
message: new AIMessageChunk({ content: result }),
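
To summarize the streaming-side change in _generate: when `streaming` is enabled, the chunks produced by _streamResponseChunks are folded into one generation with ChatGenerationChunk.concat. A condensed sketch of that aggregation pattern, using only APIs that appear in the diff above:

import { ChatGenerationChunk } from "@langchain/core/outputs";

// Condensed sketch of the aggregation used in _generate when streaming.
async function aggregateChunks(
  stream: AsyncIterable<ChatGenerationChunk>
): Promise<ChatGenerationChunk> {
  let finalChunk: ChatGenerationChunk | undefined;
  for await (const chunk of stream) {
    finalChunk = finalChunk === undefined ? chunk : finalChunk.concat(chunk);
  }
  if (finalChunk === undefined) {
    throw new Error("No chunks returned from BaiduQianFan API.");
  }
  return finalChunk;
}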