
Commit

fix: bad perf of follow-up
hanxiao committed Feb 20, 2025
1 parent 5d0fb6f commit 8b2f073
Showing 3 changed files with 34 additions and 59 deletions.
77 changes: 20 additions & 57 deletions src/agent.ts
@@ -82,9 +82,7 @@ function getSchema(allowReflect: boolean, allowRead: boolean, allowAnswer: boole
 
 
 function getPrompt(
-  question: string,
   context?: string[],
-  allChatHistory?: { user: string, assistant: string }[],
   allQuestions?: string[],
   allKeywords?: string[],
   allowReflect: boolean = true,
@@ -103,29 +101,9 @@ function getPrompt(
   // Add header section
   sections.push(`Current date: ${new Date().toUTCString()}
-You are an advanced AI research agent from Jina AI. You are specialized in multistep reasoning. Using your training data and prior lessons learned, answer the following question with absolute certainty:
-<question>
-${question}
-</question>
+You are an advanced AI research agent from Jina AI. You are specialized in multistep reasoning. Using your training data and prior lessons learned, answer the user question with absolute certainty.
 `);
 
-  // add chat history if exist
-  if (allChatHistory?.length) {
-    sections.push(`
-You have conducted the following chat history with the user:
-<chat-history>
-${allChatHistory.map(({user, assistant}) => `
-<user>
-${user}
-</user>
-<you>
-${assistant}
-</you>
-`).join('\n')}
-</chat-history>`);
-  }
-
   // Add knowledge section if exists
   if (knowledge?.length) {
     const knowledgeItems = knowledge
@@ -317,11 +295,11 @@ function removeHTMLtags(text: string) {
 }
 
 
-export async function getResponse(question: string,
+export async function getResponse(question?: string,
                                    tokenBudget: number = 1_000_000,
                                    maxBadAttempts: number = 3,
                                    existingContext?: Partial<TrackerContext>,
-                                   historyMessages?: Array<CoreAssistantMessage | CoreUserMessage>
+                                   messages?: Array<CoreAssistantMessage | CoreUserMessage>
 ): Promise<{ result: StepAction; context: TrackerContext }> {
   const context: TrackerContext = {
     tokenTracker: existingContext?.tokenTracker || new TokenTracker(tokenBudget),
@@ -331,29 +309,16 @@ export async function getResponse(question: string,
   let totalStep = 0;
   let badAttempts = 0;
   let schema: ZodObject<any> = getSchema(true, true, true, true, true)
-  question = question.trim()
+  question = question?.trim() as string;
+  if (messages && messages.length > 0) {
+    question = (messages[messages.length - 1]?.content as string).trim();
+  } else {
+    messages = [{role: 'user', content: question.trim()}]
+  }
   const gaps: string[] = [question]; // All questions to be answered including the orginal question
   const allQuestions = [question];
   const allKeywords = [];
-  const allChatHistory: { user: string, assistant: string }[] = [];
   const allKnowledge: KnowledgeItem[] = []; // knowledge are intermedidate questions that are answered
-  // iterate over historyMessages
-  // if role is user and content is question, add to allQuestions, the next assistant content should be the answer
-  // put this pair to the allKnowledge
-  historyMessages?.forEach((message, i) => {
-    if (message.role === 'user' && message.content && historyMessages[i + 1]?.role === 'assistant') {
-      allQuestions.push(message.content as string)
-      const answerContent = (historyMessages[i + 1]?.content || '') as string;
-      // Remove <think></think> tags and their content using regex
-      const cleanedAnswer = answerContent.replace(/<think>[\s\S]*?<\/think>/g, '').trim();
-      if (cleanedAnswer) {
-        allChatHistory.push({
-          user: message.content as string,
-          assistant: cleanedAnswer,
-        });
-      }
-    }
-  })
 
   const badContext = [];
   let diaryContext = [];
@@ -362,7 +327,7 @@ export async function getResponse(question: string,
   let allowRead = true;
   let allowReflect = true;
   let allowCoding = true;
-  let prompt = '';
+  let system = '';
  let thisStep: StepAction = {action: 'answer', answer: '', references: [], think: '', isFinal: false};
 
   const allURLs: Record<string, string> = {};
@@ -377,7 +342,7 @@ export async function getResponse(question: string,
     console.log(`Step ${totalStep} / Budget used ${budgetPercentage}%`);
     console.log('Gaps:', gaps);
     allowReflect = allowReflect && (gaps.length <= 1);
-    const currentQuestion = gaps.length > 0 ? gaps.shift()! : question
+    const currentQuestion: string = gaps.length > 0 ? gaps.shift()! : question
     if (!evaluationMetrics[currentQuestion]) {
       evaluationMetrics[currentQuestion] = await evaluateQuestion(currentQuestion, context.tokenTracker)
     }
@@ -387,10 +352,8 @@ export async function getResponse(question: string,
     allowSearch = allowSearch && (Object.keys(allURLs).length < 50); // disable search when too many urls already
 
     // generate prompt for this step
-    prompt = getPrompt(
-      currentQuestion,
+    system = getPrompt(
       diaryContext,
-      allChatHistory,
       allQuestions,
       allKeywords,
       allowReflect,
@@ -409,7 +372,8 @@ export async function getResponse(question: string,
     const result = await generator.generateObject({
       model: 'agent',
       schema,
-      prompt,
+      system,
+      messages,
     });
     thisStep = result.object as StepAction;
     // print allowed and chose action
@@ -780,19 +744,17 @@ But unfortunately, you failed to solve the issue. You need to think out of the b
     }
 
 
-    await storeContext(prompt, schema, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
+    await storeContext(system, schema, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
   }
 
-  await storeContext(prompt, schema, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
+  await storeContext(system, schema, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
   if (!(thisStep as AnswerAction).isFinal) {
     console.log('Enter Beast mode!!!')
     // any answer is better than no answer, humanity last resort
     step++;
     totalStep++;
-    const prompt = getPrompt(
-      question,
+    system = getPrompt(
       diaryContext,
-      allChatHistory,
       allQuestions,
       allKeywords,
       false,
@@ -812,15 +774,16 @@ But unfortunately, you failed to solve the issue. You need to think out of the b
     const result = await generator.generateObject({
       model: 'agentBeastMode',
       schema,
-      prompt,
+      system,
+      messages
     });
     thisStep = result.object as AnswerAction;
     (thisStep as AnswerAction).isFinal = true;
     context.actionTracker.trackAction({totalStep, thisStep, gaps, badAttempts});
   }
   console.log(thisStep)
 
-  await storeContext(prompt, schema, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
+  await storeContext(system, schema, [allContext, allKeywords, allQuestions, allKnowledge], totalStep);
   return {result: thisStep, context};
 
 }
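For orientation, here is a minimal sketch of how a caller might drive the updated getResponse signature. The conversation, import paths, and concrete values are illustrative assumptions and not part of this commit; only the parameter order and the tokenTracker usage are taken from the diff above.

import {getResponse} from './agent'; // assumed import path for this sketch
import {CoreAssistantMessage, CoreUserMessage} from 'ai';

async function demoFollowUp() {
  // Illustrative follow-up conversation; the final user message becomes the active question.
  const messages: Array<CoreAssistantMessage | CoreUserMessage> = [
    {role: 'user', content: 'What is Jina AI known for?'},
    {role: 'assistant', content: 'Neural search and AI research tooling.'},
    {role: 'user', content: 'Who is its founder?'},
  ];

  // question is now optional; when messages are supplied, getResponse derives it
  // from the last message instead of requiring a separate string.
  const {result, context} = await getResponse(undefined, 1_000_000, 3, undefined, messages);
  console.log(result, context.tokenTracker.getTotalUsageSnakeCase());
}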
8 changes: 7 additions & 1 deletion src/app.ts
@@ -493,6 +493,12 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
     return res.status(400).json({error: 'Last message must be from user'});
   }
 
+  // clean <think> from all assistant messages
+  body.messages?.filter(message => message.role === 'assistant').forEach(message => {
+    message.content = (message.content as string).replace(/<think>[\s\S]*?<\/think>/g, '').trim();
+  });
+  console.log('messages', body.messages);
+
   const {tokenBudget, maxBadAttempts} = getTokenBudgetAndMaxAttempts(
     body.reasoning_effort,
     body.max_completion_tokens
@@ -566,7 +572,7 @@ app.post('/v1/chat/completions', (async (req: Request, res: Response) => {
   }
 
   try {
-    const {result: finalStep} = await getResponse(lastMessage.content as string, tokenBudget, maxBadAttempts, context, body.messages)
+    const {result: finalStep} = await getResponse(undefined, tokenBudget, maxBadAttempts, context, body.messages)
 
     const usage = context.tokenTracker.getTotalUsageSnakeCase();
     if (body.stream) {
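As a standalone illustration of the new assistant-message cleanup in app.ts, the snippet below applies the same regex to a single string; the sample content is invented for demonstration, only the regex comes from the diff.

// Strip <think>...</think> reasoning blocks before the history is forwarded to getResponse.
const raw = '<think>check the earlier answer first</think>The follow-up answer goes here.';
const cleaned = raw.replace(/<think>[\s\S]*?<\/think>/g, '').trim();
console.log(cleaned); // -> "The follow-up answer goes here."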
8 changes: 7 additions & 1 deletion src/utils/safe-generator.ts
@@ -11,7 +11,9 @@ interface GenerateObjectResult<T> {
 interface GenerateOptions<T> {
   model: ToolName;
   schema: z.ZodType<T>;
-  prompt: string;
+  prompt?: string;
+  system?: string;
+  messages?: any;
 }
 
 export class ObjectGeneratorSafe {
@@ -26,6 +28,8 @@ export class ObjectGeneratorSafe {
       model,
       schema,
       prompt,
+      system,
+      messages,
     } = options;
 
     try {
@@ -34,6 +38,8 @@ export class ObjectGeneratorSafe {
         model: getModel(model),
         schema,
         prompt,
+        system,
+        messages,
         maxTokens: getToolConfig(model).maxTokens,
         temperature: getToolConfig(model).temperature,
       });
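With the widened GenerateOptions, generateObject can now be called either the old way (a single prompt) or the new way (system instructions plus the running chat history), which is how getResponse uses it above. A rough sketch, assuming a generator instance, schema, and the surrounding variables are already defined elsewhere; only the option names and the 'agent' tool name are taken from the diff.

// Old style: one self-contained prompt string.
await generator.generateObject({model: 'agent', schema, prompt: composedPrompt});

// New style: system prompt plus conversation messages.
await generator.generateObject({model: 'agent', schema, system, messages});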
