diff --git a/package.json b/package.json index 79d32e07..f973c50a 100644 --- a/package.json +++ b/package.json @@ -119,7 +119,7 @@ }, "temperature": { "type": "number", - "description": "Temperature of the OpenAI model.", + "description": "Temperature of the OpenAI model. Should be in range [0, 2], otherwise an error will be produced.", "default": 1 }, "apiKey": { @@ -306,7 +306,7 @@ }, "temperature": { "type": "number", - "description": "Temperature of the LM Studio model.", + "description": "Temperature of the LM Studio model. Should be in range [0, 2], otherwise an error will be produced.", "default": 1 }, "port": { @@ -376,6 +376,93 @@ "markdownDescription": "List of configurations that fetch completions from a locally running LLM inside [LM Studio](https://lmstudio.ai).", "order": 3 }, + "coqpilot.deepSeekModelsParameters": { + "type": "array", + "items": { + "type": "object", + "properties": { + "modelId": { + "type": "string", + "markdownDescription": "Unique identifier of this model to distinguish it from others. Could be any string.", + "default": "deep-seek-v3" + }, + "modelName": { + "type": "string", + "markdownDescription": "Model to use from the DeepSeek public API. List of models known to CoqPilot: \n * deepseek-chat \n * deepseek-reasoner", + "default": "deepseek-chat" + }, + "temperature": { + "type": "number", + "description": "Temperature of the DeepSeek model. Should be in range [0, 2], otherwise an error will be produced.", + "default": 1 + }, + "apiKey": { + "type": "string", + "description": "Api key to communicate with the DeepSeek api. You can get one [here](https://platform.deepseek.com/api_keys).", + "default": "None" + }, + "choices": { + "type": "number", + "description": "Number of attempts to generate proof for one hole with this model. All attempts are made as a single request, so this parameter should not have a significant impact on performance. However, more choices mean more tokens spent on generation.", + "default": 15 + }, + "systemPrompt": { + "type": "string", + "description": "Prompt for the DeepSeek model to begin a chat with. It is sent as a system message, which means it has more impact than other messages.", + "default": "Generate proof of the theorem from user input in Coq. You should only generate proofs in Coq. Never add special comments to the proof. Your answer should be a valid Coq proof. It should start with 'Proof.' and end with 'Qed.'." + }, + "maxTokensToGenerate": { + "type": "number", + "description": "Number of tokens that the model is allowed to generate as a response message (i.e. message with proof).", + "default": 2048 + }, + "tokensLimit": { + "type": "number", + "description": "Total length of input and generated tokens, it is determined by the model.", + "default": 4096 + }, + "maxContextTheoremsNumber": { + "type": "number", + "description": "Maximum number of context theorems to include in the prompt sent to the DeepSeek model as examples for proof generation. Lower values reduce token usage but may decrease the likelihood of generating correct proofs.", + "default": 100 + }, + "multiroundProfile": { + "type": "object", + "properties": { + "maxRoundsNumber": { + "type": "number", + "description": "Maximum number of rounds to generate and further fix the proof. Default value is 1, which means each proof will be only generated, but not fixed.", + "default": 1 + }, + "proofFixChoices": { + "type": "number", + "description": "Number of attempts to generate a proof fix for each proof in one round. 
Warning: increasing `proofFixChoices` can lead to exponential growth in generation requests if `maxRoundsNumber` is relatively large.", + "default": 1 + }, + "proofFixPrompt": { + "type": "string", + "description": "Prompt for the proof-fix request that will be sent as a user chat message in response to an incorrect proof. It may include the `${diagnostic}` substring, which will be replaced by the actual compiler diagnostic.", + "default": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: `${diagnostic}`. Please, fix the proof." + }, + "maxPreviousProofVersionsNumber": { + "type": "number", + "description": "Maximum number of previous proof versions to include in the proof-fix chat, each presented as a dialogue: the user's diagnostic followed by the assistant's corresponding proof attempt. The most recent proof version being fixed is always included and is not affected by this parameter.", + "default": 100 + } + }, + "default": { + "maxRoundsNumber": 1, + "proofFixChoices": 1, + "proofFixPrompt": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: `${diagnostic}`. Please, fix the proof.", + "maxPreviousProofVersionsNumber": 100 + } + } + } + }, + "default": [], + "markdownDescription": "List of configurations for DeepSeek models. Each configuration will be fetched for completions independently in the order they are listed.", + "order": 4 + }, "coqpilot.contextTheoremsRankerType": { "type": "string", "enum": [ @@ -390,7 +477,7 @@ ], "description": "Context of the LLM is limited. Usually not all theorems from the file may be used in the completion request. This parameter defines the way theorems are selected for the completion.", "default": "distance", - "order": 4 + "order": 5 }, "coqpilot.loggingVerbosity": { "type": "string", @@ -404,13 +491,13 @@ ], "description": "The verbosity of the logs.", "default": "info", - "order": 5 + "order": 6 }, "coqpilot.coqLspServerPath": { "type": "string", "description": "Path to the Coq LSP server. If not specified, CoqPilot will try to find the server automatically at the default location: coq-lsp at PATH.", "default": "coq-lsp", - "order": 6 + "order": 7 } } } diff --git a/src/benchmark/framework/experiment/setupDSL/benchmarkingBundleBuilder.ts b/src/benchmark/framework/experiment/setupDSL/benchmarkingBundleBuilder.ts index da9910ac..f7a8a210 100644 --- a/src/benchmark/framework/experiment/setupDSL/benchmarkingBundleBuilder.ts +++ b/src/benchmark/framework/experiment/setupDSL/benchmarkingBundleBuilder.ts @@ -10,7 +10,8 @@ export type LLMServiceStringIdentifier = | "predefined" | "openai" | "grazie" - | "lmstudio"; + | "lmstudio" + | "deepseek"; export type CorrespondingInputParams = T extends "predefined" @@ -21,7 +22,9 @@ export type CorrespondingInputParams = ? InputBenchmarkingModelParams.GrazieParams : T extends "lmstudio" ? InputBenchmarkingModelParams.LMStudioParams - : never; + : T extends "deepseek" + ? 
InputBenchmarkingModelParams.DeepSeekParams + : never; export class BenchmarkingBundle { constructor() {} @@ -46,6 +49,8 @@ export class BenchmarkingBundle { return LLMServiceIdentifier.GRAZIE; case "lmstudio": return LLMServiceIdentifier.LMSTUDIO; + case "deepseek": + return LLMServiceIdentifier.DEEPSEEK; } } } diff --git a/src/benchmark/framework/structures/common/llmServiceIdentifier.ts b/src/benchmark/framework/structures/common/llmServiceIdentifier.ts index a22c3bc7..b5a4bdae 100644 --- a/src/benchmark/framework/structures/common/llmServiceIdentifier.ts +++ b/src/benchmark/framework/structures/common/llmServiceIdentifier.ts @@ -3,4 +3,5 @@ export enum LLMServiceIdentifier { OPENAI = "Open AI", GRAZIE = "Grazie", LMSTUDIO = "LM Studio", + DEEPSEEK = "DeepSeek", } diff --git a/src/benchmark/framework/structures/inputParameters/inputBenchmarkingModelParams.ts b/src/benchmark/framework/structures/inputParameters/inputBenchmarkingModelParams.ts index aa14ffc8..455c39e8 100644 --- a/src/benchmark/framework/structures/inputParameters/inputBenchmarkingModelParams.ts +++ b/src/benchmark/framework/structures/inputParameters/inputBenchmarkingModelParams.ts @@ -1,4 +1,5 @@ import { + DeepSeekUserModelParams, GrazieUserModelParams, LMStudioUserModelParams, OpenAiUserModelParams, @@ -22,4 +23,6 @@ export namespace InputBenchmarkingModelParams { export interface GrazieParams extends GrazieUserModelParams, Params {} export interface LMStudioParams extends LMStudioUserModelParams, Params {} + + export interface DeepSeekParams extends DeepSeekUserModelParams, Params {} } diff --git a/src/benchmark/framework/utils/commonStructuresUtils/llmServicesUtils.ts b/src/benchmark/framework/utils/commonStructuresUtils/llmServicesUtils.ts index a147b24b..acfdac48 100644 --- a/src/benchmark/framework/utils/commonStructuresUtils/llmServicesUtils.ts +++ b/src/benchmark/framework/utils/commonStructuresUtils/llmServicesUtils.ts @@ -1,4 +1,6 @@ import { ErrorsHandlingMode } from "../../../../llm/llmServices/commonStructures/errorsHandlingMode"; +import { DeepSeekModelParamsResolver } from "../../../../llm/llmServices/deepSeek/deepSeekModelParamsResolver"; +import { DeepSeekService } from "../../../../llm/llmServices/deepSeek/deepSeekService"; import { GrazieModelParamsResolver } from "../../../../llm/llmServices/grazie/grazieModelParamsResolver"; import { GrazieService } from "../../../../llm/llmServices/grazie/grazieService"; import { LLMService } from "../../../../llm/llmServices/llmService"; @@ -29,6 +31,8 @@ export function getShortName(identifier: LLMServiceIdentifier): string { return "Grazie"; case LLMServiceIdentifier.LMSTUDIO: return "LM Studio"; + case LLMServiceIdentifier.DEEPSEEK: + return "DeepSeek"; } } @@ -53,6 +57,9 @@ export function selectLLMServiceBuilder( case LLMServiceIdentifier.LMSTUDIO: return (eventLogger, errorsHandlingMode) => new LMStudioService(eventLogger, errorsHandlingMode); + case LLMServiceIdentifier.DEEPSEEK: + return (eventLogger, errorsHandlingMode) => + new DeepSeekService(eventLogger, errorsHandlingMode); } } @@ -61,6 +68,7 @@ export interface LLMServicesParamsResolvers { openAiModelParamsResolver: OpenAiModelParamsResolver; grazieModelParamsResolver: GrazieModelParamsResolver; lmStudioModelParamsResolver: LMStudioModelParamsResolver; + deepSeekModelParamsResolver: DeepSeekModelParamsResolver; } export function createParamsResolvers(): LLMServicesParamsResolvers { @@ -70,6 +78,7 @@ export function createParamsResolvers(): LLMServicesParamsResolvers { openAiModelParamsResolver: new 
OpenAiModelParamsResolver(), grazieModelParamsResolver: new GrazieModelParamsResolver(), lmStudioModelParamsResolver: new LMStudioModelParamsResolver(), + deepSeekModelParamsResolver: new DeepSeekModelParamsResolver(), }; } @@ -86,5 +95,7 @@ export function getParamsResolver( return paramsResolvers.grazieModelParamsResolver; case LLMServiceIdentifier.LMSTUDIO: return paramsResolvers.lmStudioModelParamsResolver; + case LLMServiceIdentifier.DEEPSEEK: + return paramsResolvers.deepSeekModelParamsResolver; } } diff --git a/src/extension/pluginContext.ts b/src/extension/pluginContext.ts index 3bb45e88..98a1b689 100644 --- a/src/extension/pluginContext.ts +++ b/src/extension/pluginContext.ts @@ -5,6 +5,7 @@ import { Disposable, WorkspaceConfiguration, window, workspace } from "vscode"; import { LLMServices, disposeServices } from "../llm/llmServices"; import { ErrorsHandlingMode } from "../llm/llmServices/commonStructures/errorsHandlingMode"; +import { DeepSeekService } from "../llm/llmServices/deepSeek/deepSeekService"; import { GrazieService } from "../llm/llmServices/grazie/grazieService"; import { LMStudioService } from "../llm/llmServices/lmStudio/lmStudioService"; import { OpenAiService } from "../llm/llmServices/openai/openAiService"; @@ -67,6 +68,12 @@ export class PluginContext implements Disposable { path.join(this.llmServicesLogsDir, "lmstudio-logs.txt"), this.llmServicesSetup.debugLogs ), + deepSeekService: new DeepSeekService( + this.llmServicesSetup.eventLogger, + this.llmServicesSetup.errorsHandlingMode, + path.join(this.llmServicesLogsDir, "deepseek-logs.txt"), + this.llmServicesSetup.debugLogs + ), }; private parseLoggingVerbosity(config: WorkspaceConfiguration): Severity { diff --git a/src/extension/settings/configReaders.ts b/src/extension/settings/configReaders.ts index 205c1576..1e7b3a8c 100644 --- a/src/extension/settings/configReaders.ts +++ b/src/extension/settings/configReaders.ts @@ -6,11 +6,13 @@ import { LLMService } from "../../llm/llmServices/llmService"; import { ModelParams, ModelsParams } from "../../llm/llmServices/modelParams"; import { SingleParamResolutionResult } from "../../llm/llmServices/utils/paramsResolvers/abstractResolvers"; import { + DeepSeekUserModelParams, GrazieUserModelParams, LMStudioUserModelParams, OpenAiUserModelParams, PredefinedProofsUserModelParams, UserModelParams, + deepSeekUserModelParamsSchema, grazieUserModelParamsSchema, lmStudioUserModelParamsSchema, openAiUserModelParamsSchema, @@ -108,14 +110,27 @@ export function readAndValidateUserModelsParams( jsonSchemaValidator ) ); + const deepSeekUserParams: DeepSeekUserModelParams[] = + config.deepSeekModelsParameters.map((params: any) => + validateAndParseJson( + params, + deepSeekUserModelParamsSchema, + jsonSchemaValidator + ) + ); validateIdsAreUnique([ ...predefinedProofsUserParams, ...openAiUserParams, ...grazieUserParams, ...lmStudioUserParams, + ...deepSeekUserParams, ]); - validateApiKeysAreProvided(openAiUserParams, grazieUserParams); + validateApiKeysAreProvided( + openAiUserParams, + grazieUserParams, + deepSeekUserParams + ); const modelsParams: ModelsParams = { predefinedProofsModelParams: resolveParamsAndShowResolutionLogs( @@ -134,6 +149,10 @@ export function readAndValidateUserModelsParams( llmServices.lmStudioService, lmStudioUserParams ), + deepSeekParams: resolveParamsAndShowResolutionLogs( + llmServices.deepSeekService, + deepSeekUserParams + ), }; validateModelsArePresent([ @@ -198,7 +217,8 @@ function validateIdsAreUnique(allModels: UserModelParams[]) { function 
validateApiKeysAreProvided( openAiUserParams: OpenAiUserModelParams[], - grazieUserParams: GrazieUserModelParams[] + grazieUserParams: GrazieUserModelParams[], + deepSeekUserParams: DeepSeekUserModelParams[] ) { const buildApiKeyError = ( serviceName: string, @@ -218,6 +238,9 @@ function validateApiKeysAreProvided( if (grazieUserParams.some((params) => params.apiKey === "None")) { throw buildApiKeyError("Grazie", "grazie"); } + if (deepSeekUserParams.some((params) => params.apiKey === "None")) { + throw buildApiKeyError("Deep Seek", "deepSeek"); + } } function validateModelsArePresent(allModels: T[]) { diff --git a/src/extension/settings/settingsValidationError.ts b/src/extension/settings/settingsValidationError.ts index 243297c3..3749a7e3 100644 --- a/src/extension/settings/settingsValidationError.ts +++ b/src/extension/settings/settingsValidationError.ts @@ -34,7 +34,8 @@ export function toSettingName(llmService: LLMService): string { () => "predefinedProofs", () => "openAi", () => "grazie", - () => "lmStudio" + () => "lmStudio", + () => "deepSeek" ); return `${pluginId}.${serviceNameInSettings}ModelsParameters`; } diff --git a/src/llm/llmIterator.ts b/src/llm/llmIterator.ts index 58699069..ffa71122 100644 --- a/src/llm/llmIterator.ts +++ b/src/llm/llmIterator.ts @@ -36,6 +36,7 @@ export class LLMSequentialIterator ); } + // TODO: Implement a smarter way of ordering the services private createHooks( proofGenerationContext: ProofGenerationContext, modelsParams: ModelsParams, @@ -48,6 +49,16 @@ export class LLMSequentialIterator services.predefinedProofsService, "predefined-proofs" ), + // Here DeepSeek service is reordered to the beginning + // of the list, due to it's strong performance and + // low costs. Refer to discussion: + // https://github.com/JetBrains-Research/coqpilot/pull/56#discussion_r1935180516 + ...this.createLLMServiceHooks( + proofGenerationContext, + modelsParams.deepSeekParams, + services.deepSeekService, + "deepseek" + ), ...this.createLLMServiceHooks( proofGenerationContext, modelsParams.openAiParams, diff --git a/src/llm/llmServices.ts b/src/llm/llmServices.ts index 9e974ae3..d54ccdc6 100644 --- a/src/llm/llmServices.ts +++ b/src/llm/llmServices.ts @@ -1,5 +1,6 @@ import { illegalState } from "../utils/throwErrors"; +import { DeepSeekService } from "./llmServices/deepSeek/deepSeekService"; import { GrazieService } from "./llmServices/grazie/grazieService"; import { LLMService } from "./llmServices/llmService"; import { LMStudioService } from "./llmServices/lmStudio/lmStudioService"; @@ -13,6 +14,7 @@ export interface LLMServices { openAiService: OpenAiService; grazieService: GrazieService; lmStudioService: LMStudioService; + deepSeekService: DeepSeekService; } export function disposeServices(llmServices: LLMServices) { @@ -27,6 +29,7 @@ export function asLLMServices( llmServices.openAiService, llmServices.grazieService, llmServices.lmStudioService, + llmServices.deepSeekService, ]; } @@ -35,7 +38,8 @@ export function switchByLLMServiceType( onPredefinedProofsService: () => T, onOpenAiService: () => T, onGrazieService: () => T, - onLMStudioService: () => T + onLMStudioService: () => T, + onDeepSeekService: () => T ): T { if (llmService instanceof PredefinedProofsService) { return onPredefinedProofsService(); @@ -45,6 +49,8 @@ export function switchByLLMServiceType( return onGrazieService(); } else if (llmService instanceof LMStudioService) { return onLMStudioService(); + } else if (llmService instanceof DeepSeekService) { + return onDeepSeekService(); } else { 
illegalState( `switch by unknown \`LLMService\`: "${llmService.serviceName}"` diff --git a/src/llm/llmServices/deepSeek/deepSeekModelParamsResolver.ts b/src/llm/llmServices/deepSeek/deepSeekModelParamsResolver.ts new file mode 100644 index 00000000..e7efbbf1 --- /dev/null +++ b/src/llm/llmServices/deepSeek/deepSeekModelParamsResolver.ts @@ -0,0 +1,40 @@ +import { DeepSeekUserModelParams } from "../../userModelParams"; +import { DeepSeekModelParams, deepSeekModelParamsSchema } from "../modelParams"; +import { BasicModelParamsResolver } from "../utils/paramsResolvers/basicModelParamsResolvers"; +import { ValidParamsResolverImpl } from "../utils/paramsResolvers/paramsResolverImpl"; + +export class DeepSeekModelParamsResolver + extends BasicModelParamsResolver< + DeepSeekUserModelParams, + DeepSeekModelParams + > + implements + ValidParamsResolverImpl +{ + constructor() { + super(deepSeekModelParamsSchema, "DeepSeekModelParams"); + } + + readonly modelName = this.resolveParam("modelName") + .requiredToBeConfigured() + .validate([ + (value) => + DeepSeekModelParamsResolver.allowedModels.includes(value), + `be one of the allowed models: ${DeepSeekModelParamsResolver.allowedModels.join( + ", " + )}`, + ]); + + readonly temperature = this.resolveParam("temperature") + .requiredToBeConfigured() + .validate([ + (value) => value >= 0 && value <= 2, + "be in range between 0 and 2", + ]); + + readonly apiKey = this.resolveParam("apiKey") + .requiredToBeConfigured() + .validateAtRuntimeOnly(); + + static readonly allowedModels = ["deepseek-chat", "deepseek-reasoner"]; +} diff --git a/src/llm/llmServices/deepSeek/deepSeekService.ts b/src/llm/llmServices/deepSeek/deepSeekService.ts new file mode 100644 index 00000000..efb004a5 --- /dev/null +++ b/src/llm/llmServices/deepSeek/deepSeekService.ts @@ -0,0 +1,162 @@ +import OpenAI from "openai"; + +import { illegalState } from "../../../utils/throwErrors"; +import { ProofGenerationContext } from "../../proofGenerationContext"; +import { DeepSeekUserModelParams } from "../../userModelParams"; +import { AnalyzedChatHistory, ChatHistory } from "../commonStructures/chat"; +import { + GeneratedRawContent, + GeneratedRawContentItem, +} from "../commonStructures/generatedRawContent"; +import { ProofVersion } from "../commonStructures/proofVersion"; +import { GeneratedProofImpl } from "../generatedProof"; +import { LLMServiceImpl } from "../llmService"; +import { LLMServiceInternal } from "../llmServiceInternal"; +import { DeepSeekModelParams } from "../modelParams"; +import { toO1CompatibleChatHistory } from "../utils/o1ClassModels"; + +import { DeepSeekModelParamsResolver } from "./deepSeekModelParamsResolver"; + +export class DeepSeekService extends LLMServiceImpl< + DeepSeekUserModelParams, + DeepSeekModelParams, + DeepSeekService, + DeepSeekGeneratedProof, + DeepSeekServiceInternal +> { + readonly serviceName = "DeepSeekService"; + protected readonly internal = new DeepSeekServiceInternal( + this, + this.eventLogger, + this.generationsLoggerBuilder + ); + protected readonly modelParamsResolver = new DeepSeekModelParamsResolver(); +} + +export class DeepSeekGeneratedProof extends GeneratedProofImpl< + DeepSeekModelParams, + DeepSeekService, + DeepSeekGeneratedProof, + DeepSeekServiceInternal +> { + constructor( + rawProof: GeneratedRawContentItem, + proofGenerationContext: ProofGenerationContext, + modelParams: DeepSeekModelParams, + llmServiceInternal: DeepSeekServiceInternal, + previousProofVersions?: ProofVersion[] + ) { + super( + rawProof, + 
proofGenerationContext, + modelParams, + llmServiceInternal, + previousProofVersions + ); + } +} + +class DeepSeekServiceInternal extends LLMServiceInternal< + DeepSeekModelParams, + DeepSeekService, + DeepSeekGeneratedProof, + DeepSeekServiceInternal +> { + static readonly baseApiUrl = "https://api.deepseek.com/v1"; + + constructGeneratedProof( + rawProof: GeneratedRawContentItem, + proofGenerationContext: ProofGenerationContext, + modelParams: DeepSeekModelParams, + previousProofVersions?: ProofVersion[] | undefined + ): DeepSeekGeneratedProof { + return new DeepSeekGeneratedProof( + rawProof, + proofGenerationContext, + modelParams, + this, + previousProofVersions + ); + } + + async generateFromChatImpl( + analyzedChat: AnalyzedChatHistory, + params: DeepSeekModelParams, + choices: number + ): Promise { + LLMServiceInternal.validateChoices(choices); + + const openaiCompatibleApi = new OpenAI({ + apiKey: params.apiKey, + baseURL: DeepSeekServiceInternal.baseApiUrl, + }); + const formattedChat = this.formatChatHistory(analyzedChat.chat, params); + this.logDebug.event("Completion requested", { + history: formattedChat, + }); + + try { + const completion = + await openaiCompatibleApi.chat.completions.create({ + messages: formattedChat, + model: params.modelName, + n: choices, + temperature: params.temperature, + // eslint-disable-next-line @typescript-eslint/naming-convention + max_tokens: params.maxTokensToGenerate, + }); + const rawContentItems = completion.choices.map((choice) => { + const content = choice.message.content; + if (content === null) { + illegalState("response message content is null"); + } + return content; + }); + + return this.packContentWithTokensMetrics( + rawContentItems, + completion.usage, + analyzedChat, + params + ); + } catch (e) { + // TODO: Implement error repacking, according to the documentation + // https://api-docs.deepseek.com/quick_start/error_codes + // Postponed due to API unavailability + // throw DeepSeekServiceInternal.repackKnownError(e, params); + throw e; + } + } + + private packContentWithTokensMetrics( + rawContentItems: string[], + tokensUsage: OpenAI.Completions.CompletionUsage | undefined, + analyzedChat: AnalyzedChatHistory, + params: DeepSeekModelParams + ): GeneratedRawContent { + const promptTokens = + tokensUsage?.prompt_tokens ?? 
+ analyzedChat.estimatedTokens.messagesTokens; + return LLMServiceInternal.aggregateToGeneratedRawContent( + rawContentItems, + promptTokens, + params.modelName, + { + promptTokens: promptTokens, + generatedTokens: tokensUsage?.completion_tokens, + tokensSpentInTotal: tokensUsage?.total_tokens, + } + ); + } + + private formatChatHistory( + chat: ChatHistory, + modelParams: DeepSeekModelParams + ): ChatHistory { + return toO1CompatibleChatHistory( + chat, + modelParams.modelName, + "deepseek" + ); + } +} diff --git a/src/llm/llmServices/modelParams.ts b/src/llm/llmServices/modelParams.ts index 3b4c0f55..c95fe336 100644 --- a/src/llm/llmServices/modelParams.ts +++ b/src/llm/llmServices/modelParams.ts @@ -56,11 +56,18 @@ export interface LMStudioModelParams extends ModelParams { port: number; } +export interface DeepSeekModelParams extends ModelParams { + modelName: string; + temperature: number; + apiKey: string; +} + export interface ModelsParams { predefinedProofsModelParams: PredefinedProofsModelParams[]; openAiParams: OpenAiModelParams[]; grazieParams: GrazieModelParams[]; lmStudioParams: LMStudioModelParams[]; + deepSeekParams: DeepSeekModelParams[]; } export const multiroundProfileSchema: JSONSchemaType = { @@ -170,3 +177,21 @@ export const lmStudioModelParamsSchema: JSONSchemaType = { required: ["temperature", "port", ...modelParamsSchema.required], additionalProperties: false, }; + +export const deepSeekModelParamsSchema: JSONSchemaType = { + title: "deepSeekModelsParameters", + type: "object", + properties: { + modelName: { type: "string" }, + temperature: { type: "number" }, + apiKey: { type: "string" }, + ...(modelParamsSchema.properties as PropertiesSchema), + }, + required: [ + "modelName", + "temperature", + "apiKey", + ...modelParamsSchema.required, + ], + additionalProperties: false, +}; diff --git a/src/llm/llmServices/utils/o1ClassModels.ts b/src/llm/llmServices/utils/o1ClassModels.ts index 8a311560..faa1bd48 100644 --- a/src/llm/llmServices/utils/o1ClassModels.ts +++ b/src/llm/llmServices/utils/o1ClassModels.ts @@ -9,6 +9,10 @@ const o1ClassModelsOpenAI = [ const o1ClassModelsGrazie = ["openai-o1", "openai-o1-mini"]; +// TODO: When DeepSeek API becomes stable: +// check whether the r1 model chat history is compatible with o1 model +const o1ClassModelsDeepSeek = ["deepseek-reasoner"]; + /** * As of November 2024, o1 model requires a different format of chat history. * It doesn't support the system prompt, therefore we manually @@ -18,10 +22,15 @@ const o1ClassModelsGrazie = ["openai-o1", "openai-o1-mini"]; export function toO1CompatibleChatHistory( chatHistory: ChatHistory, modelName: string, - service: "openai" | "grazie" + service: "openai" | "grazie" | "deepseek" ): ChatHistory { const o1ClassModels = - service === "openai" ? o1ClassModelsOpenAI : o1ClassModelsGrazie; + service === "openai" + ? o1ClassModelsOpenAI + : service === "grazie" + ? 
o1ClassModelsGrazie + : o1ClassModelsDeepSeek; + if (o1ClassModels.includes(modelName)) { return chatHistory.map((message: ChatMessage) => { return { diff --git a/src/llm/userModelParams.ts b/src/llm/userModelParams.ts index bb91c048..a4fd067e 100644 --- a/src/llm/userModelParams.ts +++ b/src/llm/userModelParams.ts @@ -78,6 +78,12 @@ export interface LMStudioUserModelParams extends UserModelParams { port: number; } +export interface DeepSeekUserModelParams extends UserModelParams { + modelName: string; + temperature: number; + apiKey: string; +} + export const userMultiroundProfileSchema: JSONSchemaType = { type: "object", @@ -168,3 +174,17 @@ export const lmStudioUserModelParamsSchema: JSONSchemaType = + { + title: "deepSeekModelsParameters", + type: "object", + properties: { + modelName: { type: "string" }, + temperature: { type: "number" }, + apiKey: { type: "string" }, + ...(userModelParamsSchema.properties as PropertiesSchema), + }, + required: ["modelId", "modelName", "temperature", "apiKey"], + additionalProperties: false, + }; diff --git a/src/test/commonTestFunctions/defaultLLMServicesBuilder.ts b/src/test/commonTestFunctions/defaultLLMServicesBuilder.ts index e2da73cf..72afba7e 100644 --- a/src/test/commonTestFunctions/defaultLLMServicesBuilder.ts +++ b/src/test/commonTestFunctions/defaultLLMServicesBuilder.ts @@ -1,4 +1,5 @@ import { LLMServices } from "../../llm/llmServices"; +import { DeepSeekService } from "../../llm/llmServices/deepSeek/deepSeekService"; import { GrazieService } from "../../llm/llmServices/grazie/grazieService"; import { LMStudioService } from "../../llm/llmServices/lmStudio/lmStudioService"; import { @@ -16,11 +17,13 @@ export function createDefaultServices(): LLMServices { const openAiService = new OpenAiService(); const grazieService = new GrazieService(); const lmStudioService = new LMStudioService(); + const deepSeekService = new DeepSeekService(); return { predefinedProofsService, openAiService, grazieService, lmStudioService, + deepSeekService, }; } @@ -32,6 +35,7 @@ export function createTrivialModelsParams( openAiParams: [], grazieParams: [], lmStudioParams: [], + deepSeekParams: [], }; } diff --git a/src/test/legacyBenchmark/benchmarkingFramework.ts b/src/test/legacyBenchmark/benchmarkingFramework.ts index ca193885..751c2d18 100644 --- a/src/test/legacyBenchmark/benchmarkingFramework.ts +++ b/src/test/legacyBenchmark/benchmarkingFramework.ts @@ -3,6 +3,7 @@ import * as fs from "fs"; import { LLMServices } from "../../llm/llmServices"; import { isLLMServiceRequestSucceeded } from "../../llm/llmServices/commonStructures/llmServiceRequest"; +import { DeepSeekService } from "../../llm/llmServices/deepSeek/deepSeekService"; import { GrazieService } from "../../llm/llmServices/grazie/grazieService"; import { LLMServiceImpl } from "../../llm/llmServices/llmService"; import { LMStudioService } from "../../llm/llmServices/lmStudio/lmStudioService"; @@ -461,6 +462,7 @@ async function prepareForBenchmarkCompletions( grazieService: new GrazieService(eventLogger), predefinedProofsService: new PredefinedProofsService(eventLogger), lmStudioService: new LMStudioService(eventLogger), + deepSeekService: new DeepSeekService(eventLogger), }; const processEnvironment: ProcessEnvironment = { coqProofChecker: coqProofChecker, @@ -604,5 +606,8 @@ function resolveInputModelsParametersOrThrow( lmStudioParams: inputModelsParams.lmStudioParams.map((inputParams) => resolveParametersOrThrow(llmServices.lmStudioService, inputParams) ), + deepSeekParams: 
inputModelsParams.deepSeekParams.map((inputParams) => + resolveParametersOrThrow(llmServices.deepSeekService, inputParams) + ), }; } diff --git a/src/test/legacyBenchmark/inputModelsParams.ts b/src/test/legacyBenchmark/inputModelsParams.ts index ef0c55bb..73c9799a 100644 --- a/src/test/legacyBenchmark/inputModelsParams.ts +++ b/src/test/legacyBenchmark/inputModelsParams.ts @@ -1,4 +1,5 @@ import { + DeepSeekUserModelParams, GrazieUserModelParams, LMStudioUserModelParams, OpenAiUserModelParams, @@ -10,6 +11,7 @@ export interface InputModelsParams { openAiParams: OpenAiUserModelParams[]; grazieParams: GrazieUserModelParams[]; lmStudioParams: LMStudioUserModelParams[]; + deepSeekParams: DeepSeekUserModelParams[]; } export const onlyAutoModelsParams: InputModelsParams = { @@ -22,6 +24,7 @@ export const onlyAutoModelsParams: InputModelsParams = { }, ], lmStudioParams: [], + deepSeekParams: [], }; export const tacticianModelsParams: InputModelsParams = { @@ -34,4 +37,5 @@ export const tacticianModelsParams: InputModelsParams = { }, ], lmStudioParams: [], + deepSeekParams: [], }; diff --git a/src/test/llm/llmServices/deepSeekService.test.ts b/src/test/llm/llmServices/deepSeekService.test.ts new file mode 100644 index 00000000..1a191ac9 --- /dev/null +++ b/src/test/llm/llmServices/deepSeekService.test.ts @@ -0,0 +1,144 @@ +import { expect } from "earl"; + +import { ConfigurationError } from "../../../llm/llmServiceErrors"; +import { DeepSeekService } from "../../../llm/llmServices/deepSeek/deepSeekService"; +import { DeepSeekModelParams } from "../../../llm/llmServices/modelParams"; +import { DeepSeekUserModelParams } from "../../../llm/userModelParams"; + +import { testIf } from "../../commonTestFunctions/conditionalTest"; +import { + withLLMService, + withLLMServiceAndParams, +} from "../../commonTestFunctions/withLLMService"; +import { + deepSeekModelName, + mockProofGenerationContext, + testModelId, +} from "../llmSpecificTestUtils/constants"; +import { testLLMServiceCompletesAdmitFromFile } from "../llmSpecificTestUtils/testAdmitCompletion"; +import { + paramsResolvedWithBasicDefaults, + testResolveValidCompleteParameters, +} from "../llmSpecificTestUtils/testResolveParameters"; + +suite("[LLMService] Test `DeepSeekService`", function () { + const apiKey = process.env.DEEPSEEK_API_KEY; + const choices = 15; + const inputFile = ["small_document.v"]; + + const requiredInputParamsTemplate = { + modelId: testModelId, + modelName: deepSeekModelName, + temperature: 1, + choices: choices, + maxTokensToGenerate: 2000, + tokensLimit: 4000, + }; + + testIf( + apiKey !== undefined, + "`DEEPSEEK_API_KEY` is not specified", + this.title, + `Simple generation: 1 request, ${choices} choices`, + async () => { + const inputParams: DeepSeekUserModelParams = { + ...requiredInputParamsTemplate, + apiKey: apiKey!, + }; + const deepSeekService = new DeepSeekService(); + await testLLMServiceCompletesAdmitFromFile( + deepSeekService, + inputParams, + inputFile, + choices + ); + } + )?.timeout(5000); + + test("Test `resolveParameters` reads & accepts valid params", async () => { + const inputParams: DeepSeekUserModelParams = { + ...requiredInputParamsTemplate, + apiKey: "undefined", + }; + await withLLMService(new DeepSeekService(), async (deepSeekService) => { + testResolveValidCompleteParameters(deepSeekService, inputParams); + testResolveValidCompleteParameters( + deepSeekService, + { + ...inputParams, + ...paramsResolvedWithBasicDefaults, + maxTokensToGenerate: 2000, + tokensLimit: 4000, + }, + true + ); + }); + 
}); + + test("Test `generateProof` throws on invalid configurations, ", async () => { + const inputParams: DeepSeekUserModelParams = { + ...requiredInputParamsTemplate, + apiKey: "undefined", + }; + await withLLMServiceAndParams( + new DeepSeekService(), + inputParams, + async (deepSeekService, resolvedParams: DeepSeekModelParams) => { + // non-positive choices + await expect(async () => { + await deepSeekService.generateProof( + mockProofGenerationContext, + resolvedParams, + -1 + ); + }).toBeRejectedWith(ConfigurationError, "choices"); + } + ); + }); + + testIf( + apiKey !== undefined, + "`DEEPSEEK_API_KEY` is not specified", + this.title, + "Test `generateProof` throws on invalid configurations, ", + async () => { + const inputParams: DeepSeekUserModelParams = { + ...requiredInputParamsTemplate, + apiKey: apiKey!, + }; + await withLLMServiceAndParams( + new DeepSeekService(), + inputParams, + async (deepSeekService, resolvedParams) => { + // unknown model name + await expect(async () => { + await deepSeekService.generateProof( + mockProofGenerationContext, + { + ...resolvedParams, + modelName: "unknown", + } as DeepSeekModelParams, + 1 + ); + }).toBeRejectedWith(ConfigurationError, "model name"); + + // context length exceeded (requested too many tokens for the completion) + await expect(async () => { + await deepSeekService.generateProof( + mockProofGenerationContext, + { + ...resolvedParams, + maxTokensToGenerate: 500_000, + tokensLimit: 1_000_000, + } as DeepSeekModelParams, + 1 + ); + }).toBeRejectedWith( + ConfigurationError, + "`tokensLimit` and `maxTokensToGenerate`" + ); + } + ); + } + ); +}); diff --git a/src/test/llm/llmSpecificTestUtils/constants.ts b/src/test/llm/llmSpecificTestUtils/constants.ts index 73683bb8..05a3f88c 100644 --- a/src/test/llm/llmSpecificTestUtils/constants.ts +++ b/src/test/llm/llmSpecificTestUtils/constants.ts @@ -18,6 +18,8 @@ export const testModelId = "test model"; export const gptModelName = "gpt-4o-mini-2024-07-18"; +export const deepSeekModelName = "deepseek-chat"; + export const mockChat: AnalyzedChatHistory = { chat: [{ role: "system", content: "Generate proofs." 
}], contextTheorems: [], diff --git a/src/test/llm/parseUserModelParams.test.ts b/src/test/llm/parseUserModelParams.test.ts index fb7a8c2d..e71d9993 100644 --- a/src/test/llm/parseUserModelParams.test.ts +++ b/src/test/llm/parseUserModelParams.test.ts @@ -2,6 +2,7 @@ import { JSONSchemaType } from "ajv"; import { expect } from "earl"; import { + deepSeekUserModelParamsSchema, grazieUserModelParamsSchema, lmStudioUserModelParamsSchema, openAiUserModelParamsSchema, @@ -65,7 +66,7 @@ suite("Parse `UserModelParams` from JSON test", () => { const validOpenAiUserModelParamsComplete = { ...validUserModelParamsCompelete, modelName: "gpt-model", - temperature: 36.6, + temperature: 0.8, apiKey: "api-key", }; const validGrazieUserModelParamsComplete = { @@ -76,9 +77,15 @@ suite("Parse `UserModelParams` from JSON test", () => { }; const validLMStudioUserModelParamsComplete = { ...validUserModelParamsCompelete, - temperature: 36.6, + temperature: 0.8, port: 555, }; + const validDeepSeekUserModelParamsComplete = { + ...validUserModelParamsCompelete, + modelName: "deepseek-chat", + temperature: 0.8, + apiKey: "api-key", + }; test("Validate `UserMultiroundProfile`", () => { isValidJSON( @@ -172,4 +179,11 @@ suite("Parse `UserModelParams` from JSON test", () => { lmStudioUserModelParamsSchema ); }); + + test("Validate `DeepSeekUserModelParams`", () => { + isValidJSON( + validDeepSeekUserModelParamsComplete, + deepSeekUserModelParamsSchema + ); + }); });
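
For illustration only (not part of the patch): a user-side `settings.json` entry matching the new `coqpilot.deepSeekModelsParameters` schema introduced above could look like the sketch below. Every field name and value mirrors a default declared in the schema; the `apiKey` value is a placeholder, since the default `"None"` is rejected by `validateApiKeysAreProvided`.

```json
{
  "coqpilot.deepSeekModelsParameters": [
    {
      "modelId": "deep-seek-v3",
      "modelName": "deepseek-chat",
      "temperature": 1,
      "apiKey": "<your DeepSeek API key from https://platform.deepseek.com/api_keys>",
      "choices": 15,
      "systemPrompt": "Generate proof of the theorem from user input in Coq. You should only generate proofs in Coq. Never add special comments to the proof. Your answer should be a valid Coq proof. It should start with 'Proof.' and end with 'Qed.'.",
      "maxTokensToGenerate": 2048,
      "tokensLimit": 4096,
      "maxContextTheoremsNumber": 100,
      "multiroundProfile": {
        "maxRoundsNumber": 1,
        "proofFixChoices": 1,
        "proofFixPrompt": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: `${diagnostic}`. Please, fix the proof.",
        "maxPreviousProofVersionsNumber": 100
      }
    }
  ]
}
```

Note that, per the resolver added in `deepSeekModelParamsResolver.ts`, `modelName` must be one of `deepseek-chat` or `deepseek-reasoner`, and `temperature` must lie in the range [0, 2]; values outside these constraints produce a configuration error.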