From f7281504f1666efe4f8386f04dd3fdfe57ded3d2 Mon Sep 17 00:00:00 2001 From: Roberto Echeagaray Date: Wed, 15 Apr 2026 02:29:53 -0600 Subject: [PATCH 1/2] feat: TTS humanization pipeline, rating system, and STT upgrade - Upgrade default Whisper model to large-v3-mlx (Apple Silicon) / large-v3 (PyTorch) for better Spanish transcription; fix missing preprocessor_config.json with fallback to turbo processor - Add paralinguistic tags (sniff, shush, whimper, scream, whisper) to tag router PARA_TAGS set - Add thumbs up/down rating system on history rows; rating + sampling params stored per generation; GET /profiles/{id}/suggested-params returns averaged best params after 3+ high-rated generations - Show all 5 sampling params (temperature, top_k, top_p, repetition_penalty, speed) in history row badge popover with Reuse button that restores text + params to generation form - Add Top-K, Top-P, Rep. Penalty sliders to FloatingGenerateBox Advanced popover so those fields are properly saved - Add breath_injection, hybrid_generate, tag_router, text_preprocess utility modules for TTS humanization pipeline Co-Authored-By: Claude Sonnet 4.6 --- .../Generation/FloatingGenerateBox.tsx | 371 +++++++++++++++++- .../components/Generation/GenerationForm.tsx | 303 +++++++++++++- app/src/components/History/HistoryTable.tsx | 158 +++++++- .../VoiceProfiles/AudioSampleRecording.tsx | 205 ++++++++-- .../components/VoiceProfiles/ProfileForm.tsx | 11 +- .../components/VoiceProfiles/SampleUpload.tsx | 17 +- app/src/lib/api/client.ts | 13 + app/src/lib/api/types.ts | 29 +- app/src/lib/hooks/useAudioRecording.ts | 2 +- app/src/lib/hooks/useGenerationForm.ts | 27 ++ app/src/stores/generationStore.ts | 19 + backend/app.py | 12 + backend/backends/__init__.py | 91 ++++- backend/backends/chatterbox_turbo_backend.py | 10 +- backend/backends/mlx_backend.py | 116 +++++- backend/backends/pytorch_backend.py | 2 +- backend/database/migrations.py | 14 + backend/database/models.py | 7 + backend/models.py | 39 +- backend/routes/generations.py | 53 +++ backend/routes/profiles.py | 55 +++ backend/routes/transcription.py | 20 +- backend/services/generation.py | 77 +++- backend/services/history.py | 16 + backend/utils/audio.py | 60 ++- backend/utils/breath_injection.py | 98 +++++ backend/utils/chunked_tts.py | 45 ++- backend/utils/effects.py | 205 ++++++++++ backend/utils/hybrid_generate.py | 125 ++++++ backend/utils/tag_router.py | 58 +++ backend/utils/text_preprocess.py | 117 ++++++ 31 files changed, 2280 insertions(+), 95 deletions(-) create mode 100644 backend/utils/breath_injection.py create mode 100644 backend/utils/hybrid_generate.py create mode 100644 backend/utils/tag_router.py create mode 100644 backend/utils/text_preprocess.py diff --git a/app/src/components/Generation/FloatingGenerateBox.tsx b/app/src/components/Generation/FloatingGenerateBox.tsx index f1cd571d..5ab0feb4 100644 --- a/app/src/components/Generation/FloatingGenerateBox.tsx +++ b/app/src/components/Generation/FloatingGenerateBox.tsx @@ -1,10 +1,12 @@ import { useQuery } from '@tanstack/react-query'; import { useMatchRoute } from '@tanstack/react-router'; import { AnimatePresence, motion } from 'framer-motion'; -import { Loader2, Sparkles } from 'lucide-react'; +import { CheckCircle, Loader2, SlidersHorizontal, Sparkles } from 'lucide-react'; import { useEffect, useRef, useState } from 'react'; import { Button } from '@/components/ui/button'; +import { Checkbox } from '@/components/ui/checkbox'; import { Form, FormControl, FormField, FormItem, FormMessage } from '@/components/ui/form'; +import { Popover, PopoverContent, PopoverTrigger } from '@/components/ui/popover'; import { Select, SelectContent, @@ -12,6 +14,7 @@ import { SelectTrigger, SelectValue, } from '@/components/ui/select'; +import { Slider } from '@/components/ui/slider'; import { Textarea } from '@/components/ui/textarea'; import { apiClient } from '@/lib/api/client'; import { getLanguageOptionsForEngine, type LanguageCode } from '@/lib/constants/languages'; @@ -40,15 +43,19 @@ export function FloatingGenerateBox({ const { data: selectedProfile } = useProfile(selectedProfileId || ''); const { data: profiles } = useProfiles(); const [isExpanded, setIsExpanded] = useState(false); + const [showAdvanced, setShowAdvanced] = useState(false); const [selectedPresetId, setSelectedPresetId] = useState(null); const containerRef = useRef(null); const textareaRef = useRef(null); + const reuseEffectsChainRef = useRef(null); const matchRoute = useMatchRoute(); const isStoriesRoute = matchRoute({ to: '/stories' }); const selectedStoryId = useStoryStore((state) => state.selectedStoryId); const trackEditorHeight = useStoryStore((state) => state.trackEditorHeight); const { data: currentStory } = useStory(selectedStoryId); const addPendingStoryAdd = useGenerationStore((s) => s.addPendingStoryAdd); + const reuseParams = useGenerationStore((s) => s.reuseParams); + const setReuseParams = useGenerationStore((s) => s.setReuseParams); // Fetch effect presets for the dropdown const { data: effectPresets } = useQuery({ @@ -56,6 +63,13 @@ export function FloatingGenerateBox({ queryFn: () => apiClient.listEffectPresets(), }); + // Fetch suggested params for the selected profile + const { data: suggestedParams } = useQuery({ + queryKey: ['suggestedParams', selectedProfileId], + queryFn: () => apiClient.getSuggestedParams(selectedProfileId!), + enabled: !!selectedProfileId, + }); + // Calculate if track editor is visible (on stories route with items) const hasTrackEditor = isStoriesRoute && currentStory && currentStory.items.length > 0; @@ -73,6 +87,10 @@ export function FloatingGenerateBox({ if (selectedPresetId === '_profile') { return selectedProfile?.effects_chain ?? undefined; } + // Effects chain reused from history (no matching preset) + if (selectedPresetId === '_reuse') { + return reuseEffectsChainRef.current ?? undefined; + } if (!effectPresets) return undefined; const preset = effectPresets.find((p) => p.id === selectedPresetId); return preset?.effects_chain; @@ -217,6 +235,63 @@ export function FloatingGenerateBox({ }; }, [isExpanded]); + // Apply params from history "Reuse" button + useEffect(() => { + if (!reuseParams) return; + form.setValue('text', reuseParams.text); + if (reuseParams.language) form.setValue('language', reuseParams.language as LanguageCode); + if (reuseParams.engine) + form.setValue( + 'engine', + reuseParams.engine as + | 'qwen' + | 'qwen_custom_voice' + | 'luxtts' + | 'chatterbox' + | 'chatterbox_turbo' + | 'tada' + | 'kokoro', + ); + if (reuseParams.temperature != null) form.setValue('temperature', reuseParams.temperature); + if (reuseParams.top_k != null) form.setValue('top_k', Math.round(reuseParams.top_k)); + if (reuseParams.top_p != null) form.setValue('top_p', reuseParams.top_p); + if (reuseParams.repetition_penalty != null) + form.setValue('repetition_penalty', reuseParams.repetition_penalty); + if (reuseParams.speed != null) form.setValue('speed', reuseParams.speed); + // Apply effects chain if present + if (reuseParams.effects_chain && reuseParams.effects_chain.length > 0) { + reuseEffectsChainRef.current = reuseParams.effects_chain; + if (effectPresets) { + const chainJson = JSON.stringify(reuseParams.effects_chain); + const matchingPreset = effectPresets.find( + (p) => JSON.stringify(p.effects_chain) === chainJson, + ); + if (matchingPreset) { + setSelectedPresetId(matchingPreset.id); + } else { + // No matching preset — use sentinel so getEffectsChain returns the stored chain + setSelectedPresetId('_reuse'); + } + } else { + setSelectedPresetId('_reuse'); + } + } else { + reuseEffectsChainRef.current = null; + } + setIsExpanded(true); + // Consume the params so this effect doesn't re-fire + setReuseParams(null); + }, [reuseParams]); // eslint-disable-line react-hooks/exhaustive-deps + + function applySuggestedParams() { + if (!suggestedParams) return; + if (suggestedParams.temperature != null) form.setValue('temperature', suggestedParams.temperature); + if (suggestedParams.top_k != null) form.setValue('top_k', Math.round(suggestedParams.top_k)); + if (suggestedParams.top_p != null) form.setValue('top_p', suggestedParams.top_p); + if (suggestedParams.repetition_penalty != null) form.setValue('repetition_penalty', suggestedParams.repetition_penalty); + if (suggestedParams.speed != null) form.setValue('speed', suggestedParams.speed); + } + async function onSubmit(data: Parameters[0]) { await handleSubmit(data, selectedProfileId); } @@ -262,7 +337,7 @@ export function FloatingGenerateBox({ transition={{ duration: 0.15, ease: 'easeOut' }} style={{ overflow: 'hidden' }} > - {form.watch('engine') === 'chatterbox_turbo' ? ( + {(form.watch('engine') === 'chatterbox_turbo' || form.watch('engine') === 'qwen') ? ( -
+
+ {/* Settings / Advanced popover */} + + + + + +

Advanced settings

+ + {/* Row 1: Temperature + Speed */} +
+ ( + +
+ + + {field.value?.toFixed(2) ?? '—'} + +
+ + field.onChange(v)} + className="h-3" + /> + +
+ )} + /> + ( + +
+ + + {field.value?.toFixed(2) ?? '—'} + +
+ + field.onChange(v)} + className="h-3" + /> + +
+ )} + /> +
+ + {/* Row 2: Top-K + Top-P */} +
+ ( + +
+ + + {field.value !== undefined ? field.value : '—'} + +
+ + field.onChange(v)} + className="h-3" + /> + +
+ )} + /> + ( + +
+ + + {field.value?.toFixed(2) ?? '—'} + +
+ + field.onChange(v)} + className="h-3" + /> + +
+ )} + /> +
+ + {/* Row 3: Repetition Penalty (half-width, left column) */} +
+ ( + +
+ + + {field.value?.toFixed(2) ?? '—'} + +
+ + field.onChange(v)} + className="h-3" + /> + +
+ )} + /> +
+ + {/* Row 5: Humanize text + intensity */} +
+ ( + + +
+ + +
+
+
+ )} + /> + ( + + + + + + )} + /> +
+ + {/* Row 6: Inject breaths + Jitter */} +
+ ( + + +
+ + +
+
+
+ )} + /> + ( + +
+ + + {field.value !== undefined ? `${field.value}ms` : '—'} + +
+ + field.onChange(v)} + className="h-3" + /> + +
+ )} + /> +
+
+
+ + {/* Generate button */}
+
+ )}
{showVoiceSelector && (
@@ -414,9 +766,10 @@ export function FloatingGenerateBox({
+ diff --git a/app/src/components/Generation/GenerationForm.tsx b/app/src/components/Generation/GenerationForm.tsx index ef3ff2c0..060d273c 100644 --- a/app/src/components/Generation/GenerationForm.tsx +++ b/app/src/components/Generation/GenerationForm.tsx @@ -1,7 +1,8 @@ -import { useEffect } from 'react'; -import { Loader2, Mic } from 'lucide-react'; +import { useEffect, useState } from 'react'; +import { ChevronDown, ChevronUp, Loader2, Mic } from 'lucide-react'; import { Button } from '@/components/ui/button'; import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; +import { Checkbox } from '@/components/ui/checkbox'; import { Form, FormControl, @@ -19,6 +20,7 @@ import { SelectTrigger, SelectValue, } from '@/components/ui/select'; +import { Slider } from '@/components/ui/slider'; import { Textarea } from '@/components/ui/textarea'; import { getLanguageOptionsForEngine, type LanguageCode } from '@/lib/constants/languages'; import { useGenerationForm } from '@/lib/hooks/useGenerationForm'; @@ -39,6 +41,7 @@ export function GenerationForm() { const { data: selectedProfile } = useProfile(selectedProfileId || ''); const { form, handleSubmit, isPending } = useGenerationForm(); + const [advancedOpen, setAdvancedOpen] = useState(false); useEffect(() => { if (!selectedProfile) { @@ -59,6 +62,11 @@ export function GenerationForm() { await handleSubmit(data, selectedProfileId); } + const engine = form.watch('engine'); + const humanizeText = form.watch('humanize_text'); + const showParalinguistic = engine === 'chatterbox_turbo' || engine === 'qwen'; + const showSpeed = engine === 'qwen' || engine === 'qwen_custom_voice'; + return ( @@ -89,7 +97,7 @@ export function GenerationForm() { Text to Speak - {form.watch('engine') === 'chatterbox_turbo' ? ( + {showParalinguistic ? ( - {form.watch('engine') === 'chatterbox_turbo' - ? 'Max 5000 characters. Type / to insert sound effects.' - : 'Max 5000 characters'} + {showParalinguistic ? ( + <> + Max 5000 characters. Type / to insert sound effects. + {engine === 'qwen' && ( + + Tags like [laugh] route to Chatterbox Turbo internally. + + )} + + ) : ( + 'Max 5000 characters' + )} )} /> - {(form.watch('engine') === 'qwen' || form.watch('engine') === 'qwen_custom_voice') && ( + {(engine === 'qwen' || engine === 'qwen_custom_voice') && ( Model - {getEngineDescription(form.watch('engine') || 'qwen')} + {getEngineDescription(engine || 'qwen')} @@ -151,7 +168,7 @@ export function GenerationForm() { control={form.control} name="language" render={({ field }) => { - const engineLangs = getLanguageOptionsForEngine(form.watch('engine') || 'qwen'); + const engineLangs = getLanguageOptionsForEngine(engine || 'qwen'); return ( Language @@ -198,6 +215,274 @@ export function GenerationForm() { />
+ {/* Advanced section */} +
+ + + {advancedOpen && ( +
+ {/* Sampling Parameters */} +
+

+ Sampling +

+ + ( + +
+ Temperature + + {field.value ?? '—'} + +
+ + field.onChange(v)} + /> + + 0.0 – 2.0 · default ~0.9 +
+ )} + /> + + ( + +
+ Top-P + + {field.value ?? '—'} + +
+ + field.onChange(v)} + /> + + 0.0 – 1.0 +
+ )} + /> + + ( + +
+ Repetition Penalty + + {field.value ?? '—'} + +
+ + field.onChange(v)} + /> + + 0.5 – 3.0 +
+ )} + /> + + ( + + Top-K + + + field.onChange( + e.target.value ? parseInt(e.target.value, 10) : undefined, + ) + } + /> + + 0 – 5000 + + )} + /> + + {showSpeed && ( + ( + +
+ Speed + + {field.value !== undefined ? `${field.value}×` : '—'} + +
+ + field.onChange(v)} + /> + + 0.5× – 2.0× +
+ )} + /> + )} +
+ + {/* Humanization */} +
+

+ Humanization +

+ + ( + +
+ + + + + Humanize text + +
+ + Pre-process text with LLM to add natural speech patterns + +
+ )} + /> + + {humanizeText && ( + ( + + Intensity +
+ {(['light', 'medium', 'heavy'] as const).map((level) => ( + + ))} +
+ +
+ )} + /> + )} + + ( + +
+ + + + + Inject breaths + +
+ + Insert natural breath sounds between sentences + +
+ )} + /> + + ( + +
+ Timing jitter + + {field.value !== undefined ? `${field.value} ms` : '—'} + +
+ + field.onChange(v)} + /> + + + Random timing offset per chunk (0 – 50 ms) + +
+ )} + /> +
+
+ )} +
+ + + +

Generation params

+ {gen.temperature != null && ( +
+ temp + {gen.temperature.toFixed(2)} +
+ )} + {gen.speed != null && ( +
+ speed + {gen.speed.toFixed(2)} +
+ )} + {gen.top_k != null && ( +
+ top_k + {gen.top_k} +
+ )} + {gen.top_p != null && ( +
+ top_p + {gen.top_p.toFixed(2)} +
+ )} + {gen.repetition_penalty != null && ( +
+ rep_penalty + {gen.repetition_penalty.toFixed(2)} +
+ )} + {effectsChain && effectsChain.length > 0 && ( +
+ effects + {effectsChain.map((e) => e.type).join(', ')} +
+ )} +
+ + ); +} + // ─── Audio Bars ───────────────────────────────────────────────────────────── function AudioBars({ mode }: { mode: 'idle' | 'generating' | 'playing' }) { @@ -128,6 +201,7 @@ export function HistoryTable() { const exportGenerationAudio = useExportGenerationAudio(); const importGeneration = useImportGeneration(); const addPendingGeneration = useGenerationStore((state) => state.addPendingGeneration); + const setReuseParams = useGenerationStore((state) => state.setReuseParams); const setAudioWithAutoPlay = usePlayerStore((state) => state.setAudioWithAutoPlay); const restartCurrentAudio = usePlayerStore((state) => state.restartCurrentAudio); const currentAudioId = usePlayerStore((state) => state.audioId); @@ -307,6 +381,23 @@ export function HistoryTable() { } }; + const handleRate = async (generationId: string, rating: number, currentRating: number | null | undefined) => { + // Clicking the same rating again clears it — not supported by the API directly, + // but we can send rating 0 — backend won't accept it, so we just skip toggling for now. + // If user clicks the already-active thumb, we do nothing. + if (currentRating === rating) return; + try { + await apiClient.rateGeneration(generationId, rating); + queryClient.invalidateQueries({ queryKey: ['history'] }); + } catch (error) { + toast({ + title: 'Failed to rate generation', + description: error instanceof Error ? error.message : 'Unknown error', + variant: 'destructive', + }); + } + }; + const handleApplyEffects = (generationId: string) => { const gen = allHistory.find((g) => g.id === generationId); const versions = gen?.versions ?? []; @@ -357,6 +448,25 @@ export function HistoryTable() { } }; + const handleReuseParams = (gen: HistoryResponse) => { + // Get effects chain from the active/default version of this generation + const activeVersion = gen.versions?.find((v) => v.is_default) ?? gen.versions?.[0]; + const effectsChain = activeVersion?.effects_chain ?? null; + setReuseParams({ + text: gen.text, + language: gen.language, + engine: gen.engine ?? undefined, + model_size: gen.model_size ?? undefined, + temperature: gen.temperature, + top_k: gen.top_k, + top_p: gen.top_p, + repetition_penalty: gen.repetition_penalty, + speed: gen.speed, + effects_chain: effectsChain, + }); + toast({ title: 'Params applied', description: 'Generation settings loaded into the form.' }); + }; + const handleSwitchVersion = async (generationId: string, versionId: string) => { try { await apiClient.setDefaultVersion(generationId, versionId); @@ -538,6 +648,17 @@ export function HistoryTable() { onMouseDown={(e) => e.stopPropagation()} onClick={(e) => e.stopPropagation()} > + + + {/* Rating: thumbs up (5) / thumbs down (1) */} + + {hasVersions && (
)} + {/* RECORDING: interactive guide */} {isRecording && ( -
+
{showWaveform && audioStream && ( )} -
+ + {/* Timer row */} +
{formatAudioDuration(duration)}
+ + {formatAudioDuration(40 - duration)} left + +
+ + {/* Progress dots */} +
+ {SCRIPT_LINES.map((_, i) => ( +
+ ))} + + {currentLineIndex + 1}/{SCRIPT_LINES.length} + +
+ + {/* Scrollable script */} +
+ {SCRIPT_LINES.map((line, i) => { + const isCurrent = i === currentLineIndex; + const isPast = i < currentLineIndex; + + return ( +
+

+ [{line.cue}] +

+

+ {line.text} +

+
+ ); + })} +
+ + {/* Tips */} +
+

⚠ Say it with real intention, not like a tongue twister

+

⚠ Vary volume, rhythm and emotion between lines

+
+ + {/* Controls */} +
+ +
- -

- {formatAudioDuration(30 - duration)} remaining -

)} + {/* POST-RECORDING: completion state — unchanged */} {file && !isRecording && (
@@ -139,6 +291,7 @@ export function AudioSampleRecording({ Recording complete

File: {file.name}

+

Transcript auto-filled from guide

-
)} + {gen.humanize_text === true && ( +
+ humanize + {gen.humanize_intensity ?? 'on'} +
+ )} + {gen.jitter_ms != null && gen.jitter_ms > 0 && ( +
+ jitter + {gen.jitter_ms}ms +
+ )} {effectsChain && effectsChain.length > 0 && (
effects diff --git a/app/src/lib/api/types.ts b/app/src/lib/api/types.ts index deaea7ef..2ee447d2 100644 --- a/app/src/lib/api/types.ts +++ b/app/src/lib/api/types.ts @@ -117,6 +117,9 @@ export interface GenerationResponse { top_p?: number | null; repetition_penalty?: number | null; speed?: number | null; + humanize_text?: boolean; + humanize_intensity?: string | null; + jitter_ms?: number | null; created_at: string; versions?: GenerationVersionResponse[]; active_version_id?: string; @@ -132,6 +135,7 @@ export interface SuggestedParams { top_p?: number | null; repetition_penalty?: number | null; speed?: number | null; + n_samples?: number; } export interface HistoryQuery { diff --git a/backend/database/migrations.py b/backend/database/migrations.py index 9f682e50..d9088172 100644 --- a/backend/database/migrations.py +++ b/backend/database/migrations.py @@ -178,6 +178,12 @@ def _migrate_generations(engine, inspector, tables: set[str]) -> None: _add_column(engine, "generations", "repetition_penalty FLOAT", "repetition_penalty") if "speed" not in columns: _add_column(engine, "generations", "speed FLOAT", "speed") + if "humanize_text" not in columns: + _add_column(engine, "generations", "humanize_text BOOLEAN DEFAULT 0", "humanize_text") + if "humanize_intensity" not in columns: + _add_column(engine, "generations", "humanize_intensity VARCHAR", "humanize_intensity") + if "jitter_ms" not in columns: + _add_column(engine, "generations", "jitter_ms INTEGER DEFAULT 0", "jitter_ms") def _migrate_effect_presets(engine, inspector, tables: set[str]) -> None: diff --git a/backend/database/models.py b/backend/database/models.py index 7cd6afe9..a9e358f5 100644 --- a/backend/database/models.py +++ b/backend/database/models.py @@ -74,6 +74,10 @@ class Generation(Base): top_p = Column(Float, nullable=True) repetition_penalty = Column(Float, nullable=True) speed = Column(Float, nullable=True) + # Humanize / jitter settings + humanize_text = Column(Boolean, default=False) + humanize_intensity = Column(String, nullable=True) + jitter_ms = Column(Integer, default=0) created_at = Column(DateTime, default=datetime.utcnow) diff --git a/backend/models.py b/backend/models.py index 482ab429..fffb8ff6 100644 --- a/backend/models.py +++ b/backend/models.py @@ -122,6 +122,9 @@ class GenerationResponse(BaseModel): top_p: Optional[float] = None repetition_penalty: Optional[float] = None speed: Optional[float] = None + humanize_text: bool = False + humanize_intensity: Optional[str] = None + jitter_ms: int = 0 created_at: datetime versions: Optional[List["GenerationVersionResponse"]] = None active_version_id: Optional[str] = None @@ -162,6 +165,9 @@ class HistoryResponse(BaseModel): top_p: Optional[float] = None repetition_penalty: Optional[float] = None speed: Optional[float] = None + humanize_text: bool = False + humanize_intensity: Optional[str] = None + jitter_ms: int = 0 created_at: datetime versions: Optional[List["GenerationVersionResponse"]] = None active_version_id: Optional[str] = None @@ -184,6 +190,7 @@ class SuggestedParams(BaseModel): top_p: Optional[float] = None repetition_penalty: Optional[float] = None speed: Optional[float] = None + n_samples: int = 0 class HistoryListResponse(BaseModel): diff --git a/backend/routes/generations.py b/backend/routes/generations.py index 16e64a64..bca0831c 100644 --- a/backend/routes/generations.py +++ b/backend/routes/generations.py @@ -85,6 +85,9 @@ async def generate_speech( top_p=data.top_p, repetition_penalty=data.repetition_penalty, speed=data.speed, + humanize_text=data.humanize_text, + humanize_intensity=data.humanize_intensity, + jitter_ms=data.jitter_ms, ) task_manager.start_generation( diff --git a/backend/routes/profiles.py b/backend/routes/profiles.py index e61345d9..1b938917 100644 --- a/backend/routes/profiles.py +++ b/backend/routes/profiles.py @@ -313,17 +313,18 @@ async def get_suggested_params( profile_id: str, db: Session = Depends(get_db), ): - """Return averaged sampling params from generations with rating >= 4 for this profile. + """Return exponentially-weighted averaged sampling params from generations with rating >= 4. - Returns None if fewer than 3 highly-rated generations exist. + Weights more recent likes higher: w_i = exp(-0.1 * i) where i=0 is most recent. + Returns None only if zero liked generations exist. """ - from sqlalchemy import func + import math profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first() if not profile: raise HTTPException(status_code=404, detail="Profile not found") - # Query all highly-rated generations that have at least one non-null sampling param + # Order by most recent first so index 0 = newest like rated = ( db.query(DBGeneration) .filter( @@ -331,22 +332,30 @@ async def get_suggested_params( DBGeneration.rating >= 4, DBGeneration.status == "completed", ) + .order_by(DBGeneration.created_at.desc()) .all() ) - if len(rated) < 3: + if not rated: return None - # Average across generations that have each param set - def _avg(attr: str) -> float | None: - vals = [getattr(g, attr) for g in rated if getattr(g, attr) is not None] - return sum(vals) / len(vals) if vals else None - - temperature = _avg("temperature") - top_k_avg = _avg("top_k") - top_p = _avg("top_p") - repetition_penalty = _avg("repetition_penalty") - speed = _avg("speed") + # Weighted average using exponential decay; skip nulls per param + def _weighted_avg(attr: str) -> float | None: + pairs = [ + (math.exp(-0.1 * i), getattr(g, attr)) + for i, g in enumerate(rated) + if getattr(g, attr) is not None + ] + if not pairs: + return None + total_w = sum(w for w, _ in pairs) + return sum(w * v for w, v in pairs) / total_w + + temperature = _weighted_avg("temperature") + top_k_avg = _weighted_avg("top_k") + top_p = _weighted_avg("top_p") + repetition_penalty = _weighted_avg("repetition_penalty") + speed = _weighted_avg("speed") # If none of the params are available across all rated gens, no suggestion if all(v is None for v in (temperature, top_k_avg, top_p, repetition_penalty, speed)): @@ -358,6 +367,7 @@ def _avg(attr: str) -> float | None: top_p=top_p, repetition_penalty=repetition_penalty, speed=speed, + n_samples=len(rated), ) diff --git a/backend/services/history.py b/backend/services/history.py index 982543a1..9002177a 100644 --- a/backend/services/history.py +++ b/backend/services/history.py @@ -70,6 +70,9 @@ async def create_generation( top_p: Optional[float] = None, repetition_penalty: Optional[float] = None, speed: Optional[float] = None, + humanize_text: bool = False, + humanize_intensity: Optional[str] = None, + jitter_ms: int = 0, ) -> GenerationResponse: """ Create a new generation history entry. @@ -108,6 +111,9 @@ async def create_generation( top_p=top_p, repetition_penalty=repetition_penalty, speed=speed, + humanize_text=humanize_text, + humanize_intensity=humanize_intensity, + jitter_ms=jitter_ms, created_at=datetime.utcnow(), ) @@ -234,6 +240,9 @@ async def list_generations( top_p=generation.top_p, repetition_penalty=generation.repetition_penalty, speed=generation.speed, + humanize_text=bool(generation.humanize_text), + humanize_intensity=generation.humanize_intensity, + jitter_ms=generation.jitter_ms or 0, created_at=generation.created_at, versions=versions, active_version_id=active_version_id,