diff --git a/prisma/migrations/20250214221114_add_user_model_preference/migration.sql b/prisma/migrations/20250214221114_add_user_model_preference/migration.sql new file mode 100644 index 00000000..1e70a0d6 --- /dev/null +++ b/prisma/migrations/20250214221114_add_user_model_preference/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "users" ADD COLUMN "modelPreference" TEXT; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 8bb5c489..91dd59c4 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -35,6 +35,7 @@ model User { degenMode Boolean @default(false) referralCode String? @unique referringUserId String? + modelPreference String? wallets Wallet[] conversations Conversation[] diff --git a/src/ai/providers.tsx b/src/ai/providers.tsx index 00b037ee..f3bc2ff1 100644 --- a/src/ai/providers.tsx +++ b/src/ai/providers.tsx @@ -5,6 +5,7 @@ import { createOpenAI } from '@ai-sdk/openai'; import { z } from 'zod'; import { Card } from '@/components/ui/card'; +import { MODEL_PREFERENCE_MAP } from '@/lib/constants'; import { actionTools } from './generic/action'; import { jinaTools } from './generic/jina'; @@ -27,35 +28,47 @@ const usingAnthropic = !!process.env.ANTHROPIC_API_KEY; const anthropic = createAnthropic({ apiKey: process.env.ANTHROPIC_API_KEY }); const claude35Sonnet = anthropic('claude-3-5-sonnet-20241022'); -const openai = createOpenAI({ - baseURL: process.env.OPENAI_BASE_URL || 'https://api.openai.com/v1', - apiKey: process.env.OPENAI_API_KEY, - compatibility: 'strict', - ...(process.env.OPENAI_BASE_URL?.includes('openrouter.ai') && { - fetch: async (url, options) => { - if (!options?.body) return fetch(url, options); - - const body = JSON.parse(options.body as string); - - const modifiedBody = { - ...body, - provider: { - order: ['Anthropic', 'OpenAI'], - allow_fallbacks: false, - }, - }; - - options.body = JSON.stringify(modifiedBody); - - return fetch(url, options); - }, - }), -}); +const openai = (modelName: string) => { + const shouldOverrideFetch = + process.env.OPENAI_BASE_URL?.includes('openrouter.ai') && + modelName.includes('anthropic'); + return createOpenAI({ + baseURL: process.env.OPENAI_BASE_URL || 'https://api.openai.com/v1', + apiKey: process.env.OPENAI_API_KEY, + compatibility: 'strict', + ...(shouldOverrideFetch && { + fetch: async (url, options) => { + if (!options?.body) return fetch(url, options); + + const body = JSON.parse(options.body as string); + + const modifiedBody = { + ...body, + provider: { + order: ['Anthropic'], + allow_fallbacks: false, + }, + }; + + options.body = JSON.stringify(modifiedBody); + + return fetch(url, options); + }, + }), + })(modelName); +}; export const orchestratorModel = openai('gpt-4o-mini'); const openAiModel = openai(process.env.OPENAI_MODEL_NAME || 'gpt-4o'); +export const defaultModel = usingAnthropic ? claude35Sonnet : openAiModel; + +export const getModelFromPreference = (preference?: string) => { + if (!preference || !MODEL_PREFERENCE_MAP[preference]) return defaultModel; + return openai(preference); +}; + export const defaultSystemPrompt = ` Your name is Neur (Agent). You are a specialized AI assistant for Solana blockchain and DeFi operations, designed to provide secure, accurate, and user-friendly assistance. @@ -111,8 +124,6 @@ Realtime knowledge: - { approximateCurrentTime: ${new Date().toISOString()}} `; -export const defaultModel = usingAnthropic ? claude35Sonnet : openAiModel; - export interface ToolConfig { displayName?: string; icon?: ReactNode; diff --git a/src/app/(user)/chat/[id]/chat-interface.tsx b/src/app/(user)/chat/[id]/chat-interface.tsx index ab2b77e5..31e077ca 100644 --- a/src/app/(user)/chat/[id]/chat-interface.tsx +++ b/src/app/(user)/chat/[id]/chat-interface.tsx @@ -411,6 +411,9 @@ function ChatMessage({ (_, alt, src, width, height) => `![${alt}](${src}#size=${width}x${height})`, ); + // Some models have extra messages about tool outputs + const isToolOutputResponse = message.content?.startsWith('```tool_outputs'); + return (
)} - {message.content && ( -
+ {!isToolOutputResponse && + message.content && + message.content.trim() !== '' && (
- ( - - ), - img: ({ node, alt, src, ...props }) => { - if (!src) return null; - - try { - // Handle both relative and absolute URLs safely - const url = new URL(src, 'http://dummy.com'); - const size = url.hash.match(/size=(\d+)x(\d+)/); - - if (size) { - const [, width, height] = size; - // Remove hash from src - url.hash = ''; - return ( - {alt - ); - } - } catch (e) { - // If URL parsing fails, fallback to original src - console.warn('Failed to parse image URL:', e); - } - - const thumbnailPattern = /_thumb\.(png|jpg|jpeg|gif)$/i; - const isThumbnail = thumbnailPattern.test(src); - - const width = isThumbnail ? 40 : 500; - const height = isThumbnail ? 40 : 300; - - // Fallback to Image component with default dimensions - return ( - {alt - ); - }, - }} +
- {processedContent} - + ( + + ), + img: ({ node, alt, src, ...props }) => { + if (!src) return null; + + try { + // Handle both relative and absolute URLs safely + const url = new URL(src, 'http://dummy.com'); + const size = url.hash.match(/size=(\d+)x(\d+)/); + + if (size) { + const [, width, height] = size; + // Remove hash from src + url.hash = ''; + return ( + {alt + ); + } + } catch (e) { + // If URL parsing fails, fallback to original src + console.warn('Failed to parse image URL:', e); + } + + const thumbnailPattern = /_thumb\.(png|jpg|jpeg|gif)$/i; + const isThumbnail = thumbnailPattern.test(src); + + const width = isThumbnail ? 40 : 500; + const height = isThumbnail ? 40 : 300; + + // Fallback to Image component with default dimensions + return ( + {alt + ); + }, + }} + > + {processedContent} + +
-
- )} + )} {message.toolInvocations && ( { scrollToBottom(); - }, []); + }, [chatMessages]); const handleSend = async (value: string, attachments: Attachment[]) => { if (!value.trim() && (!attachments || attachments.length === 0)) { @@ -750,7 +759,7 @@ export default function ChatInterface({ useAnimationEffect(); return ( -
+
diff --git a/src/app/(user)/chat/[id]/page.tsx b/src/app/(user)/chat/[id]/page.tsx index 1eca9fd8..7215aab6 100644 --- a/src/app/(user)/chat/[id]/page.tsx +++ b/src/app/(user)/chat/[id]/page.tsx @@ -3,6 +3,7 @@ import { Suspense } from 'react'; import { Metadata } from 'next'; import { notFound } from 'next/navigation'; +import { ModelSelector } from '@/components/model-selector'; import { verifyUser } from '@/server/actions/user'; import { dbGetConversation, @@ -51,6 +52,7 @@ async function ChatData({ params }: { params: Promise<{ id: string }> }) { // Verify user authentication and access rights const authResponse = await verifyUser(); const userId = authResponse?.data?.data?.id; + const modelPreference = authResponse?.data?.data?.modelPreference; // Check if user has access to private conversation if ( @@ -69,7 +71,14 @@ async function ChatData({ params }: { params: Promise<{ id: string }> }) { return notFound(); } - return ; + return ( + <> +
+ +
+ + + ); } /** diff --git a/src/app/(user)/home/home-content.tsx b/src/app/(user)/home/home-content.tsx index 960d2499..d0ffc61e 100644 --- a/src/app/(user)/home/home-content.tsx +++ b/src/app/(user)/home/home-content.tsx @@ -11,9 +11,11 @@ import { Attachment, JSONValue } from 'ai'; import { useChat } from 'ai/react'; import { CheckCircle2, Loader2 } from 'lucide-react'; import { toast } from 'sonner'; +import { mutate } from 'swr'; import { v4 as uuidv4 } from 'uuid'; import ChatInterface from '@/app/(user)/chat/[id]/chat-interface'; +import { ModelSelector } from '@/components/model-selector'; import { SavedPromptsMenu } from '@/components/saved-prompts-menu'; import { Badge } from '@/components/ui/badge'; import BlurFade from '@/components/ui/blur-fade'; @@ -268,7 +270,6 @@ export function HomeContent() { } }, [pathname, resetChat]); - // 监听浏览器的前进后退 useEffect(() => { const handlePopState = () => { if (location.pathname === '/home') { @@ -390,6 +391,11 @@ export function HomeContent() {
+ mutate(`user-${user?.privyId}`)} + /> +
+ mutate(`user-${user?.privyId}`)} + /> +
)} diff --git a/src/app/api/chat/route.ts b/src/app/api/chat/route.ts index 73ec0967..d8b8b348 100644 --- a/src/app/api/chat/route.ts +++ b/src/app/api/chat/route.ts @@ -17,6 +17,7 @@ import { defaultModel, defaultSystemPrompt, defaultTools, + getModelFromPreference, getToolsFromRequiredTools, } from '@/ai/providers'; import { MAX_TOKEN_MESSAGES } from '@/lib/constants'; @@ -47,6 +48,7 @@ export async function POST(req: Request) { const userId = session?.data?.data?.id; const publicKey = session?.data?.data?.publicKey; const degenMode = session?.data?.data?.degenMode; + const modelPreference = session?.data?.data?.modelPreference; if (!userId) { return new Response('Unauthorized', { status: 401 }); @@ -193,7 +195,7 @@ export async function POST(req: Request) { // Begin streaming text from the model const result = streamText({ - model: defaultModel, + model: getModelFromPreference(modelPreference), system: systemPrompt, tools: tools as Record>, experimental_toolCallStreaming: true, diff --git a/src/components/model-selector.tsx b/src/components/model-selector.tsx new file mode 100644 index 00000000..9da00a12 --- /dev/null +++ b/src/components/model-selector.tsx @@ -0,0 +1,76 @@ +'use client'; + +import { startTransition, useMemo, useOptimistic, useState } from 'react'; + +import { CheckCircledIcon, ChevronDownIcon } from '@radix-ui/react-icons'; + +import { Button } from '@/components/ui/button'; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from '@/components/ui/dropdown-menu'; +import { MODEL_PREFERENCE_MAP } from '@/lib/constants'; +import { cn } from '@/lib/utils'; +import { updateUser } from '@/server/actions/user'; + +export function ModelSelector({ + selectedModelId = 'anthropic/claude-3.5-sonnet', + className, + mutate, +}: { + selectedModelId?: string; + mutate?: () => void; +} & React.ComponentProps) { + const [open, setOpen] = useState(false); + const [optimisticModelId, setOptimisticModelId] = + useOptimistic(selectedModelId); + + const selectedModel = useMemo( + () => MODEL_PREFERENCE_MAP[optimisticModelId], + [optimisticModelId], + ); + + return ( + + + + + + {Object.keys(MODEL_PREFERENCE_MAP).map((modelId) => ( + { + setOpen(false); + + startTransition(() => { + setOptimisticModelId(modelId); + updateUser({ modelPreference: modelId }); + mutate?.(); + }); + }} + className="group/item flex flex-row items-center justify-between gap-4" + data-active={modelId === optimisticModelId} + > +
+ {MODEL_PREFERENCE_MAP[modelId]} +
+
+ +
+
+ ))} +
+
+ ); +} diff --git a/src/lib/constants.ts b/src/lib/constants.ts index 7541798e..51ddaf1f 100644 --- a/src/lib/constants.ts +++ b/src/lib/constants.ts @@ -10,3 +10,10 @@ export const RPC_URL = export const MAX_TOKEN_MESSAGES = 10; export const NO_CONFIRMATION_MESSAGE = ' (Does not require confirmation)'; + +export const MODEL_PREFERENCE_MAP: Record = { + 'anthropic/claude-3.5-sonnet': 'Claude-3.5 Sonnet', + //'deepseek/deepseek-chat': 'DeepSeek V3', + 'google/gemini-2.0-flash-001': 'Gemini Flash 2.0', + 'gpt-4o': 'GPT-4o', +}; diff --git a/src/server/actions/action.ts b/src/server/actions/action.ts index b7bdb58a..fc74870e 100644 --- a/src/server/actions/action.ts +++ b/src/server/actions/action.ts @@ -13,6 +13,7 @@ import { defaultModel, defaultSystemPrompt, defaultTools, + getModelFromPreference, getToolsFromRequiredTools, } from '@/ai/providers'; import prisma from '@/lib/prisma'; @@ -121,7 +122,7 @@ export async function processAction(action: ActionWithUser) { // Call the AI model logWithTiming(startTime, `[action:${action.id}] calling generateText`); const { response, usage } = await generateText({ - model: defaultModel, + model: getModelFromPreference(action.user.modelPreference ?? undefined), system: systemPrompt, tools: clonedTools as Record>, experimental_telemetry: { diff --git a/src/server/actions/user.ts b/src/server/actions/user.ts index 860463ca..5a977875 100644 --- a/src/server/actions/user.ts +++ b/src/server/actions/user.ts @@ -158,6 +158,7 @@ export const verifyUser = actionClient.action< degenMode: boolean; publicKey?: string; privyId: string; + modelPreference?: string; }> >(async () => { const token = (await cookies()).get('privy-token')?.value; @@ -176,6 +177,7 @@ export const verifyUser = actionClient.action< id: true, degenMode: true, privyId: true, + modelPreference: true, wallets: { select: { publicKey: true, @@ -201,6 +203,7 @@ export const verifyUser = actionClient.action< privyId: user.privyId, publicKey: user.wallets[0]?.publicKey, degenMode: user.degenMode, + modelPreference: user.modelPreference ?? undefined, }, }; } catch { @@ -314,6 +317,7 @@ export const getPrivyClient = actionClient.action( export type UserUpdateData = { degenMode?: boolean; referralCode?: string; // Add referralCode as an optional field + modelPreference?: string; }; export async function updateUser(data: UserUpdateData) { try { @@ -326,7 +330,7 @@ export async function updateUser(data: UserUpdateData) { } // Extract referralCode from the input data - const { referralCode, degenMode } = data; + const { referralCode, degenMode, modelPreference } = data; // If referralCode is provided, validate and update referringUserId if (referralCode) { @@ -358,6 +362,7 @@ export async function updateUser(data: UserUpdateData) { where: { id: userId }, data: { degenMode, + modelPreference: modelPreference, referringUserId: referringUser.id, }, }); @@ -367,6 +372,7 @@ export async function updateUser(data: UserUpdateData) { where: { id: userId }, data: { degenMode, + modelPreference, }, }); } diff --git a/src/types/db.ts b/src/types/db.ts index 910250c3..07743679 100644 --- a/src/types/db.ts +++ b/src/types/db.ts @@ -53,6 +53,7 @@ export type NeurUser = Pick< | 'subscription' | 'referralCode' | 'referringUserId' + | 'modelPreference' > & { privyUser: PrivyUser; hasEAP: boolean;