✨feat: add ollama embedding provider
cookieY committed Sep 29, 2024
1 parent 6f06057 commit 54ced07
Showing 2 changed files with 42 additions and 8 deletions.
src/libs/agent-runtime/bedrock/index.ts (6 changes: 3 additions & 3 deletions)
@@ -68,9 +68,9 @@ export class LobeBedrockAI implements LobeRuntimeAI {
     payload: EmbeddingsPayload,
     options?: EmbeddingsOptions,
   ): Promise<EmbeddingItem[]> {
-    const input = payload.input as string[];
+    const input = Array.isArray(payload.input) ? payload.input : [payload.input];
     const promises = input.map((inputText: string, index: number) =>
-      this.invokeTitanModel(
+      this.invokeEmbeddingModel(
         {
           dimensions: payload.dimensions,
           index: index,
@@ -84,7 +84,7 @@
     return embeddings;
   }
 
-  private invokeTitanModel = async (
+  private invokeEmbeddingModel = async (
     payload: BedRockEmbeddingsParams,
     options?: EmbeddingsOptions,
   ): Promise<EmbeddingItem> => {
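The same string-or-array normalization appears in both changed files. A minimal standalone sketch of the pattern, where the EmbeddingsPayload shape is an assumption inferred from how the diff uses it rather than copied from the repository:

// Sketch of the input normalization this commit applies in both providers.
// The EmbeddingsPayload shape below is assumed from the diff, not the repo.
interface EmbeddingsPayload {
  dimensions?: number;
  input: string | string[];
  model: string;
}

const toInputArray = (payload: EmbeddingsPayload): string[] =>
  // The old `payload.input as string[]` cast silently mistyped a bare
  // string; the Array.isArray guard accepts both payload shapes.
  Array.isArray(payload.input) ? payload.input : [payload.input];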
src/libs/agent-runtime/ollama/index.ts (44 changes: 39 additions & 5 deletions)
@@ -6,7 +6,13 @@ import { ChatModelCard } from '@/types/llm';
 
 import { LobeRuntimeAI } from '../BaseAI';
 import { AgentRuntimeErrorType } from '../error';
-import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
+import {
+  ChatCompetitionOptions,
+  ChatStreamPayload,
+  EmbeddingItem,
+  EmbeddingsPayload,
+  ModelProvider,
+} from '../types';
 import { AgentRuntimeError } from '../utils/createError';
 import { StreamingResponse } from '../utils/response';
 import { OllamaStream } from '../utils/streams';
@@ -45,10 +51,7 @@ export class LobeOllamaAI implements LobeRuntimeAI {
       options: {
         frequency_penalty: payload.frequency_penalty,
         presence_penalty: payload.presence_penalty,
-        temperature:
-          payload.temperature !== undefined
-            ? payload.temperature / 2
-            : undefined,
+        temperature: payload.temperature !== undefined ? payload.temperature / 2 : undefined,
         top_p: payload.top_p,
       },
       stream: true,
@@ -68,13 +71,44 @@ export class LobeOllamaAI implements LobeRuntimeAI {
     }
   }
 
+  async embeddings(payload: EmbeddingsPayload): Promise<EmbeddingItem[]> {
+    const input = Array.isArray(payload.input) ? payload.input : [payload.input];
+    const promises = input.map((inputText: string, index: number) =>
+      this.invokeEmbeddingModel(inputText, payload.model, index),
+    );
+    const embeddings = await Promise.all(promises);
+    return embeddings;
+  }
+
   async models(): Promise<ChatModelCard[]> {
     const list = await this.client.list();
     return list.models.map((model) => ({
       id: model.name,
     }));
   }
 
+  private invokeEmbeddingModel = async (
+    inputText: string,
+    model: string,
+    index: number,
+  ): Promise<EmbeddingItem> => {
+    try {
+      const responseBody = await this.client.embeddings({
+        model: model,
+        prompt: inputText,
+      });
+      return { embedding: responseBody.embedding, index: index, object: 'embedding' };
+    } catch (error) {
+      const e = error as { message: string; name: string; status_code: number };
+
+      throw AgentRuntimeError.chat({
+        error: { message: e.message, name: e.name, status_code: e.status_code },
+        errorType: AgentRuntimeErrorType.OllamaBizError,
+        provider: ModelProvider.Ollama,
+      });
+    }
+  };
+
   private buildOllamaMessages(messages: OpenAIChatMessage[]) {
     return messages.map((message) => this.convertContentToOllamaMessage(message));
   }
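The new embeddings method fans out one request per input because the client.embeddings call used here takes a single prompt, as the diff shows; Promise.all plus the captured index keeps results in input order. A hedged usage sketch, where the baseURL constructor option and the model name are illustrative assumptions rather than anything shown in this commit:

// Usage sketch, assuming LobeOllamaAI accepts a baseURL option and that an
// embedding model such as 'nomic-embed-text' is pulled locally; neither
// detail appears in this commit.
const runtime = new LobeOllamaAI({ baseURL: 'http://127.0.0.1:11434' });

const items = await runtime.embeddings({
  input: ['first document', 'second document'], // a bare string also works
  model: 'nomic-embed-text',
});

// Each EmbeddingItem carries its vector plus the index of the input it came
// from, so ordering survives the parallel fan-out.
console.log(items.map((item) => item.embedding.length));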
