Skip to content

Commit 41d2c82

Browse files
committed
implement model switching for custom prompts
Add support for both permanent and temporary model switching in custom prompts using frontmatter configuration. Users can now specify models using either 'model:' for permanent switches or 'temp-model:' for temporary switches. This change improves the flexibility of custom prompts by allowing users to switch models contextually while maintaining their default preferences. - closes #1146
1 parent faf8b3d commit 41d2c82

File tree

7 files changed

+143
-21
lines changed

7 files changed

+143
-21
lines changed

src/LLMProviders/chatModelManager.ts

+6
Original file line numberDiff line numberDiff line change
@@ -390,4 +390,10 @@ export default class ChatModelManager {
390390
const settings = getSettings();
391391
return settings.activeModels.find((model) => model.name === modelName);
392392
}
393+
394+
getCurrentModel(): CustomModel | null {
395+
const currentModelKey = getModelKey();
396+
if (!currentModelKey) return null;
397+
return this.findModelByName(currentModelKey.split("|")[0]) || null;
398+
}
393399
}

src/commands/index.ts

+8-2
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ export function registerBuiltInCommands(plugin: CopilotPlugin) {
122122
(plugin as any).removeCommand(id);
123123
});
124124

125-
const promptProcessor = CustomPromptProcessor.getInstance(plugin.app.vault);
125+
const promptProcessor = CustomPromptProcessor.getInstance(plugin.app);
126126

127127
addEditorCommand(plugin, COMMAND_IDS.FIX_GRAMMAR, (editor) => {
128128
processInlineEditCommand(plugin, editor, COMMAND_IDS.FIX_GRAMMAR);
@@ -249,7 +249,12 @@ export function registerBuiltInCommands(plugin: CopilotPlugin) {
249249
new Notice(`No prompt found with the title "${promptTitle}".`);
250250
return;
251251
}
252-
plugin.processCustomPrompt(COMMAND_IDS.APPLY_CUSTOM_PROMPT, prompt.content);
252+
plugin.processCustomPrompt(
253+
COMMAND_IDS.APPLY_CUSTOM_PROMPT,
254+
prompt.content,
255+
prompt.model,
256+
prompt.isTemporaryModel
257+
);
253258
} catch (err) {
254259
console.error(err);
255260
new Notice("An error occurred.");
@@ -260,6 +265,7 @@ export function registerBuiltInCommands(plugin: CopilotPlugin) {
260265
addCommand(plugin, COMMAND_IDS.APPLY_ADHOC_PROMPT, async () => {
261266
const modal = new AdhocPromptModal(plugin.app, async (adhocPrompt: string) => {
262267
try {
268+
// For ad-hoc prompts, we don't support model switching
263269
plugin.processCustomPrompt(COMMAND_IDS.APPLY_ADHOC_PROMPT, adhocPrompt);
264270
} catch (err) {
265271
console.error(err);

src/components/Chat.tsx

+2-2
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ const Chat: React.FC<ChatProps> = ({
154154
setLoadingMessage(LOADING_MESSAGES.DEFAULT);
155155

156156
// First, process the original user message for custom prompts
157-
const customPromptProcessor = CustomPromptProcessor.getInstance(app.vault);
157+
const customPromptProcessor = CustomPromptProcessor.getInstance(app);
158158
let processedUserMessage = await customPromptProcessor.processCustomPrompt(
159159
inputMessage || "",
160160
"",
@@ -482,7 +482,7 @@ ${chatContent}`;
482482
};
483483
};
484484

485-
const customPromptProcessor = CustomPromptProcessor.getInstance(app.vault);
485+
const customPromptProcessor = CustomPromptProcessor.getInstance(app);
486486
// eslint-disable-next-line react-hooks/exhaustive-deps
487487
useEffect(
488488
createEffect(COMMAND_IDS.APPLY_CUSTOM_PROMPT, async (selectedText, customPrompt) => {

src/components/chat-components/ChatInput.tsx

+22-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import React, {
3939
} from "react";
4040
import { useDropzone } from "react-dropzone";
4141
import ContextControl from "./ContextControl";
42+
import ChatModelManager from "@/LLMProviders/chatModelManager";
4243

4344
interface ChatInputProps {
4445
inputMessage: string;
@@ -199,14 +200,34 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>(
199200
};
200201

201202
const showCustomPromptModal = async () => {
202-
const customPromptProcessor = CustomPromptProcessor.getInstance(app.vault);
203+
const customPromptProcessor = CustomPromptProcessor.getInstance(app);
203204
const prompts = await customPromptProcessor.getAllPrompts();
204205
const promptTitles = prompts.map((prompt) => prompt.title);
205206

206207
new ListPromptModal(app, promptTitles, async (promptTitle: string) => {
207208
const selectedPrompt = prompts.find((prompt) => prompt.title === promptTitle);
208209
if (selectedPrompt) {
209210
customPromptProcessor.recordPromptUsage(selectedPrompt.title);
211+
212+
// If the prompt specifies a model, try to switch to it
213+
if (selectedPrompt.model) {
214+
try {
215+
const modelInstance = ChatModelManager.getInstance().findModelByName(
216+
selectedPrompt.model
217+
);
218+
if (modelInstance) {
219+
await ChatModelManager.getInstance().setChatModel(modelInstance);
220+
} else {
221+
new Notice(`Model "${selectedPrompt.model}" not found. Using current model.`);
222+
}
223+
} catch (error) {
224+
console.error("Error switching model:", error);
225+
new Notice(
226+
`Failed to switch to model "${selectedPrompt.model}". Using current model.`
227+
);
228+
}
229+
}
230+
210231
setInputMessage(selectedPrompt.content);
211232
}
212233
}).open();

src/customPromptProcessor.test.ts

+14-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import { CustomPrompt, CustomPromptProcessor } from "@/customPromptProcessor";
22
import { extractNoteFiles, getFileContent, getNotesFromPath } from "@/utils";
3-
import { Notice, TFile, Vault } from "obsidian";
3+
import { Notice, TFile, Vault, App } from "obsidian";
44

55
// Mock Obsidian
66
jest.mock("obsidian", () => ({
77
Notice: jest.fn(),
88
TFile: jest.fn(),
99
Vault: jest.fn(),
10+
App: jest.fn(),
1011
}));
1112

1213
// Mock the utility functions
@@ -21,22 +22,29 @@ jest.mock("@/utils", () => ({
2122

2223
describe("CustomPromptProcessor", () => {
2324
let processor: CustomPromptProcessor;
24-
let mockVault: Vault;
25+
let mockApp: App;
2526
let mockActiveNote: TFile;
2627

2728
beforeEach(() => {
2829
// Reset mocks before each test
2930
jest.clearAllMocks();
3031

3132
// Create mock objects
32-
mockVault = {} as Vault;
33+
mockApp = {
34+
vault: {} as Vault,
35+
metadataCache: {
36+
getFileCache: jest.fn().mockReturnValue({
37+
frontmatter: {},
38+
}),
39+
},
40+
} as unknown as App;
3341
mockActiveNote = {
3442
path: "path/to/active/note.md",
3543
basename: "Active Note",
3644
} as TFile;
3745

3846
// Create an instance of CustomPromptProcessor with mocked dependencies
39-
processor = CustomPromptProcessor.getInstance(mockVault);
47+
processor = CustomPromptProcessor.getInstance(mockApp);
4048
});
4149

4250
it("should add 1 context and selectedText", async () => {
@@ -106,7 +114,7 @@ describe("CustomPromptProcessor", () => {
106114

107115
expect(result).toContain("This is the active note: {activenote}");
108116
expect(result).toContain("Content of the active note");
109-
expect(getFileContent).toHaveBeenCalledWith(mockActiveNote, mockVault);
117+
expect(getFileContent).toHaveBeenCalledWith(mockActiveNote, mockApp);
110118
});
111119

112120
it("should handle {activeNote} when no active note is provided", async () => {
@@ -299,7 +307,7 @@ describe("CustomPromptProcessor", () => {
299307

300308
expect(result).toContain("Summarize this: {selectedText}");
301309
expect(result).toContain("selectedText (entire active note):\n\n Content of the active note");
302-
expect(getFileContent).toHaveBeenCalledWith(mockActiveNote, mockVault);
310+
expect(getFileContent).toHaveBeenCalledWith(mockActiveNote, mockApp.vault);
303311
});
304312

305313
it("should not duplicate active note content when both {} and {activeNote} are present", async () => {

src/customPromptProcessor.ts

+45-8
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,60 @@ import {
99
getNotesFromTags,
1010
processVariableNameForNotePath,
1111
} from "@/utils";
12-
import { normalizePath, Notice, TFile, Vault } from "obsidian";
12+
import { normalizePath, Notice, TFile, Vault, App } from "obsidian";
1313

1414
export interface CustomPrompt {
1515
title: string;
1616
content: string;
17+
model?: string;
18+
isTemporaryModel?: boolean;
1719
}
1820

1921
export class CustomPromptProcessor {
20-
private static instance: CustomPromptProcessor;
22+
private static instance: CustomPromptProcessor | null = null;
23+
private vault: Vault;
2124
private usageStrategy: TimestampUsageStrategy;
25+
private app: App;
2226

23-
private constructor(private vault: Vault) {
27+
private constructor(app: App) {
28+
this.app = app;
29+
this.vault = app.vault;
2430
this.usageStrategy = new TimestampUsageStrategy();
2531
}
2632

2733
get customPromptsFolder(): string {
2834
return getSettings().customPromptsFolder;
2935
}
3036

31-
static getInstance(vault: Vault): CustomPromptProcessor {
37+
static getInstance(app: App): CustomPromptProcessor {
3238
if (!CustomPromptProcessor.instance) {
33-
CustomPromptProcessor.instance = new CustomPromptProcessor(vault);
39+
CustomPromptProcessor.instance = new CustomPromptProcessor(app);
3440
}
3541
return CustomPromptProcessor.instance;
3642
}
3743

44+
private parseFrontMatter(
45+
file: TFile,
46+
content: string
47+
): { frontmatter: { model?: string; isTemporaryModel?: boolean }; content: string } {
48+
// Get the cached frontmatter from Obsidian's metadata cache
49+
const cache = this.app.metadataCache.getFileCache(file);
50+
const frontmatter = (cache?.frontmatter as Record<string, unknown>) || {};
51+
52+
// Handle both model and temp-model fields
53+
const tempModel = frontmatter["temp-model"] as string | undefined;
54+
const model = tempModel || (frontmatter.model as string | undefined);
55+
const isTemporaryModel = "temp-model" in frontmatter;
56+
57+
// Get the content without frontmatter
58+
const contentWithoutFrontmatter = content.replace(/^---\n[\s\S]*?\n---\n/, "").trim();
59+
60+
return {
61+
frontmatter: { model, isTemporaryModel },
62+
content: contentWithoutFrontmatter,
63+
};
64+
}
65+
3866
recordPromptUsage(title: string) {
3967
this.usageStrategy.recordUsage(title);
4068
}
@@ -47,10 +75,13 @@ export class CustomPromptProcessor {
4775

4876
const prompts: CustomPrompt[] = [];
4977
for (const file of files) {
50-
const content = await this.vault.read(file);
78+
const rawContent = await this.vault.read(file);
79+
const { frontmatter, content } = this.parseFrontMatter(file, rawContent);
5180
prompts.push({
5281
title: file.basename,
5382
content: content,
83+
model: frontmatter.model,
84+
isTemporaryModel: frontmatter.isTemporaryModel,
5485
});
5586
}
5687

@@ -64,8 +95,14 @@ export class CustomPromptProcessor {
6495
const filePath = `${this.customPromptsFolder}/${title}.md`;
6596
const file = this.vault.getAbstractFileByPath(filePath);
6697
if (file instanceof TFile) {
67-
const content = await this.vault.read(file);
68-
return { title, content };
98+
const rawContent = await this.vault.read(file);
99+
const { frontmatter, content } = this.parseFrontMatter(file, rawContent);
100+
return {
101+
title,
102+
content,
103+
model: frontmatter.model,
104+
isTemporaryModel: frontmatter.isTemporaryModel,
105+
};
69106
}
70107
return null;
71108
}

src/main.ts

+46-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import {
3232
WorkspaceLeaf,
3333
} from "obsidian";
3434
import { IntentAnalyzer } from "./LLMProviders/intentAnalyzer";
35+
import ChatModelManager from "@/LLMProviders/chatModelManager";
3536

3637
export default class CopilotPlugin extends Plugin {
3738
// A chat history that stores the messages sent and received
@@ -43,6 +44,7 @@ export default class CopilotPlugin extends Plugin {
4344
vectorStoreManager: VectorStoreManager;
4445
fileParserManager: FileParserManager;
4546
settingsUnsubscriber?: () => void;
47+
private chatModelManager: ChatModelManager;
4648

4749
async onload(): Promise<void> {
4850
await this.loadSettings();
@@ -75,6 +77,8 @@ export default class CopilotPlugin extends Plugin {
7577

7678
this.initActiveLeafChangeHandler();
7779

80+
this.chatModelManager = ChatModelManager.getInstance();
81+
7882
this.addRibbonIcon("message-square", "Open Copilot Chat", (evt: MouseEvent) => {
7983
this.activateView();
8084
});
@@ -202,9 +206,49 @@ export default class CopilotPlugin extends Plugin {
202206
} as Partial<Editor> as Editor;
203207
}
204208

205-
processCustomPrompt(eventType: string, customPrompt: string) {
209+
async processCustomPrompt(
210+
eventType: string,
211+
customPrompt: string,
212+
model?: string,
213+
isTemporary?: boolean
214+
) {
206215
const editor = this.getCurrentEditorOrDummy();
207-
this.processText(editor, eventType, customPrompt, false);
216+
217+
// If a model is specified, switch to it
218+
let originalModel: CustomModel | null = null;
219+
if (model) {
220+
try {
221+
// Only store the original model if this is a temporary switch
222+
if (isTemporary) {
223+
originalModel = this.chatModelManager.getCurrentModel();
224+
}
225+
const targetModel = this.chatModelManager.findModelByName(model);
226+
if (targetModel) {
227+
await this.chatModelManager.setChatModel(targetModel);
228+
} else {
229+
new Notice(`Model "${model}" not found. Using current model.`);
230+
}
231+
} catch (error) {
232+
console.error("Error switching model:", error);
233+
new Notice(`Failed to switch to model "${model}". Using current model.`);
234+
}
235+
}
236+
237+
try {
238+
await this.processText(editor, eventType, customPrompt, false);
239+
} finally {
240+
// Restore the original model only if this was a temporary switch
241+
if (originalModel && isTemporary) {
242+
try {
243+
await this.chatModelManager.setChatModel(originalModel);
244+
} catch (error) {
245+
console.error("Error restoring original model:", error);
246+
new Notice(
247+
"Failed to restore original model. You may need to manually reset your model in settings."
248+
);
249+
}
250+
}
251+
}
208252
}
209253

210254
toggleView() {

0 commit comments

Comments
 (0)