Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "users" ADD COLUMN "modelPreference" TEXT;
1 change: 1 addition & 0 deletions prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ model User {
degenMode Boolean @default(false)
referralCode String? @unique
referringUserId String?
modelPreference String?

wallets Wallet[]
conversations Conversation[]
Expand Down
63 changes: 37 additions & 26 deletions src/ai/providers.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { createOpenAI } from '@ai-sdk/openai';
import { z } from 'zod';

import { Card } from '@/components/ui/card';
import { MODEL_PREFERENCE_MAP } from '@/lib/constants';

import { actionTools } from './generic/action';
import { jinaTools } from './generic/jina';
Expand All @@ -27,35 +28,47 @@ const usingAnthropic = !!process.env.ANTHROPIC_API_KEY;
const anthropic = createAnthropic({ apiKey: process.env.ANTHROPIC_API_KEY });
const claude35Sonnet = anthropic('claude-3-5-sonnet-20241022');

const openai = createOpenAI({
baseURL: process.env.OPENAI_BASE_URL || 'https://api.openai.com/v1',
apiKey: process.env.OPENAI_API_KEY,
compatibility: 'strict',
...(process.env.OPENAI_BASE_URL?.includes('openrouter.ai') && {
fetch: async (url, options) => {
if (!options?.body) return fetch(url, options);

const body = JSON.parse(options.body as string);

const modifiedBody = {
...body,
provider: {
order: ['Anthropic', 'OpenAI'],
allow_fallbacks: false,
},
};

options.body = JSON.stringify(modifiedBody);

return fetch(url, options);
},
}),
});
const openai = (modelName: string) => {
const shouldOverrideFetch =
process.env.OPENAI_BASE_URL?.includes('openrouter.ai') &&
modelName.includes('anthropic');
return createOpenAI({
baseURL: process.env.OPENAI_BASE_URL || 'https://api.openai.com/v1',
apiKey: process.env.OPENAI_API_KEY,
compatibility: 'strict',
...(shouldOverrideFetch && {
fetch: async (url, options) => {
if (!options?.body) return fetch(url, options);

const body = JSON.parse(options.body as string);

const modifiedBody = {
...body,
provider: {
order: ['Anthropic'],
allow_fallbacks: false,
},
};

options.body = JSON.stringify(modifiedBody);

return fetch(url, options);
},
}),
})(modelName);
};

export const orchestratorModel = openai('gpt-4o-mini');

const openAiModel = openai(process.env.OPENAI_MODEL_NAME || 'gpt-4o');

export const defaultModel = usingAnthropic ? claude35Sonnet : openAiModel;

export const getModelFromPreference = (preference?: string) => {
if (!preference || !MODEL_PREFERENCE_MAP[preference]) return defaultModel;
return openai(preference);
};

export const defaultSystemPrompt = `
Your name is Neur (Agent).
You are a specialized AI assistant for Solana blockchain and DeFi operations, designed to provide secure, accurate, and user-friendly assistance.
Expand Down Expand Up @@ -111,8 +124,6 @@ Realtime knowledge:
- { approximateCurrentTime: ${new Date().toISOString()}}
`;

export const defaultModel = usingAnthropic ? claude35Sonnet : openAiModel;

export interface ToolConfig {
displayName?: string;
icon?: ReactNode;
Expand Down
147 changes: 78 additions & 69 deletions src/app/(user)/chat/[id]/chat-interface.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,9 @@ function ChatMessage({
(_, alt, src, width, height) => `![${alt}](${src}#size=${width}x${height})`,
);

// Some models have extra messages about tool outputs
const isToolOutputResponse = message.content?.startsWith('```tool_outputs');
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seemed to be limited to Gemini, not sure where it came from. It would be sent as a normal message


return (
<div
className={cn(
Expand Down Expand Up @@ -453,79 +456,85 @@ function ChatMessage({
</div>
)}

{message.content && (
<div
className={cn(
'relative flex flex-col gap-2 rounded-2xl px-4 py-3 text-sm shadow-sm',
isUser ? 'bg-primary' : 'bg-muted/60',
)}
>
{!isToolOutputResponse &&
message.content &&
message.content.trim() !== '' && (
<div
className={cn(
'prose prose-sm max-w-prose break-words leading-tight md:prose-base',
isUser
? 'prose-invert dark:prose-neutral'
: 'prose-neutral dark:prose-invert',
'relative flex flex-col gap-2 rounded-2xl px-4 py-3 text-sm shadow-sm',
isUser ? 'bg-primary' : 'bg-muted/60',
)}
>
<ReactMarkdown
rehypePlugins={[rehypeRaw]}
remarkPlugins={[remarkGfm]}
components={{
a: ({ node, ...props }) => (
<a {...props} target="_blank" rel="noopener noreferrer" />
),
img: ({ node, alt, src, ...props }) => {
if (!src) return null;

try {
// Handle both relative and absolute URLs safely
const url = new URL(src, 'http://dummy.com');
const size = url.hash.match(/size=(\d+)x(\d+)/);

if (size) {
const [, width, height] = size;
// Remove hash from src
url.hash = '';
return (
<Image
src={url.pathname + url.search}
alt={alt || ''}
width={Number(width)}
height={Number(height)}
className="inline-block align-middle"
/>
);
}
} catch (e) {
// If URL parsing fails, fallback to original src
console.warn('Failed to parse image URL:', e);
}

const thumbnailPattern = /_thumb\.(png|jpg|jpeg|gif)$/i;
const isThumbnail = thumbnailPattern.test(src);

const width = isThumbnail ? 40 : 500;
const height = isThumbnail ? 40 : 300;

// Fallback to Image component with default dimensions
return (
<Image
src={src}
alt={alt || ''}
width={width}
height={height}
className="inline-block align-middle"
/>
);
},
}}
<div
className={cn(
'prose prose-sm max-w-prose break-words leading-tight md:prose-base',
isUser
? 'prose-invert dark:prose-neutral'
: 'prose-neutral dark:prose-invert',
)}
>
{processedContent}
</ReactMarkdown>
<ReactMarkdown
rehypePlugins={[rehypeRaw]}
remarkPlugins={[remarkGfm]}
components={{
a: ({ node, ...props }) => (
<a
{...props}
target="_blank"
rel="noopener noreferrer"
/>
),
img: ({ node, alt, src, ...props }) => {
if (!src) return null;

try {
// Handle both relative and absolute URLs safely
const url = new URL(src, 'http://dummy.com');
const size = url.hash.match(/size=(\d+)x(\d+)/);

if (size) {
const [, width, height] = size;
// Remove hash from src
url.hash = '';
return (
<Image
src={url.pathname + url.search}
alt={alt || ''}
width={Number(width)}
height={Number(height)}
className="inline-block align-middle"
/>
);
}
} catch (e) {
// If URL parsing fails, fallback to original src
console.warn('Failed to parse image URL:', e);
}

const thumbnailPattern = /_thumb\.(png|jpg|jpeg|gif)$/i;
const isThumbnail = thumbnailPattern.test(src);

const width = isThumbnail ? 40 : 500;
const height = isThumbnail ? 40 : 300;

// Fallback to Image component with default dimensions
return (
<Image
src={src}
alt={alt || ''}
width={width}
height={height}
className="inline-block align-middle"
/>
);
},
}}
>
{processedContent}
</ReactMarkdown>
</div>
</div>
</div>
)}
)}

{message.toolInvocations && (
<MessageToolInvocations
Expand Down Expand Up @@ -717,7 +726,7 @@ export default function ChatInterface({

useEffect(() => {
scrollToBottom();
}, []);
}, [chatMessages]);

const handleSend = async (value: string, attachments: Attachment[]) => {
if (!value.trim() && (!attachments || attachments.length === 0)) {
Expand Down Expand Up @@ -750,7 +759,7 @@ export default function ChatInterface({
useAnimationEffect();

return (
<div className="flex h-full flex-col">
<div className="flex h-[calc(100%-theme(spacing.16))] flex-col">
<div className="no-scrollbar relative flex-1 overflow-y-auto">
<div className="mx-auto w-full max-w-3xl">
<div className="space-y-4 px-4 pb-36 pt-4">
Expand Down
11 changes: 10 additions & 1 deletion src/app/(user)/chat/[id]/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { Suspense } from 'react';
import { Metadata } from 'next';
import { notFound } from 'next/navigation';

import { ModelSelector } from '@/components/model-selector';
import { verifyUser } from '@/server/actions/user';
import {
dbGetConversation,
Expand Down Expand Up @@ -51,6 +52,7 @@ async function ChatData({ params }: { params: Promise<{ id: string }> }) {
// Verify user authentication and access rights
const authResponse = await verifyUser();
const userId = authResponse?.data?.data?.id;
const modelPreference = authResponse?.data?.data?.modelPreference;

// Check if user has access to private conversation
if (
Expand All @@ -69,7 +71,14 @@ async function ChatData({ params }: { params: Promise<{ id: string }> }) {
return notFound();
}

return <ChatInterface id={id} initialMessages={messagesFromDB} />;
return (
<>
<div className="my-2 flex w-full items-center justify-center md:ml-2 md:block">
<ModelSelector selectedModelId={modelPreference ?? undefined} />
</div>
<ChatInterface id={id} initialMessages={messagesFromDB} />
</>
);
}

/**
Expand Down
14 changes: 13 additions & 1 deletion src/app/(user)/home/home-content.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ import { Attachment, JSONValue } from 'ai';
import { useChat } from 'ai/react';
import { CheckCircle2, Loader2 } from 'lucide-react';
import { toast } from 'sonner';
import { mutate } from 'swr';
import { v4 as uuidv4 } from 'uuid';

import ChatInterface from '@/app/(user)/chat/[id]/chat-interface';
import { ModelSelector } from '@/components/model-selector';
import { SavedPromptsMenu } from '@/components/saved-prompts-menu';
import { Badge } from '@/components/ui/badge';
import BlurFade from '@/components/ui/blur-fade';
Expand Down Expand Up @@ -268,7 +270,6 @@ export function HomeContent() {
}
}, [pathname, resetChat]);

// 监听浏览器的前进后退
useEffect(() => {
const handlePopState = () => {
if (location.pathname === '/home') {
Expand Down Expand Up @@ -390,6 +391,11 @@ export function HomeContent() {

<div className="mx-auto w-full max-w-3xl space-y-8">
<BlurFade delay={0.1}>
<ModelSelector
className="mb-4"
selectedModelId={user?.modelPreference ?? undefined}
mutate={() => mutate(`user-${user?.privyId}`)}
/>
<ConversationInput
value={input}
onChange={setInput}
Expand Down Expand Up @@ -687,6 +693,12 @@ export function HomeContent() {
showChat ? 'opacity-100' : 'pointer-events-none opacity-0',
)}
>
<div className="my-2 flex w-full items-center justify-center md:ml-2 md:block">
<ModelSelector
selectedModelId={user?.modelPreference ?? undefined}
mutate={() => mutate(`user-${user?.privyId}`)}
/>
</div>
<ChatInterface id={chatId} initialMessages={messages} />
</div>
)}
Expand Down
4 changes: 3 additions & 1 deletion src/app/api/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {
defaultModel,
defaultSystemPrompt,
defaultTools,
getModelFromPreference,
getToolsFromRequiredTools,
} from '@/ai/providers';
import { MAX_TOKEN_MESSAGES } from '@/lib/constants';
Expand Down Expand Up @@ -47,6 +48,7 @@ export async function POST(req: Request) {
const userId = session?.data?.data?.id;
const publicKey = session?.data?.data?.publicKey;
const degenMode = session?.data?.data?.degenMode;
const modelPreference = session?.data?.data?.modelPreference;

if (!userId) {
return new Response('Unauthorized', { status: 401 });
Expand Down Expand Up @@ -193,7 +195,7 @@ export async function POST(req: Request) {

// Begin streaming text from the model
const result = streamText({
model: defaultModel,
model: getModelFromPreference(modelPreference),
system: systemPrompt,
tools: tools as Record<string, CoreTool<any, any>>,
experimental_toolCallStreaming: true,
Expand Down
Loading