
Commit eed9357

✨ feat: support to customize Embedding model with env (lobehub#5177)
* feat: add embedding model configuration support; update related docs and tests
* feat: refactor the file configuration; update default settings and related tests
* ♻️ Code Refactoring: update the file configuration and standardize model naming to camel case
1 parent 0ac5802 commit eed9357

File tree

16 files changed: +275 -31 lines

.env.example

+4 -1

```diff
@@ -208,4 +208,7 @@ OPENAI_API_KEY=sk-xxxxxxxxx
 
 # use `openssl rand -base64 32` to generate a key for the encryption of the database
 # we use this key to encrypt the user api key
-# KEY_VAULTS_SECRET=xxxxx/xxxxxxxxxxxxxx=
+#KEY_VAULTS_SECRET=xxxxx/xxxxxxxxxxxxxx=
+
+# Specify the Embedding model and the Reranker model (Reranker is not yet implemented)
+# DEFAULT_FILES_CONFIG="embedding_model=openai/text-embedding-3-small,reranker_model=cohere/rerank-english-v3.0,query_model=full_text"
```
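
For reference, `parseFilesConfig` (added later in this commit) splits this value on commas into `key=value` pairs and splits each model value on the first `/` into provider and model. A sketch of the parsed result for the example above, assuming `query_model` maps onto `queryModel`:

```ts
// Illustrative parsed shape for the example value above (hypothetical variable).
const parsed = {
  embeddingModel: { model: 'text-embedding-3-small', provider: 'openai' },
  queryModel: 'full_text',
  rerankerModel: { model: 'rerank-english-v3.0', provider: 'cohere' },
};
```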

docs/self-hosting/advanced/knowledge-base.mdx

+9
````diff
@@ -62,3 +62,12 @@ Unstructured.io is a powerful document processing tool.
 - **Note**: Evaluate processing needs based on document complexity
 
 By correctly configuring and integrating these core components, you can build a powerful and efficient knowledge base system for LobeChat. Each component plays a crucial role in the overall architecture, supporting advanced document management and intelligent retrieval functions.
+
+### 5. Custom Embedding
+
+- **Purpose**: Use a different Embedding model to generate vector representations for semantic search
+- **Options**: Supported model providers: zhipu/github/openai/bedrock/ollama
+- **Deployment Tip**: Set the default Embedding model via an environment variable
+```
+environment: DEFAULT_FILES_CONFIG=embedding_model=openai/text-embedding-3-small
+```
````

docs/self-hosting/advanced/knowledge-base.zh-CN.mdx

+9
````diff
@@ -60,3 +60,12 @@ Unstructured.io 是一个强大的文档处理工具。
 - **注意事项**:评估处理需求,根据文档复杂度决定是否部署
 
 通过正确配置和集成这些核心组件,您可以为 LobeChat 构建一个强大、高效的知识库系统。每个组件都在整体架构中扮演着关键角色,共同支持高级的文档管理和智能检索功能。
+
+### 5. 自定义 Embedding(可选)
+
+- **用途**: 使用不同的嵌入模型(Embedding)生成文本的向量表示,用于语义搜索
+- **选项**: 支持的模型提供商:zhipu/github/openai/bedrock/ollama
+- **部署建议**: 使用环境变量配置默认嵌入模型
+```
+environment: DEFAULT_FILES_CONFIG=embedding_model=openai/text-embedding-3-small
+```
````

src/config/knowledge.ts

+2
```diff
@@ -4,10 +4,12 @@ import { z } from 'zod';
 export const getKnowledgeConfig = () => {
   return createEnv({
     runtimeEnv: {
+      DEFAULT_FILES_CONFIG: process.env.DEFAULT_FILES_CONFIG,
       UNSTRUCTURED_API_KEY: process.env.UNSTRUCTURED_API_KEY,
       UNSTRUCTURED_SERVER_URL: process.env.UNSTRUCTURED_SERVER_URL,
     },
     server: {
+      DEFAULT_FILES_CONFIG: z.string().optional(),
       UNSTRUCTURED_API_KEY: z.string().optional(),
       UNSTRUCTURED_SERVER_URL: z.string().optional(),
     },
```

src/const/settings/knowledge.ts

+25
```diff
@@ -0,0 +1,25 @@
+import { FilesConfig, FilesConfigItem } from '@/types/user/settings/filesConfig';
+
+import {
+  DEFAULT_EMBEDDING_MODEL,
+  DEFAULT_PROVIDER,
+  DEFAULT_RERANK_MODEL,
+  DEFAULT_RERANK_PROVIDER,
+  DEFAULT_RERANK_QUERY_MODE,
+} from './llm';
+
+export const DEFAULT_FILE_EMBEDDING_MODEL_ITEM: FilesConfigItem = {
+  model: DEFAULT_EMBEDDING_MODEL,
+  provider: DEFAULT_PROVIDER,
+};
+
+export const DEFAULT_FILE_RERANK_MODEL_ITEM: FilesConfigItem = {
+  model: DEFAULT_RERANK_MODEL,
+  provider: DEFAULT_RERANK_PROVIDER,
+};
+
+export const DEFAULT_FILES_CONFIG: FilesConfig = {
+  embeddingModel: DEFAULT_FILE_EMBEDDING_MODEL_ITEM,
+  queryModel: DEFAULT_RERANK_QUERY_MODE,
+  rerankerModel: DEFAULT_FILE_RERANK_MODEL_ITEM,
+};
```
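
The `FilesConfig` and `FilesConfigItem` types referenced above live in `src/types/user/settings/filesConfig.ts`, one of the changed files not shown in this excerpt. A plausible shape, inferred from their usage here and in `parseFilesConfig` below:

```ts
// Inferred sketch only; the real definitions are in
// src/types/user/settings/filesConfig.ts (not shown in this excerpt).
export interface FilesConfigItem {
  model: string;
  provider: string;
}

export interface FilesConfig {
  embeddingModel: FilesConfigItem;
  queryModel: string;
  rerankerModel: FilesConfigItem;
}
```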

src/const/settings/llm.ts

+6
```diff
@@ -15,6 +15,12 @@ export const DEFAULT_LLM_CONFIG = genUserLLMConfig({
 });
 
 export const DEFAULT_MODEL = 'gpt-4o-mini';
+
 export const DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small';
+export const DEFAULT_EMBEDDING_PROVIDER = ModelProvider.OpenAI;
+
+export const DEFAULT_RERANK_MODEL = 'rerank-english-v3.0';
+export const DEFAULT_RERANK_PROVIDER = 'cohere';
+export const DEFAULT_RERANK_QUERY_MODE = 'full_text';
 
 export const DEFAULT_PROVIDER = ModelProvider.OpenAI;
```

src/database/schemas/ragEvals.ts

+2 -2

```diff
@@ -1,7 +1,7 @@
 /* eslint-disable sort-keys-fix/sort-keys-fix */
 import { integer, jsonb, pgTable, text, uuid } from 'drizzle-orm/pg-core';
 
-import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings';
+import { DEFAULT_MODEL } from '@/const/settings';
 import { EvalEvaluationStatus } from '@/types/eval';
 
 import { timestamps } from './_helpers';
@@ -60,7 +60,7 @@ export const evalEvaluation = pgTable('rag_eval_evaluations', {
     onDelete: 'cascade',
   }),
   languageModel: text('language_model').$defaultFn(() => DEFAULT_MODEL),
-  embeddingModel: text('embedding_model').$defaultFn(() => DEFAULT_EMBEDDING_MODEL),
+  embeddingModel: text('embedding_model'),
 
   userId: text('user_id').references(() => users.id, { onDelete: 'cascade' }),
   ...timestamps,
```

src/libs/agent-runtime/bedrock/index.ts

+64 -3

```diff
@@ -1,12 +1,20 @@
 import {
   BedrockRuntimeClient,
+  InvokeModelCommand,
   InvokeModelWithResponseStreamCommand,
 } from '@aws-sdk/client-bedrock-runtime';
 import { experimental_buildLlama2Prompt } from 'ai/prompts';
 
 import { LobeRuntimeAI } from '../BaseAI';
 import { AgentRuntimeErrorType } from '../error';
-import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
+import {
+  ChatCompetitionOptions,
+  ChatStreamPayload,
+  Embeddings,
+  EmbeddingsOptions,
+  EmbeddingsPayload,
+  ModelProvider,
+} from '../types';
 import { buildAnthropicMessages, buildAnthropicTools } from '../utils/anthropicHelpers';
 import { AgentRuntimeError } from '../utils/createError';
 import { debugStream } from '../utils/debugStream';
@@ -32,9 +40,7 @@ export class LobeBedrockAI implements LobeRuntimeAI {
   constructor({ region, accessKeyId, accessKeySecret, sessionToken }: LobeBedrockAIParams = {}) {
     if (!(accessKeyId && accessKeySecret))
       throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidBedrockCredentials);
-
     this.region = region ?? 'us-east-1';
-
     this.client = new BedrockRuntimeClient({
       credentials: {
         accessKeyId: accessKeyId,
@@ -50,6 +56,61 @@ export class LobeBedrockAI implements LobeRuntimeAI {
 
     return this.invokeClaudeModel(payload, options);
   }
+  /**
+   * Supports the Amazon Titan Text model series.
+   * Cohere Embed models are not supported
+   * because the current text size per request
+   * exceeds the maximum 2048-character limit
+   * for a single request for this series of models.
+   * [Bedrock embed guide]: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html
+   */
+  async embeddings(payload: EmbeddingsPayload, options?: EmbeddingsOptions): Promise<Embeddings[]> {
+    const input = Array.isArray(payload.input) ? payload.input : [payload.input];
+    const promises = input.map((inputText: string) =>
+      this.invokeEmbeddingModel(
+        {
+          dimensions: payload.dimensions,
+          input: inputText,
+          model: payload.model,
+        },
+        options,
+      ),
+    );
+    return Promise.all(promises);
+  }
+
+  private invokeEmbeddingModel = async (
+    payload: EmbeddingsPayload,
+    options?: EmbeddingsOptions,
+  ): Promise<Embeddings> => {
+    const command = new InvokeModelCommand({
+      accept: 'application/json',
+      body: JSON.stringify({
+        dimensions: payload.dimensions,
+        inputText: payload.input,
+        normalize: true,
+      }),
+      contentType: 'application/json',
+      modelId: payload.model,
+    });
+    try {
+      const res = await this.client.send(command, { abortSignal: options?.signal });
+      const responseBody = JSON.parse(new TextDecoder().decode(res.body));
+      return responseBody.embedding;
+    } catch (e) {
+      const err = e as Error & { $metadata: any };
+      throw AgentRuntimeError.chat({
+        error: {
+          body: err.$metadata,
+          message: err.message,
+          type: err.name,
+        },
+        errorType: AgentRuntimeErrorType.ProviderBizError,
+        provider: ModelProvider.Bedrock,
+        region: this.region,
+      });
+    }
+  };
 
   private invokeClaudeModel = async (
     payload: ChatStreamPayload,
```
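
A minimal usage sketch of the new embeddings path; the credentials and the Titan model id below are placeholders for illustration, not values from this commit:

```ts
// Hypothetical usage; 'amazon.titan-embed-text-v2:0' is an assumed Titan model id.
const runtime = new LobeBedrockAI({
  accessKeyId: '...',
  accessKeySecret: '...',
  region: 'us-east-1',
});

// Each input string is fanned out into its own InvokeModelCommand and the
// resulting vectors come back in input order via Promise.all.
const vectors = await runtime.embeddings({
  dimensions: 1024,
  input: ['first chunk', 'second chunk'],
  model: 'amazon.titan-embed-text-v2:0',
});
```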

src/libs/agent-runtime/ollama/index.ts

+37 -1

```diff
@@ -6,7 +6,13 @@ import { ChatModelCard } from '@/types/llm';
 
 import { LobeRuntimeAI } from '../BaseAI';
 import { AgentRuntimeErrorType } from '../error';
-import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
+import {
+  ChatCompetitionOptions,
+  ChatStreamPayload,
+  Embeddings,
+  EmbeddingsPayload,
+  ModelProvider,
+} from '../types';
 import { AgentRuntimeError } from '../utils/createError';
 import { debugStream } from '../utils/debugStream';
 import { StreamingResponse } from '../utils/response';
@@ -84,13 +90,43 @@ export class LobeOllamaAI implements LobeRuntimeAI {
     }
   }
 
+  async embeddings(payload: EmbeddingsPayload): Promise<Embeddings[]> {
+    const input = Array.isArray(payload.input) ? payload.input : [payload.input];
+    const promises = input.map((inputText: string) =>
+      this.invokeEmbeddingModel({
+        dimensions: payload.dimensions,
+        input: inputText,
+        model: payload.model,
+      }),
+    );
+    return await Promise.all(promises);
+  }
+
   async models(): Promise<ChatModelCard[]> {
     const list = await this.client.list();
     return list.models.map((model) => ({
       id: model.name,
     }));
   }
 
+  private invokeEmbeddingModel = async (payload: EmbeddingsPayload): Promise<Embeddings> => {
+    try {
+      const responseBody = await this.client.embeddings({
+        model: payload.model,
+        prompt: payload.input as string,
+      });
+      return responseBody.embedding;
+    } catch (error) {
+      const e = error as { message: string; name: string; status_code: number };
+
+      throw AgentRuntimeError.chat({
+        error: { message: e.message, name: e.name, status_code: e.status_code },
+        errorType: AgentRuntimeErrorType.OllamaBizError,
+        provider: ModelProvider.Ollama,
+      });
+    }
+  };
+
   private buildOllamaMessages(messages: OpenAIChatMessage[]) {
     return messages.map((message) => this.convertContentToOllamaMessage(message));
   }
```
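
Similarly for Ollama, a hedged sketch; it assumes the constructor accepts `{ baseURL }`, a local Ollama server, and that the model below has been pulled ('nomic-embed-text' is an assumption for illustration):

```ts
// Hypothetical usage; each input string becomes one client.embeddings call.
const ollama = new LobeOllamaAI({ baseURL: 'http://127.0.0.1:11434' });

const vectors = await ollama.embeddings({
  input: ['hello', 'world'],
  model: 'nomic-embed-text',
});
```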

src/server/globalConfig/index.ts

+6
```diff
@@ -1,6 +1,7 @@
 import { appEnv, getAppConfig } from '@/config/app';
 import { authEnv } from '@/config/auth';
 import { fileEnv } from '@/config/file';
+import { knowledgeEnv } from '@/config/knowledge';
 import { langfuseEnv } from '@/config/langfuse';
 import { enableNextAuth } from '@/const/auth';
 import { parseSystemAgent } from '@/server/globalConfig/parseSystemAgent';
@@ -9,6 +10,7 @@ import { GlobalServerConfig } from '@/types/serverConfig';
 import { genServerLLMConfig } from './_deprecated';
 import { genServerAiProvidersConfig } from './genServerAiProviderConfig';
 import { parseAgentConfig } from './parseDefaultAgent';
+import { parseFilesConfig } from './parseFilesConfig';
 
 export const getServerGlobalConfig = () => {
   const { ACCESS_CODES, DEFAULT_AGENT_CONFIG } = getAppConfig();
@@ -73,3 +75,7 @@ export const getServerDefaultAgentConfig = () => {
 
   return parseAgentConfig(DEFAULT_AGENT_CONFIG) || {};
 };
+
+export const getServerDefaultFilesConfig = () => {
+  return parseFilesConfig(knowledgeEnv.DEFAULT_FILES_CONFIG);
+};
```
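
Behavior sketch, assuming the parser and defaults added elsewhere in this commit: with `DEFAULT_FILES_CONFIG` unset, `parseFilesConfig` receives `undefined` and falls back to the built-in defaults; when set, the parsed value wins.

```ts
// Sketch of the expected resolution (hypothetical usage).
import { getServerDefaultFilesConfig } from '@/server/globalConfig';

const filesConfig = getServerDefaultFilesConfig();
// env unset -> DEFAULT_FILES_CONFIG from '@/const/settings/knowledge'
// env set   -> the parsed { embeddingModel, rerankerModel, queryModel } shape
```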

src/server/globalConfig/parseFilesConfig.test.ts

+17

```diff
@@ -0,0 +1,17 @@
+import { describe, expect, it } from 'vitest';
+
+import { parseFilesConfig } from './parseFilesConfig';
+
+describe('parseFilesConfig', () => {
+  // verifies the embeddings configuration is parsed correctly
+  it('parses embeddings configuration correctly', () => {
+    const envStr =
+      'embedding_model=openai/text-embedding-3-large,reranker_model=cohere/rerank-english-v3.0,query_model=full_text';
+    const expected = {
+      embeddingModel: { provider: 'openai', model: 'text-embedding-3-large' },
+      rerankerModel: { provider: 'cohere', model: 'rerank-english-v3.0' },
+      queryModel: 'full_text',
+    };
+    expect(parseFilesConfig(envStr)).toEqual(expected);
+  });
+});
```
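
A follow-up case one might add (not part of this commit) to pin down the fallback path:

```ts
// Hypothetical extra test, not in this commit: empty or unset input should
// fall back to DEFAULT_FILES_CONFIG (see parseFilesConfig below).
import { describe, expect, it } from 'vitest';

import { DEFAULT_FILES_CONFIG } from '@/const/settings/knowledge';

import { parseFilesConfig } from './parseFilesConfig';

describe('parseFilesConfig fallback', () => {
  it('returns the default config when the env string is empty or undefined', () => {
    expect(parseFilesConfig('')).toEqual(DEFAULT_FILES_CONFIG);
    expect(parseFilesConfig(undefined)).toEqual(DEFAULT_FILES_CONFIG);
  });
});
```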
src/server/globalConfig/parseFilesConfig.ts

+57

```diff
@@ -0,0 +1,57 @@
+import { DEFAULT_FILES_CONFIG } from '@/const/settings/knowledge';
+import { SystemEmbeddingConfig } from '@/types/knowledgeBase';
+import { FilesConfig } from '@/types/user/settings/filesConfig';
+
+const protectedKeys = Object.keys({
+  embedding_model: null,
+  query_model: null,
+  reranker_model: null,
+});
+
+export const parseFilesConfig = (envString: string = ''): SystemEmbeddingConfig => {
+  if (!envString) return DEFAULT_FILES_CONFIG;
+  const config: FilesConfig = {} as any;
+
+  // normalize full-width commas and trim surrounding whitespace
+  let envValue = envString.replaceAll(',', ',').trim();
+
+  const pairs = envValue.split(',');
+
+  for (const pair of pairs) {
+    const [key, value] = pair.split('=').map((s) => s.trim());
+
+    if (key && value) {
+      const [provider, ...modelParts] = value.split('/');
+      const model = modelParts.join('/');
+
+      if ((!provider || !model) && key !== 'query_model') {
+        throw new Error('Missing model or provider value');
+      }
+
+      if (key === 'query_model' && value === '') {
+        throw new Error('Missing query mode value');
+      }
+
+      if (protectedKeys.includes(key)) {
+        switch (key) {
+          case 'embedding_model': {
+            config.embeddingModel = { model: model.trim(), provider: provider.trim() };
+            break;
+          }
+          case 'reranker_model': {
+            config.rerankerModel = { model: model.trim(), provider: provider.trim() };
+            break;
+          }
+          case 'query_model': {
+            config.queryModel = value;
+            break;
+          }
+        }
+      }
+    } else {
+      throw new Error('Invalid environment variable format');
+    }
+  }
+
+  return config;
+};
```
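
Two behaviors worth noting in the implementation above: full-width commas (`,`) are normalized to ASCII commas before splitting, and keys absent from the env string are simply left out of the returned config rather than merged with the defaults. A hedged sketch:

```ts
// Illustrative calls, assuming the implementation above.
// 'nomic-embed-text' is an arbitrary example model id; note the full-width
// comma between the pairs, which the parser normalizes before splitting.
parseFilesConfig('embedding_model=ollama/nomic-embed-text,query_model=full_text');
// -> { embeddingModel: { model: 'nomic-embed-text', provider: 'ollama' },
//      queryModel: 'full_text' }   (no rerankerModel key)

parseFilesConfig(''); // -> DEFAULT_FILES_CONFIG
```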
