Skip to content

Commit

Permalink
Merge pull request #1068 from mendableai/nsc/llm-usage-extract
Browse files Browse the repository at this point in the history
(feat/extract) - LLMs usage analysis + billing
  • Loading branch information
nickscamara authored Jan 20, 2025
2 parents 34ad9ec + 02dea23 commit 406f28c
Show file tree
Hide file tree
Showing 16 changed files with 8,344 additions and 27 deletions.
16 changes: 11 additions & 5 deletions apps/api/src/controllers/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ export async function getACUC(
api_key: string,
cacheOnly = false,
useCache = true,
mode?: RateLimiterMode,
): Promise<AuthCreditUsageChunk | null> {
const cacheKeyACUC = `acuc_${api_key}`;
const cacheKeyACUC = `acuc_${api_key}_${mode}`;

if (useCache) {
const cachedACUC = await getValue(cacheKeyACUC);
Expand All @@ -93,9 +94,13 @@ export async function getACUC(
let retries = 0;
const maxRetries = 5;

let rpcName =
mode === RateLimiterMode.Extract || mode === RateLimiterMode.ExtractStatus
? "auth_credit_usage_chunk_extract"
: "auth_credit_usage_chunk_test_22_credit_pack_n_extract";
while (retries < maxRetries) {
({ data, error } = await supabase_service.rpc(
"auth_credit_usage_chunk_test_21_credit_pack",
rpcName,
{ input_key: api_key },
{ get: true },
));
Expand Down Expand Up @@ -127,8 +132,6 @@ export async function getACUC(
setCachedACUC(api_key, chunk);
}

// console.log(chunk);

return chunk;
} else {
return null;
Expand Down Expand Up @@ -203,7 +206,7 @@ export async function supaAuthenticateUser(
};
}

chunk = await getACUC(normalizedApi);
chunk = await getACUC(normalizedApi, false, true, mode);

if (chunk === null) {
return {
Expand Down Expand Up @@ -258,6 +261,9 @@ export async function supaAuthenticateUser(
subscriptionData.plan,
);
break;
case RateLimiterMode.ExtractStatus:
rateLimiter = getRateLimiter(RateLimiterMode.ExtractStatus, token);
break;
case RateLimiterMode.CrawlStatus:
rateLimiter = getRateLimiter(RateLimiterMode.CrawlStatus, token);
break;
Expand Down
1 change: 1 addition & 0 deletions apps/api/src/controllers/v1/extract-status.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,6 @@ export async function extractStatusController(
error: extract?.error ?? undefined,
expiresAt: (await getExtractExpiry(req.params.jobId)).toISOString(),
steps: extract.showSteps ? extract.steps : undefined,
llmUsage: extract.showLLMUsage ? extract.llmUsage : undefined,
});
}
1 change: 1 addition & 0 deletions apps/api/src/controllers/v1/extract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ export async function extractController(
createdAt: Date.now(),
status: "processing",
showSteps: req.body.__experimental_streamSteps,
showLLMUsage: req.body.__experimental_llmUsage,
});

if (Sentry.isInitialized()) {
Expand Down
10 changes: 10 additions & 0 deletions apps/api/src/controllers/v1/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ export const extractV1Options = z
origin: z.string().optional().default("api"),
urlTrace: z.boolean().default(false),
__experimental_streamSteps: z.boolean().default(false),
__experimental_llmUsage: z.boolean().default(false),
timeout: z.number().int().positive().finite().safe().default(60000),
})
.strict(strictMessage);
Expand Down Expand Up @@ -881,3 +882,12 @@ export type SearchResponse =
warning?: string;
data: Document[];
};


/**
 * LLM token accounting for a single model call (or an aggregated group of calls).
 *
 * Counts mirror the OpenAI usage payload: `totalTokens` is expected to equal
 * `promptTokens + completionTokens` as reported by the provider.
 */
export type TokenUsage = {
  /** Tokens consumed by the input/prompt side of the request. */
  promptTokens: number;
  /** Tokens generated by the model in its response. */
  completionTokens: number;
  /** Provider-reported total (prompt + completion). */
  totalTokens: number;
  /** Optional label for the pipeline step that incurred this usage (e.g. schema analysis). */
  step?: string;
  /** Optional model identifier (e.g. "gpt-4o") used for per-model cost estimation. */
  model?: string;
};
2 changes: 2 additions & 0 deletions apps/api/src/lib/extract/extract-redis.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ export type StoredExtract = {
error?: any;
showSteps?: boolean;
steps?: ExtractedStep[];
showLLMUsage?: boolean;
llmUsage?: number;
};

export async function saveExtract(id: string, extract: StoredExtract) {
Expand Down
76 changes: 64 additions & 12 deletions apps/api/src/lib/extract/extraction-service.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import {
Document,
ExtractRequest,
TokenUsage,
toLegacyCrawlerOptions,
URLTrace,
} from "../../controllers/v1/types";
Expand Down Expand Up @@ -31,6 +32,7 @@ import { ExtractStep, updateExtract } from "./extract-redis";
import { deduplicateObjectsArray } from "./helpers/deduplicate-objs-array";
import { mergeNullValObjs } from "./helpers/merge-null-val-objs";
import { CUSTOM_U_TEAMS } from "./config";
import { calculateFinalResultCost, estimateCost, estimateTotalCost } from "./usage/llm-cost";

interface ExtractServiceOptions {
request: ExtractRequest;
Expand All @@ -46,6 +48,9 @@ interface ExtractResult {
warning?: string;
urlTrace?: URLTrace[];
error?: string;
tokenUsageBreakdown?: TokenUsage[];
llmUsage?: number;
totalUrlsScraped?: number;
}

async function analyzeSchemaAndPrompt(
Expand All @@ -57,6 +62,7 @@ async function analyzeSchemaAndPrompt(
multiEntityKeys: string[];
reasoning?: string;
keyIndicators?: string[];
tokenUsage: TokenUsage;
}> {
if (!schema) {
schema = await generateSchemaFromPrompt(prompt);
Expand All @@ -71,8 +77,10 @@ async function analyzeSchemaAndPrompt(
keyIndicators: z.array(z.string()),
});

const model = "gpt-4o";

const result = await openai.beta.chat.completions.parse({
model: "gpt-4o",
model: model,
messages: [
{
role: "system",
Expand Down Expand Up @@ -131,12 +139,20 @@ Schema: ${schemaString}\nPrompt: ${prompt}\nRelevant URLs: ${urls}`,

const { isMultiEntity, multiEntityKeys, reasoning, keyIndicators } =
checkSchema.parse(result.choices[0].message.parsed);
return { isMultiEntity, multiEntityKeys, reasoning, keyIndicators };

const tokenUsage: TokenUsage = {
promptTokens: result.usage?.prompt_tokens ?? 0,
completionTokens: result.usage?.completion_tokens ?? 0,
totalTokens: result.usage?.total_tokens ?? 0,
model: model,
};
return { isMultiEntity, multiEntityKeys, reasoning, keyIndicators, tokenUsage };
}

/**
 * Result shape of a single LLM extraction completion.
 *
 * NOTE(review): the lowercase name breaks the PascalCase convention for type
 * aliases, but it is referenced by callers elsewhere in this file, so it is
 * kept as-is here.
 */
type completions = {
  /** The structured data the model extracted, keyed by schema field. */
  extract: Record<string, any>;
  /** Raw token count for this completion (legacy field; see `totalUsage`). */
  numTokens: number;
  /** Detailed token breakdown used for cost estimation and billing. */
  totalUsage: TokenUsage;
  /** Optional non-fatal message to surface to the caller. */
  warning?: string;
};

Expand All @@ -163,6 +179,11 @@ export async function performExtraction(
let multiEntityCompletions: completions[] = [];
let multiEntityResult: any = {};
let singleAnswerResult: any = {};
let totalUrlsScraped = 0;


// Token tracking
let tokenUsage: TokenUsage[] = [];

await updateExtract(extractId, {
status: "processing",
Expand Down Expand Up @@ -219,6 +240,7 @@ export async function performExtraction(
"No valid URLs found to scrape. Try adjusting your search criteria or including more URLs.",
extractId,
urlTrace: urlTraces,
totalUrlsScraped: 0
};
}

Expand Down Expand Up @@ -249,9 +271,12 @@ export async function performExtraction(
// 1. the first one is a completion that will extract the array of items
// 2. the second one is multiple completions that will extract the items from the array
let startAnalyze = Date.now();
const { isMultiEntity, multiEntityKeys, reasoning, keyIndicators } =
const { isMultiEntity, multiEntityKeys, reasoning, keyIndicators, tokenUsage: schemaAnalysisTokenUsage } =
await analyzeSchemaAndPrompt(links, reqSchema, request.prompt ?? "");

// Track schema analysis tokens
tokenUsage.push(schemaAnalysisTokenUsage);

// console.log("\nIs Multi Entity:", isMultiEntity);
// console.log("\nMulti Entity Keys:", multiEntityKeys);
// console.log("\nReasoning:", reasoning);
Expand Down Expand Up @@ -312,6 +337,8 @@ export async function performExtraction(
(doc): doc is Document => doc !== null,
);

totalUrlsScraped += multyEntityDocs.length;

let endScrape = Date.now();

await updateExtract(extractId, {
Expand Down Expand Up @@ -376,6 +403,8 @@ export async function performExtraction(
true,
);

tokenUsage.push(shouldExtractCheck.totalUsage);

if (!shouldExtractCheck.extract["extract"]) {
console.log(
`Skipping extraction for ${doc.metadata.url} as content is irrelevant`,
Expand Down Expand Up @@ -438,6 +467,11 @@ export async function performExtraction(
timeoutPromise,
])) as Awaited<ReturnType<typeof generateOpenAICompletions>>;

// Track multi-entity extraction tokens
if (multiEntityCompletion) {
tokenUsage.push(multiEntityCompletion.totalUsage);
}

// console.log(multiEntityCompletion.extract)
// if (!multiEntityCompletion.extract?.is_content_relevant) {
// console.log(`Skipping extraction for ${doc.metadata.url} as content is not relevant`);
Expand Down Expand Up @@ -500,6 +534,7 @@ export async function performExtraction(
"An unexpected error occurred. Please contact [email protected] for help.",
extractId,
urlTrace: urlTraces,
totalUrlsScraped
};
}
}
Expand Down Expand Up @@ -551,15 +586,17 @@ export async function performExtraction(
}
}

singleAnswerDocs.push(
...results.filter((doc): doc is Document => doc !== null),
);
const validResults = results.filter((doc): doc is Document => doc !== null);
singleAnswerDocs.push(...validResults);
totalUrlsScraped += validResults.length;

} catch (error) {
return {
success: false,
error: error.message,
extractId,
urlTrace: urlTraces,
totalUrlsScraped
};
}

Expand All @@ -571,6 +608,7 @@ export async function performExtraction(
"All provided URLs are invalid. Please check your input and try again.",
extractId,
urlTrace: request.urlTrace ? urlTraces : undefined,
totalUrlsScraped: 0
};
}

Expand Down Expand Up @@ -603,6 +641,11 @@ export async function performExtraction(
true,
);

// Track single answer extraction tokens
if (singleAnswerCompletions) {
tokenUsage.push(singleAnswerCompletions.totalUsage);
}

singleAnswerResult = singleAnswerCompletions.extract;

// Update token usage in traces
Expand All @@ -629,19 +672,24 @@ export async function performExtraction(
? await mixSchemaObjects(reqSchema, singleAnswerResult, multiEntityResult)
: singleAnswerResult || multiEntityResult;

let linksBilled = links.length * 5;

const totalTokensUsed = tokenUsage.reduce((a, b) => a + b.totalTokens, 0);
const llmUsage = estimateTotalCost(tokenUsage);
let tokensToBill = calculateFinalResultCost(finalResult);


if (CUSTOM_U_TEAMS.includes(teamId)) {
linksBilled = 1;
tokensToBill = 1;
}
// Bill team for usage
billTeam(teamId, subId, linksBilled).catch((error) => {
billTeam(teamId, subId, tokensToBill, logger, true).catch((error) => {
logger.error(
`Failed to bill team ${teamId} for ${linksBilled} credits: ${error}`,
`Failed to bill team ${teamId} for ${tokensToBill} tokens: ${error}`,
);
});

// Log job

// Log job with token usage
logJob({
job_id: extractId,
success: true,
Expand All @@ -654,10 +702,12 @@ export async function performExtraction(
url: request.urls.join(", "),
scrapeOptions: request,
origin: request.origin ?? "api",
num_tokens: 0, // completions?.numTokens ?? 0,
num_tokens: totalTokensUsed,
tokens_billed: tokensToBill,
}).then(() => {
updateExtract(extractId, {
status: "completed",
llmUsage,
}).catch((error) => {
logger.error(
`Failed to update extract ${extractId} status to completed: ${error}`,
Expand All @@ -671,5 +721,7 @@ export async function performExtraction(
extractId,
warning: undefined, // TODO FIX
urlTrace: request.urlTrace ? urlTraces : undefined,
llmUsage,
totalUrlsScraped
};
}
15 changes: 12 additions & 3 deletions apps/api/src/lib/extract/reranker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,21 @@ function filterAndProcessLinks(
);
}

/**
 * Outcome of an LLM-based link-reranking pass.
 *
 * Bundles the filtered links together with the token count spent producing
 * them, so callers can accumulate usage across multiple reranking passes.
 */
export type RerankerResult = {
  /** Links the model judged relevant, in ranked order. */
  mapDocument: MapDocument[];
  /** Total LLM tokens consumed across all chunked reranking calls. */
  tokensUsed: number;
};

export async function rerankLinksWithLLM(
mappedLinks: MapDocument[],
searchQuery: string,
urlTraces: URLTrace[],
): Promise<MapDocument[]> {
): Promise<RerankerResult> {
const chunkSize = 100;
const chunks: MapDocument[][] = [];
const TIMEOUT_MS = 20000;
const MAX_RETRIES = 2;
let totalTokensUsed = 0;

// Split mappedLinks into chunks of 200
for (let i = 0; i < mappedLinks.length; i += chunkSize) {
Expand Down Expand Up @@ -225,6 +230,7 @@ export async function rerankLinksWithLLM(
return [];
}

totalTokensUsed += completion.numTokens || 0;
// console.log(`Chunk ${chunkIndex + 1}: Found ${completion.extract.relevantLinks.length} relevant links`);
return completion.extract.relevantLinks;

Expand Down Expand Up @@ -252,5 +258,8 @@ export async function rerankLinksWithLLM(
.filter((link): link is MapDocument => link !== undefined);

// console.log(`Returning ${relevantLinks.length} relevant links`);
return relevantLinks;
}
return {
mapDocument: relevantLinks,
tokensUsed: totalTokensUsed,
};
}
8 changes: 6 additions & 2 deletions apps/api/src/lib/extract/url-processor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,19 @@ export async function processUrl(
// (link, index) => `${index + 1}. URL: ${link.url}, Title: ${link.title}, Description: ${link.description}`
// );

mappedLinks = await rerankLinksWithLLM(mappedLinks, searchQuery, urlTraces);
const rerankerResult = await rerankLinksWithLLM(mappedLinks, searchQuery, urlTraces);
mappedLinks = rerankerResult.mapDocument;
let tokensUsed = rerankerResult.tokensUsed;

// 2nd Pass, useful for when the first pass returns too many links
if (mappedLinks.length > 100) {
mappedLinks = await rerankLinksWithLLM(
const rerankerResult = await rerankLinksWithLLM(
mappedLinks,
searchQuery,
urlTraces,
);
mappedLinks = rerankerResult.mapDocument;
tokensUsed += rerankerResult.tokensUsed;
}

// dumpToFile(
Expand Down
Loading

0 comments on commit 406f28c

Please sign in to comment.