diff --git a/frontend/src/components/config/ConfigForm.tsx b/frontend/src/components/config/ConfigForm.tsx index 3710dd45..e75243d7 100644 --- a/frontend/src/components/config/ConfigForm.tsx +++ b/frontend/src/components/config/ConfigForm.tsx @@ -146,6 +146,13 @@ const ConfigForm: React.FC = ({ config, onChange }) => { onChange={onChange} label={t("config.sections.api.sanitizeMessages")} /> + + diff --git a/frontend/src/locales/en/translation.json b/frontend/src/locales/en/translation.json index b5de9308..affcce7c 100644 --- a/frontend/src/locales/en/translation.json +++ b/frontend/src/locales/en/translation.json @@ -152,7 +152,8 @@ "preserveChats": "Preserve Chats", "webSearch": "Web Search", "webCountTokens": "Enable web count_tokens", - "sanitizeMessages": "Sanitize messages (trim whitespace)" + "sanitizeMessages": "Sanitize messages (trim whitespace)", + "reuseConversation": "Reuse Conversations (reduce input tokens)" }, "cookie": { "title": "Cookie Settings", diff --git a/frontend/src/locales/zh/translation.json b/frontend/src/locales/zh/translation.json index 9410ed70..4cf660a1 100644 --- a/frontend/src/locales/zh/translation.json +++ b/frontend/src/locales/zh/translation.json @@ -152,7 +152,8 @@ "preserveChats": "保留聊天", "webSearch": "网页搜索", "webCountTokens": "允许 Web 渠道调用 count_tokens", - "sanitizeMessages": "消息清理(去除空白)" + "sanitizeMessages": "消息清理(去除空白)", + "reuseConversation": "复用对话(减少输入 Token 消耗)" }, "cookie": { "title": "Cookie设置", diff --git a/frontend/src/types/config.types.ts b/frontend/src/types/config.types.ts index 4ce1fc87..c07beaea 100644 --- a/frontend/src/types/config.types.ts +++ b/frontend/src/types/config.types.ts @@ -21,6 +21,7 @@ export interface ConfigData { web_search: boolean; enable_web_count_tokens: boolean; sanitize_messages: boolean; + reuse_conversation: boolean; // Cookie settings skip_first_warning: boolean; diff --git a/src/claude_web_state/chat.rs b/src/claude_web_state/chat.rs index 5d25dac1..11baca22 100644 --- 
a/src/claude_web_state/chat.rs +++ b/src/claude_web_state/chat.rs @@ -1,18 +1,35 @@ +use std::sync::atomic::AtomicBool; +use std::sync::Arc; + use colored::Colorize; use futures::TryFutureExt; use serde_json::json; use snafu::ResultExt; use tracing::{Instrument, debug, error, info, info_span, warn}; -use wreq::{Method, Response, header::ACCEPT}; +use wreq::{Method, Response, header::{ACCEPT, REFERER}}; -use super::ClaudeWebState; +use super::{ClaudeWebState, PendingCacheWrite}; use crate::{ - config::CLEWDR_CONFIG, + claude_web_state::conversation_cache::{CachedConversation, CachedTurn}, + claude_web_state::diff::{self, DiffResult, extract_user_hashes, hash_system}, + config::{CLAUDE_ENDPOINT, CLEWDR_CONFIG}, error::{CheckClaudeErr, ClewdrError, WreqSnafu}, - types::claude::CreateMessageParams, - utils::print_out_json, + types::claude::{ContentBlock, CreateMessageParams, ImageSource, Message, MessageContent}, + types::claude_web::request::TurnMessageUuids, + utils::{TIME_ZONE, print_out_json}, }; +/// Bundled user messages ready to be sent +struct BundledMessages { + /// Short content goes into prompt + prompt: String, + /// Long content goes into attachments + attachments: Vec, + /// Extracted images (if any) + #[allow(dead_code)] + images: Vec, +} + impl ClaudeWebState { /// Attempts to send a chat message to Claude API with retry mechanism /// @@ -43,6 +60,15 @@ impl ClaudeWebState { let mut state = self.to_owned(); let p = p.to_owned(); + // Create shared stream health flag for monitoring SSE completion + let can_reuse = CLEWDR_CONFIG.load().reuse_conversation + && !CLEWDR_CONFIG.load().preserve_chats; + if can_reuse { + let flag = Arc::new(AtomicBool::new(false)); + state.stream_health_flag = Some(flag.clone()); + self.stream_health_flag = Some(flag); + } + let cookie = state.request_cookie().await?; // check if request is successful let web_res = async { state.bootstrap().await.and(state.send_chat(p).await) }; @@ -52,12 +78,21 @@ impl ClaudeWebState { match 
transform_res.await { Ok(b) => { + // Commit pending cache write (optimistic) + if let Some(pending) = state.pending_cache_write.take() { + state.commit_cache_write(pending).await; + } + if let Err(e) = state.clean_chat().await { warn!("Failed to clean chat: {}", e); } return Ok(b); } Err(e) => { + // Invalidate cache on error + state.conv_cache.invalidate(&state.cache_key()).await; + state.pending_cache_write = None; + // delete chat after an error if let Err(e) = state.clean_chat().await { warn!("Failed to clean chat: {}", e); @@ -76,25 +111,110 @@ impl ClaudeWebState { Err(ClewdrError::TooManyRetries) } - /// Sends a message to the Claude API by creating a new conversation and processing the request - /// - /// This method performs several key operations: - /// - Creates a new conversation with a unique UUID - /// - Configures thinking mode if applicable - /// - Transforms the client request to the Claude API format - /// - Handles image uploads if present - /// - Sends the request to the Claude API endpoint - /// - /// The method properly manages conversation state, including creating a new conversation, - /// configuring its settings, and sending the actual message content. It handles special - /// features like thinking mode for Pro accounts and image uploads for multimodal requests. 
- /// - /// # Arguments - /// * `p` - The client request body containing messages and configuration - /// - /// # Returns - /// * `Result` - Response from Claude or error + /// Main entry point — tries cache reuse, falls back to full paste async fn send_chat(&mut self, p: CreateMessageParams) -> Result { + let _org_uuid = self + .org_uuid + .to_owned() + .ok_or(ClewdrError::UnexpectedNone { + msg: "Organization UUID is not set", + })?; + + let can_reuse = CLEWDR_CONFIG.load().reuse_conversation + && !CLEWDR_CONFIG.load().preserve_chats; + + if can_reuse { + if let Some(result) = self.try_reuse_conversation(&p).await { + match result { + Ok(response) => return Ok(response), + Err(e) => { + warn!("Reuse failed, falling back to full: {}", e); + // invalidate on failure + self.conv_cache.invalidate(&self.cache_key()).await; + // fall through to full path + } + } + } + // cache miss or reuse not possible → full path + // but this time we'll also write cache on success + } + + self.send_full(p, can_reuse).await + } + + /// Attempt to reuse a cached conversation + /// Returns None if no cache or not reusable + /// Returns Some(Ok(response)) on success + /// Returns Some(Err(e)) if reuse was attempted but failed + async fn try_reuse_conversation( + &mut self, + p: &CreateMessageParams, + ) -> Option> { + let key = self.cache_key(); + let cached = self.conv_cache.get(&key).await?; + + // Check stream health from previous request + if !self.conv_cache.is_last_stream_healthy(&key).await { + info!("[CACHE] last stream was unhealthy, invalidating"); + self.conv_cache.invalidate(&key).await; + return None; + } + + // Validate: cookie must match + if cached.cookie_id != self.cookie_id() { + info!("[CACHE] cookie mismatch, invalidating"); + self.conv_cache.invalidate(&key).await; + return None; + } + // Validate: model must match + if cached.model != p.model { + info!("[CACHE] model changed: {} → {}", cached.model, p.model); + self.conv_cache.invalidate(&key).await; + return None; 
+ } + // Validate: is_pro must match + if cached.is_pro != self.is_pro() { + info!("[CACHE] pro status changed"); + self.conv_cache.invalidate(&key).await; + return None; + } + + // Extract user message hashes from new request + let user_hashes = extract_user_hashes(&p.messages); + let sys_hash = hash_system(&p.system); + + let diff = diff::diff_messages(&cached, sys_hash, &user_hashes); + + match diff { + DiffResult::Append { parent_uuid, new_user_indices, new_user_hashes } => { + info!("[CACHE HIT] appending {} new user message(s)", new_user_indices.len()); + let result = self.send_incremental( + &cached, &parent_uuid, &new_user_indices, &new_user_hashes, p + ).await; + Some(result) + } + DiffResult::Fork { parent_uuid, fork_turn_index, remaining_user_indices, remaining_user_hashes } => { + info!("[CACHE FORK] forking at turn {}, {} user message(s)", fork_turn_index, remaining_user_indices.len()); + let result = self.send_incremental_fork( + &cached, &parent_uuid, fork_turn_index, + &remaining_user_indices, &remaining_user_hashes, p + ).await; + Some(result) + } + DiffResult::FullRebuild => { + info!("[CACHE MISS] full rebuild required"); + self.conv_cache.invalidate(&key).await; + None // fall through to send_full + } + } + } + + /// Full paste path (existing logic + cache write on success) + async fn send_full( + &mut self, + p: CreateMessageParams, + write_cache: bool, + ) -> Result { let org_uuid = self .org_uuid .to_owned() @@ -102,8 +222,9 @@ impl ClaudeWebState { msg: "Organization UUID is not set", })?; - // Create a new conversation + // === Create new conversation === let new_uuid = uuid::Uuid::new_v4().to_string(); + let is_temporary = !CLEWDR_CONFIG.load().preserve_chats; let endpoint = self .endpoint .join(&format!( @@ -114,24 +235,24 @@ impl ClaudeWebState { let is_temporary = !CLEWDR_CONFIG.load().preserve_chats; let body = json!({ "uuid": new_uuid, - "name": if is_temporary { "".to_string() } else { format!("ClewdR-{}", 
chrono::Utc::now().format("%Y-%m-%d %H:%M:%S")) }, + "name": if is_temporary { "".to_string() } else { + format!("ClewdR-{}", chrono::Utc::now().format("%Y-%m-%d %H:%M:%S")) + }, "is_temporary": is_temporary, }); let referer = if is_temporary { - self.endpoint - .join("new?incognito") + self.endpoint.join("new?incognito") .map(|u| u.to_string()) - .unwrap_or_else(|_| format!("{}new?incognito", crate::config::CLAUDE_ENDPOINT)) + .unwrap_or_else(|_| format!("{CLAUDE_ENDPOINT}new?incognito")) } else { - self.endpoint - .join("new") + self.endpoint.join("new") .map(|u| u.to_string()) - .unwrap_or_else(|_| format!("{}new", crate::config::CLAUDE_ENDPOINT)) + .unwrap_or_else(|_| format!("{CLAUDE_ENDPOINT}new")) }; self.build_request(Method::POST, endpoint) - .header(wreq::header::REFERER, referer) + .header(REFERER, referer) .json(&body) .send() .await @@ -143,16 +264,14 @@ impl ClaudeWebState { self.conv_uuid = Some(new_uuid.to_string()); debug!("New conversation created: {}", new_uuid); - // preserve original params for possible post-call token accounting + // === PUT settings === self.last_params = Some(p.clone()); - let mut body = json!({}); - // enable thinking mode - body["settings"]["paprika_mode"] = if p.thinking.is_some() && self.is_pro() { + let paprika = if p.thinking.is_some() && self.is_pro() { "extended".into() } else { json!(null) }; - + let settings_body = json!({ "settings": { "paprika_mode": paprika } }); let endpoint = self .endpoint .join(&format!( @@ -162,16 +281,23 @@ impl ClaudeWebState { .expect("Url parse error"); let _ = self .build_request(Method::PUT, endpoint) - .json(&body) + .json(&settings_body) .send() .await; - // generate the request body - // check if the request is empty - let mut body = self.transform_request(p).ok_or(ClewdrError::BadRequest { + + // === Transform and send === + let mut body = self.transform_request(p.clone()).ok_or(ClewdrError::BadRequest { msg: "Request body is empty", })?; - // check images + // Generate 
turn_message_uuids + let human_uuid = uuid::Uuid::new_v4().to_string(); + let assistant_uuid = uuid::Uuid::new_v4().to_string(); + body.turn_message_uuids = Some(TurnMessageUuids { + human_message_uuid: human_uuid.clone(), + assistant_message_uuid: assistant_uuid.clone(), + }); + let images = body.images.drain(..).collect::>(); // upload images @@ -188,7 +314,8 @@ impl ClaudeWebState { )) .expect("Url parse error"); - self.build_request(Method::POST, endpoint) + let response = self + .build_request(Method::POST, endpoint) .json(&body) .header(ACCEPT, "text/event-stream") .send() @@ -197,6 +324,307 @@ impl ClaudeWebState { msg: "Failed to send chat request", })? .check_claude() + .await?; + + // === Prepare cache write === + if write_cache { + let user_hashes = extract_user_hashes(&p.messages) + .iter().map(|(_, h)| *h).collect(); + let sys_hash = hash_system(&p.system); + let stream_flag = self.stream_health_flag.clone() + .unwrap_or_else(|| Arc::new(AtomicBool::new(true))); + + self.pending_cache_write = Some(PendingCacheWrite::Init { + key: self.cache_key(), + conv: CachedConversation { + conv_uuid: new_uuid.clone(), + org_uuid: org_uuid.clone(), + cookie_id: self.cookie_id(), + model: p.model.clone(), + is_pro: self.is_pro(), + system_hash: sys_hash, + turns: vec![CachedTurn { + user_hashes, + assistant_uuid, + }], + created_at: chrono::Utc::now(), + last_used: chrono::Utc::now(), + valid: true, + last_stream_healthy: stream_flag, + }, + }); + } + + Ok(response) + } + + /// Incremental send — append new messages to existing conversation + async fn send_incremental( + &mut self, + cached: &CachedConversation, + parent_uuid: &str, + new_user_indices: &[usize], + new_user_hashes: &[u64], + p: &CreateMessageParams, + ) -> Result { + self.conv_uuid = Some(cached.conv_uuid.clone()); + self.last_params = Some(p.clone()); + + // Update paprika_mode if needed + let need_thinking = p.thinking.is_some() && self.is_pro(); + self.update_paprika(&cached.conv_uuid, 
need_thinking).await; + + // Extract new user messages from original messages array + let new_user_msgs: Vec<&Message> = new_user_indices.iter() + .map(|&idx| &p.messages[idx]) + .collect(); + + // Bundle user messages into prompt + optional attachment + let bundled = self.bundle_user_messages(&new_user_msgs); + + // Generate turn UUIDs + let human_uuid = uuid::Uuid::new_v4().to_string(); + let assistant_uuid = uuid::Uuid::new_v4().to_string(); + + let body = self.build_incremental_body( + &bundled, parent_uuid, &human_uuid, &assistant_uuid, p, + ); + + print_out_json(&body, "claude_web_incremental_req.json"); + + let endpoint = self + .endpoint + .join(&format!( + "api/organizations/{}/chat_conversations/{}/completion", + cached.org_uuid, cached.conv_uuid + )) + .expect("Url parse error"); + + let response = self + .build_request(Method::POST, endpoint) + .json(&body) + .header(ACCEPT, "text/event-stream") + .send() .await + .context(WreqSnafu { + msg: "Failed to send incremental chat", + })? 
+ .check_claude() + .await?; + + // Prepare optimistic cache write + self.pending_cache_write = Some(PendingCacheWrite::AppendTurn { + key: self.cache_key(), + turn: CachedTurn { + user_hashes: new_user_hashes.to_vec(), + assistant_uuid, + }, + }); + + Ok(response) + } + + /// Incremental send with fork — edit scenario + async fn send_incremental_fork( + &mut self, + cached: &CachedConversation, + parent_uuid: &str, + fork_turn_index: usize, + remaining_user_indices: &[usize], + remaining_user_hashes: &[u64], + p: &CreateMessageParams, + ) -> Result { + self.conv_uuid = Some(cached.conv_uuid.clone()); + self.last_params = Some(p.clone()); + + // Update paprika_mode if needed + let need_thinking = p.thinking.is_some() && self.is_pro(); + self.update_paprika(&cached.conv_uuid, need_thinking).await; + + // Extract remaining user messages + let remaining_user_msgs: Vec<&Message> = remaining_user_indices.iter() + .map(|&idx| &p.messages[idx]) + .collect(); + + // Bundle all remaining user messages + let bundled = self.bundle_user_messages(&remaining_user_msgs); + + let human_uuid = uuid::Uuid::new_v4().to_string(); + let assistant_uuid = uuid::Uuid::new_v4().to_string(); + + let body = self.build_incremental_body( + &bundled, parent_uuid, &human_uuid, &assistant_uuid, p, + ); + + let endpoint = self + .endpoint + .join(&format!( + "api/organizations/{}/chat_conversations/{}/completion", + cached.org_uuid, cached.conv_uuid + )) + .expect("Url parse error"); + + let response = self + .build_request(Method::POST, endpoint) + .json(&body) + .header(ACCEPT, "text/event-stream") + .send() + .await + .context(WreqSnafu { + msg: "Failed to send forked chat", + })? 
+ .check_claude() + .await?; + + // Prepare fork cache write + self.pending_cache_write = Some(PendingCacheWrite::ForkAndAppend { + key: self.cache_key(), + fork_turn_index, + turn: CachedTurn { + user_hashes: remaining_user_hashes.to_vec(), + assistant_uuid, + }, + }); + + Ok(response) + } + + /// PUT paprika_mode setting on existing conversation + async fn update_paprika(&self, conv_uuid: &str, need_thinking: bool) { + let paprika = if need_thinking { "extended".into() } else { json!(null) }; + let endpoint = self + .endpoint + .join(&format!( + "api/organizations/{}/chat_conversations/{}", + self.org_uuid.as_ref().unwrap(), conv_uuid + )) + .expect("Url parse error"); + let body = json!({ "settings": { "paprika_mode": paprika } }); + let _ = self + .build_request(Method::PUT, endpoint) + .json(&body) + .send() + .await; + } + + /// Build the completion request body for incremental sends + fn build_incremental_body( + &self, + bundled: &BundledMessages, + parent_uuid: &str, + human_uuid: &str, + assistant_uuid: &str, + p: &CreateMessageParams, + ) -> serde_json::Value { + let mut body = json!({ + "prompt": bundled.prompt, + "parent_message_uuid": parent_uuid, + "timezone": TIME_ZONE.to_string(), + "turn_message_uuids": { + "human_message_uuid": human_uuid, + "assistant_message_uuid": assistant_uuid, + }, + "attachments": bundled.attachments, + "files": [], + "rendering_mode": if p.stream.unwrap_or_default() { "messages" } else { "raw" }, + }); + // Model (only for pro) + if self.is_pro() { + body["model"] = json!(p.model); + } + // Tools (same as full request) + let mut tools = vec![]; + if CLEWDR_CONFIG.load().web_search { + tools.push(json!({"type": "web_search_v0", "name": "web_search"})); + } + if !tools.is_empty() { + body["tools"] = json!(tools); + } + body + } + + /// Merge user messages into prompt or attachment based on length + fn bundle_user_messages( + &self, + user_msgs: &[&Message], + ) -> BundledMessages { + let mut texts: Vec = vec![]; + let mut 
images: Vec = vec![]; + + for msg in user_msgs { + match &msg.content { + MessageContent::Text { content } => { + texts.push(content.trim().to_string()); + } + MessageContent::Blocks { content } => { + for block in content { + match block { + ContentBlock::Text { text, .. } => { + texts.push(text.trim().to_string()); + } + ContentBlock::Image { source, .. } => { + images.push(source.clone()); + } + _ => {} + } + } + } + } + } + + let combined = texts.join("\n\n"); + + // Threshold: if combined text is under ~4000 chars, use prompt directly + // Otherwise put it in an attachment + const PROMPT_THRESHOLD: usize = 4000; + + if combined.len() <= PROMPT_THRESHOLD { + BundledMessages { + prompt: combined, + attachments: vec![], + images, + } + } else { + // Use attachment for long content + // prompt gets the custom_prompt polyfill or a short summary + let p_str = CLEWDR_CONFIG.load().custom_prompt.clone(); + BundledMessages { + prompt: p_str, + attachments: vec![json!({ + "extracted_content": combined, + "file_name": "paste.txt", + "file_size": combined.len(), + "file_type": "text/plain", + })], + images, + } + } + } + + /// Execute a pending cache write + async fn commit_cache_write(&self, pending: PendingCacheWrite) { + match pending { + PendingCacheWrite::Init { key, conv } => { + info!("[CACHE] initialized for conv {}", conv.conv_uuid); + self.conv_cache.set(key, conv).await; + } + PendingCacheWrite::AppendTurn { key, turn } => { + info!("[CACHE] appended turn (assistant={})", turn.assistant_uuid); + self.conv_cache.append_turn(&key, turn).await; + // Update stream health flag for the new request + if let Some(flag) = self.stream_health_flag.as_ref() { + self.conv_cache.update_stream_health(&key, flag.clone()).await; + } + } + PendingCacheWrite::ForkAndAppend { key, fork_turn_index, turn } => { + info!("[CACHE] forked at turn {}, new assistant={}", + fork_turn_index, turn.assistant_uuid); + self.conv_cache.fork_and_append(&key, fork_turn_index, turn).await; + // 
Update stream health flag for the new request + if let Some(flag) = self.stream_health_flag.as_ref() { + self.conv_cache.update_stream_health(&key, flag.clone()).await; + } + } + } } } diff --git a/src/claude_web_state/conversation_cache.rs b/src/claude_web_state/conversation_cache.rs new file mode 100644 index 00000000..6b7c6a64 --- /dev/null +++ b/src/claude_web_state/conversation_cache.rs @@ -0,0 +1,154 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use tokio::sync::Mutex; +use chrono::{DateTime, Duration, Utc}; + +/// Represents one round-trip (ClewdR request → Claude response) in a cached conversation +#[derive(Clone, Debug)] +pub struct CachedTurn { + /// Hash of each Role::User message's text content sent in this turn. + /// Turn 0 (full paste) may contain multiple user hashes. + /// Subsequent turns typically contain 1+ user hashes (bundled). + pub user_hashes: Vec, + /// The assistant_message_uuid from turn_message_uuids. + /// Used as `parent_message_uuid` for the next turn. 
+ pub assistant_uuid: String, +} + +/// A cached conversation that can be reused across requests +#[derive(Clone, Debug)] +pub struct CachedConversation { + /// Claude.ai conversation UUID + pub conv_uuid: String, + /// Organization UUID (must match) + pub org_uuid: String, + /// Cookie identifier string (must match — different cookie = different account) + pub cookie_id: String, + /// Model used (must match) + pub model: String, + /// Whether the account was pro when conversation was created + pub is_pro: bool, + /// Hash of the system prompt (system change → full rebuild) + pub system_hash: u64, + /// Ordered list of completed turns + pub turns: Vec<CachedTurn>, + /// When this conversation was first created + pub created_at: DateTime<Utc>, + /// Last time this conversation was successfully used + pub last_used: DateTime<Utc>, + /// Whether cache is currently valid (set to false on stream errors) + pub valid: bool, + /// Shared flag set to true when the SSE stream completes with a stop signal. + /// Checked on next reuse; if still false, the previous stream was incomplete. 
+ pub last_stream_healthy: Arc<AtomicBool>, +} + +impl CachedConversation { + /// Check if this cached conversation has expired (conservative 25-day TTL) + pub fn is_expired(&self) -> bool { + Utc::now() - self.created_at > Duration::days(25) + } + + /// Get the last assistant UUID (parent for next turn) + pub fn last_parent_uuid(&self) -> Option<&str> { + self.turns.last().map(|t| t.assistant_uuid.as_str()) + } + + /// Truncate turns from `from_index` onward (for fork scenarios) + pub fn truncate_turns(&mut self, from_index: usize) { + self.turns.truncate(from_index); + } +} + +/// Cache key: identifies a unique "conversation slot" +/// First version: one conversation per (cookie, key_index) pair +/// This means each downstream API key gets one cached conversation per cookie +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct CacheKey { + /// Index of the downstream API key in config (from self.key) + pub key_index: usize, +} + +/// Thread-safe conversation cache +#[derive(Clone)] +pub struct ConversationCache { + inner: Arc<Mutex<HashMap<CacheKey, CachedConversation>>>, +} + +impl ConversationCache { + pub fn new() -> Self { + Self { + inner: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub async fn get(&self, key: &CacheKey) -> Option<CachedConversation> { + let map = self.inner.lock().await; + map.get(key).filter(|c| c.valid && !c.is_expired()).cloned() + } + + pub async fn set(&self, key: CacheKey, conv: CachedConversation) { + let mut map = self.inner.lock().await; + map.insert(key, conv); + } + + /// Append a new turn to an existing cached conversation + pub async fn append_turn(&self, key: &CacheKey, turn: CachedTurn) { + let mut map = self.inner.lock().await; + if let Some(conv) = map.get_mut(key) { + conv.turns.push(turn); + conv.last_used = Utc::now(); + } + } + + /// Truncate turns and append a new one (fork scenario) + pub async fn fork_and_append(&self, key: &CacheKey, from_index: usize, turn: CachedTurn) { + let mut map = self.inner.lock().await; + if let Some(conv) = map.get_mut(key) { + 
conv.truncate_turns(from_index); + conv.turns.push(turn); + conv.last_used = Utc::now(); + } + } + + /// Mark a cached conversation as invalid + pub async fn invalidate(&self, key: &CacheKey) { + let mut map = self.inner.lock().await; + if let Some(conv) = map.get_mut(key) { + conv.valid = false; + } + } + + /// Remove expired entries (call periodically) + pub async fn cleanup(&self) { + let mut map = self.inner.lock().await; + map.retain(|_, v| v.valid && !v.is_expired()); + } + + /// Invalidate all entries for a given cookie_id (cookie rotation) + pub async fn invalidate_by_cookie(&self, cookie_id: &str) { + let mut map = self.inner.lock().await; + for conv in map.values_mut() { + if conv.cookie_id == cookie_id { + conv.valid = false; + } + } + } + + /// Update the stream health flag on an existing cached conversation + pub async fn update_stream_health(&self, key: &CacheKey, flag: Arc) { + let mut map = self.inner.lock().await; + if let Some(conv) = map.get_mut(key) { + conv.last_stream_healthy = flag; + } + } + + /// Check if the last stream completed healthily for a given cache key + pub async fn is_last_stream_healthy(&self, key: &CacheKey) -> bool { + let map = self.inner.lock().await; + map.get(key) + .map(|c| c.last_stream_healthy.load(Ordering::Relaxed)) + .unwrap_or(true) + } +} diff --git a/src/claude_web_state/diff.rs b/src/claude_web_state/diff.rs new file mode 100644 index 00000000..ad71de57 --- /dev/null +++ b/src/claude_web_state/diff.rs @@ -0,0 +1,365 @@ +use std::hash::{DefaultHasher, Hash, Hasher}; +use serde_json::Value; +use crate::types::claude::{Message, MessageContent, ContentBlock, Role}; +use super::conversation_cache::CachedConversation; + +/// Hash a single user message's text content +pub fn hash_user_message(msg: &Message) -> u64 { + let mut hasher = DefaultHasher::new(); + // Hash the text content of the message + match &msg.content { + MessageContent::Text { content } => { + content.hash(&mut hasher); + } + MessageContent::Blocks { 
content } => { + for block in content { + match block { + ContentBlock::Text { text, .. } => { + text.hash(&mut hasher); + } + ContentBlock::Image { source, .. } => { + // Hash the image source for change detection + // For base64: hash the data + // This ensures image changes are detected + format!("{:?}", source).hash(&mut hasher); + } + _ => {} + } + } + } + } + hasher.finish() +} + +/// Hash system prompt +pub fn hash_system(system: &Option<Value>) -> u64 { + let mut hasher = DefaultHasher::new(); + match system { + Some(v) => format!("{}", v).hash(&mut hasher), + None => 0u64.hash(&mut hasher), + } + hasher.finish() +} + +/// Extract (message_ref, hash) pairs for all Role::User messages in order +pub fn extract_user_hashes(messages: &[Message]) -> Vec<(usize, u64)> { + messages.iter().enumerate() + .filter(|(_, m)| m.role == Role::User) + .map(|(idx, m)| (idx, hash_user_message(m))) + .collect() +} + +/// Result of diffing new messages against cached conversation +#[derive(Debug)] +pub enum DiffResult { + /// All cached turns match; new user messages should be appended + Append { + /// parent_message_uuid = last cached turn's assistant_uuid + parent_uuid: String, + /// Indices into the original messages array for new Role::User messages + new_user_indices: Vec<usize>, + /// Hashes of the new user messages (for cache write) + new_user_hashes: Vec<u64>, + }, + /// A cached turn doesn't match; fork from an earlier point + Fork { + /// parent_message_uuid = the turn before the mismatch + parent_uuid: String, + /// Which turn index to truncate from + fork_turn_index: usize, + /// Indices into the original messages array for all user messages from fork point + remaining_user_indices: Vec<usize>, + /// Hashes of those user messages + remaining_user_hashes: Vec<u64>, + }, + /// Cannot reuse; must do full rebuild + FullRebuild, +} + +/// Compare new user messages against cached turns to determine reuse strategy +/// +/// # Arguments +/// * `cached` - The cached conversation state +/// * 
`new_system_hash` - Hash of the current system prompt +/// * `user_hashes` - (original_index, hash) pairs for all Role::User messages +/// +/// # Logic +/// Walk through cached turns in order. Each turn has a list of user_hashes. +/// For each turn, check if the corresponding user messages match. +/// +/// - If all turns match and there are remaining user messages → Append +/// - If a turn mismatches at its start → Fork from previous turn +/// - If turn[0] mismatches internally → FullRebuild (can't partially reuse paste) +/// - If system prompt changed → FullRebuild +pub fn diff_messages( + cached: &CachedConversation, + new_system_hash: u64, + user_hashes: &[(usize, u64)], // (original_msg_index, hash) +) -> DiffResult { + // System prompt changed → full rebuild + if cached.system_hash != new_system_hash { + return DiffResult::FullRebuild; + } + + // No cached turns (shouldn't happen, but defensive) + if cached.turns.is_empty() { + return DiffResult::FullRebuild; + } + + let mut cursor: usize = 0; // position in user_hashes + + for (turn_idx, turn) in cached.turns.iter().enumerate() { + for (hash_idx_in_turn, &cached_hash) in turn.user_hashes.iter().enumerate() { + if cursor >= user_hashes.len() { + // New messages are shorter than cached → can't reuse + // (user deleted messages from the end) + return DiffResult::FullRebuild; + } + + let (_orig_idx, new_hash) = user_hashes[cursor]; + + if new_hash != cached_hash { + // === MISMATCH DETECTED === + + if turn_idx == 0 { + // Turn 0 is the full paste — can't partially reuse + return DiffResult::FullRebuild; + } + + // Fork from the previous turn + let parent_uuid = cached.turns[turn_idx - 1].assistant_uuid.clone(); + + // Rewind cursor to the start of this turn + let rewind_count = hash_idx_in_turn; + let fork_cursor = cursor - rewind_count; + + let remaining_user_indices: Vec<usize> = + user_hashes[fork_cursor..].iter().map(|(idx, _)| *idx).collect(); + let remaining_user_hashes: Vec<u64> = + user_hashes[fork_cursor..].iter().map(|(_, h)| *h).collect(); + + return DiffResult::Fork { + parent_uuid, + fork_turn_index: turn_idx, + remaining_user_indices, + remaining_user_hashes, + }; + } + + cursor += 1; + } + } + + // All cached turns matched + if cursor >= user_hashes.len() { + // No new messages — this shouldn't normally happen + // (implies exact same request as before) + return DiffResult::FullRebuild; + } + + // Remaining messages are new → Append + let parent_uuid = cached.turns.last().unwrap().assistant_uuid.clone(); + let new_user_indices: Vec<usize> = + user_hashes[cursor..].iter().map(|(idx, _)| *idx).collect(); + let new_user_hashes: Vec<u64> = + user_hashes[cursor..].iter().map(|(_, h)| *h).collect(); + + DiffResult::Append { + parent_uuid, + new_user_indices, + new_user_hashes, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::claude::Message; + use crate::claude_web_state::conversation_cache::{CachedConversation, CachedTurn}; + use chrono::Utc; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + + fn make_cached(conv_uuid: &str, turns: Vec<CachedTurn>, system_hash: u64) -> CachedConversation { + CachedConversation { + conv_uuid: conv_uuid.to_string(), + org_uuid: "org".to_string(), + cookie_id: "cookie".to_string(), + model: "model".to_string(), + is_pro: false, + system_hash, + turns, + created_at: Utc::now(), + last_used: Utc::now(), + valid: true, + last_stream_healthy: Arc::new(AtomicBool::new(true)), + } + } + + fn make_user_msg(text: &str) -> Message { + Message::new_text(Role::User, text) + } + + #[test] + fn test_hash_user_message_text() { + let msg = make_user_msg("hello"); + let h1 = hash_user_message(&msg); + let msg2 = make_user_msg("hello"); + let h2 = hash_user_message(&msg2); + assert_eq!(h1, h2); + + let msg3 = make_user_msg("world"); + let h3 = hash_user_message(&msg3); + assert_ne!(h1, h3); + } + + #[test] + fn test_hash_system_matching() { + let sys = Some(serde_json::json!("system prompt")); + let h1 = 
hash_system(&sys); + let h2 = hash_system(&sys); + assert_eq!(h1, h2); + } + + #[test] + fn test_hash_system_different() { + let h1 = hash_system(&Some(serde_json::json!("sys1"))); + let h2 = hash_system(&Some(serde_json::json!("sys2"))); + assert_ne!(h1, h2); + } + + #[test] + fn test_extract_user_hashes() { + let messages = vec![ + make_user_msg("u1"), + Message::new_text(Role::Assistant, "a1"), + make_user_msg("u2"), + ]; + let hashes = extract_user_hashes(&messages); + assert_eq!(hashes.len(), 2); + assert_eq!(hashes[0].0, 0); + assert_eq!(hashes[1].0, 2); + } + + #[test] + fn test_diff_messages_append() { + let msgs = vec![make_user_msg("u1"), make_user_msg("u2"), make_user_msg("u3")]; + let hashes = extract_user_hashes(&msgs); + let sys_hash = hash_system(&None); + + let cached = make_cached("conv1", vec![CachedTurn { + user_hashes: vec![hashes[0].1, hashes[1].1], + assistant_uuid: "asst1".to_string(), + }], sys_hash); + + let result = diff_messages(&cached, sys_hash, &hashes); + match result { + DiffResult::Append { parent_uuid, new_user_indices, new_user_hashes } => { + assert_eq!(parent_uuid, "asst1"); + assert_eq!(new_user_indices, vec![2]); + assert_eq!(new_user_hashes.len(), 1); + assert_eq!(new_user_hashes[0], hashes[2].1); + } + _ => panic!("Expected Append, got {:?}", result), + } + } + + #[test] + fn test_diff_messages_full_rebuild_system_change() { + let msgs = vec![make_user_msg("u1")]; + let hashes = extract_user_hashes(&msgs); + let sys_hash1 = hash_system(&Some(serde_json::json!("sys1"))); + let sys_hash2 = hash_system(&Some(serde_json::json!("sys2"))); + + let cached = make_cached("conv1", vec![CachedTurn { + user_hashes: vec![hashes[0].1], + assistant_uuid: "asst1".to_string(), + }], sys_hash1); + + let result = diff_messages(&cached, sys_hash2, &hashes); + assert!(matches!(result, DiffResult::FullRebuild)); + } + + #[test] + fn test_diff_messages_full_rebuild_turn0_mismatch() { + let msgs = vec![make_user_msg("u1_changed")]; + let hashes = 
extract_user_hashes(&msgs); + + let cached = make_cached("conv1", vec![CachedTurn { + user_hashes: vec![12345u64], // mismatching hash + assistant_uuid: "asst1".to_string(), + }], hash_system(&None)); + + let result = diff_messages(&cached, hash_system(&None), &hashes); + assert!(matches!(result, DiffResult::FullRebuild)); + } + + #[test] + fn test_diff_messages_fork() { + let msgs = vec![make_user_msg("u1"), make_user_msg("u2"), make_user_msg("u3_edited")]; + let hashes = extract_user_hashes(&msgs); + + // turn 0 has u1, turn 1 has u2_original + let u2_original_hash = hash_user_message(&make_user_msg("u2_original")); + let cached = make_cached("conv1", vec![ + CachedTurn { + user_hashes: vec![hashes[0].1], + assistant_uuid: "asst0".to_string(), + }, + CachedTurn { + user_hashes: vec![u2_original_hash], // mismatch at turn 1 + assistant_uuid: "asst1".to_string(), + }, + ], hash_system(&None)); + + let result = diff_messages(&cached, hash_system(&None), &hashes); + match result { + DiffResult::Fork { parent_uuid, fork_turn_index, remaining_user_indices, .. 
} => { + assert_eq!(parent_uuid, "asst0"); + assert_eq!(fork_turn_index, 1); + // remaining starts from u2 in the new messages + assert!(remaining_user_indices.contains(&1)); + assert!(remaining_user_indices.contains(&2)); + } + _ => panic!("Expected Fork, got {:?}", result), + } + } + + #[test] + fn test_diff_messages_full_rebuild_same_request() { + let msgs = vec![make_user_msg("u1")]; + let hashes = extract_user_hashes(&msgs); + let sys_hash = hash_system(&None); + + let cached = make_cached("conv1", vec![CachedTurn { + user_hashes: vec![hashes[0].1], + assistant_uuid: "asst1".to_string(), + }], sys_hash); + + // Same messages, no new ones → FullRebuild + let result = diff_messages(&cached, sys_hash, &hashes); + assert!(matches!(result, DiffResult::FullRebuild)); + } + + #[test] + fn test_diff_messages_full_rebuild_shorter_messages() { + let msgs = vec![make_user_msg("u1")]; + let hashes = extract_user_hashes(&msgs); + let sys_hash = hash_system(&None); + + // Cache has more turns than new messages + let cached = make_cached("conv1", vec![ + CachedTurn { + user_hashes: vec![hashes[0].1], + assistant_uuid: "asst0".to_string(), + }, + CachedTurn { + user_hashes: vec![99999u64], // additional turn that new messages don't have + assistant_uuid: "asst1".to_string(), + }, + ], sys_hash); + + let result = diff_messages(&cached, sys_hash, &hashes); + assert!(matches!(result, DiffResult::FullRebuild)); + } +} diff --git a/src/claude_web_state/mod.rs b/src/claude_web_state/mod.rs index c9a5c88d..2de7e8a5 100644 --- a/src/claude_web_state/mod.rs +++ b/src/claude_web_state/mod.rs @@ -1,4 +1,6 @@ use std::sync::LazyLock; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; use axum::http::{ HeaderValue, @@ -24,10 +26,35 @@ use crate::{ pub mod bootstrap; pub mod chat; +pub mod conversation_cache; +pub mod diff; mod transform; /// Placeholder pub static SUPER_CLIENT: LazyLock = LazyLock::new(Client::new); +use conversation_cache::{CacheKey, CachedConversation, CachedTurn, 
ConversationCache}; + +/// Information needed to write cache after a successful response +#[derive(Clone, Debug)] +pub enum PendingCacheWrite { + /// First request: initialize cache with full conversation info + Init { + key: CacheKey, + conv: CachedConversation, + }, + /// Subsequent request: append a new turn + AppendTurn { + key: CacheKey, + turn: CachedTurn, + }, + /// Fork: truncate and append + ForkAndAppend { + key: CacheKey, + fork_turn_index: usize, + turn: CachedTurn, + }, +} + /// State of current connection #[derive(Clone)] pub struct ClaudeWebState { @@ -44,13 +71,21 @@ pub struct ClaudeWebState { pub client: Client, pub key: Option<(u64, usize)>, pub usage: Usage, - // keep the last request params for potential post-call token accounting + // keep the last request params for possible post-call token accounting pub last_params: Option, + /// Shared conversation cache for reuse across requests + pub conv_cache: ConversationCache, + /// Pending cache write info (set by send_chat, consumed after success) + pub pending_cache_write: Option, + /// Shared flag for monitoring stream health. + /// Set to true when the SSE stream completes with a proper stop signal. + /// Checked on next cache reuse attempt. 
+    pub stream_health_flag: Option<Arc<AtomicBool>>,
 }
 
 impl ClaudeWebState {
     /// Create a new AppState instance
-    pub fn new(cookie_actor_handle: CookieActorHandle) -> Self {
+    pub fn new(cookie_actor_handle: CookieActorHandle, conv_cache: ConversationCache) -> Self {
         ClaudeWebState {
             cookie_actor_handle,
             cookie: None,
@@ -66,6 +101,9 @@ impl ClaudeWebState {
             key: None,
             usage: Usage::default(),
             last_params: None,
+            conv_cache,
+            pending_cache_write: None,
+            stream_health_flag: None,
         }
     }
 
@@ -145,6 +183,10 @@ impl ClaudeWebState {
     pub async fn return_cookie(&self, reason: Option<Reason>) {
         // return the cookie to the cookie manager
         if let Some(ref cookie) = self.cookie {
+            // Invalidate cache for this cookie if there's a reason (cookie changed)
+            if reason.is_some() {
+                self.conv_cache.invalidate_by_cookie(&cookie.cookie.to_string()).await;
+            }
             self.cookie_actor_handle
                 .return_cookie(cookie.to_owned(), reason)
                 .await
@@ -186,7 +228,34 @@ impl ClaudeWebState {
     /// Deletes or renames the current chat conversation based on configuration
     /// If preserve_chats is true, the chat is renamed rather than deleted
     pub async fn clean_chat(&self) -> Result<(), ClewdrError> {
-        Ok(())
+        if CLEWDR_CONFIG.load().preserve_chats {
+            return Ok(());
+        }
+        // If reuse is enabled, don't delete — we need the conversation alive
+        if CLEWDR_CONFIG.load().reuse_conversation {
+            return Ok(());
+        }
+        let Some(ref org_uuid) = self.org_uuid else {
+            return Ok(());
+        };
+        let Some(ref conv_uuid) = self.conv_uuid else {
+            return Ok(());
+        };
+        // BUGFIX: this hunk previously contained a mis-merged copy of the
+        // fetch_web_usage body (undefined `state`/`url`/`cookie`, the wrong
+        // `…/usage` endpoint, and `.ok()?` on an Option inside a Result fn).
+        // Delete the chat conversation instead, best-effort: cleanup failure
+        // must not fail the completed request.
+        let Ok(endpoint) = self.endpoint.join(&format!(
+            "api/organizations/{}/chat_conversations/{}",
+            org_uuid, conv_uuid
+        )) else {
+            return Ok(());
+        };
+        if let Err(e) = self.build_request(Method::DELETE, endpoint).send().await {
+            warn!("clean_chat: failed to delete conversation {}: {}", conv_uuid, e);
+        }
+        Ok(())
+    }
 
     /// Fetch usage data via the claude.ai web endpoint.
@@ -239,4 +325,16 @@ impl ClaudeWebState {
         })
         .ok()
     }
+
+    /// Cache key for the current cookie slot; falls back to index 0 when unset.
+    fn cache_key(&self) -> CacheKey {
+        CacheKey {
+            key_index: self.key.map(|(_, idx)| idx).unwrap_or(0),
+        }
+    }
+
+    /// String identity of the current cookie; empty when no cookie is bound.
+    fn cookie_id(&self) -> String {
+        self.cookie.as_ref()
+            .map(|c| c.cookie.to_string())
+            .unwrap_or_default()
+    }
 }
diff --git a/src/claude_web_state/transform.rs b/src/claude_web_state/transform.rs
index 1084c125..ad18eef6 100644
--- a/src/claude_web_state/transform.rs
+++ b/src/claude_web_state/transform.rs
@@ -46,6 +46,8 @@ impl ClaudeWebState {
             timezone: TIME_ZONE.to_string(),
             images: merged.images,
             tools,
+            parent_message_uuid: None,
+            turn_message_uuids: None,
         })
     }
diff --git a/src/config/clewdr_config.rs b/src/config/clewdr_config.rs
index b0e9dda6..3c3e8ccc 100644
--- a/src/config/clewdr_config.rs
+++ b/src/config/clewdr_config.rs
@@ -24,7 +24,8 @@ use crate::{
     Args,
     config::{
         CC_CLIENT_ID, CookieStatus, UselessCookie, default_check_update, default_ip,
-        default_max_retries, default_port, default_skip_cool_down, default_use_real_roles,
+        default_max_retries, default_port, default_skip_cool_down, default_true,
+        default_use_real_roles,
     },
     error::ClewdrError,
     utils::enabled,
@@ -97,6 +98,13 @@ pub struct ClewdrConfig {
     pub enable_web_count_tokens: bool,
     #[serde(default)]
     pub sanitize_messages: bool,
+    /// Enable conversation reuse to reduce input token costs.
+    /// When enabled, conversations are cached and subsequent requests
+    /// send only incremental messages instead of the full paste.
+    /// Only effective when preserve_chats = false (incognito mode).
+ /// Default: true + #[serde(default = "default_true")] + pub reuse_conversation: bool, // Cookie settings, can hot reload #[serde(default)] @@ -156,6 +164,7 @@ impl Default for ClewdrConfig { web_search: false, enable_web_count_tokens: false, sanitize_messages: false, + reuse_conversation: true, skip_first_warning: false, skip_second_warning: false, skip_restricted: false, diff --git a/src/config/constants.rs b/src/config/constants.rs index a2f87b31..f0ffe298 100644 --- a/src/config/constants.rs +++ b/src/config/constants.rs @@ -139,5 +139,10 @@ pub const fn default_skip_cool_down() -> bool { true } +/// Default value for boolean flags that default to true +pub const fn default_true() -> bool { + true +} + /// Default cookie value for testing purposes pub const PLACEHOLDER_COOKIE: &str = "sk-ant-sidXX----------------------------SET_YOUR_COOKIE_HERE----------------------------------------AAAAAAAA"; diff --git a/src/providers/claude/mod.rs b/src/providers/claude/mod.rs index f2ed5c91..56d986f1 100644 --- a/src/providers/claude/mod.rs +++ b/src/providers/claude/mod.rs @@ -8,6 +8,7 @@ use super::LLMProvider; use crate::{ claude_code_state::ClaudeCodeState, claude_web_state::ClaudeWebState, + claude_web_state::conversation_cache::ConversationCache, error::ClewdrError, middleware::claude::{ClaudeApiFormat, ClaudeContext}, services::cookie_actor::CookieActorHandle, @@ -53,12 +54,14 @@ pub struct ClaudeProviderResponse { struct ClaudeSharedState { cookie_actor_handle: CookieActorHandle, + conv_cache: ConversationCache, } impl ClaudeSharedState { - fn new(cookie_actor_handle: CookieActorHandle) -> Self { + fn new(cookie_actor_handle: CookieActorHandle, conv_cache: ConversationCache) -> Self { Self { cookie_actor_handle, + conv_cache, } } } @@ -70,8 +73,8 @@ pub struct ClaudeProviders { } impl ClaudeProviders { - pub fn new(cookie_actor_handle: CookieActorHandle) -> Self { - let shared = Arc::new(ClaudeSharedState::new(cookie_actor_handle)); + pub fn 
new(cookie_actor_handle: CookieActorHandle, conv_cache: ConversationCache) -> Self { + let shared = Arc::new(ClaudeSharedState::new(cookie_actor_handle, conv_cache)); let web = Arc::new(ClaudeWebProvider::new(shared.clone())); let code = Arc::new(ClaudeCodeProvider::new(shared.clone())); Self { web, code } @@ -103,7 +106,10 @@ impl LLMProvider for ClaudeWebProvider { type Output = ClaudeProviderResponse; async fn invoke(&self, request: Self::Request) -> Result { - let mut state = ClaudeWebState::new(self.shared.cookie_actor_handle.clone()); + let mut state = ClaudeWebState::new( + self.shared.cookie_actor_handle.clone(), + self.shared.conv_cache.clone(), + ); let stream = request.context.is_stream(); state.api_format = request.context.api_format(); state.stream = stream; @@ -212,6 +218,6 @@ impl LLMProvider for ClaudeCodeProvider { } } -pub fn build_providers(cookie_actor_handle: CookieActorHandle) -> ClaudeProviders { - ClaudeProviders::new(cookie_actor_handle) +pub fn build_providers(cookie_actor_handle: CookieActorHandle, conv_cache: ConversationCache) -> ClaudeProviders { + ClaudeProviders::new(cookie_actor_handle, conv_cache) } diff --git a/src/router.rs b/src/router.rs index 79a32b46..1804b637 100644 --- a/src/router.rs +++ b/src/router.rs @@ -10,6 +10,7 @@ use tower_http::{compression::CompressionLayer, cors::CorsLayer}; use crate::{ api::*, + claude_web_state::conversation_cache::ConversationCache, middleware::{ RequireAdminAuth, RequireBearerAuth, RequireFlexibleAuth, claude::{add_usage_info, apply_stop_sequences, check_overloaded, to_oai}, @@ -35,7 +36,24 @@ impl RouterBuilder { let cookie_handle = CookieActorHandle::start() .await .expect("Failed to start CookieActor"); - let claude_providers = crate::providers::claude::build_providers(cookie_handle.clone()); + + // Create shared conversation cache + let conv_cache = ConversationCache::new(); + + // Spawn periodic cleanup task (every hour) + let cache_clone = conv_cache.clone(); + tokio::spawn(async move 
{ + let mut interval = tokio::time::interval(std::time::Duration::from_secs(3600)); + loop { + interval.tick().await; + cache_clone.cleanup().await; + } + }); + + let claude_providers = crate::providers::claude::build_providers( + cookie_handle.clone(), + conv_cache, + ); RouterBuilder { claude_providers, cookie_actor_handle: cookie_handle, diff --git a/src/types/claude_web/request.rs b/src/types/claude_web/request.rs index b4c1e7c3..e7027712 100644 --- a/src/types/claude_web/request.rs +++ b/src/types/claude_web/request.rs @@ -29,6 +29,13 @@ impl Attachment { } } +/// Client-generated UUIDs for a single turn's messages +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TurnMessageUuids { + pub human_message_uuid: String, + pub assistant_message_uuid: String, +} + /// Request body to be sent to the Claude.ai #[derive(Deserialize, Serialize, Debug)] pub struct WebRequestBody { @@ -43,6 +50,12 @@ pub struct WebRequestBody { #[serde(skip)] pub images: Vec, pub tools: Vec, + /// Parent message UUID for conversation continuation + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_message_uuid: Option, + /// Client-generated UUIDs for this turn's messages + #[serde(skip_serializing_if = "Option::is_none")] + pub turn_message_uuids: Option, } #[derive(Deserialize, Serialize, Debug)] diff --git a/src/types/claude_web/response.rs b/src/types/claude_web/response.rs index 03d31bb6..e4bc589b 100644 --- a/src/types/claude_web/response.rs +++ b/src/types/claude_web/response.rs @@ -7,6 +7,7 @@ use bytes::Bytes; use eventsource_stream::{EventStream, Eventsource}; use futures::{Stream, TryStreamExt}; use serde::Deserialize; +use std::sync::atomic::Ordering; use url::Url; use wreq::Proxy; @@ -78,6 +79,9 @@ impl ClaudeWebState { &mut self, wreq_res: wreq::Response, ) -> Result { + // Take the stream health flag so it can be moved into the stream wrapper + let stream_health_flag = self.stream_health_flag.take(); + if self.stream { // Stream through while 
accumulating completion text; persist usage at end let mut input_tokens = self.usage.input_tokens as u64; @@ -112,6 +116,10 @@ impl ClaudeWebState { let e = if let Some(retry) = event.retry { e.retry(retry) } else { e }; yield e.data(event.data); } + // Stream completed successfully — mark as healthy + if let Some(flag) = stream_health_flag.as_ref() { + flag.store(true, Ordering::Relaxed); + } // on end of stream, compute output tokens and persist totals if !acc.is_empty() { // Prefer official count_tokens if enabled and possible; else estimate locally @@ -177,6 +185,12 @@ impl ClaudeWebState { let stream = wreq_res.bytes_stream(); let stream = stream.eventsource(); let text = merge_sse(stream).await?; + + // Non-streaming: full response received successfully — mark as healthy + if let Some(flag) = stream_health_flag.as_ref() { + flag.store(true, Ordering::Relaxed); + } + print_out_text(text.to_owned(), "claude_web_non_stream.txt"); let mut response = CreateMessageResponse::text(text.clone(), Default::default(), self.usage.to_owned()); diff --git a/tests/conversation_cache_test.rs b/tests/conversation_cache_test.rs new file mode 100644 index 00000000..e41b3948 --- /dev/null +++ b/tests/conversation_cache_test.rs @@ -0,0 +1,421 @@ +use std::sync::Arc; +use std::sync::atomic::AtomicBool; + +use clewdr::claude_web_state::conversation_cache::{ + CacheKey, CachedConversation, CachedTurn, ConversationCache, +}; +use clewdr::claude_web_state::diff::{ + self, DiffResult, extract_user_hashes, hash_system, hash_user_message, +}; +use clewdr::types::claude::{Message, Role}; + +fn make_user_msg(text: &str) -> Message { + Message::new_text(Role::User, text) +} + +fn make_cached(conv_uuid: &str, turns: Vec, system_hash: u64) -> CachedConversation { + CachedConversation { + conv_uuid: conv_uuid.to_string(), + org_uuid: "org".to_string(), + cookie_id: "cookie".to_string(), + model: "model".to_string(), + is_pro: false, + system_hash, + turns, + created_at: chrono::Utc::now(), + 
last_used: chrono::Utc::now(), + valid: true, + last_stream_healthy: Arc::new(AtomicBool::new(true)), + } +} + +/// Test: 3 sequential requests, verify 2nd and 3rd use cache +#[tokio::test] +async fn test_sequential_requests_use_cache() { + let cache = ConversationCache::new(); + let key = CacheKey { key_index: 0 }; + let sys_hash = hash_system(&None); + + // Request 1: full messages [u1, u2, u3] + let msgs1 = vec![make_user_msg("u1"), make_user_msg("u2"), make_user_msg("u3")]; + let hashes1 = extract_user_hashes(&msgs1); + let conv = make_cached( + "conv1", + vec![CachedTurn { + user_hashes: hashes1.iter().map(|(_, h)| *h).collect(), + assistant_uuid: "asst0".to_string(), + }], + sys_hash, + ); + cache.set(key.clone(), conv).await; + + // Request 2: same prefix + new message [u1, u2, u3, u4] + let msgs2 = vec![make_user_msg("u1"), make_user_msg("u2"), make_user_msg("u3"), make_user_msg("u4")]; + let hashes2 = extract_user_hashes(&msgs2); + let cached = cache.get(&key).await.unwrap(); + let result = diff::diff_messages(&cached, sys_hash, &hashes2); + match result { + DiffResult::Append { parent_uuid, new_user_indices, new_user_hashes } => { + assert_eq!(parent_uuid, "asst0"); + assert_eq!(new_user_indices, vec![3]); + assert_eq!(new_user_hashes.len(), 1); + } + _ => panic!("Expected Append, got {:?}", result), + } + + // Simulate successful append: update cache + cache.append_turn(&key, CachedTurn { + user_hashes: vec![hashes2[3].1], + assistant_uuid: "asst1".to_string(), + }).await; + + // Request 3: same prefix + another new message [u1, u2, u3, u4, u5] + let msgs3 = vec![make_user_msg("u1"), make_user_msg("u2"), make_user_msg("u3"), make_user_msg("u4"), make_user_msg("u5")]; + let hashes3 = extract_user_hashes(&msgs3); + let cached = cache.get(&key).await.unwrap(); + assert_eq!(cached.turns.len(), 2); + let result = diff::diff_messages(&cached, sys_hash, &hashes3); + match result { + DiffResult::Append { parent_uuid, new_user_indices, .. 
} => { + assert_eq!(parent_uuid, "asst1"); + assert_eq!(new_user_indices, vec![4]); + } + _ => panic!("Expected Append, got {:?}", result), + } +} + +/// Test: edit scenario (message modification → fork) +#[tokio::test] +async fn test_edit_scenario_fork() { + let cache = ConversationCache::new(); + let key = CacheKey { key_index: 0 }; + let sys_hash = hash_system(&None); + + // Initial: [u1, u2, u3] + let msgs1 = vec![make_user_msg("u1"), make_user_msg("u2"), make_user_msg("u3")]; + let hashes1 = extract_user_hashes(&msgs1); + let conv = make_cached( + "conv1", + vec![CachedTurn { + user_hashes: hashes1.iter().map(|(_, h)| *h).collect(), + assistant_uuid: "asst0".to_string(), + }], + sys_hash, + ); + cache.set(key.clone(), conv).await; + + // Edit: [u1, u2_edited, u3] + let msgs2 = vec![make_user_msg("u1"), make_user_msg("u2_edited"), make_user_msg("u3")]; + let hashes2 = extract_user_hashes(&msgs2); + let cached = cache.get(&key).await.unwrap(); + let result = diff::diff_messages(&cached, sys_hash, &hashes2); + + // Turn 0 has the mismatch (u2_edited vs u2) → FullRebuild + assert!(matches!(result, DiffResult::FullRebuild)); +} + +/// Test: edit scenario with multi-turn fork +#[tokio::test] +async fn test_edit_scenario_fork_multi_turn() { + let cache = ConversationCache::new(); + let key = CacheKey { key_index: 0 }; + let sys_hash = hash_system(&None); + + // Turn 0: [u1, u2, u3], Turn 1: [u4] + let msgs1 = vec![make_user_msg("u1"), make_user_msg("u2"), make_user_msg("u3")]; + let hashes1 = extract_user_hashes(&msgs1); + let u4_hash = hash_user_message(&make_user_msg("u4")); + let conv = make_cached( + "conv1", + vec![ + CachedTurn { + user_hashes: hashes1.iter().map(|(_, h)| *h).collect(), + assistant_uuid: "asst0".to_string(), + }, + CachedTurn { + user_hashes: vec![u4_hash], + assistant_uuid: "asst1".to_string(), + }, + ], + sys_hash, + ); + cache.set(key.clone(), conv).await; + + // Edit u4 → [u1, u2, u3, u4_edited, u5] + let msgs2 = vec![ + 
make_user_msg("u1"), make_user_msg("u2"), make_user_msg("u3"), + make_user_msg("u4_edited"), make_user_msg("u5"), + ]; + let hashes2 = extract_user_hashes(&msgs2); + let cached = cache.get(&key).await.unwrap(); + let result = diff::diff_messages(&cached, sys_hash, &hashes2); + + match result { + DiffResult::Fork { parent_uuid, fork_turn_index, remaining_user_indices, .. } => { + assert_eq!(parent_uuid, "asst0"); + assert_eq!(fork_turn_index, 1); + assert!(remaining_user_indices.contains(&3)); // u4_edited + assert!(remaining_user_indices.contains(&4)); // u5 + } + _ => panic!("Expected Fork, got {:?}", result), + } +} + +/// Test: system prompt change → full rebuild +#[tokio::test] +async fn test_system_prompt_change_full_rebuild() { + let cache = ConversationCache::new(); + let key = CacheKey { key_index: 0 }; + let sys_hash1 = hash_system(&Some(serde_json::json!("system v1"))); + let sys_hash2 = hash_system(&Some(serde_json::json!("system v2"))); + + let msgs = vec![make_user_msg("u1"), make_user_msg("u2")]; + let hashes = extract_user_hashes(&msgs); + let conv = make_cached( + "conv1", + vec![CachedTurn { + user_hashes: hashes.iter().map(|(_, h)| *h).collect(), + assistant_uuid: "asst0".to_string(), + }], + sys_hash1, + ); + cache.set(key.clone(), conv).await; + + // Same messages but different system prompt + let cached = cache.get(&key).await.unwrap(); + let result = diff::diff_messages(&cached, sys_hash2, &hashes); + assert!(matches!(result, DiffResult::FullRebuild)); +} + +/// Test: model switch → cache invalidated +#[tokio::test] +async fn test_model_switch_invalidation() { + let cache = ConversationCache::new(); + let key = CacheKey { key_index: 0 }; + let sys_hash = hash_system(&None); + + let conv = make_cached( + "conv1", + vec![CachedTurn { + user_hashes: vec![hash_user_message(&make_user_msg("u1"))], + assistant_uuid: "asst0".to_string(), + }], + sys_hash, + ); + cache.set(key.clone(), conv).await; + + // Verify cache is valid + let cached = 
cache.get(&key).await.unwrap(); + assert_eq!(cached.model, "model"); + + // Simulate model change: the caller invalidates and creates new + cache.invalidate(&key).await; + assert!(cache.get(&key).await.is_none()); +} + +/// Test: incremental failure → fallback to full rebuild +#[tokio::test] +async fn test_incremental_failure_fallback() { + let cache = ConversationCache::new(); + let key = CacheKey { key_index: 0 }; + let sys_hash = hash_system(&None); + + // Set up cache + let msgs = vec![make_user_msg("u1"), make_user_msg("u2")]; + let hashes = extract_user_hashes(&msgs); + let conv = make_cached( + "conv1", + vec![CachedTurn { + user_hashes: hashes.iter().map(|(_, h)| *h).collect(), + assistant_uuid: "asst0".to_string(), + }], + sys_hash, + ); + cache.set(key.clone(), conv).await; + + // Simulate failure: invalidate cache + cache.invalidate(&key).await; + + // Next request should get cache miss + assert!(cache.get(&key).await.is_none()); + + // Caller falls back to send_full and creates new cache entry + let new_msgs = vec![make_user_msg("u1"), make_user_msg("u2"), make_user_msg("u3")]; + let new_hashes = extract_user_hashes(&new_msgs); + let new_conv = make_cached( + "conv2", + vec![CachedTurn { + user_hashes: new_hashes.iter().map(|(_, h)| *h).collect(), + assistant_uuid: "asst_new".to_string(), + }], + sys_hash, + ); + cache.set(key.clone(), new_conv).await; + + // Verify new cache works + let cached = cache.get(&key).await.unwrap(); + assert_eq!(cached.conv_uuid, "conv2"); +} + +/// Test: cookie rotation → cache invalidation +#[tokio::test] +async fn test_cookie_rotation_invalidation() { + let cache = ConversationCache::new(); + let key = CacheKey { key_index: 0 }; + let sys_hash = hash_system(&None); + + let conv = make_cached( + "conv1", + vec![CachedTurn { + user_hashes: vec![hash_user_message(&make_user_msg("u1"))], + assistant_uuid: "asst0".to_string(), + }], + sys_hash, + ); + cache.set(key.clone(), conv).await; + + // Simulate cookie rotation + 
cache.invalidate_by_cookie("cookie").await; + let cached = cache.get(&key).await; + assert!(cached.is_none()); +} + +/// Test: cache cleanup removes expired entries +#[tokio::test] +async fn test_cache_cleanup() { + let cache = ConversationCache::new(); + let key = CacheKey { key_index: 0 }; + let sys_hash = hash_system(&None); + + // Create a conversation that's already expired (created 26 days ago) + let mut conv = make_cached( + "conv_expired", + vec![CachedTurn { + user_hashes: vec![hash_user_message(&make_user_msg("u1"))], + assistant_uuid: "asst0".to_string(), + }], + sys_hash, + ); + conv.created_at = chrono::Utc::now() - chrono::Duration::days(26); + cache.set(key.clone(), conv).await; + + // Before cleanup, it exists but is expired + let cached = cache.get(&key).await; + assert!(cached.is_none()); // get() filters expired + + // Cleanup removes it + cache.cleanup().await; +} + +/// Test: stream health flag is shared between cache and stream +#[tokio::test] +async fn test_stream_health_flag() { + let cache = ConversationCache::new(); + let key = CacheKey { key_index: 0 }; + let sys_hash = hash_system(&None); + + let flag = Arc::new(AtomicBool::new(false)); + let conv = CachedConversation { + conv_uuid: "conv1".to_string(), + org_uuid: "org".to_string(), + cookie_id: "cookie".to_string(), + model: "model".to_string(), + is_pro: false, + system_hash: sys_hash, + turns: vec![CachedTurn { + user_hashes: vec![hash_user_message(&make_user_msg("u1"))], + assistant_uuid: "asst0".to_string(), + }], + created_at: chrono::Utc::now(), + last_used: chrono::Utc::now(), + valid: true, + last_stream_healthy: flag.clone(), + }; + cache.set(key.clone(), conv).await; + + // Initially unhealthy + assert!(!cache.is_last_stream_healthy(&key).await); + + // Simulate stream completion + flag.store(true, std::sync::atomic::Ordering::Relaxed); + + // Now healthy + assert!(cache.is_last_stream_healthy(&key).await); +} + +/// Test: stream health flag update on append +#[tokio::test] 
+async fn test_stream_health_update_on_append() { + let cache = ConversationCache::new(); + let key = CacheKey { key_index: 0 }; + let sys_hash = hash_system(&None); + + let flag = Arc::new(AtomicBool::new(true)); // initially healthy (stream completed) + let conv = CachedConversation { + conv_uuid: "conv1".to_string(), + org_uuid: "org".to_string(), + cookie_id: "cookie".to_string(), + model: "model".to_string(), + is_pro: false, + system_hash: sys_hash, + turns: vec![CachedTurn { + user_hashes: vec![hash_user_message(&make_user_msg("u1"))], + assistant_uuid: "asst0".to_string(), + }], + created_at: chrono::Utc::now(), + last_used: chrono::Utc::now(), + valid: true, + last_stream_healthy: flag, + }; + cache.set(key.clone(), conv).await; + + // New request with new flag + let new_flag = Arc::new(AtomicBool::new(false)); + cache.update_stream_health(&key, new_flag.clone()).await; + + // Not yet healthy + assert!(!cache.is_last_stream_healthy(&key).await); + + // Stream completes + new_flag.store(true, std::sync::atomic::Ordering::Relaxed); + assert!(cache.is_last_stream_healthy(&key).await); +} + +/// Test: cache key isolation +#[tokio::test] +async fn test_cache_key_isolation() { + let cache = ConversationCache::new(); + let key0 = CacheKey { key_index: 0 }; + let key1 = CacheKey { key_index: 1 }; + let sys_hash = hash_system(&None); + + let conv0 = make_cached( + "conv_key0", + vec![CachedTurn { + user_hashes: vec![hash_user_message(&make_user_msg("u1"))], + assistant_uuid: "asst0".to_string(), + }], + sys_hash, + ); + let conv1 = make_cached( + "conv_key1", + vec![CachedTurn { + user_hashes: vec![hash_user_message(&make_user_msg("u1"))], + assistant_uuid: "asst1".to_string(), + }], + sys_hash, + ); + + cache.set(key0.clone(), conv0).await; + cache.set(key1.clone(), conv1).await; + + let c0 = cache.get(&key0).await.unwrap(); + let c1 = cache.get(&key1).await.unwrap(); + assert_eq!(c0.conv_uuid, "conv_key0"); + assert_eq!(c1.conv_uuid, "conv_key1"); + + // 
Invalidate one doesn't affect the other + cache.invalidate(&key0).await; + assert!(cache.get(&key0).await.is_none()); + assert!(cache.get(&key1).await.is_some()); +}