feat: support streaming output

crazyurus committed May 12, 2024
1 parent 19f6288 commit f2e00ea
Showing 8 changed files with 82 additions and 18 deletions.
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
 {
   "name": "aigc-detector",
-  "version": "1.0.3",
+  "version": "1.0.4",
   "description": "Detect if content is generated by AI",
   "keywords": [
     "aigc",
27 changes: 18 additions & 9 deletions src/cli/commands/chat.ts
@@ -42,8 +42,6 @@ class ChatCommand extends BaseCommand {
 
   static flags = {};
 
-  private lastMessage = 'How can I help you today?';
-
   private messages = new ChatMessageHistory();
 
   async run(): Promise<void> {
@@ -54,16 +52,25 @@
         apiKey: config.apiKey,
         platform: config.platform as unknown as Platform
       });
-      const userDisplay = this.getDisplayContent(PromptRole.USER);
+      const aiDisplay = this.getDisplayContent(PromptRole.AI);
+      let lastMessage = 'How can I help you today?';
+
+      process.stdout.write(aiDisplay + lastMessage + '\n');
 
       // eslint-disable-next-line no-constant-condition
       while (true) {
-        const aiMessage = await this.addMessage(PromptRole.AI, this.lastMessage);
-        const userMessage = await this.getUserMessage(aiMessage + `\n${userDisplay}`);
-        const answer = await detector.chat(userMessage, await this.messages.getMessages());
+        const userMessage = await this.getUserMessage();
+        const stream = detector.chat(userMessage, await this.messages.getMessages());
+
+        process.stdout.write(aiDisplay);
+        stream.pipe(process.stdout);
+
+        lastMessage = await stream.getData();
+
+        process.stdout.write('\n');
 
         await this.addMessage(PromptRole.USER, userMessage);
-        this.lastMessage = answer;
+        await this.addMessage(PromptRole.AI, lastMessage);
       }
     } else {
       this.showHelp();
@@ -84,9 +91,11 @@
     return chalk[roleDisplay.color](`[${roleDisplay.name}] `);
   }
 
-  private getUserMessage(aiMessage: string): Promise<string> {
+  private getUserMessage(): Promise<string> {
+    const userDisplay = this.getDisplayContent(PromptRole.USER);
+
     return new Promise<string>((resolve) => {
-      reader.question(aiMessage, resolve);
+      reader.question(userDisplay, resolve);
     });
   }
 }
9 changes: 5 additions & 4 deletions src/core/index.ts
@@ -1,5 +1,7 @@
 import type { BaseMessage } from '@langchain/core/messages';
 
+import type Stream from './stream';
+
 import { PROMPT } from '../const';
 import { getPlatform, type Platform } from '../platform';
 import { getEnvConfig } from './env';
@@ -22,15 +24,14 @@ export class AIGC {
     this.platform = (env.platform as unknown as Platform) || options.platform;
   }
 
-  public async chat(content: string, messages: BaseMessage[]) {
+  public chat(content: string, messages: BaseMessage[]): Stream {
     const platform = getPlatform(this.platform);
-    const result = await platform.invoke(
+
+    return platform.stream(
       'You are a helpful assistant. Answer all questions to the best of your ability.',
       { content, messages },
       this.apiKey
     );
-
-    return result;
   }
 
   public async detect(content: string): Promise<ReturnType<typeof getDetectResult>> {
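With this change, `chat()` is no longer `async`: it synchronously returns a `Stream` that the caller can pipe for incremental output, await for the final text, or both. A minimal caller sketch; the constructor option shape and the `'zhipu'` platform id are assumptions, since the diff only shows that `apiKey` and `platform` are read:

```ts
import { AIGC } from './index';
import type { Platform } from '../platform';

async function main(): Promise<void> {
  // Option shape and platform id are placeholders inferred from this diff.
  const detector = new AIGC({ apiKey: 'your-api-key', platform: 'zhipu' as unknown as Platform });
  const stream = detector.chat('Hello!', []);

  stream.pipe(process.stdout); // incremental: tokens render as they arrive
  const answer = await stream.getData(); // final: resolves when the stream closes

  console.log(`\nreceived ${answer.length} characters`);
}

main();
```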
23 changes: 23 additions & 0 deletions src/core/stream.ts
@@ -0,0 +1,23 @@
+import { Transform, type TransformCallback } from 'node:stream';
+
+class Stream extends Transform {
+  private data = '';
+
+  getData(): Promise<string> {
+    return new Promise((resolve) => {
+      this.on('close', () => {
+        resolve(this.data);
+      });
+    });
+  }
+
+  _transform(chunk: Buffer, encoding: BufferEncoding, callback: TransformCallback): void {
+    const data = chunk.toString();
+
+    this.data += data;
+
+    callback(null, data);
+  }
+}
+
+export default Stream;
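A standalone sketch of the behavior this class provides: `_transform` both accumulates each chunk into `data` and forwards it downstream, so one consumer can pipe the live output while another awaits the whole text, and `getData()` settles only once `'close'` is emitted.

```ts
import Stream from './stream';

const stream = new Stream();

// Consumer 1: live pass-through of every transformed chunk.
stream.pipe(process.stdout);

// Consumer 2: the accumulated text, resolved once 'close' fires.
stream.getData().then((full) => {
  process.stdout.write(`\n(buffered ${full.length} characters)\n`);
});

stream.write('Hello, ');
stream.write('world!');
stream.end(); // with Node's default autoDestroy, 'close' follows 'finish'/'end'
```

In the commit itself the producer calls `destroy()` rather than `end()`, which also emits `'close'` and therefore resolves `getData()` the same way.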
30 changes: 29 additions & 1 deletion src/platform/base.ts
@@ -4,19 +4,22 @@ import { ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate } from '@langchain/core/prompts';
 import { ChatOpenAI } from '@langchain/openai';
 import { LLMChain } from 'langchain/chains';
 
+import Stream from '../core/stream';
+
 type InvokeParameter = Parameters<InstanceType<typeof LLMChain>['invoke']>[0];
 
 abstract class Platform {
   protected temperature = 0.7;
 
-  protected getChatModel(apiKey?: string): BaseLanguageModel {
+  protected getChatModel(apiKey?: string, streaming = false): BaseLanguageModel {
     return new ChatOpenAI({
       apiKey,
       configuration: {
         baseURL: `https://${this.server}/v1`
       },
       frequencyPenalty: 1,
       model: this.model,
+      streaming,
       temperature: this.temperature
     });
   }
@@ -39,6 +42,31 @@
     return result.text;
   }
 
+  public stream(prompt: string, params: InvokeParameter, apiKey?: string): Stream {
+    const promptTemplate = this.getPrompt(prompt);
+    const chain = new LLMChain({
+      llm: this.getChatModel(apiKey, true),
+      prompt: promptTemplate
+    });
+    const stream = new Stream();
+
+    chain
+      .invoke(params, {
+        callbacks: [
+          {
+            handleLLMNewToken(token: string) {
+              stream.write(token);
+            }
+          }
+        ]
+      })
+      .then(() => {
+        stream.destroy();
+      });
+
+    return stream;
+  }
+
   protected abstract model: string;
 
   public abstract name: string;
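The bridge in `stream()` is callback-to-stream: each `handleLLMNewToken` call writes one token into the `Stream`, and once `invoke()` settles, `destroy()` emits `'close'`, which resolves any pending `getData()`. A dependency-free sketch of that pattern; `fakeInvoke` stands in for `LLMChain#invoke` and is not LangChain API:

```ts
import Stream from '../core/stream';

// Stand-in for chain.invoke(): emits tokens through a callback, then resolves.
function fakeInvoke(onToken: (token: string) => void): Promise<void> {
  return new Promise((resolve) => {
    const tokens = ['Str', 'eam', 'ing', '!'];
    const timer = setInterval(() => {
      const token = tokens.shift();
      if (token === undefined) {
        clearInterval(timer);
        resolve();
      } else {
        onToken(token);
      }
    }, 50);
  });
}

const stream = new Stream();

stream.pipe(process.stdout); // tokens appear as they are emitted
stream.getData().then((text) => process.stdout.write(`\n(${text.length} characters)\n`));

fakeInvoke((token) => stream.write(token)).then(() => stream.destroy());
```

One caveat visible in the diff: if `invoke()` rejects, the `.then()` never runs, `'close'` is never emitted, and a pending `getData()` never settles; destroying the stream in a `.catch()` or `.finally()` would close that gap.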
3 changes: 2 additions & 1 deletion src/platform/minimax.ts
@@ -11,11 +11,12 @@ class MiniMax extends Platform {
 
   protected server = 'api.minimax.chat';
 
-  protected getChatModel(apiKey?: string): BaseLanguageModel {
+  protected getChatModel(apiKey?: string, streaming = false): BaseLanguageModel {
     return new ChatMinimax({
       minimaxApiKey: apiKey,
       minimaxGroupId: '1782658868262748274',
       model: this.model,
+      streaming,
       temperature: this.temperature
     });
   }
3 changes: 2 additions & 1 deletion src/platform/tongyi.ts
@@ -11,10 +11,11 @@ class TongYi extends Platform {
 
   protected server = 'open.bigmodel.cn';
 
-  protected getChatModel(apiKey?: string): BaseLanguageModel {
+  protected getChatModel(apiKey?: string, streaming = false): BaseLanguageModel {
     return new ChatAlibabaTongyi({
       alibabaApiKey: apiKey,
       model: this.model,
+      streaming,
       temperature: this.temperature
     });
   }
3 changes: 2 additions & 1 deletion src/platform/zhipu.ts
@@ -11,9 +11,10 @@ class ZhiPu extends Platform {
 
   protected server = 'open.bigmodel.cn';
 
-  protected getChatModel(apiKey?: string): BaseLanguageModel {
+  protected getChatModel(apiKey?: string, streaming?: boolean): BaseLanguageModel {
     return new ChatZhipuAI({
       model: this.model,
+      streaming,
       temperature: this.temperature,
       zhipuAIApiKey: apiKey
     });
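All three platform overrides mirror the base class: they thread the new `streaming` flag into the LangChain chat-model constructor, since token-level `handleLLMNewToken` callbacks generally fire only when the model streams. A minimal model-level sketch using `ChatOpenAI`; the model name and prompt are placeholders:

```ts
import { ChatOpenAI } from '@langchain/openai';

async function demo(): Promise<void> {
  const model = new ChatOpenAI({ apiKey: 'your-api-key', model: 'gpt-4o-mini', streaming: true });

  await model.invoke('Say hi in five words.', {
    callbacks: [
      {
        handleLLMNewToken(token: string) {
          process.stdout.write(token); // fires per token only while streaming
        }
      }
    ]
  });
}

demo();
```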
