diff --git a/crates/ov_cli/src/client.rs b/crates/ov_cli/src/client.rs index b360f316a..0a2c0de78 100644 --- a/crates/ov_cli/src/client.rs +++ b/crates/ov_cli/src/client.rs @@ -5,8 +5,8 @@ use std::fs::File; use std::path::Path; use tempfile::NamedTempFile; use url::Url; -use zip::write::FileOptions; use zip::CompressionMethod; +use zip::write::FileOptions; use crate::error::{Error, Result}; @@ -62,7 +62,8 @@ impl HttpClient { let temp_file = NamedTempFile::new()?; let file = File::create(temp_file.path())?; let mut zip = zip::ZipWriter::new(file); - let options: FileOptions<'_, ()> = FileOptions::default().compression_method(CompressionMethod::Deflated); + let options: FileOptions<'_, ()> = + FileOptions::default().compression_method(CompressionMethod::Deflated); let walkdir = walkdir::WalkDir::new(dir_path); for entry in walkdir.into_iter().filter_map(|e| e.ok()) { @@ -89,14 +90,13 @@ impl HttpClient { // Read file content let file_content = tokio::fs::read(file_path).await?; - + // Create multipart form - let part = reqwest::multipart::Part::bytes(file_content) - .file_name(file_name.to_string()); - - let part = part.mime_str("application/octet-stream").map_err(|e| { - Error::Network(format!("Failed to set mime type: {}", e)) - })?; + let part = reqwest::multipart::Part::bytes(file_content).file_name(file_name.to_string()); + + let part = part + .mime_str("application/octet-stream") + .map_err(|e| Error::Network(format!("Failed to set mime type: {}", e)))?; let form = reqwest::multipart::Form::new().part("file", part); @@ -235,10 +235,7 @@ impl HttpClient { self.handle_response(response).await } - async fn handle_response( - &self, - response: reqwest::Response, - ) -> Result { + async fn handle_response(&self, response: reqwest::Response) -> Result { let status = response.status(); // Handle empty response (204 No Content, etc.) 
@@ -259,7 +256,11 @@ impl HttpClient { .and_then(|e| e.get("message")) .and_then(|m| m.as_str()) .map(|s| s.to_string()) - .or_else(|| json.get("detail").and_then(|d| d.as_str()).map(|s| s.to_string())) + .or_else(|| { + json.get("detail") + .and_then(|d| d.as_str()) + .map(|s| s.to_string()) + }) .unwrap_or_else(|| format!("HTTP error {}", status)); return Err(Error::Api(error_msg)); } @@ -307,7 +308,12 @@ impl HttpClient { self.get("/api/v1/content/overview", ¶ms).await } - pub async fn reindex(&self, uri: &str, regenerate: bool, wait: bool) -> Result { + pub async fn reindex( + &self, + uri: &str, + regenerate: bool, + wait: bool, + ) -> Result { let body = serde_json::json!({ "uri": uri, "regenerate": regenerate, @@ -320,7 +326,7 @@ impl HttpClient { pub async fn get_bytes(&self, uri: &str) -> Result> { let url = format!("{}/api/v1/content/download", self.base_url); let params = vec![("uri".to_string(), uri.to_string())]; - + let response = self .http .get(&url) @@ -337,20 +343,22 @@ impl HttpClient { .json() .await .map_err(|e| Error::Network(format!("Failed to parse error response: {}", e))); - + let error_msg = match json_result { - Ok(json) => { - json - .get("error") - .and_then(|e| e.get("message")) - .and_then(|m| m.as_str()) - .map(|s| s.to_string()) - .or_else(|| json.get("detail").and_then(|d| d.as_str()).map(|s| s.to_string())) - .unwrap_or_else(|| format!("HTTP error {}", status)) - } + Ok(json) => json + .get("error") + .and_then(|e| e.get("message")) + .and_then(|m| m.as_str()) + .map(|s| s.to_string()) + .or_else(|| { + json.get("detail") + .and_then(|d| d.as_str()) + .map(|s| s.to_string()) + }) + .unwrap_or_else(|| format!("HTTP error {}", status)), Err(_) => format!("HTTP error {}", status), }; - + return Err(Error::Api(error_msg)); } @@ -363,7 +371,16 @@ impl HttpClient { // ============ Filesystem Methods ============ - pub async fn ls(&self, uri: &str, simple: bool, recursive: bool, output: &str, abs_limit: i32, show_all_hidden: bool, 
node_limit: i32) -> Result { + pub async fn ls( + &self, + uri: &str, + simple: bool, + recursive: bool, + output: &str, + abs_limit: i32, + show_all_hidden: bool, + node_limit: i32, + ) -> Result { let params = vec![ ("uri".to_string(), uri.to_string()), ("simple".to_string(), simple.to_string()), @@ -376,7 +393,15 @@ impl HttpClient { self.get("/api/v1/fs/ls", ¶ms).await } - pub async fn tree(&self, uri: &str, output: &str, abs_limit: i32, show_all_hidden: bool, node_limit: i32, level_limit: i32) -> Result { + pub async fn tree( + &self, + uri: &str, + output: &str, + abs_limit: i32, + show_all_hidden: bool, + node_limit: i32, + level_limit: i32, + ) -> Result { let params = vec![ ("uri".to_string(), uri.to_string()), ("output".to_string(), output.to_string()), @@ -453,7 +478,13 @@ impl HttpClient { self.post("/api/v1/search/search", &body).await } - pub async fn grep(&self, uri: &str, pattern: &str, ignore_case: bool, node_limit: i32) -> Result { + pub async fn grep( + &self, + uri: &str, + pattern: &str, + ignore_case: bool, + node_limit: i32, + ) -> Result { let body = serde_json::json!({ "uri": uri, "pattern": pattern, @@ -463,8 +494,12 @@ impl HttpClient { self.post("/api/v1/search/grep", &body).await } - - pub async fn glob(&self, pattern: &str, uri: &str, node_limit: i32) -> Result { + pub async fn glob( + &self, + pattern: &str, + uri: &str, + node_limit: i32, + ) -> Result { let body = serde_json::json!({ "pattern": pattern, "uri": uri, @@ -737,11 +772,7 @@ impl HttpClient { self.put(&path, &body).await } - pub async fn admin_regenerate_key( - &self, - account_id: &str, - user_id: &str, - ) -> Result { + pub async fn admin_regenerate_key(&self, account_id: &str, user_id: &str) -> Result { let path = format!( "/api/v1/admin/accounts/{}/users/{}/key", account_id, user_id diff --git a/crates/ov_cli/src/commands/admin.rs b/crates/ov_cli/src/commands/admin.rs index 227694391..13771a68f 100644 --- a/crates/ov_cli/src/commands/admin.rs +++ 
b/crates/ov_cli/src/commands/admin.rs @@ -1,6 +1,6 @@ use crate::client::HttpClient; use crate::error::Result; -use crate::output::{output_success, OutputFormat}; +use crate::output::{OutputFormat, output_success}; use serde_json::json; pub async fn create_account( @@ -10,7 +10,9 @@ pub async fn create_account( output_format: OutputFormat, compact: bool, ) -> Result<()> { - let response = client.admin_create_account(account_id, admin_user_id).await?; + let response = client + .admin_create_account(account_id, admin_user_id) + .await?; output_success(&response, output_format, compact); Ok(()) } @@ -32,13 +34,12 @@ pub async fn delete_account( compact: bool, ) -> Result<()> { let response = client.admin_delete_account(account_id).await?; - let result = if response.is_null() - || response.as_object().map(|o| o.is_empty()).unwrap_or(false) - { - json!({"account_id": account_id}) - } else { - response - }; + let result = + if response.is_null() || response.as_object().map(|o| o.is_empty()).unwrap_or(false) { + json!({"account_id": account_id}) + } else { + response + }; output_success(&result, output_format, compact); Ok(()) } @@ -51,7 +52,9 @@ pub async fn register_user( output_format: OutputFormat, compact: bool, ) -> Result<()> { - let response = client.admin_register_user(account_id, user_id, role).await?; + let response = client + .admin_register_user(account_id, user_id, role) + .await?; output_success(&response, output_format, compact); Ok(()) } @@ -75,13 +78,12 @@ pub async fn remove_user( compact: bool, ) -> Result<()> { let response = client.admin_remove_user(account_id, user_id).await?; - let result = if response.is_null() - || response.as_object().map(|o| o.is_empty()).unwrap_or(false) - { - json!({"account_id": account_id, "user_id": user_id}) - } else { - response - }; + let result = + if response.is_null() || response.as_object().map(|o| o.is_empty()).unwrap_or(false) { + json!({"account_id": account_id, "user_id": user_id}) + } else { + response + }; 
output_success(&result, output_format, compact); Ok(()) } diff --git a/crates/ov_cli/src/commands/chat.rs b/crates/ov_cli/src/commands/chat.rs index 6c15f1b4f..de125d4ba 100644 --- a/crates/ov_cli/src/commands/chat.rs +++ b/crates/ov_cli/src/commands/chat.rs @@ -10,8 +10,8 @@ use std::time::Duration; use clap::Parser; use reqwest::Client; -use rustyline::error::ReadlineError; use rustyline::DefaultEditor; +use rustyline::error::ReadlineError; use serde::{Deserialize, Serialize}; use termimad::MadSkin; @@ -90,7 +90,7 @@ struct ChatResponse { /// Stream event from SSE #[derive(Debug, Deserialize)] struct ChatStreamEvent { - event: String, // "reasoning", "tool_call", "tool_result", "response" + event: String, // "reasoning", "tool_call", "tool_result", "response" data: serde_json::Value, timestamp: Option, } @@ -198,7 +198,11 @@ impl ChatCommand { let mut buffer = String::new(); let mut final_message = String::new(); - while let Some(chunk) = response.chunk().await.map_err(|e| Error::Network(format!("Stream error: {}", e)))? { + while let Some(chunk) = response + .chunk() + .await + .map_err(|e| Error::Network(format!("Stream error: {}", e)))? 
+ { let chunk_str = String::from_utf8_lossy(&chunk); buffer.push_str(&chunk_str); @@ -221,7 +225,8 @@ impl ChatCommand { } else if let Some(obj) = event.data.as_object() { if let Some(msg) = obj.get("message").and_then(|m| m.as_str()) { final_message = msg.to_string(); - } else if let Some(err) = obj.get("error").and_then(|e| e.as_str()) { + } else if let Some(err) = obj.get("error").and_then(|e| e.as_str()) + { eprintln!("\x1b[1;31mError: {}\x1b[0m", err); } } @@ -290,7 +295,10 @@ impl ChatCommand { } // Send message - match self.send_interactive_message(client, input, &mut session_id).await { + match self + .send_interactive_message(client, input, &mut session_id) + .await + { Ok(_) => {} Err(e) => { eprintln!("\x1b[1;31mError: {}\x1b[0m", e); @@ -330,9 +338,11 @@ impl ChatCommand { session_id: &mut Option, ) -> Result<()> { if self.stream { - self.send_interactive_message_stream(client, input, session_id).await + self.send_interactive_message_stream(client, input, session_id) + .await } else { - self.send_interactive_message_non_stream(client, input, session_id).await + self.send_interactive_message_non_stream(client, input, session_id) + .await } } @@ -431,7 +441,11 @@ impl ChatCommand { let mut final_message = String::new(); let mut got_session_id = false; - while let Some(chunk) = response.chunk().await.map_err(|e| Error::Network(format!("Stream error: {}", e)))? { + while let Some(chunk) = response + .chunk() + .await + .map_err(|e| Error::Network(format!("Stream error: {}", e)))? 
+ { let chunk_str = String::from_utf8_lossy(&chunk); buffer.push_str(&chunk_str); @@ -464,7 +478,8 @@ impl ChatCommand { } else if let Some(obj) = event.data.as_object() { if let Some(msg) = obj.get("message").and_then(|m| m.as_str()) { final_message = msg.to_string(); - } else if let Some(err) = obj.get("error").and_then(|e| e.as_str()) { + } else if let Some(err) = obj.get("error").and_then(|e| e.as_str()) + { eprintln!("\x1b[1;31mError: {}\x1b[0m", err); } } diff --git a/crates/ov_cli/src/commands/content.rs b/crates/ov_cli/src/commands/content.rs index 7674d23e1..4803bb180 100644 --- a/crates/ov_cli/src/commands/content.rs +++ b/crates/ov_cli/src/commands/content.rs @@ -51,17 +51,14 @@ pub async fn reindex( Ok(()) } -pub async fn get( - client: &HttpClient, - uri: &str, - local_path: &str, -) -> Result<()> { +pub async fn get(client: &HttpClient, uri: &str, local_path: &str) -> Result<()> { // Check if target path already exists let path = Path::new(local_path); if path.exists() { - return Err(crate::error::Error::Client( - format!("File already exists: {}", local_path) - )); + return Err(crate::error::Error::Client(format!( + "File already exists: {}", + local_path + ))); } // Ensure parent directory exists diff --git a/crates/ov_cli/src/commands/filesystem.rs b/crates/ov_cli/src/commands/filesystem.rs index bd9f64996..934faa180 100644 --- a/crates/ov_cli/src/commands/filesystem.rs +++ b/crates/ov_cli/src/commands/filesystem.rs @@ -1,6 +1,6 @@ use crate::client::HttpClient; use crate::error::Result; -use crate::output::{output_success, OutputFormat}; +use crate::output::{OutputFormat, output_success}; pub async fn ls( client: &HttpClient, @@ -14,7 +14,17 @@ pub async fn ls( output_format: OutputFormat, compact: bool, ) -> Result<()> { - let result = client.ls(uri, simple, recursive, output, abs_limit, show_all_hidden, node_limit).await?; + let result = client + .ls( + uri, + simple, + recursive, + output, + abs_limit, + show_all_hidden, + node_limit, + ) + 
.await?; output_success(&result, output_format, compact); Ok(()) } @@ -30,7 +40,16 @@ pub async fn tree( output_format: OutputFormat, compact: bool, ) -> Result<()> { - let result = client.tree(uri, output, abs_limit, show_all_hidden, node_limit, level_limit).await?; + let result = client + .tree( + uri, + output, + abs_limit, + show_all_hidden, + node_limit, + level_limit, + ) + .await?; output_success(&result, output_format, compact); Ok(()) } diff --git a/crates/ov_cli/src/commands/mod.rs b/crates/ov_cli/src/commands/mod.rs index e9786eb65..dfde8b48f 100644 --- a/crates/ov_cli/src/commands/mod.rs +++ b/crates/ov_cli/src/commands/mod.rs @@ -2,11 +2,11 @@ pub mod admin; pub mod chat; pub mod content; pub mod crypto; -pub mod search; pub mod filesystem; pub mod observer; +pub mod pack; +pub mod relations; +pub mod resources; +pub mod search; pub mod session; pub mod system; -pub mod resources; -pub mod relations; -pub mod pack; diff --git a/crates/ov_cli/src/commands/observer.rs b/crates/ov_cli/src/commands/observer.rs index 8a6178041..83caa5ef7 100644 --- a/crates/ov_cli/src/commands/observer.rs +++ b/crates/ov_cli/src/commands/observer.rs @@ -1,12 +1,8 @@ use crate::client::HttpClient; use crate::error::Result; -use crate::output::{output_success, OutputFormat}; +use crate::output::{OutputFormat, output_success}; -pub async fn queue( - client: &HttpClient, - output_format: OutputFormat, - compact: bool, -) -> Result<()> { +pub async fn queue(client: &HttpClient, output_format: OutputFormat, compact: bool) -> Result<()> { let response: serde_json::Value = client.get("/api/v1/observer/queue", &[]).await?; output_success(&response, output_format, compact); Ok(()) @@ -22,11 +18,7 @@ pub async fn vikingdb( Ok(()) } -pub async fn vlm( - client: &HttpClient, - output_format: OutputFormat, - compact: bool, -) -> Result<()> { +pub async fn vlm(client: &HttpClient, output_format: OutputFormat, compact: bool) -> Result<()> { let response: serde_json::Value = 
client.get("/api/v1/observer/vlm", &[]).await?; output_success(&response, output_format, compact); Ok(()) @@ -52,11 +44,7 @@ pub async fn retrieval( Ok(()) } -pub async fn system( - client: &HttpClient, - output_format: OutputFormat, - compact: bool, -) -> Result<()> { +pub async fn system(client: &HttpClient, output_format: OutputFormat, compact: bool) -> Result<()> { let response: serde_json::Value = client.get("/api/v1/observer/system", &[]).await?; output_success(&response, output_format, compact); Ok(()) diff --git a/crates/ov_cli/src/commands/pack.rs b/crates/ov_cli/src/commands/pack.rs index 7f6e4dcfa..82b5f2f2a 100644 --- a/crates/ov_cli/src/commands/pack.rs +++ b/crates/ov_cli/src/commands/pack.rs @@ -1,6 +1,6 @@ use crate::client::HttpClient; use crate::error::Result; -use crate::output::{output_success, OutputFormat}; +use crate::output::{OutputFormat, output_success}; pub async fn export( client: &HttpClient, diff --git a/crates/ov_cli/src/commands/relations.rs b/crates/ov_cli/src/commands/relations.rs index 357755e62..f0b943e0f 100644 --- a/crates/ov_cli/src/commands/relations.rs +++ b/crates/ov_cli/src/commands/relations.rs @@ -1,6 +1,6 @@ use crate::client::HttpClient; use crate::error::Result; -use crate::output::{output_success, OutputFormat}; +use crate::output::{OutputFormat, output_success}; pub async fn list_relations( client: &HttpClient, diff --git a/crates/ov_cli/src/commands/resources.rs b/crates/ov_cli/src/commands/resources.rs index 78ff6c63e..dc29317f5 100644 --- a/crates/ov_cli/src/commands/resources.rs +++ b/crates/ov_cli/src/commands/resources.rs @@ -1,6 +1,6 @@ use crate::client::HttpClient; use crate::error::Result; -use crate::output::{output_success, OutputFormat}; +use crate::output::{OutputFormat, output_success}; pub async fn add_resource( client: &HttpClient, diff --git a/crates/ov_cli/src/commands/search.rs b/crates/ov_cli/src/commands/search.rs index 633776254..ea1d63d67 100644 --- a/crates/ov_cli/src/commands/search.rs +++ 
b/crates/ov_cli/src/commands/search.rs @@ -1,6 +1,6 @@ use crate::client::HttpClient; use crate::error::Result; -use crate::output::{output_success, OutputFormat}; +use crate::output::{OutputFormat, output_success}; pub async fn find( client: &HttpClient, @@ -11,7 +11,9 @@ pub async fn find( output_format: OutputFormat, compact: bool, ) -> Result<()> { - let result = client.find(query.to_string(), uri.to_string(), node_limit, threshold).await?; + let result = client + .find(query.to_string(), uri.to_string(), node_limit, threshold) + .await?; output_success(&result, output_format, compact); Ok(()) } @@ -26,7 +28,15 @@ pub async fn search( output_format: OutputFormat, compact: bool, ) -> Result<()> { - let result = client.search(query.to_string(), uri.to_string(), session_id, node_limit, threshold).await?; + let result = client + .search( + query.to_string(), + uri.to_string(), + session_id, + node_limit, + threshold, + ) + .await?; output_success(&result, output_format, compact); Ok(()) } @@ -45,7 +55,6 @@ pub async fn grep( Ok(()) } - pub async fn glob( client: &HttpClient, pattern: &str, diff --git a/crates/ov_cli/src/commands/session.rs b/crates/ov_cli/src/commands/session.rs index 7a052a490..05a7ab361 100644 --- a/crates/ov_cli/src/commands/session.rs +++ b/crates/ov_cli/src/commands/session.rs @@ -1,6 +1,6 @@ use crate::client::HttpClient; use crate::error::Result; -use crate::output::{output_success, OutputFormat}; +use crate::output::{OutputFormat, output_success}; use serde_json::json; pub async fn new_session( @@ -35,6 +35,41 @@ pub async fn get_session( Ok(()) } +pub async fn get_session_context( + client: &HttpClient, + session_id: &str, + token_budget: i32, + output_format: OutputFormat, + compact: bool, +) -> Result<()> { + let path = format!("/api/v1/sessions/{}/context", url_encode(session_id)); + let response: serde_json::Value = client + .get( + &path, + &[("token_budget".to_string(), token_budget.to_string())], + ) + .await?; + 
output_success(&response, output_format, compact); + Ok(()) +} + +pub async fn get_session_archive( + client: &HttpClient, + session_id: &str, + archive_id: &str, + output_format: OutputFormat, + compact: bool, +) -> Result<()> { + let path = format!( + "/api/v1/sessions/{}/archives/{}", + url_encode(session_id), + url_encode(archive_id) + ); + let response: serde_json::Value = client.get(&path, &[]).await?; + output_success(&response, output_format, compact); + Ok(()) +} + pub async fn delete_session( client: &HttpClient, session_id: &str, @@ -43,14 +78,15 @@ pub async fn delete_session( ) -> Result<()> { let path = format!("/api/v1/sessions/{}", url_encode(session_id)); let response: serde_json::Value = client.delete(&path, &[]).await?; - + // Return session_id in result if empty (similar to Python implementation) - let result = if response.is_null() || response.as_object().map(|o| o.is_empty()).unwrap_or(false) { - json!({"session_id": session_id}) - } else { - response - }; - + let result = + if response.is_null() || response.as_object().map(|o| o.is_empty()).unwrap_or(false) { + json!({"session_id": session_id}) + } else { + response + }; + output_success(&result, output_format, compact); Ok(()) } @@ -68,7 +104,7 @@ pub async fn add_message( "role": role, "content": content }); - + let response: serde_json::Value = client.post(&path, &body).await?; output_success(&response, output_format, compact); Ok(()) @@ -99,35 +135,36 @@ pub async fn add_memory( compact: bool, ) -> Result<()> { // Parse input to determine messages - let messages: Vec<(String, String)> = if let Ok(value) = serde_json::from_str::(input) { - if let Some(arr) = value.as_array() { - // JSON array of {role, content} - arr.iter() - .map(|item| { - let role = item["role"].as_str().unwrap_or("user").to_string(); - let content = item["content"].as_str().unwrap_or("").to_string(); - (role, content) - }) - .collect() - } else if value.get("role").is_some() || value.get("content").is_some() { - // 
Single JSON object with role/content - let role = value["role"].as_str().unwrap_or("user").to_string(); - let content = value["content"].as_str().unwrap_or("").to_string(); - vec![(role, content)] + let messages: Vec<(String, String)> = + if let Ok(value) = serde_json::from_str::(input) { + if let Some(arr) = value.as_array() { + // JSON array of {role, content} + arr.iter() + .map(|item| { + let role = item["role"].as_str().unwrap_or("user").to_string(); + let content = item["content"].as_str().unwrap_or("").to_string(); + (role, content) + }) + .collect() + } else if value.get("role").is_some() || value.get("content").is_some() { + // Single JSON object with role/content + let role = value["role"].as_str().unwrap_or("user").to_string(); + let content = value["content"].as_str().unwrap_or("").to_string(); + vec![(role, content)] + } else { + // JSON but not a message object, treat as plain string + vec![("user".to_string(), input.to_string())] + } } else { - // JSON but not a message object, treat as plain string + // Plain string vec![("user".to_string(), input.to_string())] - } - } else { - // Plain string - vec![("user".to_string(), input.to_string())] - }; + }; // 1. Create a new session let session_response: serde_json::Value = client.post("/api/v1/sessions", &json!({})).await?; - let session_id = session_response["session_id"] - .as_str() - .ok_or_else(|| crate::error::Error::Api("Failed to get session_id from new session response".to_string()))?; + let session_id = session_response["session_id"].as_str().ok_or_else(|| { + crate::error::Error::Api("Failed to get session_id from new session response".to_string()) + })?; // 2. 
Add messages for (role, content) in &messages { diff --git a/crates/ov_cli/src/commands/system.rs b/crates/ov_cli/src/commands/system.rs index eae7b35d8..f41fc7c6d 100644 --- a/crates/ov_cli/src/commands/system.rs +++ b/crates/ov_cli/src/commands/system.rs @@ -1,6 +1,6 @@ use crate::client::HttpClient; use crate::error::Result; -use crate::output::{output_success, OutputFormat}; +use crate::output::{OutputFormat, output_success}; use serde_json::json; pub async fn wait( @@ -20,11 +20,7 @@ pub async fn wait( Ok(()) } -pub async fn status( - client: &HttpClient, - output_format: OutputFormat, - compact: bool, -) -> Result<()> { +pub async fn status(client: &HttpClient, output_format: OutputFormat, compact: bool) -> Result<()> { let response: serde_json::Value = client.get("/api/v1/system/status", &[]).await?; output_success(&response, output_format, compact); Ok(()) @@ -36,13 +32,16 @@ pub async fn health( compact: bool, ) -> Result { let response: serde_json::Value = client.get("/health", &[]).await?; - + // Extract the key fields - let healthy = response.get("healthy").and_then(|v| v.as_bool()).unwrap_or(false); + let healthy = response + .get("healthy") + .and_then(|v| v.as_bool()) + .unwrap_or(false); let _status = response.get("status").and_then(|v| v.as_str()); let version = response.get("version").and_then(|v| v.as_str()); let user_id = response.get("user_id").and_then(|v| v.as_str()); - + // For table output, print in a readable format if matches!(output_format, OutputFormat::Table) || matches!(output_format, OutputFormat::Json) { output_success(&response, output_format, compact); @@ -57,6 +56,6 @@ pub async fn health( } println!(); } - + Ok(healthy) } diff --git a/crates/ov_cli/src/main.rs b/crates/ov_cli/src/main.rs index 5e49d0319..23d0a7f82 100644 --- a/crates/ov_cli/src/main.rs +++ b/crates/ov_cli/src/main.rs @@ -202,7 +202,12 @@ enum Commands { #[arg(short, long)] all: bool, /// Maximum number of nodes to list - #[arg(long = "node-limit", short = 'n', 
alias = "limit", default_value = "256")] + #[arg( + long = "node-limit", + short = 'n', + alias = "limit", + default_value = "256" + )] node_limit: i32, }, /// Get directory tree @@ -216,7 +221,12 @@ enum Commands { #[arg(short, long)] all: bool, /// Maximum number of nodes to list - #[arg(long = "node-limit", short = 'n', alias = "limit", default_value = "256")] + #[arg( + long = "node-limit", + short = 'n', + alias = "limit", + default_value = "256" + )] node_limit: i32, /// Maximum depth level to traverse (default: 3) #[arg(short = 'L', long = "level-limit", default_value = "3")] @@ -290,7 +300,12 @@ enum Commands { #[arg(short, long, default_value = "")] uri: String, /// Maximum number of results - #[arg(short = 'n', long = "node-limit", alias = "limit", default_value = "10")] + #[arg( + short = 'n', + long = "node-limit", + alias = "limit", + default_value = "10" + )] node_limit: i32, /// Score threshold #[arg(short, long)] @@ -307,7 +322,12 @@ enum Commands { #[arg(long)] session_id: Option, /// Maximum number of results - #[arg(short = 'n', long = "node-limit", alias = "limit", default_value = "10")] + #[arg( + short = 'n', + long = "node-limit", + alias = "limit", + default_value = "10" + )] node_limit: i32, /// Score threshold #[arg(short, long)] @@ -324,7 +344,12 @@ enum Commands { #[arg(short, long)] ignore_case: bool, /// Maximum number of results - #[arg(short = 'n', long = "node-limit", alias = "limit", default_value = "256")] + #[arg( + short = 'n', + long = "node-limit", + alias = "limit", + default_value = "256" + )] node_limit: i32, }, /// Run file glob pattern search @@ -335,7 +360,12 @@ enum Commands { #[arg(short, long, default_value = "viking://")] uri: String, /// Maximum number of results - #[arg(short = 'n', long = "node-limit", alias = "limit", default_value = "256")] + #[arg( + short = 'n', + long = "node-limit", + alias = "limit", + default_value = "256" + )] node_limit: i32, }, /// Add memory in one shot (creates session, adds messages, 
commits) @@ -427,6 +457,21 @@ enum SessionCommands { /// Session ID session_id: String, }, + /// Get full merged session context + GetSessionContext { + /// Session ID + session_id: String, + /// Token budget for latest archive overview inclusion + #[arg(long = "token-budget", default_value = "128000")] + token_budget: i32, + }, + /// Get one completed archive for a session + GetSessionArchive { + /// Session ID + session_id: String, + /// Archive ID + archive_id: String, + }, /// Delete a session Delete { /// Session ID @@ -564,65 +609,71 @@ async fn main() { ) .await } - Commands::AddSkill { data, wait, timeout } => { - handle_add_skill(data, wait, timeout, ctx).await - } - Commands::Relations { uri } => { - handle_relations(uri, ctx).await - } - Commands::Link { from_uri, to_uris, reason } => { - handle_link(from_uri, to_uris, reason, ctx).await - } - Commands::Unlink { from_uri, to_uri } => { - handle_unlink(from_uri, to_uri, ctx).await - } - Commands::Export { uri, to } => { - handle_export(uri, to, ctx).await - } - Commands::Import { file_path, target_uri, force, no_vectorize } => { - handle_import(file_path, target_uri, force, no_vectorize, ctx).await - } + Commands::AddSkill { + data, + wait, + timeout, + } => handle_add_skill(data, wait, timeout, ctx).await, + Commands::Relations { uri } => handle_relations(uri, ctx).await, + Commands::Link { + from_uri, + to_uris, + reason, + } => handle_link(from_uri, to_uris, reason, ctx).await, + Commands::Unlink { from_uri, to_uri } => handle_unlink(from_uri, to_uri, ctx).await, + Commands::Export { uri, to } => handle_export(uri, to, ctx).await, + Commands::Import { + file_path, + target_uri, + force, + no_vectorize, + } => handle_import(file_path, target_uri, force, no_vectorize, ctx).await, Commands::Wait { timeout } => { let client = ctx.get_client(); commands::system::wait(&client, timeout, ctx.output_format, ctx.compact).await - }, + } Commands::Status => { let client = ctx.get_client(); 
commands::observer::system(&client, ctx.output_format, ctx.compact).await - }, + } Commands::Health => handle_health(ctx).await, Commands::System { action } => handle_system(action, ctx).await, Commands::Observer { action } => handle_observer(action, ctx).await, Commands::Session { action } => handle_session(action, ctx).await, Commands::Admin { action } => handle_admin(action, ctx).await, - Commands::Ls { uri, simple, recursive, abs_limit, all, node_limit } => { - handle_ls(uri, simple, recursive, abs_limit, all, node_limit, ctx).await - } - Commands::Tree { uri, abs_limit, all, node_limit, level_limit } => { - handle_tree(uri, abs_limit, all, node_limit, level_limit, ctx).await - } - Commands::Mkdir { uri } => { - handle_mkdir(uri, ctx).await - } - Commands::Rm { uri, recursive } => { - handle_rm(uri, recursive, ctx).await - } - Commands::Mv { from_uri, to_uri } => { - handle_mv(from_uri, to_uri, ctx).await - } - Commands::Stat { uri } => { - handle_stat(uri, ctx).await - } - Commands::AddMemory { content } => { - handle_add_memory(content, ctx).await - } - Commands::Tui { uri } => { - handle_tui(uri, ctx).await - } - Commands::Chat { message, session, sender, stream, no_format, no_history } => { + Commands::Ls { + uri, + simple, + recursive, + abs_limit, + all, + node_limit, + } => handle_ls(uri, simple, recursive, abs_limit, all, node_limit, ctx).await, + Commands::Tree { + uri, + abs_limit, + all, + node_limit, + level_limit, + } => handle_tree(uri, abs_limit, all, node_limit, level_limit, ctx).await, + Commands::Mkdir { uri } => handle_mkdir(uri, ctx).await, + Commands::Rm { uri, recursive } => handle_rm(uri, recursive, ctx).await, + Commands::Mv { from_uri, to_uri } => handle_mv(from_uri, to_uri, ctx).await, + Commands::Stat { uri } => handle_stat(uri, ctx).await, + Commands::AddMemory { content } => handle_add_memory(content, ctx).await, + Commands::Tui { uri } => handle_tui(uri, ctx).await, + Commands::Chat { + message, + session, + sender, + stream, + 
no_format, + no_history, + } => { let session_id = session.or_else(|| config::get_or_create_machine_id().ok()); let cmd = commands::chat::ChatCommand { - endpoint: std::env::var("VIKINGBOT_ENDPOINT").unwrap_or_else(|_| "http://localhost:1933/bot/v1".to_string()), + endpoint: std::env::var("VIKINGBOT_ENDPOINT") + .unwrap_or_else(|_| "http://localhost:1933/bot/v1".to_string()), api_key: std::env::var("VIKINGBOT_API_KEY").ok(), session: session_id, sender, @@ -641,23 +692,37 @@ async fn main() { Commands::Read { uri } => handle_read(uri, ctx).await, Commands::Abstract { uri } => handle_abstract(uri, ctx).await, Commands::Overview { uri } => handle_overview(uri, ctx).await, - Commands::Reindex { uri, regenerate, wait } => { - handle_reindex(uri, regenerate, wait, ctx).await - } + Commands::Reindex { + uri, + regenerate, + wait, + } => handle_reindex(uri, regenerate, wait, ctx).await, Commands::Get { uri, local_path } => handle_get(uri, local_path, ctx).await, - Commands::Find { query, uri, node_limit, threshold } => { - handle_find(query, uri, node_limit, threshold, ctx).await - } - Commands::Search { query, uri, session_id, node_limit, threshold } => { - handle_search(query, uri, session_id, node_limit, threshold, ctx).await - } - Commands::Grep { uri, pattern, ignore_case, node_limit } => { - handle_grep(uri, pattern, ignore_case, node_limit, ctx).await - } + Commands::Find { + query, + uri, + node_limit, + threshold, + } => handle_find(query, uri, node_limit, threshold, ctx).await, + Commands::Search { + query, + uri, + session_id, + node_limit, + threshold, + } => handle_search(query, uri, session_id, node_limit, threshold, ctx).await, + Commands::Grep { + uri, + pattern, + ignore_case, + node_limit, + } => handle_grep(uri, pattern, ignore_case, node_limit, ctx).await, - Commands::Glob { pattern, uri, node_limit } => { - handle_glob(pattern, uri, node_limit, ctx).await - } + Commands::Glob { + pattern, + uri, + node_limit, + } => handle_glob(pattern, uri, 
node_limit, ctx).await, }; if let Err(e) = result { @@ -682,32 +747,35 @@ async fn handle_add_resource( watch_interval: f64, ctx: CliContext, ) -> Result<()> { - let is_url = path.starts_with("http://") - || path.starts_with("https://") - || path.starts_with("git@"); - + let is_url = + path.starts_with("http://") || path.starts_with("https://") || path.starts_with("git@"); + if !is_url { use std::path::Path; - + // Unescape path: replace backslash followed by space with just space let unescaped_path = path.replace("\\ ", " "); let path_obj = Path::new(&unescaped_path); if !path_obj.exists() { eprintln!("Error: Path '{}' does not exist.", path); - + // Check if there might be unquoted spaces use std::env; let args: Vec = env::args().collect(); - - if let Some(add_resource_pos) = args.iter().position(|s| s == "add-resource" || s == "add") { + + if let Some(add_resource_pos) = + args.iter().position(|s| s == "add-resource" || s == "add") + { if args.len() > add_resource_pos + 2 { let extra_args = &args[add_resource_pos + 2..]; let suggested_path = format!("{} {}", path, extra_args.join(" ")); - eprintln!("\nIt looks like you may have forgotten to quote a path with spaces."); + eprintln!( + "\nIt looks like you may have forgotten to quote a path with spaces." 
+ ); eprintln!("Suggested command: ov add-resource \"{}\"", suggested_path); } } - + std::process::exit(1); } path = unescaped_path; @@ -750,7 +818,8 @@ async fn handle_add_resource( watch_interval, ctx.output_format, ctx.compact, - ).await + ) + .await } async fn handle_add_skill( @@ -761,14 +830,19 @@ async fn handle_add_skill( ) -> Result<()> { let client = ctx.get_client(); commands::resources::add_skill( - &client, &data, wait, timeout, ctx.output_format, ctx.compact - ).await + &client, + &data, + wait, + timeout, + ctx.output_format, + ctx.compact, + ) + .await } async fn handle_relations(uri: String, ctx: CliContext) -> Result<()> { let client = ctx.get_client(); - commands::relations::list_relations(&client, &uri, ctx.output_format, ctx.compact - ).await + commands::relations::list_relations(&client, &uri, ctx.output_format, ctx.compact).await } async fn handle_link( @@ -779,25 +853,24 @@ async fn handle_link( ) -> Result<()> { let client = ctx.get_client(); commands::relations::link( - &client, &from_uri, &to_uris, &reason, ctx.output_format, ctx.compact - ).await + &client, + &from_uri, + &to_uris, + &reason, + ctx.output_format, + ctx.compact, + ) + .await } -async fn handle_unlink( - from_uri: String, - to_uri: String, - ctx: CliContext, -) -> Result<()> { +async fn handle_unlink(from_uri: String, to_uri: String, ctx: CliContext) -> Result<()> { let client = ctx.get_client(); - commands::relations::unlink( - &client, &from_uri, &to_uri, ctx.output_format, ctx.compact - ).await + commands::relations::unlink(&client, &from_uri, &to_uri, ctx.output_format, ctx.compact).await } async fn handle_export(uri: String, to: String, ctx: CliContext) -> Result<()> { let client = ctx.get_client(); - commands::pack::export(&client, &uri, &to, ctx.output_format, ctx.compact - ).await + commands::pack::export(&client, &uri, &to, ctx.output_format, ctx.compact).await } async fn handle_import( @@ -809,8 +882,15 @@ async fn handle_import( ) -> Result<()> { let client = 
ctx.get_client(); commands::pack::import( - &client, &file_path, &target_uri, force, no_vectorize, ctx.output_format, ctx.compact - ).await + &client, + &file_path, + &target_uri, + force, + no_vectorize, + ctx.output_format, + ctx.compact, + ) + .await } async fn handle_system(cmd: SystemCommands, ctx: CliContext) -> Result<()> { @@ -823,8 +903,7 @@ async fn handle_system(cmd: SystemCommands, ctx: CliContext) -> Result<()> { commands::system::status(&client, ctx.output_format, ctx.compact).await } SystemCommands::Health => { - let _ = - commands::system::health(&client, ctx.output_format, ctx.compact).await?; + let _ = commands::system::health(&client, ctx.output_format, ctx.compact).await?; Ok(()) } SystemCommands::Crypto { action } => commands::crypto::handle_crypto(action).await, @@ -865,21 +944,57 @@ async fn handle_session(cmd: SessionCommands, ctx: CliContext) -> Result<()> { commands::session::list_sessions(&client, ctx.output_format, ctx.compact).await } SessionCommands::Get { session_id } => { - commands::session::get_session(&client, &session_id, ctx.output_format, ctx.compact - ).await + commands::session::get_session(&client, &session_id, ctx.output_format, ctx.compact) + .await + } + SessionCommands::GetSessionContext { + session_id, + token_budget, + } => { + commands::session::get_session_context( + &client, + &session_id, + token_budget, + ctx.output_format, + ctx.compact, + ) + .await + } + SessionCommands::GetSessionArchive { + session_id, + archive_id, + } => { + commands::session::get_session_archive( + &client, + &session_id, + &archive_id, + ctx.output_format, + ctx.compact, + ) + .await } SessionCommands::Delete { session_id } => { - commands::session::delete_session(&client, &session_id, ctx.output_format, ctx.compact - ).await + commands::session::delete_session(&client, &session_id, ctx.output_format, ctx.compact) + .await } - SessionCommands::AddMessage { session_id, role, content } => { + SessionCommands::AddMessage { + session_id, + 
role, + content, + } => { commands::session::add_message( - &client, &session_id, &role, &content, ctx.output_format, ctx.compact - ).await + &client, + &session_id, + &role, + &content, + ctx.output_format, + ctx.compact, + ) + .await } SessionCommands::Commit { session_id } => { - commands::session::commit_session(&client, &session_id, ctx.output_format, ctx.compact - ).await + commands::session::commit_session(&client, &session_id, ctx.output_format, ctx.compact) + .await } } } @@ -887,43 +1002,84 @@ async fn handle_session(cmd: SessionCommands, ctx: CliContext) -> Result<()> { async fn handle_admin(cmd: AdminCommands, ctx: CliContext) -> Result<()> { let client = ctx.get_client(); match cmd { - AdminCommands::CreateAccount { account_id, admin_user_id } => { + AdminCommands::CreateAccount { + account_id, + admin_user_id, + } => { commands::admin::create_account( - &client, &account_id, &admin_user_id, ctx.output_format, ctx.compact, - ).await + &client, + &account_id, + &admin_user_id, + ctx.output_format, + ctx.compact, + ) + .await } AdminCommands::ListAccounts => { commands::admin::list_accounts(&client, ctx.output_format, ctx.compact).await } AdminCommands::DeleteAccount { account_id } => { - commands::admin::delete_account( - &client, &account_id, ctx.output_format, ctx.compact, - ).await + commands::admin::delete_account(&client, &account_id, ctx.output_format, ctx.compact) + .await } - AdminCommands::RegisterUser { account_id, user_id, role } => { + AdminCommands::RegisterUser { + account_id, + user_id, + role, + } => { commands::admin::register_user( - &client, &account_id, &user_id, &role, ctx.output_format, ctx.compact, - ).await + &client, + &account_id, + &user_id, + &role, + ctx.output_format, + ctx.compact, + ) + .await } AdminCommands::ListUsers { account_id } => { - commands::admin::list_users( - &client, &account_id, ctx.output_format, ctx.compact, - ).await + commands::admin::list_users(&client, &account_id, ctx.output_format, 
ctx.compact).await } - AdminCommands::RemoveUser { account_id, user_id } => { + AdminCommands::RemoveUser { + account_id, + user_id, + } => { commands::admin::remove_user( - &client, &account_id, &user_id, ctx.output_format, ctx.compact, - ).await + &client, + &account_id, + &user_id, + ctx.output_format, + ctx.compact, + ) + .await } - AdminCommands::SetRole { account_id, user_id, role } => { + AdminCommands::SetRole { + account_id, + user_id, + role, + } => { commands::admin::set_role( - &client, &account_id, &user_id, &role, ctx.output_format, ctx.compact, - ).await + &client, + &account_id, + &user_id, + &role, + ctx.output_format, + ctx.compact, + ) + .await } - AdminCommands::RegenerateKey { account_id, user_id } => { + AdminCommands::RegenerateKey { + account_id, + user_id, + } => { commands::admin::regenerate_key( - &client, &account_id, &user_id, ctx.output_format, ctx.compact, - ).await + &client, + &account_id, + &user_id, + ctx.output_format, + ctx.compact, + ) + .await } } } @@ -940,21 +1096,17 @@ async fn handle_config(cmd: ConfigCommands, _ctx: CliContext) -> Result<()> { output::output_success( &serde_json::to_value(config).unwrap(), output::OutputFormat::Json, - true + true, ); Ok(()) } - ConfigCommands::Validate => { - match Config::load() { - Ok(_) => { - println!("Configuration is valid"); - Ok(()) - } - Err(e) => { - Err(Error::Config(e.to_string())) - } + ConfigCommands::Validate => match Config::load() { + Ok(_) => { + println!("Configuration is valid"); + Ok(()) } - } + Err(e) => Err(Error::Config(e.to_string())), + }, } } @@ -975,7 +1127,15 @@ async fn handle_overview(uri: String, ctx: CliContext) -> Result<()> { async fn handle_reindex(uri: String, regenerate: bool, wait: bool, ctx: CliContext) -> Result<()> { let client = ctx.get_client(); - commands::content::reindex(&client, &uri, regenerate, wait, ctx.output_format, ctx.compact).await + commands::content::reindex( + &client, + &uri, + regenerate, + wait, + ctx.output_format, + 
ctx.compact, + ) + .await } async fn handle_get(uri: String, local_path: String, ctx: CliContext) -> Result<()> { @@ -997,7 +1157,16 @@ async fn handle_find( params.push(format!("\"{}\"", query)); print_command_echo("ov find", ¶ms.join(" "), ctx.config.echo_command); let client = ctx.get_client(); - commands::search::find(&client, &query, &uri, node_limit, threshold, ctx.output_format, ctx.compact).await + commands::search::find( + &client, + &query, + &uri, + node_limit, + threshold, + ctx.output_format, + ctx.compact, + ) + .await } async fn handle_search( @@ -1018,7 +1187,17 @@ async fn handle_search( params.push(format!("\"{}\"", query)); print_command_echo("ov search", ¶ms.join(" "), ctx.config.echo_command); let client = ctx.get_client(); - commands::search::search(&client, &query, &uri, session_id, node_limit, threshold, ctx.output_format, ctx.compact).await + commands::search::search( + &client, + &query, + &uri, + session_id, + node_limit, + threshold, + ctx.output_format, + ctx.compact, + ) + .await } /// Print command with specified parameters for debugging @@ -1028,35 +1207,81 @@ fn print_command_echo(command: &str, params: &str, echo_enabled: bool) { } } -async fn handle_ls(uri: String, simple: bool, recursive: bool, abs_limit: i32, show_all_hidden: bool, node_limit: i32, ctx: CliContext) -> Result<()> { +async fn handle_ls( + uri: String, + simple: bool, + recursive: bool, + abs_limit: i32, + show_all_hidden: bool, + node_limit: i32, + ctx: CliContext, +) -> Result<()> { let mut params = vec![ uri.clone(), format!("-l {}", abs_limit), format!("-n {}", node_limit), ]; - if simple { params.push("-s".to_string()); } - if recursive { params.push("-r".to_string()); } - if show_all_hidden { params.push("-a".to_string()); } + if simple { + params.push("-s".to_string()); + } + if recursive { + params.push("-r".to_string()); + } + if show_all_hidden { + params.push("-a".to_string()); + } print_command_echo("ov ls", ¶ms.join(" "), ctx.config.echo_command); let 
client = ctx.get_client(); let api_output = if ctx.compact { "agent" } else { "original" }; - commands::filesystem::ls(&client, &uri, simple, recursive, api_output, abs_limit, show_all_hidden, node_limit, ctx.output_format, ctx.compact).await + commands::filesystem::ls( + &client, + &uri, + simple, + recursive, + api_output, + abs_limit, + show_all_hidden, + node_limit, + ctx.output_format, + ctx.compact, + ) + .await } -async fn handle_tree(uri: String, abs_limit: i32, show_all_hidden: bool, node_limit: i32, level_limit: i32, ctx: CliContext) -> Result<()> { +async fn handle_tree( + uri: String, + abs_limit: i32, + show_all_hidden: bool, + node_limit: i32, + level_limit: i32, + ctx: CliContext, +) -> Result<()> { let mut params = vec![ uri.clone(), format!("-l {}", abs_limit), format!("-n {}", node_limit), format!("-L {}", level_limit), ]; - if show_all_hidden { params.push("-a".to_string()); } + if show_all_hidden { + params.push("-a".to_string()); + } print_command_echo("ov tree", ¶ms.join(" "), ctx.config.echo_command); let client = ctx.get_client(); let api_output = if ctx.compact { "agent" } else { "original" }; - commands::filesystem::tree(&client, &uri, api_output, abs_limit, show_all_hidden, node_limit, level_limit, ctx.output_format, ctx.compact).await + commands::filesystem::tree( + &client, + &uri, + api_output, + abs_limit, + show_all_hidden, + node_limit, + level_limit, + ctx.output_format, + ctx.compact, + ) + .await } async fn handle_mkdir(uri: String, ctx: CliContext) -> Result<()> { @@ -1079,29 +1304,57 @@ async fn handle_stat(uri: String, ctx: CliContext) -> Result<()> { commands::filesystem::stat(&client, &uri, ctx.output_format, ctx.compact).await } -async fn handle_grep(uri: String, pattern: String, ignore_case: bool, node_limit: i32, ctx: CliContext) -> Result<()> { +async fn handle_grep( + uri: String, + pattern: String, + ignore_case: bool, + node_limit: i32, + ctx: CliContext, +) -> Result<()> { let mut params = vec![format!("--uri={}", 
uri), format!("-n {}", node_limit)]; - if ignore_case { params.push("-i".to_string()); } + if ignore_case { + params.push("-i".to_string()); + } params.push(format!("\"{}\"", pattern)); print_command_echo("ov grep", ¶ms.join(" "), ctx.config.echo_command); let client = ctx.get_client(); - commands::search::grep(&client, &uri, &pattern, ignore_case, node_limit, ctx.output_format, ctx.compact).await + commands::search::grep( + &client, + &uri, + &pattern, + ignore_case, + node_limit, + ctx.output_format, + ctx.compact, + ) + .await } - async fn handle_glob(pattern: String, uri: String, node_limit: i32, ctx: CliContext) -> Result<()> { - let params = vec![format!("--uri={}", uri), format!("-n {}", node_limit), format!("\"{}\"", pattern)]; + let params = vec![ + format!("--uri={}", uri), + format!("-n {}", node_limit), + format!("\"{}\"", pattern), + ]; print_command_echo("ov glob", ¶ms.join(" "), ctx.config.echo_command); let client = ctx.get_client(); - commands::search::glob(&client, &pattern, &uri, node_limit, ctx.output_format, ctx.compact).await + commands::search::glob( + &client, + &pattern, + &uri, + node_limit, + ctx.output_format, + ctx.compact, + ) + .await } async fn handle_health(ctx: CliContext) -> Result<()> { let client = ctx.get_client(); - + // Reuse the system health command let _ = commands::system::health(&client, ctx.output_format, ctx.compact).await?; - + Ok(()) } diff --git a/crates/ov_cli/src/output.rs b/crates/ov_cli/src/output.rs index c6db0f430..017cd3c19 100644 --- a/crates/ov_cli/src/output.rs +++ b/crates/ov_cli/src/output.rs @@ -91,6 +91,16 @@ fn print_table(result: T, compact: bool) { // Handle object if let Some(obj) = value.as_object() { if !obj.is_empty() { + if let Some(rendered) = render_session_context(obj, compact) { + println!("{}", rendered); + return; + } + + if let Some(rendered) = render_session_archive(obj, compact) { + println!("{}", rendered); + return; + } + // Rule 5: ComponentStatus (name + is_healthy + status) if 
obj.contains_key("name") && obj.contains_key("is_healthy") @@ -349,6 +359,241 @@ fn value_to_table(value: &serde_json::Value, compact: bool) -> Option { None } +fn render_session_context( + obj: &serde_json::Map, + compact: bool, +) -> Option { + if !(obj.contains_key("latest_archive_overview") + && obj.contains_key("latest_archive_id") + && obj.contains_key("pre_archive_abstracts") + && obj.contains_key("messages")) + { + return None; + } + + let latest_archive_id = obj + .get("latest_archive_id") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let latest_archive_overview = obj + .get("latest_archive_overview") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let estimated_tokens = obj + .get("estimatedTokens") + .map(format_value) + .unwrap_or_else(|| "0".to_string()); + + let mut lines: Vec = Vec::new(); + lines.push(format!( + "latest_archive_id {}", + if latest_archive_id.is_empty() { + "(none)" + } else { + latest_archive_id + } + )); + lines.push(format!("estimated_tokens {}", estimated_tokens)); + + if let Some(stats) = obj.get("stats").and_then(|v| v.as_object()) { + lines.push(format!( + "active_messages {}", + obj.get("messages") + .and_then(|v| v.as_array()) + .map(|items| items.len()) + .unwrap_or(0) + )); + lines.push(format!( + "total_archives {}", + stats.get("totalArchives") + .map(format_value) + .unwrap_or_else(|| "0".to_string()) + )); + lines.push(format!( + "included_archives {}", + stats.get("includedArchives") + .map(format_value) + .unwrap_or_else(|| "0".to_string()) + )); + lines.push(format!( + "dropped_archives {}", + stats.get("droppedArchives") + .map(format_value) + .unwrap_or_else(|| "0".to_string()) + )); + } + + lines.push(String::new()); + lines.push("latest_archive_overview".to_string()); + if latest_archive_overview.is_empty() { + if latest_archive_id.is_empty() { + lines.push("(none)".to_string()); + } else { + lines.push("(trimmed by token budget or unavailable)".to_string()); + } + } else { + 
lines.push(latest_archive_overview.to_string()); + } + + if let Some(items) = obj.get("pre_archive_abstracts").and_then(|v| v.as_array()) { + lines.push(String::new()); + lines.push(format!("pre_archive_abstracts ({})", items.len())); + if items.is_empty() { + lines.push("(empty)".to_string()); + } else if let Some(table) = format_array_to_table(items, compact) { + lines.push(table.trim_end().to_string()); + } + } + + if let Some(messages) = obj.get("messages").and_then(|v| v.as_array()) { + lines.push(String::new()); + lines.push(format!("messages ({})", messages.len())); + if messages.is_empty() { + lines.push("(empty)".to_string()); + } else { + let rows = build_message_rows(messages); + if let Some(table) = format_array_to_table(&rows, compact) { + lines.push(table.trim_end().to_string()); + } + } + } + + Some(lines.join("\n")) +} + +fn render_session_archive( + obj: &serde_json::Map, + compact: bool, +) -> Option { + if !(obj.contains_key("archive_id") && obj.contains_key("overview") && obj.contains_key("messages")) + { + return None; + } + + let archive_id = obj.get("archive_id").and_then(|v| v.as_str()).unwrap_or(""); + let abstract_text = obj.get("abstract").and_then(|v| v.as_str()).unwrap_or(""); + let overview = obj.get("overview").and_then(|v| v.as_str()).unwrap_or(""); + + let mut lines: Vec = Vec::new(); + lines.push(format!( + "archive_id {}", + if archive_id.is_empty() { + "(none)" + } else { + archive_id + } + )); + lines.push(format!( + "abstract {}", + if abstract_text.is_empty() { + "(empty)" + } else { + abstract_text + } + )); + + lines.push(String::new()); + lines.push("overview".to_string()); + lines.push(if overview.is_empty() { + "(empty)".to_string() + } else { + overview.to_string() + }); + + if let Some(messages) = obj.get("messages").and_then(|v| v.as_array()) { + lines.push(String::new()); + lines.push(format!("messages ({})", messages.len())); + if messages.is_empty() { + lines.push("(empty)".to_string()); + } else { + let rows = 
build_message_rows(messages); + if let Some(table) = format_array_to_table(&rows, compact) { + lines.push(table.trim_end().to_string()); + } + } + } + + Some(lines.join("\n")) +} + +fn build_message_rows(messages: &[serde_json::Value]) -> Vec { + let mut rows: Vec = Vec::new(); + + for message in messages { + let Some(obj) = message.as_object() else { + continue; + }; + + let mut row = serde_json::Map::new(); + row.insert( + "id".to_string(), + obj.get("id").cloned().unwrap_or(serde_json::Value::Null), + ); + row.insert( + "role".to_string(), + obj.get("role").cloned().unwrap_or(serde_json::Value::Null), + ); + row.insert( + "created_at".to_string(), + obj.get("created_at") + .cloned() + .unwrap_or(serde_json::Value::Null), + ); + row.insert( + "content".to_string(), + serde_json::Value::String(summarize_message_content( + obj.get("parts").and_then(|v| v.as_array()), + )), + ); + rows.push(serde_json::Value::Object(row)); + } + + rows +} + +fn summarize_message_content(parts: Option<&Vec>) -> String { + let Some(parts) = parts else { + return String::new(); + }; + + let mut chunks: Vec = Vec::new(); + for part in parts { + let Some(obj) = part.as_object() else { + chunks.push(format_value(part)); + continue; + }; + + let part_type = obj.get("type").and_then(|v| v.as_str()).unwrap_or(""); + match part_type { + "text" => { + if let Some(text) = obj.get("text").and_then(|v| v.as_str()) { + chunks.push(text.to_string()); + } + } + "context" => { + let abstract_text = obj.get("abstract").and_then(|v| v.as_str()).unwrap_or(""); + chunks.push(if abstract_text.is_empty() { + "[context]".to_string() + } else { + format!("[context] {}", abstract_text) + }); + } + "tool" => { + let name = obj.get("tool_name").and_then(|v| v.as_str()).unwrap_or("tool"); + let status = obj.get("tool_status").and_then(|v| v.as_str()).unwrap_or(""); + chunks.push(if status.is_empty() { + format!("[tool:{}]", name) + } else { + format!("[tool:{}:{}]", name, status) + }); + } + _ => 
chunks.push(format_value(part)), + } + } + + chunks.join(" | ") +} + struct ColumnInfo { max_width: usize, // Max width for alignment (capped at 120) is_numeric: bool, // True if all values in column are numeric diff --git a/crates/ov_cli/src/tui/app.rs b/crates/ov_cli/src/tui/app.rs index 8e7517d46..7fdc5a181 100644 --- a/crates/ov_cli/src/tui/app.rs +++ b/crates/ov_cli/src/tui/app.rs @@ -120,15 +120,14 @@ impl App { // If in vector mode, reload records with new current_uri if self.showing_vector_records { - self.load_vector_records(Some(self.current_uri.clone())).await; + self.load_vector_records(Some(self.current_uri.clone())) + .await; } } async fn load_directory_content(&mut self, uri: &str) { - let (abstract_result, overview_result) = tokio::join!( - self.client.abstract_content(uri), - self.client.overview(uri), - ); + let (abstract_result, overview_result) = + tokio::join!(self.client.abstract_content(uri), self.client.overview(uri),); let mut parts = Vec::new(); @@ -218,7 +217,8 @@ impl App { self.vector_state.next_page_cursor = next_cursor; self.vector_state.cursor = 0; self.vector_state.scroll_offset = 0; - self.status_message = format!("Loaded {} vector records", self.vector_state.records.len()); + self.status_message = + format!("Loaded {} vector records", self.vector_state.records.len()); } Err(e) => { self.status_message = format!("Failed to load vector records: {}", e); @@ -246,7 +246,10 @@ impl App { self.vector_state.records.append(&mut new_records); self.vector_state.has_more = next_cursor.is_some(); self.vector_state.next_page_cursor = next_cursor; - self.status_message = format!("Loaded {} total vector records", self.vector_state.records.len()); + self.status_message = format!( + "Loaded {} total vector records", + self.vector_state.records.len() + ); } Err(e) => { self.status_message = format!("Failed to load next page: {}", e); @@ -257,7 +260,8 @@ impl App { pub async fn toggle_vector_records_mode(&mut self) { self.showing_vector_records = 
!self.showing_vector_records; if self.showing_vector_records && self.vector_state.records.is_empty() { - self.load_vector_records(Some(self.current_uri.clone())).await; + self.load_vector_records(Some(self.current_uri.clone())) + .await; } } diff --git a/crates/ov_cli/src/tui/mod.rs b/crates/ov_cli/src/tui/mod.rs index 6c9e9d28e..7ff9a0973 100644 --- a/crates/ov_cli/src/tui/mod.rs +++ b/crates/ov_cli/src/tui/mod.rs @@ -6,9 +6,9 @@ mod ui; use std::io; use crossterm::{ - event::{self as ct_event, Event}, - terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, ExecutableCommand, + event::{self as ct_event, Event}, + terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode}, }; use ratatui::prelude::*; diff --git a/crates/ov_cli/src/tui/tree.rs b/crates/ov_cli/src/tui/tree.rs index bc90a0a56..f51e386fe 100644 --- a/crates/ov_cli/src/tui/tree.rs +++ b/crates/ov_cli/src/tui/tree.rs @@ -137,10 +137,7 @@ impl TreeState { } } - async fn fetch_children( - client: &HttpClient, - uri: &str, - ) -> Result, String> { + async fn fetch_children(client: &HttpClient, uri: &str) -> Result, String> { let result = client .ls(uri, false, false, "original", 256, false, 1000) .await @@ -167,10 +164,12 @@ impl TreeState { // Sort: directories first, then alphabetical nodes.sort_by(|a, b| { - b.entry - .is_dir - .cmp(&a.entry.is_dir) - .then_with(|| a.entry.name().to_lowercase().cmp(&b.entry.name().to_lowercase())) + b.entry.is_dir.cmp(&a.entry.is_dir).then_with(|| { + a.entry + .name() + .to_lowercase() + .cmp(&b.entry.name().to_lowercase()) + }) }); Ok(nodes) diff --git a/crates/ov_cli/src/tui/ui.rs b/crates/ov_cli/src/tui/ui.rs index d3d0630f9..5b3b2072e 100644 --- a/crates/ov_cli/src/tui/ui.rs +++ b/crates/ov_cli/src/tui/ui.rs @@ -1,9 +1,9 @@ use ratatui::{ + Frame, layout::{Constraint, Direction, Layout}, style::{Color, Modifier, Style}, text::{Line, Span}, widgets::{Block, Borders, List, ListItem, ListState, 
Paragraph, Wrap}, - Frame, }; use super::app::{App, Panel}; @@ -61,11 +61,7 @@ fn render_tree(frame: &mut Frame, app: &App, area: ratatui::layout::Rect) { .map(|row| { let indent = " ".repeat(row.depth); let icon = if row.is_dir { - if row.expanded { - "▾ " - } else { - "▸ " - } + if row.expanded { "▾ " } else { "▸ " } } else { " " }; diff --git a/crates/ov_cli/src/utils.rs b/crates/ov_cli/src/utils.rs index 24dee374b..89408443f 100644 --- a/crates/ov_cli/src/utils.rs +++ b/crates/ov_cli/src/utils.rs @@ -11,9 +11,5 @@ pub fn truncate_utf8(s: &str, max_bytes: usize) -> &str { boundary -= 1; } - if boundary == 0 { - "" - } else { - &s[..boundary] - } + if boundary == 0 { "" } else { &s[..boundary] } } diff --git a/docs/en/api/01-overview.md b/docs/en/api/01-overview.md index ae4ddbd44..a679f0402 100644 --- a/docs/en/api/01-overview.md +++ b/docs/en/api/01-overview.md @@ -323,6 +323,7 @@ Compact JSON with status wrapper (when `--compact` is true, which is the default | POST | `/api/v1/sessions` | Create session | | GET | `/api/v1/sessions` | List sessions | | GET | `/api/v1/sessions/{id}` | Get session | +| GET | `/api/v1/sessions/{id}/context` | Get assembled session context | | DELETE | `/api/v1/sessions/{id}` | Delete session | | POST | `/api/v1/sessions/{id}/commit` | Commit session | | POST | `/api/v1/sessions/{id}/messages` | Add message | diff --git a/docs/en/api/05-sessions.md b/docs/en/api/05-sessions.md index 7c4526980..8ce60c0d3 100644 --- a/docs/en/api/05-sessions.md +++ b/docs/en/api/05-sessions.md @@ -176,6 +176,191 @@ openviking session get a1b2c3d4 --- +### get_session_context() + +Get the assembled session context used by OpenClaw-style context rebuilding. 
+ +This endpoint returns: +- `latest_archive_overview`: the `overview` of the latest completed archive, when it fits the token budget +- `latest_archive_id`: the ID of the latest completed archive, used for archive expansion +- `pre_archive_abstracts`: lightweight history entries for older completed archives, each containing `archive_id` and `abstract` +- `messages`: all incomplete archive messages after the latest completed archive, plus current live session messages +- `stats`: token and inclusion stats for the returned context + +Notes: +- `latest_archive_overview` becomes an empty string when no completed archive exists, or when the latest overview does not fit in the token budget. +- `latest_archive_id` is returned whenever a latest completed archive exists, even if `latest_archive_overview` is trimmed by budget. +- `token_budget` is applied to the assembled payload after active `messages`: `latest_archive_overview` has higher priority than `pre_archive_abstracts`, and older abstracts are dropped first when budget is tight. +- Only archive content that is actually returned is counted toward `estimatedTokens` and `stats.archiveTokens`. +- Session commit generates an archive summary during Phase 2 for every non-empty archive attempt. Only archives with a completed `.done` marker are exposed here. 
+ +**Parameters** + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| session_id | str | Yes | - | Session ID | +| token_budget | int | No | 128000 | Token budget for assembled archive payload after active `messages` | + +**Python SDK (Embedded / HTTP)** + +```python +context = await client.get_session_context("a1b2c3d4", token_budget=128000) +print(context["latest_archive_overview"]) +print(context["latest_archive_id"]) +print(context["pre_archive_abstracts"]) +print(len(context["messages"])) + +session = client.session("a1b2c3d4") +context = await session.get_session_context(token_budget=128000) +``` + +**HTTP API** + +``` +GET /api/v1/sessions/{session_id}/context?token_budget=128000 +``` + +```bash +curl -X GET "http://localhost:1933/api/v1/sessions/a1b2c3d4/context?token_budget=128000" \ + -H "X-API-Key: your-key" +``` + +**CLI** + +```bash +ov session get-session-context a1b2c3d4 --token-budget 128000 +``` + +**Response** + +```json +{ + "status": "ok", + "result": { + "latest_archive_overview": "# Session Summary\n\n**Overview**: User discussed deployment and auth setup.", + "latest_archive_id": "archive_002", + "pre_archive_abstracts": [ + { + "archive_id": "archive_001", + "abstract": "User previously discussed repository bootstrap and authentication setup." + } + ], + "messages": [ + { + "id": "msg_pending_1", + "role": "user", + "parts": [ + {"type": "text", "text": "Pending user message"} + ], + "created_at": "2026-03-24T09:10:11Z" + }, + { + "id": "msg_live_1", + "role": "assistant", + "parts": [ + {"type": "text", "text": "Current live message"} + ], + "created_at": "2026-03-24T09:10:20Z" + } + ], + "estimatedTokens": 160, + "stats": { + "totalArchives": 2, + "includedArchives": 2, + "droppedArchives": 0, + "failedArchives": 0, + "activeTokens": 98, + "archiveTokens": 62 + } + } +} +``` + +--- + +### get_session_archive() + +Get the full contents of one completed archive for a session. 
+ +This endpoint is intended to work with `latest_archive_id` and `pre_archive_abstracts[*].archive_id` returned by `get_session_context()`. + +This endpoint returns: +- `archive_id`: the archive ID that was expanded +- `abstract`: the lightweight summary for the archive +- `overview`: the full archive overview +- `messages`: the archived transcript for that archive + +**Parameters** + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| session_id | str | Yes | - | Session ID | +| archive_id | str | Yes | - | Archive ID such as `archive_002` | + +**Python SDK (Embedded / HTTP)** + +```python +archive = await client.get_session_archive("a1b2c3d4", "archive_002") +print(archive["archive_id"]) +print(archive["overview"]) +print(len(archive["messages"])) + +session = client.session("a1b2c3d4") +archive = await session.get_archive("archive_002") +``` + +**HTTP API** + +``` +GET /api/v1/sessions/{session_id}/archives/{archive_id} +``` + +```bash +curl -X GET "http://localhost:1933/api/v1/sessions/a1b2c3d4/archives/archive_002" \ + -H "X-API-Key: your-key" +``` + +**CLI** + +```bash +ov session get-session-archive a1b2c3d4 archive_002 +``` + +**Response** + +```json +{ + "status": "ok", + "result": { + "archive_id": "archive_002", + "abstract": "User discussed deployment and auth setup.", + "overview": "# Session Summary\n\n**Overview**: User discussed deployment and auth setup.", + "messages": [ + { + "id": "msg_archive_1", + "role": "user", + "parts": [ + {"type": "text", "text": "How should I deploy this service?"} + ], + "created_at": "2026-03-24T08:55:01Z" + }, + { + "id": "msg_archive_2", + "role": "assistant", + "parts": [ + {"type": "text", "text": "Use the staged deployment flow and verify auth first."} + ], + "created_at": "2026-03-24T08:55:18Z" + } + ] + } +} +``` + +If the archive does not exist, is incomplete, or does not belong to the session, the API returns `404`. 
+ +--- + ### delete_session() Delete a session. @@ -422,6 +607,11 @@ curl -X POST http://localhost:1933/api/v1/sessions/a1b2c3d4/used \ Commit a session. Message archiving (Phase 1) completes immediately. Summary generation and memory extraction (Phase 2) run asynchronously in the background. Returns a `task_id` for polling progress. +Notes: +- Rapid consecutive commits on the same session are accepted; each request gets its own `task_id`. +- Background Phase 2 work is serialized by archive order: archive `N+1` waits until archive `N` writes `.done`. +- If an earlier archive failed and left no `.done`, later commit requests fail with `FAILED_PRECONDITION` until that failure is resolved. + **Parameters** | Parameter | Type | Required | Default | Description | diff --git a/docs/en/concepts/08-session.md b/docs/en/concepts/08-session.md index 1b6a59027..c53bcefcb 100644 --- a/docs/en/concepts/08-session.md +++ b/docs/en/concepts/08-session.md @@ -6,7 +6,7 @@ Session manages conversation messages, tracks context usage, and extracts long-t **Lifecycle**: Create → Interact → Commit -Getting a session by ID will auto-create it if it does not exist. +Getting a session by ID does not auto-create it by default. Use `client.get_session(..., auto_create=True)` when you want missing sessions to be created automatically. 
```python session = client.session(session_id="chat_001") diff --git a/docs/zh/api/01-overview.md b/docs/zh/api/01-overview.md index 3b1346098..98351d2f4 100644 --- a/docs/zh/api/01-overview.md +++ b/docs/zh/api/01-overview.md @@ -325,6 +325,7 @@ openviking -o json ls viking://resources/ | POST | `/api/v1/sessions` | 创建会话 | | GET | `/api/v1/sessions` | 列出会话 | | GET | `/api/v1/sessions/{id}` | 获取会话 | +| GET | `/api/v1/sessions/{id}/context` | 获取组装后的会话上下文 | | DELETE | `/api/v1/sessions/{id}` | 删除会话 | | POST | `/api/v1/sessions/{id}/commit` | 提交会话 | | POST | `/api/v1/sessions/{id}/messages` | 添加消息 | diff --git a/docs/zh/api/05-sessions.md b/docs/zh/api/05-sessions.md index 2ec116466..4e294eb09 100644 --- a/docs/zh/api/05-sessions.md +++ b/docs/zh/api/05-sessions.md @@ -176,6 +176,191 @@ openviking session get a1b2c3d4 --- +### get_session_context() + +获取供上下文组装使用的会话上下文。 + +该接口返回: +- `latest_archive_overview`:最新一个已完成归档的 `overview` 文本,在 token budget 足够时返回 +- `latest_archive_id`:最新一个已完成归档的 ID,用于后续展开 archive 详情 +- `pre_archive_abstracts`:更早历史归档的轻量列表,每项只包含 `archive_id` 和 `abstract` +- `messages`:最新已完成归档之后的所有未完成归档消息,再加上当前 live session 消息 +- `stats`:返回结果对应的 token 与纳入统计 + +说明: +- 没有可用 completed archive,或最新 overview 超出 token budget 时,`latest_archive_overview` 返回空字符串。 +- 只要存在最新 completed archive,就会返回 `latest_archive_id`;即使 `latest_archive_overview` 因 budget 被裁剪,这个 ID 仍然可用。 +- `token_budget` 会在 active `messages` 之后作用于 assembled archive payload:`latest_archive_overview` 优先级高于 `pre_archive_abstracts`,预算紧张时先淘汰最旧的 abstracts。 +- 只有最终实际返回的 archive 内容,才会计入 `estimatedTokens` 和 `stats.archiveTokens`。 +- 当前每次有消息的 session commit 都会在 Phase 2 生成 archive 摘要;只有带 `.done` 标记的 completed archive 才会被这里返回。 + +**参数** + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|------|------|------|--------|------| +| session_id | str | 是 | - | 会话 ID | +| token_budget | int | 否 | 128000 | active `messages` 之后留给 assembled archive payload 的 token 预算 | + +**Python SDK (Embedded / HTTP)** + +```python +context = await 
client.get_session_context("a1b2c3d4", token_budget=128000) +print(context["latest_archive_overview"]) +print(context["latest_archive_id"]) +print(context["pre_archive_abstracts"]) +print(len(context["messages"])) + +session = client.session("a1b2c3d4") +context = await session.get_session_context(token_budget=128000) +``` + +**HTTP API** + +``` +GET /api/v1/sessions/{session_id}/context?token_budget=128000 +``` + +```bash +curl -X GET "http://localhost:1933/api/v1/sessions/a1b2c3d4/context?token_budget=128000" \ + -H "X-API-Key: your-key" +``` + +**CLI** + +```bash +ov session get-session-context a1b2c3d4 --token-budget 128000 +``` + +**响应** + +```json +{ + "status": "ok", + "result": { + "latest_archive_overview": "# Session Summary\n\n**Overview**: User discussed deployment and auth setup.", + "latest_archive_id": "archive_002", + "pre_archive_abstracts": [ + { + "archive_id": "archive_001", + "abstract": "用户之前讨论了仓库初始化和鉴权配置。" + } + ], + "messages": [ + { + "id": "msg_pending_1", + "role": "user", + "parts": [ + {"type": "text", "text": "Pending user message"} + ], + "created_at": "2026-03-24T09:10:11Z" + }, + { + "id": "msg_live_1", + "role": "assistant", + "parts": [ + {"type": "text", "text": "Current live message"} + ], + "created_at": "2026-03-24T09:10:20Z" + } + ], + "estimatedTokens": 147, + "stats": { + "totalArchives": 2, + "includedArchives": 2, + "droppedArchives": 0, + "failedArchives": 0, + "activeTokens": 98, + "archiveTokens": 49 + } + } +} +``` + +--- + +### get_session_archive() + +获取某次已完成归档的完整内容。 + +该接口通常配合 `get_session_context()` 返回的 `latest_archive_id` 或 `pre_archive_abstracts[*].archive_id` 使用。 + +该接口返回: +- `archive_id`:被展开的 archive ID +- `abstract`:该 archive 的轻量摘要 +- `overview`:该 archive 的完整 overview +- `messages`:该次 archive 对应的完整消息内容 + +**参数** + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|------|------|------|--------|------| +| session_id | str | 是 | - | 会话 ID | +| archive_id | str | 是 | - | 归档 ID,例如 `archive_002` | + +**Python SDK (Embedded / HTTP)** 
+ +```python +archive = await client.get_session_archive("a1b2c3d4", "archive_002") +print(archive["archive_id"]) +print(archive["overview"]) +print(len(archive["messages"])) + +session = client.session("a1b2c3d4") +archive = await session.get_archive("archive_002") +``` + +**HTTP API** + +``` +GET /api/v1/sessions/{session_id}/archives/{archive_id} +``` + +```bash +curl -X GET "http://localhost:1933/api/v1/sessions/a1b2c3d4/archives/archive_002" \ + -H "X-API-Key: your-key" +``` + +**CLI** + +```bash +ov session get-session-archive a1b2c3d4 archive_002 +``` + +**响应** + +```json +{ + "status": "ok", + "result": { + "archive_id": "archive_002", + "abstract": "用户讨论了部署流程和鉴权配置。", + "overview": "# Session Summary\n\n**Overview**: 用户讨论了部署流程和鉴权配置。", + "messages": [ + { + "id": "msg_archive_1", + "role": "user", + "parts": [ + {"type": "text", "text": "这个服务应该怎么部署?"} + ], + "created_at": "2026-03-24T08:55:01Z" + }, + { + "id": "msg_archive_2", + "role": "assistant", + "parts": [ + {"type": "text", "text": "建议先走分阶段部署,再核验鉴权链路。"} + ], + "created_at": "2026-03-24T08:55:18Z" + } + ] + } +} +``` + +如果 archive 不存在、未完成,或者不属于该 session,接口返回 `404`。 + +--- + ### delete_session() 删除会话。 @@ -422,6 +607,11 @@ curl -X POST http://localhost:1933/api/v1/sessions/a1b2c3d4/used \ 提交会话。归档消息(Phase 1)立即完成,摘要生成和记忆提取(Phase 2)在后台异步执行。返回 `task_id` 用于查询后台任务进度。 +说明: +- 同一 session 的多次快速连续 commit 会被接受;每次请求都会拿到独立的 `task_id`。 +- 后台 Phase 2 会按 archive 顺序串行推进:`archive N+1` 会等待 `archive N` 写出 `.done` 后再继续。 +- 如果更早的 archive 已失败且没有 `.done`,后续 commit 会直接返回 `FAILED_PRECONDITION`,直到该失败被处理。 + **参数** | 参数 | 类型 | 必填 | 默认值 | 说明 | diff --git a/docs/zh/concepts/08-session.md b/docs/zh/concepts/08-session.md index d8f00d3aa..e5959059f 100644 --- a/docs/zh/concepts/08-session.md +++ b/docs/zh/concepts/08-session.md @@ -6,7 +6,7 @@ Session 负责管理对话消息、记录上下文使用、提取长期记忆。 **生命周期**:创建 → 交互 → 提交 -通过 session_id 获取会话时,如果会话不存在将自动创建。 +通过 session_id 获取会话时,默认不会自动创建不存在的会话;如果需要自动创建,请显式使用 `client.get_session(..., auto_create=True)`。 ```python 
session = client.session(session_id="chat_001") diff --git a/examples/openclaw-plugin/__tests__/context-engine-assemble.test.ts b/examples/openclaw-plugin/__tests__/context-engine-assemble.test.ts new file mode 100644 index 000000000..de1c17e94 --- /dev/null +++ b/examples/openclaw-plugin/__tests__/context-engine-assemble.test.ts @@ -0,0 +1,284 @@ +import { describe, expect, it, vi } from "vitest"; + +import type { OpenVikingClient } from "../client.js"; +import { memoryOpenVikingConfigSchema } from "../config.js"; +import { createMemoryOpenVikingContextEngine } from "../context-engine.js"; + +const cfg = memoryOpenVikingConfigSchema.parse({ + mode: "remote", + baseUrl: "http://127.0.0.1:1933", + autoCapture: false, + autoRecall: false, + ingestReplyAssist: false, +}); + +function roughEstimate(messages: unknown[]): number { + return Math.ceil(JSON.stringify(messages).length / 4); +} + +function makeLogger() { + return { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }; +} + +function makeStats() { + return { + totalArchives: 0, + includedArchives: 0, + droppedArchives: 0, + failedArchives: 0, + activeTokens: 0, + archiveTokens: 0, + }; +} + +function makeEngine(contextResult: unknown) { + const logger = makeLogger(); + const client = { + getSessionContext: vi.fn().mockResolvedValue(contextResult), + } as unknown as OpenVikingClient; + const getClient = vi.fn().mockResolvedValue(client); + const resolveAgentId = vi.fn((sessionId: string) => `agent:${sessionId}`); + + const engine = createMemoryOpenVikingContextEngine({ + id: "openviking", + name: "Context Engine (OpenViking)", + version: "test", + cfg, + logger, + getClient, + resolveAgentId, + }); + + return { + engine, + client: client as unknown as { getSessionContext: ReturnType }, + getClient, + logger, + resolveAgentId, + }; +} + +describe("context-engine assemble()", () => { + it("assembles summary archive and completed tool parts into agent messages", async () => { + const { engine, client, 
resolveAgentId } = makeEngine({ + latest_archive_overview: "# Session Summary\nPreviously discussed repository setup.", + latest_archive_id: "archive_001", + pre_archive_abstracts: [], + messages: [ + { + id: "msg_1", + role: "assistant", + created_at: "2026-03-24T00:00:00Z", + parts: [ + { type: "text", text: "I checked the latest context." }, + { type: "context", abstract: "User prefers concise answers." }, + { + type: "tool", + tool_id: "tool_123", + tool_name: "read_file", + tool_input: { path: "src/app.ts" }, + tool_output: "export const value = 1;", + tool_status: "completed", + }, + ], + }, + ], + estimatedTokens: 321, + stats: { + ...makeStats(), + totalArchives: 1, + includedArchives: 1, + archiveTokens: 40, + activeTokens: 281, + }, + }); + + const liveMessages = [{ role: "user", content: "fallback live message" }]; + const result = await engine.assemble({ + sessionId: "session-1", + messages: liveMessages, + tokenBudget: 4096, + }); + + expect(resolveAgentId).toHaveBeenCalledWith("session-1"); + expect(client.getSessionContext).toHaveBeenCalledWith("session-1", 4096, "agent:session-1"); + expect(result.estimatedTokens).toBe(321); + expect(result.systemPromptAddition).toContain("Compressed Context"); + expect(result.messages).toEqual([ + { + role: "user", + content: "# Session Summary\nPreviously discussed repository setup.", + }, + { + role: "assistant", + content: [ + { type: "text", text: "I checked the latest context." }, + { type: "text", text: "User prefers concise answers." 
}, + { + type: "toolUse", + id: "tool_123", + name: "read_file", + input: { path: "src/app.ts" }, + }, + ], + }, + { + role: "toolResult", + toolCallId: "tool_123", + toolName: "read_file", + content: [{ type: "text", text: "export const value = 1;" }], + isError: false, + }, + ]); + }); + + it("emits a non-error toolResult for a running tool (not a synthetic error)", async () => { + const { engine } = makeEngine({ + latest_archive_overview: "", + latest_archive_id: "", + pre_archive_abstracts: [], + messages: [ + { + id: "msg_2", + role: "assistant", + created_at: "2026-03-24T00:00:00Z", + parts: [ + { + type: "tool", + tool_id: "tool_running", + tool_name: "bash", + tool_input: { command: "npm test" }, + tool_output: "", + tool_status: "running", + }, + ], + }, + ], + estimatedTokens: 88, + stats: { + ...makeStats(), + activeTokens: 88, + }, + }); + + const result = await engine.assemble({ + sessionId: "session-running", + messages: [], + }); + + expect(result.systemPromptAddition).toBeUndefined(); + expect(result.messages).toHaveLength(2); + expect(result.messages[0]).toEqual({ + role: "assistant", + content: [ + { + type: "toolUse", + id: "tool_running", + name: "bash", + input: { command: "npm test" }, + }, + ], + }); + expect(result.messages[1]).toMatchObject({ + role: "toolResult", + toolCallId: "tool_running", + toolName: "bash", + isError: false, + }); + const text = (result.messages[1] as any).content?.[0]?.text ?? 
""; + expect(text).toContain("interrupted"); + expect((result.messages[1] as { content: Array<{ text: string }> }).content[0]?.text).toContain( + "missing tool result", + ); + }); + + it("degrades tool parts without tool_id into assistant text blocks", async () => { + const { engine } = makeEngine({ + latest_archive_overview: "", + latest_archive_id: "", + pre_archive_abstracts: [], + messages: [ + { + id: "msg_3", + role: "assistant", + created_at: "2026-03-24T00:00:00Z", + parts: [ + { type: "text", text: "Tool state snapshot:" }, + { + type: "tool", + tool_id: "", + tool_name: "grep", + tool_input: { pattern: "TODO" }, + tool_output: "src/app.ts:17 TODO refine this", + tool_status: "completed", + }, + ], + }, + ], + estimatedTokens: 71, + stats: { + ...makeStats(), + activeTokens: 71, + }, + }); + + const result = await engine.assemble({ + sessionId: "session-missing-id", + messages: [], + }); + + expect(result.messages).toEqual([ + { + role: "assistant", + content: [ + { type: "text", text: "Tool state snapshot:" }, + { + type: "text", + text: "[grep] (completed)\nInput: {\"pattern\":\"TODO\"}\nOutput: src/app.ts:17 TODO refine this", + }, + ], + }, + ]); + }); + + it("falls back to live messages when assembled active messages look truncated", async () => { + const { engine } = makeEngine({ + latest_archive_overview: "", + latest_archive_id: "", + pre_archive_abstracts: [], + messages: [ + { + id: "msg_4", + role: "user", + created_at: "2026-03-24T00:00:00Z", + parts: [{ type: "text", text: "Only one stored message" }], + }, + ], + estimatedTokens: 12, + stats: { + ...makeStats(), + activeTokens: 12, + }, + }); + + const liveMessages = [ + { role: "user", content: "message one" }, + { role: "assistant", content: [{ type: "text", text: "message two" }] }, + ]; + + const result = await engine.assemble({ + sessionId: "session-fallback", + messages: liveMessages, + tokenBudget: 1024, + }); + + expect(result).toEqual({ + messages: liveMessages, + estimatedTokens: 
roughEstimate(liveMessages), + }); + }); +}); diff --git a/examples/openclaw-plugin/client.ts b/examples/openclaw-plugin/client.ts index 11187dd96..1ea0a1b1c 100644 --- a/examples/openclaw-plugin/client.ts +++ b/examples/openclaw-plugin/client.ts @@ -35,6 +35,82 @@ export type PendingClientEntry = { reject: (err: unknown) => void; }; +export type CommitSessionResult = { + session_id: string; + /** "accepted" (async), "completed", "failed", or "timeout" (wait mode). */ + status: string; + task_id?: string; + archive_uri?: string; + archived?: boolean; + /** Present when wait=true and extraction completed. Keyed by category. */ + memories_extracted?: Record; + error?: string; +}; + +export type TaskResult = { + task_id: string; + task_type: string; + status: string; + created_at: number; + updated_at: number; + resource_id?: string; + result?: Record; + error?: string; +}; + +export type OVMessagePart = { + type: string; + text?: string; + uri?: string; + abstract?: string; + context_type?: string; + tool_id?: string; + tool_name?: string; + tool_input?: unknown; + tool_output?: string; + tool_status?: string; + skill_uri?: string; +}; + +export type OVMessage = { + id: string; + role: string; + parts: OVMessagePart[]; + created_at: string; +}; + +export type PreArchiveAbstract = { + archive_id: string; + abstract: string; +}; + +export type SessionContextResult = { + latest_archive_overview: string; + latest_archive_id: string; + pre_archive_abstracts: PreArchiveAbstract[]; + messages: OVMessage[]; + estimatedTokens: number; + stats: { + totalArchives: number; + includedArchives: number; + droppedArchives: number; + failedArchives: number; + activeTokens: number; + archiveTokens: number; + }; +}; + +export type SessionArchiveResult = { + archive_id: string; + abstract: string; + overview: string; + messages: OVMessage[]; +}; + +function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + export const localClientCache = new 
Map(); // Module-level pending promise map: shared across all plugin registrations so @@ -260,9 +336,21 @@ export class OpenVikingClient { ); } - /** GET session — server auto-creates if absent; also loads messages from storage before extract. */ - async getSession(sessionId: string, agentId?: string): Promise<{ message_count?: number }> { - return this.request<{ message_count?: number }>( + /** GET session — returns session meta including message stats and token usage. NOTE(review): this previously relied on server-side auto-create; the docs updated in this same change say auto-create is now opt-in — confirm the server still creates missing sessions on GET. */ + async getSession(sessionId: string, agentId?: string): Promise<{ + message_count?: number; + commit_count?: number; + last_commit_at?: string; + pending_tokens?: number; + llm_token_usage?: { prompt_tokens: number; completion_tokens: number; total_tokens: number }; + }> { + return this.request<{ + message_count?: number; + commit_count?: number; + last_commit_at?: string; + pending_tokens?: number; + llm_token_usage?: { prompt_tokens: number; completion_tokens: number; total_tokens: number }; + }>( `/api/v1/sessions/${encodeURIComponent(sessionId)}`, { method: "GET" }, agentId, @@ -271,38 +359,84 @@ /** * Commit a session: archive (Phase 1) and extract memories (Phase 2). - * wait=false (default): Phase 2 runs in background, returns task_id for polling. - * wait=true: blocks until Phase 2 completes, returns memories_extracted count. + * + * wait=false (default): returns immediately after Phase 1 with task_id. + * wait=true: after Phase 1, polls GET /tasks/{task_id} until Phase 2 + * completes (or times out), then returns the merged result. */ async commitSession( sessionId: string, - options?: { wait?: boolean; agentId?: string }, - ): Promise<{ - session_id: string; - status: string; - task_id?: string; - archive_uri?: string; - archived?: boolean; - memories_extracted?: number; - }> { - const wait = options?.wait ??
false; - return this.request<{ - session_id: string; - status: string; - task_id?: string; - archive_uri?: string; - archived?: boolean; - memories_extracted?: number; - }>(`/api/v1/sessions/${encodeURIComponent(sessionId)}/commit?wait=${wait}`, { - method: "POST", - body: JSON.stringify({}), - }, options?.agentId); + options?: { wait?: boolean; timeoutMs?: number; agentId?: string }, + ): Promise { + const result = await this.request( + `/api/v1/sessions/${encodeURIComponent(sessionId)}/commit`, + { method: "POST", body: JSON.stringify({}) }, + options?.agentId, + ); + + if (!options?.wait || !result.task_id) { + return result; + } + + // Client-side poll until Phase 2 finishes + const deadline = Date.now() + (options.timeoutMs ?? 120_000); + const pollInterval = 500; + while (Date.now() < deadline) { + await sleep(pollInterval); + const task = await this.getTask(result.task_id, options.agentId).catch(() => null); + if (!task) break; + if (task.status === "completed") { + const taskResult = (task.result ?? {}) as Record; + result.status = "completed"; + result.memories_extracted = (taskResult.memories_extracted ?? {}) as Record; + return result; + } + if (task.status === "failed") { + result.status = "failed"; + result.error = task.error; + return result; + } + } + result.status = "timeout"; + return result; + } + + /** Poll a background task by ID. 
*/ + async getTask(taskId: string, agentId?: string): Promise { + return this.request( + `/api/v1/tasks/${encodeURIComponent(taskId)}`, + { method: "GET" }, + agentId, + ); + } + + async getSessionContext( + sessionId: string, + tokenBudget: number = 128_000, + agentId?: string, + ): Promise { + return this.request( + `/api/v1/sessions/${encodeURIComponent(sessionId)}/context?token_budget=${tokenBudget}`, + { method: "GET" }, + agentId, + ); + } + + async getSessionArchive( + sessionId: string, + archiveId: string, + agentId?: string, + ): Promise { + return this.request( + `/api/v1/sessions/${encodeURIComponent(sessionId)}/archives/${encodeURIComponent(archiveId)}`, + { method: "GET" }, + agentId, + ); } async deleteSession(sessionId: string, agentId?: string): Promise { await this.request(`/api/v1/sessions/${encodeURIComponent(sessionId)}`, { method: "DELETE" }, agentId); } - async deleteUri(uri: string, agentId?: string): Promise { await this.request(`/api/v1/fs?uri=${encodeURIComponent(uri)}&recursive=false`, { method: "DELETE", diff --git a/examples/openclaw-plugin/config.ts b/examples/openclaw-plugin/config.ts index 90100593f..3b52f6666 100644 --- a/examples/openclaw-plugin/config.ts +++ b/examples/openclaw-plugin/config.ts @@ -23,9 +23,15 @@ export type MemoryOpenVikingConfig = { recallMaxContentChars?: number; recallPreferAbstract?: boolean; recallTokenBudget?: number; + commitTokenThreshold?: number; ingestReplyAssist?: boolean; ingestReplyAssistMinSpeakerTurns?: number; ingestReplyAssistMinChars?: number; + /** + * When true (default), emit structured `openviking: diag {...}` lines (and any future + * standard-diagnostics file writes) for assemble/afterTurn. Set false to disable. 
+ */ + emitStandardDiagnostics?: boolean; }; const DEFAULT_BASE_URL = "http://127.0.0.1:1933"; @@ -39,9 +45,11 @@ const DEFAULT_RECALL_SCORE_THRESHOLD = 0.15; const DEFAULT_RECALL_MAX_CONTENT_CHARS = 500; const DEFAULT_RECALL_PREFER_ABSTRACT = true; const DEFAULT_RECALL_TOKEN_BUDGET = 2000; +const DEFAULT_COMMIT_TOKEN_THRESHOLD = 2000; const DEFAULT_INGEST_REPLY_ASSIST = true; const DEFAULT_INGEST_REPLY_ASSIST_MIN_SPEAKER_TURNS = 2; const DEFAULT_INGEST_REPLY_ASSIST_MIN_CHARS = 120; +const DEFAULT_EMIT_STANDARD_DIAGNOSTICS = true; const DEFAULT_LOCAL_CONFIG_PATH = join(homedir(), ".openviking", "ov.conf"); const DEFAULT_AGENT_ID = "default"; @@ -118,9 +126,11 @@ export const memoryOpenVikingConfigSchema = { "recallMaxContentChars", "recallPreferAbstract", "recallTokenBudget", + "commitTokenThreshold", "ingestReplyAssist", "ingestReplyAssistMinSpeakerTurns", "ingestReplyAssistMinChars", + "emitStandardDiagnostics", ], "openviking config", ); @@ -181,6 +191,10 @@ export const memoryOpenVikingConfigSchema = { 100, Math.min(50000, Math.floor(toNumber(cfg.recallTokenBudget, DEFAULT_RECALL_TOKEN_BUDGET))), ), + commitTokenThreshold: Math.max( + 0, + Math.min(100_000, Math.floor(toNumber(cfg.commitTokenThreshold, DEFAULT_COMMIT_TOKEN_THRESHOLD))), + ), ingestReplyAssist: cfg.ingestReplyAssist !== false, ingestReplyAssistMinSpeakerTurns: Math.max( 1, @@ -201,6 +215,10 @@ export const memoryOpenVikingConfigSchema = { Math.floor(toNumber(cfg.ingestReplyAssistMinChars, DEFAULT_INGEST_REPLY_ASSIST_MIN_CHARS)), ), ), + emitStandardDiagnostics: + typeof cfg.emitStandardDiagnostics === "boolean" + ? cfg.emitStandardDiagnostics + : DEFAULT_EMIT_STANDARD_DIAGNOSTICS, }; }, uiHints: { @@ -292,6 +310,12 @@ export const memoryOpenVikingConfigSchema = { advanced: true, help: "Maximum estimated tokens for auto-recall memory injection. 
Injection stops when budget is exhausted.", }, + commitTokenThreshold: { + label: "Commit Token Threshold", + placeholder: String(DEFAULT_COMMIT_TOKEN_THRESHOLD), + advanced: true, + help: "Minimum estimated pending tokens before auto-commit triggers. Set to 0 to commit every turn.", + }, ingestReplyAssist: { label: "Ingest Reply Assist", help: "When transcript-like memory ingestion is detected, add a lightweight reply instruction to reduce NO_REPLY.", @@ -309,6 +333,11 @@ export const memoryOpenVikingConfigSchema = { help: "Minimum sanitized text length required before ingest reply assist can trigger.", advanced: true, }, + emitStandardDiagnostics: { + label: "Standard diagnostics (diag JSON lines)", + advanced: true, + help: "When enabled, emit structured openviking: diag {...} lines for assemble and afterTurn. Disable to reduce log noise.", + }, }, }; diff --git a/examples/openclaw-plugin/context-engine.ts b/examples/openclaw-plugin/context-engine.ts index 738dd2281..a81390945 100644 --- a/examples/openclaw-plugin/context-engine.ts +++ b/examples/openclaw-plugin/context-engine.ts @@ -1,4 +1,4 @@ -import type { OpenVikingClient } from "./client.js"; +import type { OpenVikingClient, OVMessage } from "./client.js"; import type { MemoryOpenVikingConfig } from "./config.js"; import { getCaptureDecision, @@ -8,6 +8,7 @@ import { trimForLog, toJsonLog, } from "./memory-ranking.js"; +import { sanitizeToolUseResultPairing } from "./session-transcript-repair.js"; type AgentMessage = { role?: string; @@ -18,6 +19,7 @@ type ContextEngineInfo = { id: string; name: string; version?: string; + ownsCompaction: true; }; type AssembleResult = { @@ -72,13 +74,9 @@ type ContextEngine = { }) => Promise; }; -export type ContextEngineWithSessionMapping = ContextEngine & { - /** Return the OV session ID for an OpenClaw sessionKey (identity: sessionKey IS the OV session ID). 
*/ - getOVSessionForKey: (sessionKey: string) => string; - /** Ensure an OV session exists on the server for the given OpenClaw sessionKey (auto-created by getSession if absent). */ - resolveOVSession: (sessionKey: string) => Promise; - /** Commit (extract + archive) then delete the OV session, so a fresh one is created on next use. */ - commitOVSession: (sessionKey: string) => Promise; +export type ContextEngineWithCommit = ContextEngine & { + /** Commit (archive + extract) the OV session. Returns true on success. */ + commitOVSession: (sessionId: string) => Promise; }; type Logger = { @@ -91,39 +89,244 @@ function estimateTokens(messages: AgentMessage[]): number { return Math.max(1, messages.length * 80); } -async function tryLegacyCompact(params: { - sessionId: string; - sessionFile: string; - tokenBudget?: number; - force?: boolean; - currentTokenCount?: number; - compactionTarget?: "budget" | "threshold"; - customInstructions?: string; - runtimeContext?: Record; -}): Promise { - const candidates = [ - "openclaw/context-engine/legacy", - "openclaw/dist/context-engine/legacy.js", - ]; - - for (const path of candidates) { - try { - const mod = (await import(path)) as { - LegacyContextEngine?: new () => { - compact: (arg: typeof params) => Promise; - }; +function roughEstimate(messages: AgentMessage[]): number { + return Math.ceil(JSON.stringify(messages).length / 4); +} + +function msgTokenEstimate(msg: AgentMessage): number { + const raw = (msg as Record).content; + if (typeof raw === "string") return Math.ceil(raw.length / 4); + if (Array.isArray(raw)) return Math.ceil(JSON.stringify(raw).length / 4); + return 1; +} + +function messageDigest(messages: AgentMessage[], maxCharsPerMsg = 2000): Array<{role: string; content: string; tokens: number; truncated: boolean}> { + return messages.map((msg) => { + const m = msg as Record; + const role = String(m.role ?? 
"unknown"); + const raw = m.content; + let text: string; + if (typeof raw === "string") { + text = raw; + } else if (Array.isArray(raw)) { + text = (raw as Record[]) + .map((b) => { + if (b.type === "text") return String(b.text ?? ""); + if (b.type === "toolUse") return `[toolUse: ${String(b.name)}(${JSON.stringify(b.arguments ?? {}).slice(0, 200)})]`; + if (b.type === "toolResult") return `[toolResult: ${JSON.stringify(b.content ?? "").slice(0, 200)}]`; + return `[${String(b.type)}]`; + }) + .join("\n"); + } else { + text = JSON.stringify(raw) ?? ""; + } + const truncated = text.length > maxCharsPerMsg; + return { + role, + content: truncated ? text.slice(0, maxCharsPerMsg) + "..." : text, + tokens: msgTokenEstimate(msg), + truncated, + }; + }); +} + +function emitDiag(log: typeof logger, stage: string, sessionId: string, data: Record, enabled = true): void { + if (!enabled) return; + log.info(`openviking: diag ${JSON.stringify({ ts: Date.now(), stage, sessionId, data })}`); +} + +function totalExtractedMemories(memories?: Record): number { + if (!memories || typeof memories !== "object") { + return 0; + } + return Object.values(memories).reduce((sum, count) => sum + (count ?? 0), 0); +} + +function validTokenBudget(raw: unknown): number | undefined { + if (typeof raw === "number" && Number.isFinite(raw) && raw > 0) { + return raw; + } + return undefined; +} + +/** + * Convert an OpenViking stored message (parts-based format) into one or more + * OpenClaw AgentMessages (content-blocks format). + * + * For assistant messages with ToolParts, this produces: + * 1. The assistant message with toolUse blocks in its content array + * 2. A separate toolResult message per ToolPart (carrying tool_output) + */ +function convertToAgentMessages(msg: { role: string; parts: unknown[] }): AgentMessage[] { + const parts = msg.parts ?? 
[]; + const contentBlocks: Record[] = []; + const toolResults: AgentMessage[] = []; + + for (const part of parts) { + if (!part || typeof part !== "object") continue; + const p = part as Record; + + if (p.type === "text" && typeof p.text === "string") { + contentBlocks.push({ type: "text", text: p.text }); + } else if (p.type === "context") { + if (typeof p.abstract === "string" && p.abstract) { + contentBlocks.push({ type: "text", text: p.abstract }); + } + } else if (p.type === "tool" && msg.role === "assistant") { + const toolId = typeof p.tool_id === "string" ? p.tool_id : ""; + const toolName = typeof p.tool_name === "string" ? p.tool_name : "unknown"; + + if (toolId) { + contentBlocks.push({ + type: "toolUse", + id: toolId, + name: toolName, + input: p.tool_input ?? {}, + }); + + const status = typeof p.tool_status === "string" ? p.tool_status : ""; + const output = typeof p.tool_output === "string" ? p.tool_output : ""; + + if (status === "completed" || status === "error") { + toolResults.push({ + role: "toolResult", + toolCallId: toolId, + toolName, + content: [{ type: "text", text: output || "(no output)" }], + isError: status === "error", + } as unknown as AgentMessage); + } else { + toolResults.push({ + role: "toolResult", + toolCallId: toolId, + toolName, + content: [{ type: "text", text: "(interrupted — tool did not complete)" }], + isError: false, + } as unknown as AgentMessage); + } + } else { + // No tool_id: degrade to text block to preserve information. + // Cannot emit toolUse/toolResult without a valid id. + const status = typeof p.tool_status === "string" ? p.tool_status : "unknown"; + const output = typeof p.tool_output === "string" ? 
p.tool_output : ""; + const segments = [`[${toolName}] (${status})`]; + if (p.tool_input) { + try { + segments.push(`Input: ${JSON.stringify(p.tool_input)}`); + } catch { + // non-serializable input, skip + } + } + if (output) { + segments.push(`Output: ${output}`); + } + contentBlocks.push({ type: "text", text: segments.join("\n") }); + } + } + } + + const result: AgentMessage[] = []; + + if (msg.role === "assistant") { + result.push({ role: msg.role, content: contentBlocks }); + result.push(...toolResults); + } else { + const texts = contentBlocks + .filter((b) => b.type === "text") + .map((b) => b.text as string); + result.push({ role: msg.role, content: texts.join("\n") || "" }); + } + + return result; +} + +function normalizeAssistantContent(messages: AgentMessage[]): void { + for (let i = 0; i < messages.length; i++) { + const msg = messages[i]; + if (msg?.role === "assistant" && typeof msg.content === "string") { + messages[i] = { + ...msg, + content: [{ type: "text", text: msg.content }], }; - if (!mod?.LegacyContextEngine) { - continue; + } + } +} + +export function formatMessageFaithful(msg: OVMessage): string { + const roleTag = `[${msg.role}]`; + if (!msg.parts || msg.parts.length === 0) { + return `${roleTag}: (empty)`; + } + + const sections: string[] = []; + for (const part of msg.parts) { + if (!part || typeof part !== "object") continue; + switch (part.type) { + case "text": + if (part.text) sections.push(part.text); + break; + case "tool": { + const status = part.tool_status ?? "unknown"; + const header = `[Tool: ${part.tool_name ?? "unknown"}] (${status})`; + const inputStr = part.tool_input + ? `Input: ${JSON.stringify(part.tool_input, null, 2)}` + : ""; + const outputStr = part.tool_output ? 
`Output:\n${part.tool_output}` : ""; + sections.push([header, inputStr, outputStr].filter(Boolean).join("\n")); + break; } - const legacy = new mod.LegacyContextEngine(); - return legacy.compact(params); - } catch { - // continue + case "context": + sections.push( + `[Context: ${part.uri ?? "?"}]${part.abstract ? ` ${part.abstract}` : ""}`, + ); + break; + default: + sections.push(`[${part.type}]: ${JSON.stringify(part)}`); } } - return null; + return `${roleTag}:\n${sections.join("\n\n")}`; +} + +function buildSystemPromptAddition(): string { + return [ + "## Session Context Guide", + "", + "Your conversation history may include:", + "", + "1. **[Session History Summary]** — A compressed summary of all prior", + " conversation sessions. Use it to understand background and continuity.", + " It is lossy: specific details (commands, file paths, code, config", + " values) may have been compressed away. It may be omitted when the", + " token budget is tight.", + "", + "2. **[Archive Index]** — A list of archive entries in chronological order", + " (archive_001 is the oldest, higher numbers are more recent). Most", + " lines summarize one archive; the latest archive may appear as an ID", + " pointer only.", + "", + "3. **Active messages** — The current, uncompressed conversation.", + "", + "**When you need precise details from a prior session:**", + "", + "1. Review [Archive Index] to identify which archive likely contains", + " the information you need.", + "2. Call `ov_archive_expand` with that archive ID to retrieve the", + " archived conversation content.", + "3. If multiple archives look relevant, try the most recent one first.", + "4. 
Answer using the retrieved content together with active messages.", + "", + "**Rules:**", + "- If active messages conflict with archive content, trust active", + " messages as the newer source of truth.", + "- Only expand an archive when the existing context lacks the specific detail needed.", + "- If [Session History Summary] is absent, use [Archive Index] and active", + " messages to decide whether to expand an archive.", + "- Do not fabricate details from summaries. When uncertain, expand first", + " or state that the information comes from a compressed summary.", + "- After expanding, cite the archive ID in your answer", + ' (e.g. "Based on archive_003, ...").', + ].join("\n"); } function warnOrInfo(logger: Logger, message: string): void { @@ -134,6 +337,34 @@ function warnOrInfo(logger: Logger, message: string): void { logger.info(message); } +function formatMessagesForLog(label: string, messages: AgentMessage[]): string { + const lines: string[] = [`===== ${label} (${messages.length} msgs) =====`]; + for (let i = 0; i < messages.length; i++) { + const msg = messages[i] as Record; + const role = msg.role ?? 
"?"; + const raw = msg.content; + let text: string; + if (typeof raw === "string") { + text = raw; + } else if (Array.isArray(raw)) { + text = (raw as Record[]) + .map((b) => { + if (b.type === "text") return b.text; + if (b.type === "toolUse") return `[toolUse: ${b.name}]`; + if (b.type === "toolResult") return `[toolResult]`; + return `[${b.type}]`; + }) + .join("\n"); + } else { + text = JSON.stringify(raw, null, 2); + } + lines.push(`--- [${i}] ${role} ---`); + lines.push(String(text)); + } + lines.push(`===== /${label} =====`); + return lines.join("\n"); +} + export function createMemoryOpenVikingContextEngine(params: { id: string; name: string; @@ -142,7 +373,7 @@ export function createMemoryOpenVikingContextEngine(params: { logger: Logger; getClient: () => Promise; resolveAgentId: (sessionId: string) => string; -}): ContextEngineWithSessionMapping { +}): ContextEngineWithCommit { const { id, name, @@ -153,41 +384,40 @@ export function createMemoryOpenVikingContextEngine(params: { resolveAgentId, } = params; - async function doCommitOVSession(sessionKey: string): Promise { + const diagEnabled = cfg.emitStandardDiagnostics; + const diag = (stage: string, sessionId: string, data: Record) => + emitDiag(logger, stage, sessionId, data, diagEnabled); + + async function doCommitOVSession(sessionId: string): Promise { try { const client = await getClient(); - const agentId = resolveAgentId(sessionKey); - const commitResult = await client.commitSession(sessionKey, { wait: true, agentId }); + const agentId = resolveAgentId(sessionId); + const commitResult = await client.commitSession(sessionId, { wait: true, agentId }); + const memCount = totalExtractedMemories(commitResult.memories_extracted); + if (commitResult.status === "failed") { + warnOrInfo(logger, `openviking: commit Phase 2 failed for session=${sessionId}: ${commitResult.error ?? 
"unknown"}`); + return false; + } + if (commitResult.status === "timeout") { + warnOrInfo(logger, `openviking: commit Phase 2 timed out for session=${sessionId}, task_id=${commitResult.task_id ?? "none"}`); + return false; + } logger.info( - `openviking: committed OV session for sessionKey=${sessionKey}, archived=${commitResult.archived ?? false}, memories=${commitResult.memories_extracted ?? 0}, task_id=${commitResult.task_id ?? "none"}`, + `openviking: committed OV session=${sessionId}, archived=${commitResult.archived ?? false}, memories=${memCount}, task_id=${commitResult.task_id ?? "none"}`, ); - await client.deleteSession(sessionKey, agentId).catch(() => {}); + return true; } catch (err) { - warnOrInfo(logger, `openviking: commit failed for sessionKey=${sessionKey}: ${String(err)}`); + warnOrInfo(logger, `openviking: commit failed for session=${sessionId}: ${String(err)}`); + return false; } } - function extractSessionKey(runtimeContext: Record | undefined): string | undefined { - if (!runtimeContext) { - return undefined; - } - const key = runtimeContext.sessionKey; - return typeof key === "string" && key.trim() ? key.trim() : undefined; - } - return { info: { id, name, version, - }, - - // --- session-mapping extensions --- - - getOVSessionForKey: (sessionKey: string) => sessionKey, - - async resolveOVSession(sessionKey: string): Promise { - return sessionKey; + ownsCompaction: true, }, commitOVSession: doCommitOVSession, @@ -203,10 +433,137 @@ export function createMemoryOpenVikingContextEngine(params: { }, async assemble(assembleParams): Promise { - return { - messages: assembleParams.messages, - estimatedTokens: estimateTokens(assembleParams.messages), - }; + const { messages } = assembleParams; + const tokenBudget = validTokenBudget(assembleParams.tokenBudget) ?? 
128_000; + + const originalTokens = roughEstimate(messages); + logger.info(`openviking: assemble input msgs=${messages.length} ~${originalTokens} tokens, budget=${validTokenBudget(assembleParams.tokenBudget) ?? 128_000}`); + + const OVSessionId = assembleParams.sessionId; + diag("assemble_entry", OVSessionId, { + messagesCount: messages.length, + inputTokenEstimate: originalTokens, + tokenBudget, + messages: messageDigest(messages), + }); + + try { + const client = await getClient(); + const agentId = resolveAgentId(OVSessionId); + const ctx = await client.getSessionContext( + OVSessionId, + tokenBudget, + agentId, + ); + + const hasArchives = !!ctx?.latest_archive_id; + const activeCount = ctx?.messages?.length ?? 0; + const preAbstracts = ctx?.pre_archive_abstracts ?? []; + logger.info( + `openviking: assemble OV ctx hasArchives=${hasArchives} latestId=${ctx?.latest_archive_id ?? "none"} preAbstracts=${preAbstracts.length} active=${activeCount}`, + ); + + if (!ctx || (!hasArchives && activeCount === 0)) { + logger.info("openviking: assemble passthrough (no OV data)"); + diag("assemble_result", OVSessionId, { + passthrough: true, reason: "no_ov_data", + archiveCount: 0, activeCount: 0, + outputMessagesCount: messages.length, + inputTokenEstimate: originalTokens, + estimatedTokens: originalTokens, + tokensSaved: 0, savingPct: 0, + }); + return { messages, estimatedTokens: roughEstimate(messages) }; + } + + if (!hasArchives && ctx.messages.length < messages.length) { + logger.info(`openviking: assemble passthrough (OV msgs=${ctx.messages.length} < input msgs=${messages.length})`); + diag("assemble_result", OVSessionId, { + passthrough: true, reason: "ov_msgs_fewer_than_input", + archiveCount: 0, activeCount, + outputMessagesCount: messages.length, + inputTokenEstimate: originalTokens, + estimatedTokens: originalTokens, + tokensSaved: 0, savingPct: 0, + }); + return { messages, estimatedTokens: roughEstimate(messages) }; + } + + const assembled: AgentMessage[] = []; 
+ + if (ctx.latest_archive_overview) { + assembled.push({ + role: "user" as const, + content: `[Session History Summary]\n${ctx.latest_archive_overview}`, + }); + } + + if (preAbstracts.length > 0 || ctx.latest_archive_id) { + const lines: string[] = preAbstracts.map( + (a) => `${a.archive_id}: ${a.abstract}`, + ); + if (ctx.latest_archive_id) { + lines.push( + `(latest: ${ctx.latest_archive_id} — see [Session History Summary] above)`, + ); + } + assembled.push({ + role: "user" as const, + content: `[Archive Index]\n${lines.join("\n")}`, + }); + } + + assembled.push(...ctx.messages.flatMap((m) => convertToAgentMessages(m))); + + normalizeAssistantContent(assembled); + const sanitized = sanitizeToolUseResultPairing(assembled as never[]) as AgentMessage[]; + + if (sanitized.length === 0 && messages.length > 0) { + logger.info("openviking: assemble passthrough (sanitized=0, falling back to original)"); + diag("assemble_result", OVSessionId, { + passthrough: true, reason: "sanitized_empty", + archiveCount: preAbstracts.length + (ctx.latest_archive_id ? 1 : 0), + activeCount, + outputMessagesCount: messages.length, + inputTokenEstimate: originalTokens, + estimatedTokens: originalTokens, + tokensSaved: 0, savingPct: 0, + }); + return { messages, estimatedTokens: roughEstimate(messages) }; + } + + const assembledTokens = roughEstimate(sanitized); + const archiveCount = preAbstracts.length + (ctx.latest_archive_id ? 1 : 0); + logger.info(`openviking: assemble result msgs=${sanitized.length} ~${assembledTokens} tokens (ovEstimate=${ctx.estimatedTokens}), archives=${archiveCount}, active=${activeCount}`); + const tokensSaved = originalTokens - assembledTokens; + const savingPct = originalTokens > 0 ? 
Math.round((tokensSaved / originalTokens) * 100) : 0; + + diag("assemble_result", OVSessionId, { + passthrough: false, + archiveCount, + activeCount, + outputMessagesCount: sanitized.length, + inputTokenEstimate: originalTokens, + estimatedTokens: assembledTokens, + tokensSaved, + savingPct, + latestArchiveId: ctx.latest_archive_id ?? null, + messages: messageDigest(sanitized), + }); + + return { + messages: sanitized, + estimatedTokens: ctx.estimatedTokens, + ...(hasArchives + ? { systemPromptAddition: buildSystemPromptAddition() } + : {}), + }; + } catch (err) { + diag("assemble_error", OVSessionId, { + error: String(err), + }); + return { messages, estimatedTokens: roughEstimate(messages) }; + } }, async afterTurn(afterTurnParams): Promise { @@ -214,13 +571,17 @@ export function createMemoryOpenVikingContextEngine(params: { return; } + const OVSessionId = afterTurnParams.sessionId; try { - const sessionKey = extractSessionKey(afterTurnParams.runtimeContext); - const agentId = resolveAgentId(sessionKey ?? afterTurnParams.sessionId); + const agentId = resolveAgentId(OVSessionId); const messages = afterTurnParams.messages ?? 
[]; if (messages.length === 0) { - logger.info("openviking: auto-capture skipped (messages=0)"); + logger.info("openviking: afterTurn skipped (messages=0)"); + diag("afterTurn_skip", OVSessionId, { + reason: "no_messages", + totalMessages: 0, + }); return; } @@ -233,54 +594,202 @@ export function createMemoryOpenVikingContextEngine(params: { const { texts: newTexts, newCount } = extractNewTurnTexts(messages, start); if (newTexts.length === 0) { - logger.info("openviking: auto-capture skipped (no new user/assistant messages)"); + logger.info("openviking: afterTurn skipped (no new user/assistant messages)"); + diag("afterTurn_skip", OVSessionId, { + reason: "no_new_turn_messages", + totalMessages: messages.length, + prePromptMessageCount: start, + }); return; } + const newMessages = messages.slice(start).filter((m: any) => { + const r = (m as Record).role as string; + return r === "user" || r === "assistant"; + }) as AgentMessage[]; + const newMsgFull = messageDigest(newMessages); + const newTurnTokens = newMsgFull.reduce((s, d) => s + d.tokens, 0); + + diag("afterTurn_entry", OVSessionId, { + totalMessages: messages.length, + newMessageCount: newCount, + prePromptMessageCount: start, + newTurnTokens, + messages: newMsgFull, + }); + + const client = await getClient(); const turnText = newTexts.join("\n"); - const decision = getCaptureDecision(turnText, cfg.captureMode, cfg.captureMaxLength); - const preview = turnText.length > 80 ? 
`${turnText.slice(0, 80)}...` : turnText; - logger.info( - "openviking: capture-check " + - `shouldCapture=${String(decision.shouldCapture)} ` + - `reason=${decision.reason} newMsgCount=${newCount} text=\"${preview}\"`, - ); + const sanitized = turnText.replace(/[\s\S]*?<\/relevant-memories>/gi, " ").replace(/\s+/g, " ").trim(); + + if (sanitized) { + await client.addSessionMessage(OVSessionId, "user", sanitized, agentId); + logger.info( + `openviking: afterTurn stored ${newCount} msgs in session=${OVSessionId} (${sanitized.length} chars)`, + ); + } else { + logger.info("openviking: afterTurn skipped store (sanitized text empty)"); + diag("afterTurn_skip", OVSessionId, { + reason: "sanitized_empty", + }); + return; + } - if (!decision.shouldCapture) { - logger.info("openviking: auto-capture skipped (capture decision rejected)"); + const session = await client.getSession(OVSessionId, agentId); + const pendingTokens = session.pending_tokens ?? 0; + + if (pendingTokens < cfg.commitTokenThreshold) { + logger.info( + `openviking: pending_tokens=${pendingTokens}/${cfg.commitTokenThreshold} in session=${OVSessionId}, deferring commit`, + ); + diag("afterTurn_skip", OVSessionId, { + reason: "below_threshold", + pendingTokens, + commitTokenThreshold: cfg.commitTokenThreshold, + }); return; } - const client = await getClient(); - const OVSessionId = sessionKey ?? afterTurnParams.sessionId; - await client.addSessionMessage(OVSessionId, "user", decision.normalizedText, agentId); - const commitResult = await client.commitSession(OVSessionId, { wait: true, agentId }); logger.info( - `openviking: committed ${newCount} messages in session=${OVSessionId}, ` + - `archived=${commitResult.archived ?? false}, memories=${commitResult.memories_extracted ?? 
0}, ` + + `openviking: committing session=${OVSessionId} (wait=false), pendingTokens=${pendingTokens}, threshold=${cfg.commitTokenThreshold}`, + ); + const commitResult = await client.commitSession(OVSessionId, { wait: false, agentId }); + logger.info( + `openviking: committed session=${OVSessionId}, ` + + `status=${commitResult.status}, archived=${commitResult.archived ?? false}, ` + `task_id=${commitResult.task_id ?? "none"} ${toJsonLog({ captured: [trimForLog(turnText, 260)] })}`, ); + + diag("afterTurn_commit", OVSessionId, { + pendingTokens, + commitTokenThreshold: cfg.commitTokenThreshold, + status: commitResult.status, + archived: commitResult.archived ?? false, + taskId: commitResult.task_id ?? null, + extractedMemories: (commitResult as any).extracted_memories ?? null, + }); } catch (err) { - warnOrInfo(logger, `openviking: auto-capture failed: ${String(err)}`); + warnOrInfo(logger, `openviking: afterTurn failed: ${String(err)}`); + diag("afterTurn_error", OVSessionId, { + error: String(err), + }); } }, async compact(compactParams): Promise { - const delegated = await tryLegacyCompact(compactParams); - if (delegated) { - return delegated; - } + const OVSessionId = compactParams.sessionId; + diag("compact_entry", OVSessionId, { + tokenBudget: compactParams.tokenBudget ?? null, + force: compactParams.force ?? false, + currentTokenCount: compactParams.currentTokenCount ?? null, + compactionTarget: compactParams.compactionTarget ?? 
null, + hasCustomInstructions: typeof compactParams.customInstructions === "string" && + compactParams.customInstructions.trim().length > 0, + }); - warnOrInfo( - logger, - "openviking: legacy compaction delegation unavailable; skipping compact", - ); + try { + const client = await getClient(); + const agentId = resolveAgentId(OVSessionId); + logger.info( + `openviking: compact committing session=${OVSessionId} (wait=true)`, + ); + const commitResult = await client.commitSession(OVSessionId, { wait: true, agentId }); + const memCount = totalExtractedMemories(commitResult.memories_extracted); + + if (commitResult.status === "failed") { + warnOrInfo( + logger, + `openviking: compact commit Phase 2 failed for session=${OVSessionId}: ${commitResult.error ?? "unknown"}`, + ); + diag("compact_result", OVSessionId, { + ok: false, + compacted: false, + reason: "commit_failed", + status: commitResult.status, + archived: commitResult.archived ?? false, + taskId: commitResult.task_id ?? null, + error: commitResult.error ?? null, + }); + return { + ok: false, + compacted: false, + reason: "commit_failed", + result: commitResult, + }; + } - return { - ok: true, - compacted: false, - reason: "legacy_compact_unavailable", - }; + if (commitResult.status === "timeout") { + warnOrInfo( + logger, + `openviking: compact commit Phase 2 timed out for session=${OVSessionId}, task_id=${commitResult.task_id ?? "none"}`, + ); + diag("compact_result", OVSessionId, { + ok: false, + compacted: false, + reason: "commit_timeout", + status: commitResult.status, + archived: commitResult.archived ?? false, + taskId: commitResult.task_id ?? null, + }); + return { + ok: false, + compacted: false, + reason: "commit_timeout", + result: commitResult, + }; + } + + logger.info( + `openviking: compact committed session=${OVSessionId}, archived=${commitResult.archived ?? false}, memories=${memCount}, task_id=${commitResult.task_id ?? 
"none"}`, + ); + + if (!commitResult.archived) { + diag("compact_result", OVSessionId, { + ok: true, + compacted: false, + reason: "commit_no_archive", + status: commitResult.status, + archived: commitResult.archived ?? false, + taskId: commitResult.task_id ?? null, + memories: memCount, + }); + return { + ok: true, + compacted: false, + reason: "commit_no_archive", + result: commitResult, + }; + } + + diag("compact_result", OVSessionId, { + ok: true, + compacted: true, + reason: "commit_completed", + status: commitResult.status, + archived: commitResult.archived ?? false, + taskId: commitResult.task_id ?? null, + memories: memCount, + }); + return { + ok: true, + compacted: true, + reason: "commit_completed", + result: commitResult, + }; + } catch (err) { + warnOrInfo(logger, `openviking: compact commit failed for session=${OVSessionId}: ${String(err)}`); + diag("compact_error", OVSessionId, { + error: String(err), + }); + return { + ok: false, + compacted: false, + reason: "commit_error", + result: { + error: String(err), + }, + }; + } }, }; } diff --git a/examples/openclaw-plugin/index.ts b/examples/openclaw-plugin/index.ts index 70fcd09ab..f2b21b6a6 100644 --- a/examples/openclaw-plugin/index.ts +++ b/examples/openclaw-plugin/index.ts @@ -5,7 +5,8 @@ import { Type } from "@sinclair/typebox"; import { memoryOpenVikingConfigSchema } from "./config.js"; import { OpenVikingClient, localClientCache, localClientPendingPromises, isMemoryUri } from "./client.js"; -import type { FindResultItem, PendingClientEntry } from "./client.js"; +import type { FindResultItem, PendingClientEntry, CommitSessionResult, OVMessage } from "./client.js"; +import { formatMessageFaithful } from "./context-engine.js"; import { isTranscriptLikeIngest, extractLatestUserText, @@ -27,7 +28,7 @@ import { prepareLocalPort, } from "./process-manager.js"; import { createMemoryOpenVikingContextEngine } from "./context-engine.js"; -import type { ContextEngineWithSessionMapping } from 
"./context-engine.js"; +import type { ContextEngineWithCommit } from "./context-engine.js"; type PluginLogger = { debug?: (message: string) => void; @@ -42,19 +43,27 @@ type HookAgentContext = { sessionKey?: string; }; +type ToolDefinition = { + name: string; + label: string; + description: string; + parameters: unknown; + execute: (_toolCallId: string, params: Record) => Promise; +}; + +type ToolContext = { + sessionKey?: string; + sessionId?: string; + agentId?: string; +}; + type OpenClawPluginApi = { pluginConfig?: unknown; logger: PluginLogger; - registerTool: ( - tool: { - name: string; - label: string; - description: string; - parameters: unknown; - execute: (_toolCallId: string, params: Record) => Promise; - }, - opts?: { name?: string; names?: string[] }, - ) => void; + registerTool: { + (tool: ToolDefinition, opts?: { name?: string; names?: string[] }): void; + (factory: (ctx: ToolContext) => ToolDefinition): void; + }; registerService: (service: { id: string; start: (ctx?: unknown) => void | Promise; @@ -72,6 +81,12 @@ const MAX_OPENVIKING_STDERR_LINES = 200; const MAX_OPENVIKING_STDERR_CHARS = 256_000; const AUTO_RECALL_TIMEOUT_MS = 5_000; +function totalCommitMemories(r: CommitSessionResult): number { + const m = r.memories_extracted; + if (!m || typeof m !== "object") return 0; + return Object.values(m).reduce((sum, n) => sum + (n ?? 
0), 0); +} + const contextEnginePlugin = { id: "openviking", name: "Context Engine (OpenViking)", @@ -239,7 +254,6 @@ const contextEnginePlugin = { text: Type.String({ description: "Information to store as memory source text" }), role: Type.Optional(Type.String({ description: "Session role, default user" })), sessionId: Type.Optional(Type.String({ description: "Existing OpenViking session ID" })), - sessionKey: Type.Optional(Type.String({ description: "OpenClaw sessionKey — uses the persistent 1:1 mapped OV session" })), }), async execute(_toolCallId: string, params: Record) { const { text } = params as { text: string }; @@ -248,30 +262,53 @@ const contextEnginePlugin = { ? (params as { role: string }).role : "user"; const sessionIdIn = (params as { sessionId?: string }).sessionId; - const sessionKeyIn = (params as { sessionKey?: string }).sessionKey; api.logger.info?.( - `openviking: memory_store invoked (textLength=${text?.length ?? 0}, sessionId=${sessionIdIn ?? "auto"}, sessionKey=${sessionKeyIn ?? "none"})`, + `openviking: memory_store invoked (textLength=${text?.length ?? 0}, sessionId=${sessionIdIn ?? "auto"})`, ); let sessionId = sessionIdIn; - let usedMappedSession = false; - const storeAgentId = sessionKeyIn ? resolveAgentId(sessionKeyIn) : undefined; + let usedTempSession = false; + const storeAgentId = sessionId ? resolveAgentId(sessionId) : undefined; try { const c = await getClient(); - if (!sessionId && sessionKeyIn && contextEngineRef) { - sessionId = await contextEngineRef.resolveOVSession(sessionKeyIn); - usedMappedSession = true; - } if (!sessionId) { - return { - content: [{ type: "text", text: "Either sessionKey or sessionId is required to store memory." 
}], - details: { action: "rejected", reason: "missing_session_identifier" }, - }; + sessionId = `memory-store-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`; + usedTempSession = true; } await c.addSessionMessage(sessionId, role, text, storeAgentId); const commitResult = await c.commitSession(sessionId, { wait: true, agentId: storeAgentId }); - const memoriesCount = commitResult.memories_extracted ?? 0; + const memoriesCount = totalCommitMemories(commitResult); + if (commitResult.status === "failed") { + api.logger.warn( + `openviking: memory_store commit failed (sessionId=${sessionId}): ${commitResult.error ?? "unknown"}`, + ); + return { + content: [{ type: "text", text: `Memory extraction failed for session ${sessionId}: ${commitResult.error ?? "unknown"}` }], + details: { + action: "failed", + sessionId, + status: "failed", + error: commitResult.error, + usedTempSession, + }, + }; + } + if (commitResult.status === "timeout") { + api.logger.warn( + `openviking: memory_store commit timed out (sessionId=${sessionId}), task_id=${commitResult.task_id ?? "none"}. Memories may still be extracting in background.`, + ); + return { + content: [{ type: "text", text: `Memory extraction timed out for session ${sessionId}. It may still complete in the background (task_id=${commitResult.task_id ?? "none"}).` }], + details: { + action: "timeout", + sessionId, + status: "timeout", + taskId: commitResult.task_id, + usedTempSession, + }, + }; + } if (memoriesCount === 0) { api.logger.warn( `openviking: memory_store committed but 0 memories extracted (sessionId=${sessionId}). ` + @@ -287,7 +324,14 @@ const contextEnginePlugin = { text: `Stored in OpenViking session ${sessionId} and committed ${memoriesCount} memories.`, }, ], - details: { action: "stored", sessionId, memoriesCount, archived: commitResult.archived ?? false, usedMappedSession }, + details: { + action: "stored", + sessionId, + memoriesCount, + status: commitResult.status, + archived: commitResult.archived ?? 
false, + usedTempSession, + }, }; } catch (err) { api.logger.warn(`openviking: memory_store failed: ${String(err)}`); @@ -400,7 +444,77 @@ const contextEnginePlugin = { }, { name: "memory_forget" }, ); - let contextEngineRef: ContextEngineWithSessionMapping | null = null; + api.registerTool((ctx: ToolContext) => ({ + name: "ov_archive_expand", + label: "Archive Expand (OpenViking)", + description: + "Retrieve original messages from a compressed session archive. " + + "Use when a session summary lacks specific details " + + "such as exact commands, file paths, code snippets, or config values. " + + "Check [Archive Index] to find the right archive ID.", + parameters: Type.Object({ + archiveId: Type.String({ + description: + 'Archive ID from [Archive Index] (e.g. "archive_002")', + }), + }), + async execute(_toolCallId: string, params: Record) { + const archiveId = String((params as { archiveId?: string }).archiveId ?? "").trim(); + if (!archiveId) { + return { + content: [{ type: "text", text: "Error: archiveId is required." }], + details: { error: "missing_param", param: "archiveId" }, + }; + } + + const sessionId = ctx.sessionId ?? ""; + if (!sessionId) { + return { + content: [{ type: "text", text: "Error: no active session." }], + details: { error: "no_session" }, + }; + } + + try { + const client = await getClient(); + const agentId = resolveAgentId(sessionId); + const detail = await client.getSessionArchive( + sessionId, + archiveId, + agentId, + ); + + const header = [ + `## ${detail.archive_id}`, + detail.abstract ? 
`**Summary**: ${detail.abstract}` : "", + `**Messages**: ${detail.messages.length}`, + "", + ].filter(Boolean).join("\n"); + + const body = detail.messages + .map((m: OVMessage) => formatMessageFaithful(m)) + .join("\n\n"); + + return { + content: [{ type: "text", text: `${header}\n${body}` }], + details: { + action: "expanded", + archiveId: detail.archive_id, + messageCount: detail.messages.length, + sessionId, + }, + }; + } catch (err) { + const msg = err instanceof Error ? err.message : String(err); + return { + content: [{ type: "text", text: `Failed to expand ${archiveId}: ${msg}` }], + details: { error: msg, archiveId, sessionId }, + }; + } + }, + })); + + let contextEngineRef: ContextEngineWithCommit | null = null; const sessionAgentIds = new Map(); const rememberSessionAgentId = (ctx: { @@ -430,7 +544,7 @@ const contextEnginePlugin = { api.on("before_prompt_build", async (event: unknown, ctx?: HookAgentContext) => { rememberSessionAgentId(ctx ?? {}); - const hookSessionId = ctx?.sessionId ?? ctx?.sessionKey ?? ""; + const hookSessionId = ctx?.sessionId ?? ""; const agentId = resolveAgentId(hookSessionId); let client: OpenVikingClient; try { @@ -561,11 +675,13 @@ const contextEnginePlugin = { rememberSessionAgentId(ctx ?? 
{}); }); api.on("before_reset", async (_event: unknown, ctx?: HookAgentContext) => { - const sessionKey = ctx?.sessionKey; - if (sessionKey && contextEngineRef) { + const sessionId = ctx?.sessionId; + if (sessionId && contextEngineRef) { try { - await contextEngineRef.commitOVSession(sessionKey); - api.logger.info(`openviking: committed OV session on reset for sessionKey=${sessionKey}`); + const ok = await contextEngineRef.commitOVSession(sessionId); + if (ok) { + api.logger.info(`openviking: committed OV session on reset for session=${sessionId}`); + } } catch (err) { api.logger.warn(`openviking: failed to commit OV session on reset: ${String(err)}`); } @@ -589,7 +705,7 @@ const contextEnginePlugin = { return contextEngineRef; }); api.logger.info( - "openviking: registered context-engine (before_prompt_build=auto-recall, afterTurn=auto-capture, sessionKey=1:1 mapping)", + "openviking: registered context-engine (before_prompt_build=auto-recall, afterTurn=auto-capture, assemble=archive+active, sessionId=1:1 mapping)", ); } else { api.logger.warn( diff --git a/examples/openclaw-plugin/openclaw.plugin.json b/examples/openclaw-plugin/openclaw.plugin.json index 1edf6c683..e6d918d4b 100644 --- a/examples/openclaw-plugin/openclaw.plugin.json +++ b/examples/openclaw-plugin/openclaw.plugin.json @@ -90,6 +90,12 @@ "advanced": true, "help": "Maximum estimated tokens for auto-recall memory injection" }, + "commitTokenThreshold": { + "label": "Commit Token Threshold", + "placeholder": "2000", + "advanced": true, + "help": "Minimum estimated pending tokens before auto-commit triggers. Set to 0 to commit every turn." 
+ }, "ingestReplyAssist": { "label": "Ingest Reply Assist", "help": "When transcript-like memory ingestion is detected, add a lightweight reply instruction to reduce NO_REPLY.", @@ -106,6 +112,11 @@ "placeholder": "120", "help": "Minimum sanitized text length required before ingest reply assist can trigger.", "advanced": true + }, + "emitStandardDiagnostics": { + "label": "Standard diagnostics (diag JSON lines)", + "advanced": true, + "help": "Emit structured openviking: diag {...} for assemble/afterTurn. Set false to disable." } }, "configSchema": { @@ -163,6 +174,9 @@ "recallTokenBudget": { "type": "number" }, + "commitTokenThreshold": { + "type": "number" + }, "ingestReplyAssist": { "type": "boolean" }, @@ -171,6 +185,9 @@ }, "ingestReplyAssistMinChars": { "type": "number" + }, + "emitStandardDiagnostics": { + "type": "boolean" } } } diff --git a/examples/openclaw-plugin/session-transcript-repair.ts b/examples/openclaw-plugin/session-transcript-repair.ts new file mode 100644 index 000000000..ec5ab6b0d --- /dev/null +++ b/examples/openclaw-plugin/session-transcript-repair.ts @@ -0,0 +1,530 @@ +/** + * Tool use/result pairing repair for assembled context. + * + * Copied from openclaw core (src/agents/session-transcript-repair.ts). + * Only change: replaced `import type { AgentMessage } from "@mariozechner/pi-agent-core"` + * with import from local tool-call-id.ts to avoid the external dependency. 
+ */ + +import type { AgentMessage } from "./tool-call-id.js"; +import { extractToolCallsFromAssistant, extractToolResultId } from "./tool-call-id.js"; + +const TOOL_CALL_NAME_MAX_CHARS = 64; +const TOOL_CALL_NAME_RE = /^[A-Za-z0-9_-]+$/; + +type RawToolCallBlock = { + type?: unknown; + id?: unknown; + name?: unknown; + input?: unknown; + arguments?: unknown; +}; + +function isRawToolCallBlock(block: unknown): block is RawToolCallBlock { + if (!block || typeof block !== "object") { + return false; + } + const type = (block as { type?: unknown }).type; + return ( + typeof type === "string" && + (type === "toolCall" || type === "toolUse" || type === "functionCall") + ); +} + +function hasToolCallInput(block: RawToolCallBlock): boolean { + const hasInput = "input" in block ? block.input !== undefined && block.input !== null : false; + const hasArguments = + "arguments" in block ? block.arguments !== undefined && block.arguments !== null : false; + return hasInput || hasArguments; +} + +function hasNonEmptyStringField(value: unknown): boolean { + return typeof value === "string" && value.trim().length > 0; +} + +function hasToolCallId(block: RawToolCallBlock): boolean { + return hasNonEmptyStringField(block.id); +} + +function normalizeAllowedToolNames(allowedToolNames?: Iterable): Set | null { + if (!allowedToolNames) { + return null; + } + const normalized = new Set(); + for (const name of allowedToolNames) { + if (typeof name !== "string") { + continue; + } + const trimmed = name.trim(); + if (trimmed) { + normalized.add(trimmed.toLowerCase()); + } + } + return normalized.size > 0 ? 
normalized : null; +} + +function hasToolCallName(block: RawToolCallBlock, allowedToolNames: Set | null): boolean { + if (typeof block.name !== "string") { + return false; + } + const trimmed = block.name.trim(); + if (!trimmed) { + return false; + } + if (trimmed.length > TOOL_CALL_NAME_MAX_CHARS || !TOOL_CALL_NAME_RE.test(trimmed)) { + return false; + } + if (!allowedToolNames) { + return true; + } + return allowedToolNames.has(trimmed.toLowerCase()); +} + +function redactSessionsSpawnAttachmentsArgs(value: unknown): unknown { + if (!value || typeof value !== "object") { + return value; + } + const rec = value as Record; + const raw = rec.attachments; + if (!Array.isArray(raw)) { + return value; + } + const next = raw.map((item) => { + if (!item || typeof item !== "object") { + return item; + } + const a = item as Record; + if (!Object.hasOwn(a, "content")) { + return item; + } + const { content: _content, ...rest } = a; + return { ...rest, content: "__OPENCLAW_REDACTED__" }; + }); + return { ...rec, attachments: next }; +} + +function sanitizeToolCallBlock(block: RawToolCallBlock): RawToolCallBlock { + const rawName = typeof block.name === "string" ? block.name : undefined; + const trimmedName = rawName?.trim(); + const hasTrimmedName = typeof trimmedName === "string" && trimmedName.length > 0; + const normalizedName = hasTrimmedName ? trimmedName : undefined; + const nameChanged = hasTrimmedName && rawName !== trimmedName; + + const isSessionsSpawn = normalizedName?.toLowerCase() === "sessions_spawn"; + + if (!isSessionsSpawn) { + if (!nameChanged) { + return block; + } + return { ...(block as Record), name: normalizedName } as RawToolCallBlock; + } + + // Redact large/sensitive inline attachment content from persisted transcripts. 
+ // Apply redaction to both `.arguments` and `.input` properties since block structures can vary + const nextArgs = redactSessionsSpawnAttachmentsArgs(block.arguments); + const nextInput = redactSessionsSpawnAttachmentsArgs(block.input); + if (nextArgs === block.arguments && nextInput === block.input && !nameChanged) { + return block; + } + + const next = { ...(block as Record) }; + if (nameChanged && normalizedName) { + next.name = normalizedName; + } + if (nextArgs !== block.arguments || Object.hasOwn(block, "arguments")) { + next.arguments = nextArgs; + } + if (nextInput !== block.input || Object.hasOwn(block, "input")) { + next.input = nextInput; + } + return next as RawToolCallBlock; +} + +function makeMissingToolResult(params: { + toolCallId: string; + toolName?: string; +}): Extract { + return { + role: "toolResult", + toolCallId: params.toolCallId, + toolName: params.toolName ?? "unknown", + content: [ + { + type: "text", + text: "[openclaw] missing tool result in session history; inserted synthetic error result for transcript repair.", + }, + ], + isError: true, + timestamp: Date.now(), + } as Extract; +} + +function trimNonEmptyString(value: unknown): string | undefined { + if (typeof value !== "string") { + return undefined; + } + const trimmed = value.trim(); + return trimmed || undefined; +} + +function normalizeToolResultName( + message: Extract, + fallbackName?: string, +): Extract { + const rawToolName = (message as { toolName?: unknown }).toolName; + const normalizedToolName = trimNonEmptyString(rawToolName); + if (normalizedToolName) { + if (rawToolName === normalizedToolName) { + return message; + } + return { ...message, toolName: normalizedToolName }; + } + + const normalizedFallback = trimNonEmptyString(fallbackName); + if (normalizedFallback) { + return { ...message, toolName: normalizedFallback }; + } + + if (typeof rawToolName === "string") { + return { ...message, toolName: "unknown" }; + } + return message; +} + +export { 
makeMissingToolResult }; + +export type ToolCallInputRepairReport = { + messages: AgentMessage[]; + droppedToolCalls: number; + droppedAssistantMessages: number; +}; + +export type ToolCallInputRepairOptions = { + allowedToolNames?: Iterable; +}; + +export type ToolUseResultPairingOptions = { + preserveErroredAssistantResults?: boolean; +}; + +export function stripToolResultDetails(messages: AgentMessage[]): AgentMessage[] { + let touched = false; + const out: AgentMessage[] = []; + for (const msg of messages) { + if (!msg || typeof msg !== "object" || (msg as { role?: unknown }).role !== "toolResult") { + out.push(msg); + continue; + } + if (!("details" in msg)) { + out.push(msg); + continue; + } + const sanitized = { ...(msg as object) } as { details?: unknown }; + delete sanitized.details; + touched = true; + out.push(sanitized as unknown as AgentMessage); + } + return touched ? out : messages; +} + +export function repairToolCallInputs( + messages: AgentMessage[], + options?: ToolCallInputRepairOptions, +): ToolCallInputRepairReport { + let droppedToolCalls = 0; + let droppedAssistantMessages = 0; + let changed = false; + const out: AgentMessage[] = []; + const allowedToolNames = normalizeAllowedToolNames(options?.allowedToolNames); + + for (const msg of messages) { + if (!msg || typeof msg !== "object") { + out.push(msg); + continue; + } + + if (msg.role !== "assistant" || !Array.isArray(msg.content)) { + out.push(msg); + continue; + } + + const nextContent: typeof msg.content = []; + let droppedInMessage = 0; + let messageChanged = false; + + for (const block of msg.content) { + if ( + isRawToolCallBlock(block) && + (!hasToolCallInput(block) || + !hasToolCallId(block) || + !hasToolCallName(block, allowedToolNames)) + ) { + droppedToolCalls += 1; + droppedInMessage += 1; + changed = true; + messageChanged = true; + continue; + } + if (isRawToolCallBlock(block)) { + if ( + (block as { type?: unknown }).type === "toolCall" || + (block as { type?: unknown }).type 
=== "toolUse" ||
+          (block as { type?: unknown }).type === "functionCall"
+        ) {
+          // Only sessions_spawn blocks get attachment redaction; every other tool-call block
+          // passes through with at most its name whitespace-trimmed (provider-specific shapes, e.g. toolUse.input for Anthropic, are preserved).
+          const blockName =
+            typeof (block as { name?: unknown }).name === "string"
+              ? (block as { name: string }).name.trim()
+              : undefined;
+          if (blockName?.toLowerCase() === "sessions_spawn") {
+            const sanitized = sanitizeToolCallBlock(block);
+            if (sanitized !== block) {
+              changed = true;
+              messageChanged = true;
+            }
+            nextContent.push(sanitized as typeof block);
+          } else {
+            if (typeof (block as { name?: unknown }).name === "string") {
+              const rawName = (block as { name: string }).name;
+              const trimmedName = rawName.trim();
+              if (rawName !== trimmedName && trimmedName) {
+                const renamed = { ...(block as object), name: trimmedName } as typeof block;
+                nextContent.push(renamed);
+                changed = true;
+                messageChanged = true;
+              } else {
+                nextContent.push(block);
+              }
+            } else {
+              nextContent.push(block);
+            }
+          }
+          continue;
+        }
+      } else {
+        nextContent.push(block);
+      }
+    }
+
+    if (droppedInMessage > 0) {
+      if (nextContent.length === 0) {
+        droppedAssistantMessages += 1;
+        changed = true;
+        continue;
+      }
+      out.push({ ...msg, content: nextContent });
+      continue;
+    }
+
+    if (messageChanged) {
+      out.push({ ...msg, content: nextContent });
+      continue;
+    }
+
+    out.push(msg);
+  }
+
+  return {
+    messages: changed ?
out : messages, + droppedToolCalls, + droppedAssistantMessages, + }; +} + +export function sanitizeToolCallInputs( + messages: AgentMessage[], + options?: ToolCallInputRepairOptions, +): AgentMessage[] { + return repairToolCallInputs(messages, options).messages; +} + +export function sanitizeToolUseResultPairing( + messages: AgentMessage[], + options?: ToolUseResultPairingOptions, +): AgentMessage[] { + return repairToolUseResultPairing(messages, options).messages; +} + +export type ToolUseRepairReport = { + messages: AgentMessage[]; + added: Array>; + droppedDuplicateCount: number; + droppedOrphanCount: number; + moved: boolean; +}; + +export function repairToolUseResultPairing( + messages: AgentMessage[], + options?: ToolUseResultPairingOptions, +): ToolUseRepairReport { + // Anthropic (and Cloud Code Assist) reject transcripts where assistant tool calls are not + // immediately followed by matching tool results. Session files can end up with results + // displaced (e.g. after user turns) or duplicated. 
Repair by: + // - moving matching toolResult messages directly after their assistant toolCall turn + // - inserting synthetic error toolResults for missing ids + // - dropping duplicate toolResults for the same id (anywhere in the transcript) + const out: AgentMessage[] = []; + const added: Array> = []; + const seenToolResultIds = new Set(); + let droppedDuplicateCount = 0; + let droppedOrphanCount = 0; + let moved = false; + let changed = false; + + const pushToolResult = (msg: Extract) => { + const id = extractToolResultId(msg); + if (id && seenToolResultIds.has(id)) { + droppedDuplicateCount += 1; + changed = true; + return; + } + if (id) { + seenToolResultIds.add(id); + } + out.push(msg); + }; + + for (let i = 0; i < messages.length; i += 1) { + const msg = messages[i]; + if (!msg || typeof msg !== "object") { + out.push(msg); + continue; + } + + const role = (msg as { role?: unknown }).role; + if (role !== "assistant") { + // Tool results must only appear directly after the matching assistant tool call turn. + // Any "free-floating" toolResult entries in session history can make strict providers + // (Anthropic-compatible APIs, MiniMax, Cloud Code Assist) reject the entire request. 
+ if (role !== "toolResult") { + out.push(msg); + } else { + droppedOrphanCount += 1; + changed = true; + } + continue; + } + + const assistant = msg as Extract; + + const toolCalls = extractToolCallsFromAssistant(assistant); + if (toolCalls.length === 0) { + out.push(msg); + continue; + } + + const toolCallIds = new Set(toolCalls.map((t) => t.id)); + const toolCallNamesById = new Map(toolCalls.map((t) => [t.id, t.name] as const)); + + const spanResultsById = new Map>(); + const remainder: AgentMessage[] = []; + + let j = i + 1; + for (; j < messages.length; j += 1) { + const next = messages[j]; + if (!next || typeof next !== "object") { + remainder.push(next); + continue; + } + + const nextRole = (next as { role?: unknown }).role; + if (nextRole === "assistant") { + break; + } + + if (nextRole === "toolResult") { + const toolResult = next as Extract; + const id = extractToolResultId(toolResult); + if (id && toolCallIds.has(id)) { + if (seenToolResultIds.has(id)) { + droppedDuplicateCount += 1; + changed = true; + continue; + } + const normalizedToolResult = normalizeToolResultName( + toolResult, + toolCallNamesById.get(id), + ); + if (normalizedToolResult !== toolResult) { + changed = true; + } + if (!spanResultsById.has(id)) { + spanResultsById.set(id, normalizedToolResult); + } + continue; + } + } + + // Drop tool results that don't match the current assistant tool calls. + if (nextRole !== "toolResult") { + remainder.push(next); + } else { + droppedOrphanCount += 1; + changed = true; + } + } + + // Aborted/errored assistant turns should never synthesize missing tool results, but + // the replay sanitizer can still legitimately retain real tool results for surviving + // tool calls in the same turn after malformed siblings are dropped. 
+ const stopReason = (assistant as { stopReason?: string }).stopReason; + if (stopReason === "error" || stopReason === "aborted") { + out.push(msg); + if (options?.preserveErroredAssistantResults) { + for (const toolCall of toolCalls) { + const result = spanResultsById.get(toolCall.id); + if (!result) { + continue; + } + pushToolResult(result); + } + } + for (const rem of remainder) { + out.push(rem); + } + i = j - 1; + continue; + } + + out.push(msg); + + if (spanResultsById.size > 0 && remainder.length > 0) { + moved = true; + changed = true; + } + + for (const call of toolCalls) { + const existing = spanResultsById.get(call.id); + if (existing) { + pushToolResult(existing); + } else { + const missing = makeMissingToolResult({ + toolCallId: call.id, + toolName: call.name, + }); + added.push(missing); + changed = true; + pushToolResult(missing); + } + } + + for (const rem of remainder) { + if (!rem || typeof rem !== "object") { + out.push(rem); + continue; + } + out.push(rem); + } + i = j - 1; + } + + const changedOrMoved = changed || moved; + return { + messages: changedOrMoved ? out : messages, + added, + droppedDuplicateCount, + droppedOrphanCount, + moved: changedOrMoved, + }; +} diff --git a/examples/openclaw-plugin/test-memory-chain.py b/examples/openclaw-plugin/test-memory-chain.py new file mode 100644 index 000000000..877fbf31f --- /dev/null +++ b/examples/openclaw-plugin/test-memory-chain.py @@ -0,0 +1,934 @@ +#!/usr/bin/env python3 +""" +OpenClaw 记忆链路完整测试脚本 + +验证 OpenViking 记忆插件重构后的端到端链路: +1. afterTurn: 本轮消息无损写入 OpenViking session,sessionId 一致 +2. commit: 归档消息 + 提取长期记忆 + .meta.json 写入 +3. assemble: 同用户继续对话时, 从 latest_archive_overview + active messages 重组上下文 +4. assemble budget trimming: 小 token budget 下 latest_archive_overview 被裁剪 +5. sessionId 一致性: 整条链路使用统一的 sessionId (无 sessionKey) +6. 
新用户记忆召回: 验证 before_prompt_build auto-recall + +测试流程: +Phase 1: 多轮对话 (12 轮) — afterTurn 写入 +Phase 2: afterTurn 验证 — 检查 OV session 内部状态 +Phase 3: Commit 验证 — 触发 commit, 检查归档结构 +Phase 4: Assemble 验证 — 同用户继续对话, 验证上下文重组 +Phase 5: SessionId 一致性验证 +Phase 6: 新用户记忆召回 + +前提: +- OpenViking 服务已启动 (默认 http://127.0.0.1:8000) +- OpenClaw Gateway 已启动并配置了 OpenViking 插件 + +用法: + python test-memory-chain.py + python test-memory-chain.py --gateway http://127.0.0.1:18790 --openviking http://127.0.0.1:8000 + python test-memory-chain.py --phase chat + python test-memory-chain.py --phase afterTurn + python test-memory-chain.py --phase commit + python test-memory-chain.py --phase assemble + python test-memory-chain.py --phase session-id + python test-memory-chain.py --phase recall + python test-memory-chain.py --verbose + +依赖: + pip install requests rich +""" + +import argparse +import json +import time +import uuid +from datetime import datetime +from typing import Any + +import requests +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.table import Table +from rich.tree import Tree + +# ── 常量 ─────────────────────────────────────────────────────────────────── + +USER_ID = f"test-chain-{uuid.uuid4().hex[:8]}" +DISPLAY_NAME = "测试用户" +DEFAULT_GATEWAY = "http://127.0.0.1:18790" +DEFAULT_OPENVIKING = "http://127.0.0.1:8000" +AGENT_ID = "openclaw" + +console = Console() + +# ── 测试结果收集 ────────────────────────────────────────────────────────── + +assertions: list[dict] = [] + + +def check(label: str, condition: bool, detail: str = ""): + """记录一个断言结果。""" + assertions.append({"label": label, "ok": condition, "detail": detail}) + icon = "[green]✓[/green]" if condition else "[red]✗[/red]" + msg = f" {icon} {label}" + if detail: + msg += f" [dim]({detail})[/dim]" + console.print(msg) + + +# ── 对话数据 ────────────────────────────────────────────────────────────── + +CHAT_MESSAGES = [ + "你好,我是一个软件工程师,我叫张明,在一家科技公司工作。我主要负责后端服务开发,使用的技术栈是 Python 和 
Go。最近我们在重构一个订单系统,遇到了不少挑战。", + "关于订单系统的问题,主要是性能瓶颈。我们发现在高峰期,数据库连接池经常被耗尽。目前用的是 PostgreSQL,连接池大小设置的是100,但每秒峰值请求量有5000。你有什么建议吗?", + "谢谢你的建议。我还想问一下,我们目前的缓存策略用的是 Redis,但缓存击穿的问题很严重。热点数据过期后,大量请求直接打到数据库。我们尝试过加互斥锁,但性能下降很多。", + "对了,关于代码风格,我们团队更倾向于使用函数式编程的思想,尽量避免副作用。变量命名用 snake_case,文档用中文写。代码审查很严格,每个 PR 至少需要两人 review。", + "说到工作流程,我们每天早上9点站会,周三下午技术分享会。我一般上午写代码,下午处理 code review 和会议。晚上如果不加班,会看看技术书籍或者写写博客。", + "我最近在学习分布式系统的设计,正在看《数据密集型应用系统设计》这本书。之前看完了《深入理解计算机系统》,收获很大。你有什么好的分布式系统学习资料推荐吗?", + "目前订单系统重构的进度大概完成了60%,还剩下支付模块和库存同步模块。支付模块比较复杂,需要对接多个支付渠道。我们打算用消息队列来解耦库存同步。", + "消息队列我们在 Kafka 和 RabbitMQ 之间犹豫。Kafka 吞吐量高,但运维复杂;RabbitMQ 功能丰富,但性能稍差。我们的消息量大概每天1000万条,你觉得选哪个好?", + "我们团队有8个人,3个后端、2个前端、1个测试、1个运维,还有1个产品经理。后端老王经验最丰富,遇到难题都找他。测试小李很细心,bug检出率很高。", + "对了,跟我聊天的时候注意几点:我喜欢简洁直接的回答,不要太啰嗦;技术问题最好带代码示例;如果不确定的问题要说明,不要瞎编。谢谢!", + "补充一下,我们的监控用的是 Prometheus + Grafana,日志用 ELK Stack。最近在考虑引入链路追踪,OpenTelemetry 看起来不错,但不知道跟现有系统集成麻不麻烦。", + "昨天线上出了个诡异的 bug,某个接口偶发超时,但日志里看不出什么问题。后来发现是下游服务的连接数满了,但监控指标没配好,没报警。这种问题怎么预防比较好?", +] + +# assemble 阶段: 同用户继续对话,用于验证 assemble 是否携带了摘要上下文 +ASSEMBLE_FOLLOWUP_MESSAGES = [ + { + "question": "对了,我之前提到的订单系统重构进展到哪了?支付模块开始了吗?", + "anchor_keywords": ["订单系统", "支付模块", "60%"], + "hook": "assemble — latest_archive_overview 重组", + }, + { + "question": "我们团队消息队列最终选了什么?之前我跟你讨论过 Kafka 和 RabbitMQ 的取舍。", + "anchor_keywords": ["Kafka", "RabbitMQ", "消息队列"], + "hook": "assemble — latest_archive_overview 重组", + }, +] + +# 新用户记忆召回 +RECALL_QUESTIONS = [ + { + "question": "我是做什么工作的?用什么技术栈?请简洁回答", + "expected_keywords": ["软件工程师", "Python", "Go"], + }, + { + "question": "我最近在做什么项目?遇到了什么技术挑战?请简洁回答", + "expected_keywords": ["订单系统", "性能瓶颈", "缓存"], + }, + { + "question": "跟我聊天有什么注意事项?请简洁回答", + "expected_keywords": ["简洁", "代码示例"], + }, +] + + +# ── Gateway / OpenViking API ───────────────────────────────────────────── + + +def send_message(gateway_url: str, message: str, user_id: str) -> dict: + """通过 OpenClaw Responses API 发送消息。""" + resp = requests.post( + f"{gateway_url}/v1/responses", + json={"model": 
"openclaw", "input": message, "user": user_id}, + timeout=300, + ) + resp.raise_for_status() + return resp.json() + + +def extract_reply_text(data: dict) -> str: + """从 Responses API 响应中提取助手回复文本。""" + for item in data.get("output", []): + if item.get("type") == "message" and item.get("role") == "assistant": + for part in item.get("content", []): + if part.get("type") in ("text", "output_text"): + return part.get("text", "") + return "(无回复)" + + +class OpenVikingInspector: + """OpenViking 内部状态检查器。""" + + def __init__(self, base_url: str, api_key: str = "", agent_id: str = AGENT_ID): + self.base_url = base_url.rstrip("/") + self.api_key = api_key + self.agent_id = agent_id + + def _headers(self) -> dict: + h: dict[str, str] = {"Content-Type": "application/json"} + if self.api_key: + h["X-API-Key"] = self.api_key + if self.agent_id: + h["X-OpenViking-Agent"] = self.agent_id + return h + + def _get(self, path: str, timeout: int = 10) -> dict | None: + try: + resp = requests.get(f"{self.base_url}{path}", headers=self._headers(), timeout=timeout) + if resp.status_code == 200: + data = resp.json() + return data.get("result", data) + return None + except Exception as e: + console.print(f"[dim]GET {path} 失败: {e}[/dim]") + return None + + def _post(self, path: str, body: dict | None = None, timeout: int = 30) -> dict | None: + try: + resp = requests.post( + f"{self.base_url}{path}", + headers=self._headers(), + json=body or {}, + timeout=timeout, + ) + if resp.status_code == 200: + data = resp.json() + return data.get("result", data) + return None + except Exception as e: + console.print(f"[dim]POST {path} 失败: {e}[/dim]") + return None + + def health_check(self) -> bool: + try: + resp = requests.get(f"{self.base_url}/health", timeout=5) + return resp.status_code == 200 + except Exception: + return False + + def get_session(self, session_id: str) -> dict | None: + return self._get(f"/api/v1/sessions/{session_id}") + + def get_session_messages(self, session_id: str) -> list | 
None: + result = self._get(f"/api/v1/sessions/{session_id}/messages") + if isinstance(result, list): + return result + if isinstance(result, dict): + return result.get("messages", []) + return None + + def get_session_context(self, session_id: str, token_budget: int = 128000) -> dict | None: + return self._get(f"/api/v1/sessions/{session_id}/context?token_budget={token_budget}") + + def commit_session(self, session_id: str, wait: bool = True) -> dict | None: + result = self._post(f"/api/v1/sessions/{session_id}/commit", timeout=120) + if not result: + return None + + if wait and result.get("task_id"): + task_id = result["task_id"] + deadline = time.time() + 120 + while time.time() < deadline: + time.sleep(0.5) + task = self._get(f"/api/v1/tasks/{task_id}") + if not task: + continue + if task.get("status") == "completed": + result["status"] = "completed" + result["memories_extracted"] = task.get("result", {}).get( + "memories_extracted", {} + ) + return result + if task.get("status") == "failed": + result["status"] = "failed" + result["error"] = task.get("error") + return result + + return result + + def search_memories( + self, query: str, target_uri: str = "viking://user/memories", limit: int = 10 + ) -> list: + result = self._post( + "/api/v1/search/find", + {"query": query, "target_uri": target_uri, "limit": limit}, + ) + if isinstance(result, dict): + return result.get("memories", []) + return [] + + def list_fs(self, uri: str) -> list: + result = self._get(f"/api/v1/fs/ls?uri={uri}&output=original") + return result if isinstance(result, list) else [] + + def read_fs(self, uri: str) -> str | None: + """读取 fs 中某个文件的内容。""" + result = self._get(f"/api/v1/content/read?uri={uri}") + if isinstance(result, str): + return result + if isinstance(result, dict): + return result.get("content") + return None + + +# ── 渲染函数 ────────────────────────────────────────────────────────────── + + +def render_reply(text: str, title: str = "回复"): + lines = text.split("\n") + if 
len(lines) > 25: + text = "\n".join(lines[:25]) + f"\n\n... (共 {len(lines)} 行,已截断)" + console.print(Panel(Markdown(text), title=f"[green]{title}[/green]", border_style="green")) + + +def render_json(data: Any, title: str = "JSON"): + console.print( + Panel(json.dumps(data, indent=2, ensure_ascii=False, default=str)[:2000], title=title) + ) + + +def render_session_info(info: dict, title: str = "Session 信息"): + table = Table(title=title, show_header=True) + table.add_column("属性", style="cyan") + table.add_column("值", style="green") + for key, value in info.items(): + if isinstance(value, (dict, list)): + value = json.dumps(value, ensure_ascii=False) + table.add_row(str(key), str(value)[:120]) + console.print(table) + + +# ── Phase 1: 多轮对话 ──────────────────────────────────────────────────── + + +def run_phase_chat(gateway_url: str, user_id: str, delay: float, verbose: bool) -> tuple[int, int]: + """Phase 1: 多轮对话 — 测试 afterTurn 写入。""" + console.print() + console.rule(f"[bold]Phase 1: 多轮对话 ({len(CHAT_MESSAGES)} 轮) — afterTurn 写入[/bold]") + console.print(f"[yellow]用户ID:[/yellow] {user_id}") + console.print(f"[yellow]Gateway:[/yellow] {gateway_url}") + console.print() + + total = len(CHAT_MESSAGES) + ok = fail = 0 + + for i, msg in enumerate(CHAT_MESSAGES, 1): + console.rule(f"[dim]Turn {i}/{total}[/dim]", style="dim") + console.print( + Panel( + msg[:200] + ("..." if len(msg) > 200 else ""), + title=f"[bold cyan]用户 [{i}/{total}][/bold cyan]", + border_style="cyan", + ) + ) + try: + data = send_message(gateway_url, msg, user_id) + reply = extract_reply_text(data) + render_reply(reply[:500] + ("..." 
if len(reply) > 500 else "")) + ok += 1 + except Exception as e: + console.print(f"[red][ERROR][/red] {e}") + fail += 1 + + if i < total: + time.sleep(delay) + + console.print() + console.print(f"[yellow]对话完成:[/yellow] {ok} 成功, {fail} 失败") + + wait = max(delay * 2, 5) + console.print(f"[yellow]等待 {wait:.0f}s 让 afterTurn 处理完成...[/yellow]") + time.sleep(wait) + + return ok, fail + + +# ── Phase 2: afterTurn 验证 ────────────────────────────────────────────── + + +def run_phase_after_turn(openviking_url: str, user_id: str, verbose: bool) -> bool: + """Phase 2: afterTurn 验证 — 检查 OV session 内部状态确认消息已写入。""" + console.print() + console.rule("[bold]Phase 2: afterTurn 验证 — 检查 OV session 消息写入[/bold]") + console.print() + console.print("[dim]验证点:[/dim]") + console.print("[dim]- afterTurn 应将每轮消息写入 OV session[/dim]") + console.print("[dim]- session.message_count > 0[/dim]") + console.print("[dim]- pending_tokens > 0 (消息尚未 commit)[/dim]") + console.print("[dim]- sessionId 应为 OpenClaw 传入的 user_id[/dim]") + console.print() + + inspector = OpenVikingInspector(openviking_url) + + # 2.1 健康检查 + console.print("[bold]2.1 OpenViking 健康检查[/bold]") + healthy = inspector.health_check() + check("OpenViking 服务可达", healthy) + if not healthy: + return False + + # 2.2 Session 存在且有消息 + console.print("\n[bold]2.2 Session 存在性 & 消息计数[/bold]") + session_info = inspector.get_session(user_id) + check("Session 存在", session_info is not None, f"session_id={user_id}") + + if not session_info: + console.print("[red]Session 不存在,无法继续验证[/red]") + return False + + if verbose: + render_session_info(session_info, f"Session: {user_id}") + + msg_count = session_info.get("message_count", 0) + check( + "message_count > 0 (afterTurn 写入成功)", + msg_count > 0, + f"message_count={msg_count}", + ) + + # pending_tokens 表示尚未 commit 的 token 数 + pending = session_info.get("pending_tokens", 0) + check( + "pending_tokens > 0 (有待 commit 的内容)", + pending > 0, + f"pending_tokens={pending}", + ) + + # 2.3 检查消息内容: 至少部分对话内容能在 OV 消息中找到 + 
console.print("\n[bold]2.3 消息内容抽样校验[/bold]") + messages = inspector.get_session_messages(user_id) + if messages is not None: + check("能获取到 session 消息列表", True, f"共 {len(messages)} 条消息") + + # 取第一条用户消息的特征文本做匹配 + sample_text = "张明" + all_text = json.dumps(messages, ensure_ascii=False) + check( + f"消息内容包含特征文本「{sample_text}」", + sample_text in all_text, + "验证 afterTurn 写入的内容与发送一致", + ) + + sample_text_2 = "PostgreSQL" + check( + f"消息内容包含特征文本「{sample_text_2}」", + sample_text_2 in all_text, + "验证多轮消息写入", + ) + else: + check("能获取到 session 消息列表", False, "GET messages 返回 None") + + # 2.4 context 在 commit 前应返回 messages + console.print("\n[bold]2.4 Commit 前 context[/bold]") + ctx = inspector.get_session_context(user_id) + if ctx: + ctx_msg_count = len(ctx.get("messages", [])) + has_summary_archive = bool(ctx.get("latest_archive_overview")) + check( + "context 返回 messages > 0", + ctx_msg_count > 0, + f"messages={ctx_msg_count}", + ) + check( + "commit 前 latest_archive_overview 为空", + not has_summary_archive, + f"latest_archive_overview={ctx.get('latest_archive_overview')}", + ) + if verbose and ctx.get("stats"): + console.print(f" [dim]stats: {ctx['stats']}[/dim]") + else: + check("context 可调用", False, "返回 None") + + return True + + +# ── Phase 3: Commit 验证 ───────────────────────────────────────────────── + + +def run_phase_commit(openviking_url: str, user_id: str, verbose: bool) -> bool: + """Phase 3: Commit 验证 — 触发 commit, 检查归档结构和记忆提取。""" + console.print() + console.rule("[bold]Phase 3: Commit 验证 — 触发 session.commit()[/bold]") + console.print() + console.print("[dim]验证点:[/dim]") + console.print("[dim]- commit 返回 status=completed/accepted[/dim]") + console.print("[dim]- 消息被归档 (archived=true)[/dim]") + console.print("[dim]- 提取出记忆 (memories_extracted > 0)[/dim]") + console.print("[dim]- 归档目录含 .overview.md 和 .meta.json[/dim]") + console.print() + + inspector = OpenVikingInspector(openviking_url) + + # 3.1 执行 commit + console.print("[bold]3.1 执行 session.commit()[/bold]") + 
console.print("[dim]正在等待 commit 完成 (可能需要 1-2 分钟)...[/dim]") + + commit_result = inspector.commit_session(user_id, wait=True) + check("commit 返回结果", commit_result is not None) + + if not commit_result: + console.print("[red]Commit 失败,无法继续[/red]") + return False + + if verbose: + render_json(commit_result, "Commit 结果") + + status = commit_result.get("status", "unknown") + check( + "commit status 为 completed 或 accepted", + status in ("completed", "accepted"), + f"status={status}", + ) + + archived = commit_result.get("archived", False) + check("archived=true (消息已归档)", archived is True, f"archived={archived}") + + memories = commit_result.get("memories_extracted", {}) + total_mem = sum(memories.values()) if memories else 0 + check( + "memories_extracted > 0 (提取出记忆)", + total_mem > 0, + f"total={total_mem}, categories={memories}", + ) + + # 3.2 commit 后 session 状态 + console.print("\n[bold]3.2 Commit 后 session 状态[/bold]") + post_session = inspector.get_session(user_id) + if post_session: + commit_count = post_session.get("commit_count", 0) + check( + "commit_count >= 1", + commit_count >= 1, + f"commit_count={commit_count}", + ) + + post_pending = post_session.get("pending_tokens", 0) + # commit 后 pending_tokens 应该很低 (归档后清空了旧消息) + console.print(f" [dim]commit 后 pending_tokens={post_pending}[/dim]") + + # 3.3 检查归档目录结构 + console.print("\n[bold]3.3 归档目录结构检查[/bold]") + # 尝试用 context 来间接确认 latest_archive_overview 存在 + ctx_after = inspector.get_session_context(user_id) + if ctx_after: + has_summary_archive = bool(ctx_after.get("latest_archive_overview")) + check( + "commit 后 context 返回 latest_archive_overview", + has_summary_archive, + f"latest_archive_overview={ctx_after.get('latest_archive_overview')}", + ) + + if has_summary_archive: + overview = ctx_after.get("latest_archive_overview", "") + check( + "latest_archive_overview 非空 (摘要已生成)", + len(overview) > 10, + f"overview 长度={len(overview)} chars", + ) + if verbose: + console.print(f" [dim]overview 前 200 字: 
{overview[:200]}...[/dim]") + else: + check("commit 后 context 可调用", False) + + # 3.4 检查 estimatedTokens 合理性 + if ctx_after: + stats = ctx_after.get("stats", {}) + archive_tokens = stats.get("archiveTokens", 0) + check( + "archiveTokens > 0 (归档 token 计数合理)", + archive_tokens > 0, + f"archiveTokens={archive_tokens}", + ) + + return True + + +# ── Phase 4: Assemble 验证 ─────────────────────────────────────────────── + + +def run_phase_assemble( + gateway_url: str, openviking_url: str, user_id: str, delay: float, verbose: bool +) -> bool: + """Phase 4: Assemble 验证 — 同用户继续对话,验证上下文从 latest archive overview 重组。""" + console.print() + console.rule("[bold]Phase 4: Assemble 验证 — 同用户继续对话[/bold]") + console.print() + console.print("[dim]验证点:[/dim]") + console.print( + "[dim]- 同用户对话触发 assemble(): 从 OV latest_archive_overview + active messages 重组上下文[/dim]" + ) + console.print("[dim]- 回复应能引用 Phase 1 中已被归档的信息[/dim]") + console.print("[dim]- context 应返回 latest_archive_overview (证明 assemble 有数据源)[/dim]") + console.print() + + inspector = OpenVikingInspector(openviking_url) + + # 4.1 确认 assemble 的数据源 (latest_archive_overview) 就绪 + console.print("[bold]4.1 确认 assemble 数据源[/bold]") + ctx = inspector.get_session_context(user_id) + if ctx: + has_summary_archive = bool(ctx.get("latest_archive_overview")) + check( + "context 返回 latest_archive_overview", + has_summary_archive, + f"latest_archive_overview={ctx.get('latest_archive_overview')}", + ) + else: + check("context 可用", False) + return False + + # 4.2 assemble budget trimming: 用极小 budget 验证裁剪 + console.print("\n[bold]4.2 Assemble budget trimming[/bold]") + tiny_ctx = inspector.get_session_context(user_id, token_budget=1) + if tiny_ctx: + stats = tiny_ctx.get("stats", {}) + total_archives = stats.get("totalArchives", 0) + included = stats.get("includedArchives", 0) + dropped = stats.get("droppedArchives", 0) + check( + "budget=1 时 latest_archive_overview 被裁剪", + included == 0 or dropped > 0, + f"total={total_archives}, 
included={included}, dropped={dropped}", + ) + active_tokens = stats.get("activeTokens", 0) + console.print( + f" [dim]activeTokens={active_tokens}, archiveTokens={stats.get('archiveTokens', 0)}[/dim]" + ) + else: + check("tiny budget context 可调用", False) + + # 4.3 同用户继续对话 — assemble 应重组归档上下文 + console.print("\n[bold]4.3 同用户继续对话 — 验证 assemble 重组归档内容[/bold]") + console.print(f"[yellow]用户ID:[/yellow] {user_id} (同一用户,继续对话)") + console.print() + + total = len(ASSEMBLE_FOLLOWUP_MESSAGES) + for i, item in enumerate(ASSEMBLE_FOLLOWUP_MESSAGES, 1): + q = item["question"] + keywords = item["anchor_keywords"] + + console.rule(f"[dim]Assemble 验证 {i}/{total}[/dim]", style="dim") + console.print( + Panel( + f"{q}\n\n[dim]锚点关键词: {', '.join(keywords)}[/dim]\n[dim]Hook: {item['hook']}[/dim]", + title=f"[bold cyan]Assemble Q{i}[/bold cyan]", + border_style="cyan", + ) + ) + + try: + data = send_message(gateway_url, q, user_id) + reply = extract_reply_text(data) + render_reply(reply) + + reply_lower = reply.lower() + hits = [kw for kw in keywords if kw.lower() in reply_lower] + hit_rate = len(hits) / len(keywords) if keywords else 0 + check( + f"Assemble Q{i}: 回复包含归档内容 (命中率 >= 50%)", + hit_rate >= 0.5, + f"命中={hits}, 未命中={[k for k in keywords if k not in hits]}, rate={hit_rate:.0%}", + ) + except Exception as e: + check(f"Assemble Q{i}: 发送成功", False, str(e)) + + if i < total: + time.sleep(delay) + + # 4.4 对话后验证 afterTurn 继续写入 (新消息进入 active messages) + console.print("\n[bold]4.4 Assemble 后 afterTurn 继续写入[/bold]") + time.sleep(3) + post_ctx = inspector.get_session_context(user_id) + if post_ctx: + post_msg_count = len(post_ctx.get("messages", [])) + check( + "继续对话后 active messages 增加", + post_msg_count > 0, + f"active messages={post_msg_count}", + ) + + return True + + +# ── Phase 5: SessionId 一致性验证 ──────────────────────────────────────── + + +def run_phase_session_id(openviking_url: str, user_id: str, verbose: bool) -> bool: + """Phase 5: SessionId 一致性验证 — 确认整条链路使用统一的 sessionId。""" 
+ console.print() + console.rule("[bold]Phase 5: SessionId 一致性验证[/bold]") + console.print() + console.print("[dim]验证点:[/dim]") + console.print("[dim]- 重构后 sessionId 统一为 OpenClaw 传入的 user_id[/dim]") + console.print("[dim]- OV session_id == user_id (无 sessionKey 前缀/后缀)[/dim]") + console.print("[dim]- context 用同一 sessionId 可查到数据[/dim]") + console.print() + + inspector = OpenVikingInspector(openviking_url) + + # 5.1 session_id 就是 user_id + console.print("[bold]5.1 SessionId == UserId[/bold]") + session = inspector.get_session(user_id) + check( + f"OV session 以 user_id={user_id} 为 ID 可查到", + session is not None, + "sessionId 统一: 插件直接用 user_id 作为 OV session_id", + ) + + # 5.2 不存在以 sessionKey 变体为 ID 的 session + console.print("\n[bold]5.2 无 sessionKey 残留[/bold]") + # 如果旧代码有 sessionKey 逻辑, 可能会创建带前缀的 session + stale_variants = [ + f"sk:{user_id}", + f"sessionKey:{user_id}", + f"key:{user_id}", + ] + for variant in stale_variants: + stale = inspector.get_session(variant) + is_absent = stale is None or stale.get("message_count", 0) == 0 + check( + f"不存在残留 session「{variant}」", + is_absent, + "旧 sessionKey 映射应已移除" if is_absent else f"发现残留: {stale}", + ) + + # 5.3 context 用 user_id 能查到数据 + console.print("\n[bold]5.3 同一 sessionId 查询归档[/bold]") + ctx = inspector.get_session_context(user_id) + if ctx: + has_data = bool(ctx.get("latest_archive_overview")) or len(ctx.get("messages", [])) > 0 + check( + "context(user_id) 返回数据", + has_data, + f"latest_archive_overview={ctx.get('latest_archive_overview')}, messages={len(ctx.get('messages', []))}", + ) + else: + check("context(user_id) 可调用", False) + + # 5.4 验证 commit 也是用同一 sessionId (session 有 commit_count > 0) + console.print("\n[bold]5.4 Commit 使用同一 sessionId[/bold]") + if session: + cc = session.get("commit_count", 0) + check( + "session(user_id) 有 commit 记录", + cc > 0, + f"commit_count={cc}, 说明 commit 也走 user_id 而非 sessionKey", + ) + + return True + + +# ── Phase 6: 新用户记忆召回 ────────────────────────────────────────────── + + +def 
run_phase_recall(gateway_url: str, user_id: str, delay: float, verbose: bool) -> list: + """Phase 6: 新用户记忆召回 — 验证 before_prompt_build auto-recall。""" + console.print() + console.rule(f"[bold]Phase 6: 新用户记忆召回 ({len(RECALL_QUESTIONS)} 轮) — auto-recall[/bold]") + console.print() + console.print("[dim]验证点:[/dim]") + console.print("[dim]- 新用户 (新 session) 发送问题[/dim]") + console.print("[dim]- before_prompt_build 通过 memory search 注入相关记忆[/dim]") + console.print("[dim]- 回复应包含 Phase 1 对话中的关键信息[/dim]") + console.print() + + verify_user = f"{user_id}-recall-{uuid.uuid4().hex[:4]}" + console.print(f"[yellow]验证用户:[/yellow] {verify_user} (新 session)") + console.print() + + results = [] + total = len(RECALL_QUESTIONS) + + for i, item in enumerate(RECALL_QUESTIONS, 1): + q = item["question"] + expected = item["expected_keywords"] + + console.rule(f"[dim]Recall {i}/{total}[/dim]", style="dim") + console.print( + Panel( + f"{q}\n\n[dim]期望关键词: {', '.join(expected)}[/dim]", + title=f"[bold cyan]Recall Q{i}[/bold cyan]", + border_style="cyan", + ) + ) + + try: + data = send_message(gateway_url, q, verify_user) + reply = extract_reply_text(data) + render_reply(reply) + + reply_lower = reply.lower() + hits = [kw for kw in expected if kw.lower() in reply_lower] + hit_rate = len(hits) / len(expected) if expected else 0 + success = hit_rate >= 0.5 + + check( + f"Recall Q{i}: 关键词命中率 >= 50%", + success, + f"命中={hits}, rate={hit_rate:.0%}", + ) + results.append({"question": q, "hits": hits, "hit_rate": hit_rate, "success": success}) + except Exception as e: + check(f"Recall Q{i}: 发送成功", False, str(e)) + results.append({"question": q, "hits": [], "hit_rate": 0, "success": False}) + + if i < total: + time.sleep(delay) + + return results + + +# ── 完整测试 ────────────────────────────────────────────────────────────── + + +def run_full_test(gateway_url: str, openviking_url: str, user_id: str, delay: float, verbose: bool): + console.print() + console.print( + Panel.fit( + f"[bold]OpenClaw 
记忆链路完整测试[/bold]\n\n" + f"Gateway: {gateway_url}\n" + f"OpenViking: {openviking_url}\n" + f"User ID: {user_id}\n" + f"时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", + title="测试信息", + ) + ) + + # Phase 1: Chat + chat_ok, chat_fail = run_phase_chat(gateway_url, user_id, delay, verbose) + + # Phase 2: afterTurn + run_phase_after_turn(openviking_url, user_id, verbose) + + # Phase 3: Commit + run_phase_commit(openviking_url, user_id, verbose) + + console.print("\n[yellow]等待 10s 让记忆提取完成...[/yellow]") + time.sleep(10) + + # Phase 4: Assemble (同用户继续) + run_phase_assemble(gateway_url, openviking_url, user_id, delay, verbose) + + # Phase 5: SessionId 一致性 + run_phase_session_id(openviking_url, user_id, verbose) + + # Phase 6: 新用户召回 + run_phase_recall(gateway_url, user_id, delay, verbose) + + # ── 汇总报告 ────────────────────────────────────────────────────────── + console.print() + console.rule("[bold]测试报告[/bold]") + + passed = sum(1 for a in assertions if a["ok"]) + failed = sum(1 for a in assertions if not a["ok"]) + total = len(assertions) + + table = Table(title=f"断言结果: {passed}/{total} 通过") + table.add_column("#", style="bold", width=4) + table.add_column("状态", width=6) + table.add_column("断言", max_width=60) + table.add_column("详情", style="dim", max_width=50) + + for i, a in enumerate(assertions, 1): + status = "[green]PASS[/green]" if a["ok"] else "[red]FAIL[/red]" + table.add_row(str(i), status, a["label"][:60], (a.get("detail") or "")[:50]) + + console.print(table) + + # 按阶段汇总 + tree = Tree(f"[bold]通过: {passed}/{total}, 失败: {failed}[/bold]") + tree.add(f"Phase 1: 多轮对话 — {chat_ok} 成功 / {chat_fail} 失败") + + fail_list = [a for a in assertions if not a["ok"]] + if fail_list: + fail_branch = tree.add(f"[red]失败断言 ({len(fail_list)})[/red]") + for a in fail_list: + fail_branch.add(f"[red]✗[/red] {a['label']}") + + console.print(tree) + + if failed == 0: + console.print("\n[green bold]全部通过!端到端链路验证成功。[/green bold]") + else: + console.print(f"\n[red bold]有 {failed} 
个断言失败,请检查上方详情。[/red bold]") + + +# ── 入口 ─────────────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser( + description="OpenClaw 记忆链路完整测试", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +示例: + python test-memory-chain.py + python test-memory-chain.py --gateway http://127.0.0.1:18790 + python test-memory-chain.py --phase chat + python test-memory-chain.py --phase afterTurn --user-id test-chain-abc123 + python test-memory-chain.py --phase assemble --user-id test-chain-abc123 + python test-memory-chain.py --verbose + """, + ) + parser.add_argument( + "--gateway", + default=DEFAULT_GATEWAY, + help=f"OpenClaw Gateway 地址 (默认: {DEFAULT_GATEWAY})", + ) + parser.add_argument( + "--openviking", + default=DEFAULT_OPENVIKING, + help=f"OpenViking 服务地址 (默认: {DEFAULT_OPENVIKING})", + ) + parser.add_argument( + "--user-id", + default=USER_ID, + help="测试用户ID (默认: 随机生成)", + ) + parser.add_argument( + "--phase", + choices=["all", "chat", "afterTurn", "commit", "assemble", "session-id", "recall"], + default="all", + help="运行阶段 (默认: all)", + ) + parser.add_argument( + "--delay", + type=float, + default=2.0, + help="轮次间等待秒数 (默认: 2)", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="详细输出", + ) + args = parser.parse_args() + + gateway_url = args.gateway.rstrip("/") + openviking_url = args.openviking.rstrip("/") + user_id = args.user_id + + console.print("[bold]OpenClaw 记忆链路测试[/bold]") + console.print(f"[yellow]Gateway:[/yellow] {gateway_url}") + console.print(f"[yellow]OpenViking:[/yellow] {openviking_url}") + console.print(f"[yellow]User ID:[/yellow] {user_id}") + + if args.phase == "all": + run_full_test(gateway_url, openviking_url, user_id, args.delay, args.verbose) + elif args.phase == "chat": + run_phase_chat(gateway_url, user_id, args.delay, args.verbose) + elif args.phase == "afterTurn": + run_phase_after_turn(openviking_url, user_id, args.verbose) + elif args.phase == 
"commit": + run_phase_commit(openviking_url, user_id, args.verbose) + elif args.phase == "assemble": + run_phase_assemble(gateway_url, openviking_url, user_id, args.delay, args.verbose) + elif args.phase == "session-id": + run_phase_session_id(openviking_url, user_id, args.verbose) + elif args.phase == "recall": + run_phase_recall(gateway_url, user_id, args.delay, args.verbose) + + # 打印最终断言统计 + if assertions: + passed = sum(1 for a in assertions if a["ok"]) + total = len(assertions) + console.print(f"\n[yellow]断言统计: {passed}/{total} 通过[/yellow]") + + console.print("\n[yellow]测试结束。[/yellow]") + + +if __name__ == "__main__": + main() diff --git a/examples/openclaw-plugin/test-tool-capture.py b/examples/openclaw-plugin/test-tool-capture.py new file mode 100644 index 000000000..0d536ef36 --- /dev/null +++ b/examples/openclaw-plugin/test-tool-capture.py @@ -0,0 +1,432 @@ +#!/usr/bin/env python3 +""" +测试 extractNewTurnTexts 改动:验证 toolUse/toolResult 内容是否被正确捕获到 OV session 中。 + +测试策略: +1. 发送一条消息,触发模型使用工具(如 native_tool / code_execution) +2. 等待 afterTurn 完成 +3. 从 OV session 中读取已存储的消息 +4. 
断言存储的消息中包含 toolUse 和 toolResult 相关内容 + +用法: + python test-tool-capture.py + python test-tool-capture.py --verbose + python test-tool-capture.py --gateway http://127.0.0.1:18789 --openviking http://127.0.0.1:1933 + +依赖: + pip install requests rich +""" + +import argparse +import json +import re +import time +import uuid +from datetime import datetime + +import requests +from rich.console import Console +from rich.panel import Panel +from rich.table import Table + +# ── 常量 ────────────────────────────────────────────────────────────────── + +GATEWAY_URL = "http://127.0.0.1:18789" +OPENVIKING_URL = "http://127.0.0.1:1933" +AGENT_ID = "openclaw" + +console = Console() +assertions: list[dict] = [] + + +def check(label: str, condition: bool, detail: str = ""): + assertions.append({"label": label, "ok": condition, "detail": detail}) + icon = "[green]✓[/green]" if condition else "[red]✗[/red]" + msg = f" {icon} {label}" + if detail: + msg += f" [dim]({detail})[/dim]" + console.print(msg) + + +def load_gateway_token() -> str: + """从 openclaw.json 读取 gateway auth token。""" + try: + import pathlib + cfg_path = pathlib.Path.home() / ".openclaw" / "openclaw.json" + cfg = json.loads(cfg_path.read_text()) + return cfg.get("gateway", {}).get("auth", {}).get("token", "") + except Exception: + return "" + + +# ── API helpers ────────────────────────────────────────────────────────── + + +def send_message(gateway_url: str, message: str, user_id: str, token: str) -> dict: + """通过 OpenClaw Responses API 发送消息。""" + headers = {"Content-Type": "application/json"} + if token: + headers["Authorization"] = f"Bearer {token}" + resp = requests.post( + f"{gateway_url}/v1/responses", + headers=headers, + json={"model": "openclaw", "input": message, "user": user_id}, + timeout=300, + ) + resp.raise_for_status() + return resp.json() + + +def extract_reply_text(data: dict) -> str: + for item in data.get("output", []): + if item.get("type") == "message" and item.get("role") == "assistant": + for 
part in item.get("content", []): + if part.get("type") in ("text", "output_text"): + return part.get("text", "") + return "(无回复)" + + +def has_tool_use_in_output(data: dict) -> bool: + """检查 Responses API 返回中是否有 tool_use / function_call。""" + for item in data.get("output", []): + item_type = item.get("type", "") + if item_type in ("function_call", "tool_use", "computer_call"): + return True + if item.get("role") == "assistant": + for part in item.get("content", []): + if part.get("type") in ("tool_use", "toolUse"): + return True + return False + + +class OVInspector: + def __init__(self, base_url: str, agent_id: str = AGENT_ID): + self.base_url = base_url.rstrip("/") + self.agent_id = agent_id + + def _headers(self) -> dict: + h: dict[str, str] = {"Content-Type": "application/json"} + if self.agent_id: + h["X-OpenViking-Agent"] = self.agent_id + return h + + def _get(self, path: str, timeout: int = 10): + try: + resp = requests.get(f"{self.base_url}{path}", headers=self._headers(), timeout=timeout) + if resp.status_code == 200: + data = resp.json() + return data.get("result", data) + return None + except Exception as e: + console.print(f"[dim]GET {path} 失败: {e}[/dim]") + return None + + def list_sessions(self) -> list: + result = self._get("/api/v1/sessions") + if isinstance(result, list): + return result + return [] + + def get_session(self, session_id: str): + return self._get(f"/api/v1/sessions/{session_id}") + + def get_session_context(self, session_id: str, token_budget: int = 128000): + return self._get(f"/api/v1/sessions/{session_id}/context?token_budget={token_budget}") + + def find_latest_session(self) -> str | None: + """找到最近更新的 session ID(gateway 内部使用 UUID,非 user_id)。 + 通过检查每个 session 的 updated_at 来找到最新的。""" + sessions = self.list_sessions() + real_sessions = [ + s for s in sessions + if isinstance(s, dict) and not s.get("session_id", "").startswith("memory-store-") + ] + if not real_sessions: + return None + + best_id = None + best_time = "" + for s in 
real_sessions: + sid = s.get("session_id", "") + if not sid: + continue + detail = self.get_session(sid) + if not detail: + continue + updated = detail.get("updated_at", "") + if updated > best_time: + best_time = updated + best_id = sid + + return best_id or real_sessions[-1].get("session_id") + + +# ── 核心测试 ────────────────────────────────────────────────────────────── + + +TOOL_TRIGGER_MESSAGES = [ + { + "input": "请帮我计算 factorial(7) 的结果,用代码算一下", + "description": "触发代码执行工具", + "expect_keywords": ["5040", "factorial"], + }, + { + "input": "我叫李明,记住我是一名数据工程师,擅长 Spark 和 Flink,偏好用 Scala 写代码。请同时告诉我今天星期几。", + "description": "信息存储 + 可能触发工具", + "expect_keywords": ["李明", "数据工程师"], + }, + { + "input": "帮我写一段 Python 代码计算斐波那契数列前10个数,并运行它告诉我结果", + "description": "触发代码执行并返回结果", + "expect_keywords": ["斐波那契", "fibonacci"], + }, +] + + +def run_test( + gateway_url: str, + openviking_url: str, + user_id: str, + delay: float, + verbose: bool, +): + token = load_gateway_token() + inspector = OVInspector(openviking_url) + + console.print( + Panel( + f"[bold]Tool Capture 测试[/bold]\n\n" + f"Gateway: {gateway_url}\n" + f"OpenViking: {openviking_url}\n" + f"User ID: {user_id}\n" + f"时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", + title="测试信息", + ) + ) + + # ── Phase 1: 发送消息 ──────────────────────────────────────────────── + + console.rule("[bold]Phase 1: 发送消息触发 afterTurn[/bold]") + + gateway_responses = [] + for i, msg_cfg in enumerate(TOOL_TRIGGER_MESSAGES): + console.print(f"\n[cyan]消息 {i + 1}/{len(TOOL_TRIGGER_MESSAGES)}:[/cyan] {msg_cfg['description']}") + console.print(f" [dim]> {msg_cfg['input'][:80]}...[/dim]") + + try: + data = send_message(gateway_url, msg_cfg["input"], user_id, token) + reply = extract_reply_text(data) + has_tool = has_tool_use_in_output(data) + + console.print(f" [green]回复:[/green] {reply[:120]}...") + if has_tool: + console.print(" [yellow]检测到 tool_use 在响应中[/yellow]") + + if verbose: + console.print(f" [dim]完整响应: {json.dumps(data, 
ensure_ascii=False)[:500]}[/dim]") + + gateway_responses.append({ + "index": i, + "msg": msg_cfg, + "response": data, + "reply": reply, + "has_tool": has_tool, + }) + + check( + f"消息 {i + 1} 发送成功", + True, + f"reply_len={len(reply)}", + ) + except Exception as e: + console.print(f" [red]发送失败: {e}[/red]") + check(f"消息 {i + 1} 发送成功", False, str(e)) + + if i < len(TOOL_TRIGGER_MESSAGES) - 1: + time.sleep(delay) + + # ── Phase 2: 等待 afterTurn 写入 ─────────────────────────────────── + + console.rule("[bold]Phase 2: 检查 OV session 中的存储内容[/bold]") + console.print("[yellow]等待 afterTurn 写入 OV session...[/yellow]") + time.sleep(8) + + # Gateway 使用内部 UUID 作为 session ID,需要从 OV 列表中找到最新的 + ov_session_id = inspector.find_latest_session() + if not ov_session_id: + console.print("[red] OV 中没有找到任何 session[/red]") + check("OV session 存在", False, "no sessions found") + print_summary() + return + + console.print(f" [cyan]OV session ID: {ov_session_id}[/cyan]") + + session_info = inspector.get_session(ov_session_id) + if session_info: + msg_count = session_info.get("message_count", "?") + console.print(f" Session found: message_count={msg_count}") + check("OV session 存在", True, f"id={ov_session_id[:16]}...") + else: + console.print("[red] OV session 详情获取失败[/red]") + check("OV session 存在", False, "session detail failed") + print_summary() + return + + # 通过 context API 获取消息(包含 parts) + ctx = inspector.get_session_context(ov_session_id) + messages = ctx.get("messages", []) if ctx else [] + if not messages: + console.print("[red] OV session 消息为空[/red]") + check("OV session 有消息", False, "context messages empty") + print_summary() + return + + console.print(f" [green]OV session 消息数: {len(messages)}[/green]") + check("OV session 有消息", len(messages) > 0, f"count={len(messages)}") + + # ── Phase 3: 分析存储的内容是否包含 tool 信息 ────────────────────── + + console.rule("[bold]Phase 3: 验证 toolUse/toolResult 内容被捕获[/bold]") + + all_stored_text = "" + for msg in messages: + if not isinstance(msg, dict): + 
continue + parts = msg.get("parts", []) + for part in parts: + if isinstance(part, dict) and part.get("type") == "text": + all_stored_text += (part.get("text", "") or "") + "\n" + + if verbose: + console.print(Panel( + all_stored_text[:3000] + ("..." if len(all_stored_text) > 3000 else ""), + title="OV 存储的全部文本", + )) + + # 检查 toolUse 标记是否存在 + has_tool_use_marker = bool(re.search(r'\[toolUse:', all_stored_text, re.IGNORECASE)) + check( + "存储文本包含 [toolUse:] 标记", + has_tool_use_marker, + f"found={has_tool_use_marker}", + ) + + # 检查 toolResult 标记是否存在 + has_tool_result_marker = bool(re.search(r'result\]:', all_stored_text, re.IGNORECASE)) + check( + "存储文本包含 tool result 标记", + has_tool_result_marker, + f"found={has_tool_result_marker}", + ) + + # 检查 assistant 标记 + has_assistant = bool(re.search(r'\[assistant\]:', all_stored_text, re.IGNORECASE)) + check( + "存储文本包含 [assistant] 标记", + has_assistant, + f"found={has_assistant}", + ) + + # 检查 user 标记 + has_user = bool(re.search(r'\[user\]:', all_stored_text, re.IGNORECASE)) + check( + "存储文本包含 [user] 标记", + has_user, + f"found={has_user}", + ) + + # 检查关键内容是否保留 + for msg_cfg in TOOL_TRIGGER_MESSAGES: + for kw in msg_cfg.get("expect_keywords", []): + found = kw.lower() in all_stored_text.lower() + check( + f"存储文本包含关键词: {kw}", + found, + f"keyword='{kw}' found={found}", + ) + + # ── Phase 4: 对比改动前后的行为 ────────────────────────────────────── + + console.rule("[bold]Phase 4: 改动前后对比分析[/bold]") + + # 旧版本:只有 [user] 和 [assistant] 的文本 + # 新版本:应该额外包含 [toolUse: xxx] 和 [xxx result] 的内容 + tool_related_lines = [] + for line in all_stored_text.split("\n"): + stripped = line.strip() + if re.search(r'\[toolUse:', stripped, re.IGNORECASE): + tool_related_lines.append(("toolUse", stripped[:150])) + elif re.search(r'result\]:', stripped, re.IGNORECASE): + tool_related_lines.append(("toolResult", stripped[:150])) + + if tool_related_lines: + table = Table(title="捕获到的 Tool 相关内容") + table.add_column("类型", style="cyan", width=12) + 
table.add_column("内容预览", max_width=120) + for kind, preview in tool_related_lines: + table.add_row(kind, preview) + console.print(table) + + check( + "tool 相关行数 > 0(新逻辑生效)", + len(tool_related_lines) > 0, + f"tool_lines={len(tool_related_lines)}", + ) + + # ── 汇总 ───────────────────────────────────────────────────────────── + + print_summary() + + +def print_summary(): + console.print() + console.rule("[bold]测试汇总[/bold]") + + passed = sum(1 for a in assertions if a["ok"]) + failed = sum(1 for a in assertions if not a["ok"]) + total = len(assertions) + + table = Table(title=f"断言结果: {passed}/{total} 通过") + table.add_column("#", style="bold", width=4) + table.add_column("状态", width=6) + table.add_column("断言", max_width=60) + table.add_column("详情", style="dim", max_width=50) + + for i, a in enumerate(assertions, 1): + status = "[green]PASS[/green]" if a["ok"] else "[red]FAIL[/red]" + table.add_row(str(i), status, a["label"][:60], (a.get("detail") or "")[:50]) + + console.print(table) + + if failed == 0: + console.print("\n[green bold]全部通过!toolUse/toolResult 捕获验证成功。[/green bold]") + else: + console.print(f"\n[red bold]有 {failed} 个断言失败。[/red bold]") + console.print("[yellow]注: 如果模型没有调用工具,toolUse/toolResult 标记可能不存在 — 这不代表代码有 bug。[/yellow]") + console.print("[yellow]可以在 gateway 日志中确认 afterTurn 的存储内容。[/yellow]") + + +# ── 入口 ────────────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser(description="测试 toolUse/toolResult 捕获") + parser.add_argument("--gateway", default=GATEWAY_URL, help="Gateway 地址") + parser.add_argument("--openviking", default=OPENVIKING_URL, help="OpenViking 地址") + parser.add_argument("--delay", type=float, default=3.0, help="消息间延迟秒数") + parser.add_argument("--verbose", "-v", action="store_true", help="详细输出") + args = parser.parse_args() + + user_id = f"test-tool-{uuid.uuid4().hex[:8]}" + + run_test( + gateway_url=args.gateway.rstrip("/"), + openviking_url=args.openviking.rstrip("/"), + 
user_id=user_id, + delay=args.delay, + verbose=args.verbose, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/openclaw-plugin/text-utils.ts b/examples/openclaw-plugin/text-utils.ts index 40f224a45..a9f4decef 100644 --- a/examples/openclaw-plugin/text-utils.ts +++ b/examples/openclaw-plugin/text-utils.ts @@ -318,8 +318,47 @@ export function extractTextsFromUserMessages(messages: unknown[]): string[] { return texts; } +function formatToolUseBlock(b: Record): string { + const name = typeof b.name === "string" ? b.name : "unknown"; + let inputStr = ""; + if (b.input !== undefined && b.input !== null) { + try { + inputStr = typeof b.input === "string" ? b.input : JSON.stringify(b.input); + } catch { + inputStr = String(b.input); + } + } + return inputStr + ? `[toolUse: ${name}]\n${inputStr}` + : `[toolUse: ${name}]`; +} + +function formatToolResultContent(content: unknown): string { + if (typeof content === "string") return content.trim(); + if (Array.isArray(content)) { + const parts: string[] = []; + for (const block of content) { + const b = block as Record; + if (b?.type === "text" && typeof b.text === "string") { + parts.push((b.text as string).trim()); + } + } + return parts.join("\n"); + } + if (content !== undefined && content !== null) { + try { + return JSON.stringify(content); + } catch { + return String(content); + } + } + return ""; +} + /** - * 提取从 startIndex 开始的新消息(user + assistant),返回格式化的文本。 + * 提取从 startIndex 开始的新消息(user + assistant + toolResult),返回格式化的文本。 + * 保留 toolUse 完整内容(tool name + input)和 toolResult 完整内容, + * 跳过 system 消息(框架注入的元数据)。 */ export function extractNewTurnTexts( messages: unknown[], @@ -331,8 +370,18 @@ export function extractNewTurnTexts( const msg = messages[i] as Record; if (!msg || typeof msg !== "object") continue; const role = msg.role as string; - if (role !== "user" && role !== "assistant") continue; + if (!role || role === "system") continue; count++; + + if (role === "toolResult") { + const toolName = typeof 
msg.toolName === "string" ? msg.toolName : "tool"; + const resultText = formatToolResultContent(msg.content); + if (resultText) { + texts.push(`[${toolName} result]: ${resultText}`); + } + continue; + } + const content = msg.content; if (typeof content === "string" && content.trim()) { texts.push(`[${role}]: ${content.trim()}`); @@ -341,6 +390,8 @@ export function extractNewTurnTexts( const b = block as Record; if (b?.type === "text" && typeof b.text === "string") { texts.push(`[${role}]: ${(b.text as string).trim()}`); + } else if (b?.type === "toolUse") { + texts.push(`[${role}]: ${formatToolUseBlock(b)}`); } } } diff --git a/examples/openclaw-plugin/tool-call-id.ts b/examples/openclaw-plugin/tool-call-id.ts new file mode 100644 index 000000000..0b6204838 --- /dev/null +++ b/examples/openclaw-plugin/tool-call-id.ts @@ -0,0 +1,331 @@ +/** + * Tool call ID utilities. + * + * Copied from openclaw core (src/agents/tool-call-id.ts). + * Only change: replaced `import type { AgentMessage } from "@mariozechner/pi-agent-core"` + * with a local type definition to avoid the external dependency. + */ + +import { createHash } from "node:crypto"; + +// Local AgentMessage type replacing @mariozechner/pi-agent-core import. +// Discriminated union so Extract works correctly. +export type AgentMessage = + | { role: "user"; content?: unknown } + | { role: "assistant"; content?: unknown; stopReason?: string } + | { + role: "toolResult"; + content?: unknown; + toolCallId?: string; + toolUseId?: string; + toolName?: string; + isError?: boolean; + timestamp?: number; + }; + +export type ToolCallIdMode = "strict" | "strict9"; + +const STRICT9_LEN = 9; +const TOOL_CALL_TYPES = new Set(["toolCall", "toolUse", "functionCall"]); + +export type ToolCallLike = { + id: string; + name?: string; +}; + +/** + * Sanitize a tool call ID to be compatible with various providers. 
+ * + * - "strict" mode: only [a-zA-Z0-9] + * - "strict9" mode: only [a-zA-Z0-9], length 9 (Mistral tool call requirement) + */ +export function sanitizeToolCallId(id: string, mode: ToolCallIdMode = "strict"): string { + if (!id || typeof id !== "string") { + if (mode === "strict9") { + return "defaultid"; + } + return "defaulttoolid"; + } + + if (mode === "strict9") { + const alphanumericOnly = id.replace(/[^a-zA-Z0-9]/g, ""); + if (alphanumericOnly.length >= STRICT9_LEN) { + return alphanumericOnly.slice(0, STRICT9_LEN); + } + if (alphanumericOnly.length > 0) { + return shortHash(alphanumericOnly, STRICT9_LEN); + } + return shortHash("sanitized", STRICT9_LEN); + } + + // Some providers require strictly alphanumeric tool call IDs. + const alphanumericOnly = id.replace(/[^a-zA-Z0-9]/g, ""); + return alphanumericOnly.length > 0 ? alphanumericOnly : "sanitizedtoolid"; +} + +export function extractToolCallsFromAssistant( + msg: Extract, +): ToolCallLike[] { + const content = msg.content; + if (!Array.isArray(content)) { + return []; + } + + const toolCalls: ToolCallLike[] = []; + for (const block of content) { + if (!block || typeof block !== "object") { + continue; + } + const rec = block as { type?: unknown; id?: unknown; name?: unknown }; + if (typeof rec.id !== "string" || !rec.id) { + continue; + } + if (typeof rec.type === "string" && TOOL_CALL_TYPES.has(rec.type)) { + toolCalls.push({ + id: rec.id, + name: typeof rec.name === "string" ? 
rec.name : undefined, + }); + } + } + return toolCalls; +} + +export function extractToolResultId( + msg: Extract, +): string | null { + const toolCallId = (msg as { toolCallId?: unknown }).toolCallId; + if (typeof toolCallId === "string" && toolCallId) { + return toolCallId; + } + const toolUseId = (msg as { toolUseId?: unknown }).toolUseId; + if (typeof toolUseId === "string" && toolUseId) { + return toolUseId; + } + return null; +} + +export function isValidCloudCodeAssistToolId(id: string, mode: ToolCallIdMode = "strict"): boolean { + if (!id || typeof id !== "string") { + return false; + } + if (mode === "strict9") { + return /^[a-zA-Z0-9]{9}$/.test(id); + } + // Strictly alphanumeric for providers with tighter tool ID constraints + return /^[a-zA-Z0-9]+$/.test(id); +} + +function shortHash(text: string, length = 8): string { + return createHash("sha256").update(text).digest("hex").slice(0, length); +} + +function makeUniqueToolId(params: { id: string; used: Set; mode: ToolCallIdMode }): string { + if (params.mode === "strict9") { + const base = sanitizeToolCallId(params.id, params.mode); + const candidate = base.length >= STRICT9_LEN ? base.slice(0, STRICT9_LEN) : ""; + if (candidate && !params.used.has(candidate)) { + return candidate; + } + + for (let i = 0; i < 1000; i += 1) { + const hashed = shortHash(`${params.id}:${i}`, STRICT9_LEN); + if (!params.used.has(hashed)) { + return hashed; + } + } + + return shortHash(`${params.id}:${Date.now()}`, STRICT9_LEN); + } + + const MAX_LEN = 40; + + const base = sanitizeToolCallId(params.id, params.mode).slice(0, MAX_LEN); + if (!params.used.has(base)) { + return base; + } + + const hash = shortHash(params.id); + // Use separator based on mode: none for strict, underscore for non-strict variants + const separator = params.mode === "strict" ? "" : "_"; + const maxBaseLen = MAX_LEN - separator.length - hash.length; + const clippedBase = base.length > maxBaseLen ? 
base.slice(0, maxBaseLen) : base; + const candidate = `${clippedBase}${separator}${hash}`; + if (!params.used.has(candidate)) { + return candidate; + } + + for (let i = 2; i < 1000; i += 1) { + const suffix = params.mode === "strict" ? `x${i}` : `_${i}`; + const next = `${candidate.slice(0, MAX_LEN - suffix.length)}${suffix}`; + if (!params.used.has(next)) { + return next; + } + } + + const ts = params.mode === "strict" ? `t${Date.now()}` : `_${Date.now()}`; + return `${candidate.slice(0, MAX_LEN - ts.length)}${ts}`; +} + +function createOccurrenceAwareResolver(mode: ToolCallIdMode): { + resolveAssistantId: (id: string) => string; + resolveToolResultId: (id: string) => string; +} { + const used = new Set(); + const assistantOccurrences = new Map(); + const orphanToolResultOccurrences = new Map(); + const pendingByRawId = new Map(); + + const allocate = (seed: string): string => { + const next = makeUniqueToolId({ id: seed, used, mode }); + used.add(next); + return next; + }; + + const resolveAssistantId = (id: string): string => { + const occurrence = (assistantOccurrences.get(id) ?? 0) + 1; + assistantOccurrences.set(id, occurrence); + const next = allocate(occurrence === 1 ? id : `${id}:${occurrence}`); + const pending = pendingByRawId.get(id); + if (pending) { + pending.push(next); + } else { + pendingByRawId.set(id, [next]); + } + return next; + }; + + const resolveToolResultId = (id: string): string => { + const pending = pendingByRawId.get(id); + if (pending && pending.length > 0) { + const next = pending.shift()!; + if (pending.length === 0) { + pendingByRawId.delete(id); + } + return next; + } + + const occurrence = (orphanToolResultOccurrences.get(id) ?? 
0) + 1; + orphanToolResultOccurrences.set(id, occurrence); + return allocate(`${id}:tool_result:${occurrence}`); + }; + + return { resolveAssistantId, resolveToolResultId }; +} + +function rewriteAssistantToolCallIds(params: { + message: Extract; + resolveId: (id: string) => string; +}): Extract { + const content = params.message.content; + if (!Array.isArray(content)) { + return params.message; + } + + let changed = false; + const next = content.map((block) => { + if (!block || typeof block !== "object") { + return block; + } + const rec = block as { type?: unknown; id?: unknown }; + const type = rec.type; + const id = rec.id; + if ( + (type !== "functionCall" && type !== "toolUse" && type !== "toolCall") || + typeof id !== "string" || + !id + ) { + return block; + } + const nextId = params.resolveId(id); + if (nextId === id) { + return block; + } + changed = true; + return { ...(block as unknown as Record), id: nextId }; + }); + + if (!changed) { + return params.message; + } + return { ...params.message, content: next as typeof params.message.content }; +} + +function rewriteToolResultIds(params: { + message: Extract; + resolveId: (id: string) => string; +}): Extract { + const toolCallId = + typeof params.message.toolCallId === "string" && params.message.toolCallId + ? params.message.toolCallId + : undefined; + const toolUseId = (params.message as { toolUseId?: unknown }).toolUseId; + const toolUseIdStr = typeof toolUseId === "string" && toolUseId ? toolUseId : undefined; + const sharedRawId = + toolCallId && toolUseIdStr && toolCallId === toolUseIdStr ? toolCallId : undefined; + + const sharedResolvedId = sharedRawId ? params.resolveId(sharedRawId) : undefined; + const nextToolCallId = + sharedResolvedId ?? (toolCallId ? params.resolveId(toolCallId) : undefined); + const nextToolUseId = + sharedResolvedId ?? (toolUseIdStr ? 
params.resolveId(toolUseIdStr) : undefined); + + if (nextToolCallId === toolCallId && nextToolUseId === toolUseIdStr) { + return params.message; + } + + return { + ...params.message, + ...(nextToolCallId && { toolCallId: nextToolCallId }), + ...(nextToolUseId && { toolUseId: nextToolUseId }), + } as Extract; +} + +/** + * Sanitize tool call IDs for provider compatibility. + * + * @param messages - The messages to sanitize + * @param mode - "strict" (alphanumeric only) or "strict9" (alphanumeric length 9) + */ +export function sanitizeToolCallIdsForCloudCodeAssist( + messages: AgentMessage[], + mode: ToolCallIdMode = "strict", +): AgentMessage[] { + // Strict mode: only [a-zA-Z0-9] + // Strict9 mode: only [a-zA-Z0-9], length 9 (Mistral tool call requirement) + // Sanitization can introduce collisions, and some providers also reject raw + // duplicate tool-call IDs. Track assistant occurrences in-order so repeated + // raw IDs receive distinct rewritten IDs, while matching tool results consume + // the same rewritten IDs in encounter order. + const { resolveAssistantId, resolveToolResultId } = createOccurrenceAwareResolver(mode); + + let changed = false; + const out = messages.map((msg) => { + if (!msg || typeof msg !== "object") { + return msg; + } + const role = (msg as { role?: unknown }).role; + if (role === "assistant") { + const next = rewriteAssistantToolCallIds({ + message: msg as Extract, + resolveId: resolveAssistantId, + }); + if (next !== msg) { + changed = true; + } + return next; + } + if (role === "toolResult") { + const next = rewriteToolResultIds({ + message: msg as Extract, + resolveId: resolveToolResultId, + }); + if (next !== msg) { + changed = true; + } + return next; + } + return msg; + }); + + return changed ? 
out : messages; +} diff --git a/openviking/async_client.py b/openviking/async_client.py index 72bb3c00a..98b4c0f17 100644 --- a/openviking/async_client.py +++ b/openviking/async_client.py @@ -143,6 +143,18 @@ async def get_session(self, session_id: str, *, auto_create: bool = False) -> Di await self._ensure_initialized() return await self._client.get_session(session_id, auto_create=auto_create) + async def get_session_context( + self, session_id: str, token_budget: int = 128_000 + ) -> Dict[str, Any]: + """Get assembled session context.""" + await self._ensure_initialized() + return await self._client.get_session_context(session_id, token_budget=token_budget) + + async def get_session_archive(self, session_id: str, archive_id: str) -> Dict[str, Any]: + """Get one completed archive for a session.""" + await self._ensure_initialized() + return await self._client.get_session_archive(session_id, archive_id) + async def delete_session(self, session_id: str) -> None: """Delete a session.""" await self._ensure_initialized() diff --git a/openviking/client/local.py b/openviking/client/local.py index ca9e8d0ae..85215c1ff 100644 --- a/openviking/client/local.py +++ b/openviking/client/local.py @@ -19,6 +19,18 @@ from openviking_cli.utils import run_async +def _to_jsonable(value: Any) -> Any: + """Convert internal objects into JSON-serializable values.""" + to_dict = getattr(value, "to_dict", None) + if callable(to_dict): + return to_dict() + if isinstance(value, list): + return [_to_jsonable(item) for item in value] + if isinstance(value, dict): + return {k: _to_jsonable(v) for k, v in value.items()} + return value + + class LocalClient(BaseClient): """Local Client for OpenViking (embedded mode). 
@@ -328,6 +340,22 @@ async def get_session(self, session_id: str, *, auto_create: bool = False) -> Di result["user"] = session.user.to_dict() return result + async def get_session_context( + self, session_id: str, token_budget: int = 128_000 + ) -> Dict[str, Any]: + """Get assembled session context.""" + session = self._service.sessions.session(self._ctx, session_id) + await session.load() + result = await session.get_session_context(token_budget=token_budget) + return _to_jsonable(result) + + async def get_session_archive(self, session_id: str, archive_id: str) -> Dict[str, Any]: + """Get one completed archive for a session.""" + session = self._service.sessions.session(self._ctx, session_id) + await session.load() + result = await session.get_session_archive(archive_id) + return _to_jsonable(result) + async def delete_session(self, session_id: str) -> None: """Delete a session.""" await self._service.sessions.delete(session_id, self._ctx) diff --git a/openviking/client/session.py b/openviking/client/session.py index cf0a8ac29..e0c3ceced 100644 --- a/openviking/client/session.py +++ b/openviking/client/session.py @@ -87,5 +87,13 @@ async def load(self) -> Dict[str, Any]: """ return await self._client.get_session(self.session_id) + async def get_session_context(self, token_budget: int = 128_000) -> Dict[str, Any]: + """Get assembled session context.""" + return await self._client.get_session_context(self.session_id, token_budget=token_budget) + + async def get_archive(self, archive_id: str) -> Dict[str, Any]: + """Get one completed archive for the session.""" + return await self._client.get_session_archive(self.session_id, archive_id) + def __repr__(self) -> str: return f"Session(id={self.session_id}, user={self.user.__str__()})" diff --git a/openviking/message/message.py b/openviking/message/message.py index c40808527..fdf185e07 100644 --- a/openviking/message/message.py +++ b/openviking/message/message.py @@ -31,6 +31,37 @@ def content(self) -> str: return p.text 
return "" + @property + def estimated_tokens(self) -> int: + """Estimate token count from all parts (ceil(len/4) heuristic). + + Counts fields that actually appear in the assembled prompt: + - TextPart.text: always emitted + - ContextPart.abstract: injected as text (uri is not sent to the model) + - ToolPart: tool_id (appears in toolUse.id / toolResult.toolCallId), + tool_name, tool_input (JSON), tool_output + + Known limitation: ToolPart estimation undercounts by ~10-20 tokens per + tool call because tool_id/toolName appear twice in the assembled transcript + (toolUse + toolResult), and small literals like "(no output)" / "{}" are + not counted. Under 128k budgets this is negligible; for smaller budgets + (8k/16k) or tool-dense sessions, consider adding a conservative per-tool + buffer instead of mirroring the full convertToAgentMessages logic. + """ + total_chars = 0 + for p in self.parts: + if isinstance(p, TextPart): + total_chars += len(p.text) + elif isinstance(p, ContextPart): + total_chars += len(p.abstract) + elif isinstance(p, ToolPart): + total_chars += len(p.tool_id) + len(p.tool_name) + if p.tool_input: + total_chars += len(json.dumps(p.tool_input, ensure_ascii=False)) + if p.tool_output: + total_chars += len(p.tool_output) + return -(-total_chars // 4) # ceil division + def to_dict(self) -> dict: """Serialize to JSONL.""" created_at_val = self.created_at or datetime.now(timezone.utc) diff --git a/openviking/server/routers/sessions.py b/openviking/server/routers/sessions.py index de94d2bf2..0977685bf 100644 --- a/openviking/server/routers/sessions.py +++ b/openviking/server/routers/sessions.py @@ -13,7 +13,6 @@ from openviking.server.dependencies import get_service from openviking.server.identity import RequestContext from openviking.server.models import ErrorInfo, Response -from openviking.service.task_tracker import get_task_tracker router = APIRouter(prefix="/api/v1/sessions", tags=["sessions"]) logger = logging.getLogger(__name__) @@ -138,9 +137,47 
@@ async def get_session( ) result = session.meta.to_dict() result["user"] = session.user.to_dict() + pending_tokens = sum(len(m.content) // 4 for m in session.messages) + result["pending_tokens"] = pending_tokens return Response(status="ok", result=result) +@router.get("/{session_id}/context") +async def get_session_context( + session_id: str = Path(..., description="Session ID"), + token_budget: int = Query(128_000, description="Token budget for session context"), + _ctx: RequestContext = Depends(get_request_context), +): + """Get assembled session context.""" + service = get_service() + session = service.sessions.session(_ctx, session_id) + await session.load() + result = await session.get_session_context(token_budget=token_budget) + return Response(status="ok", result=_to_jsonable(result)) + + +@router.get("/{session_id}/archives/{archive_id}") +async def get_session_archive( + session_id: str = Path(..., description="Session ID"), + archive_id: str = Path(..., description="Archive ID"), + _ctx: RequestContext = Depends(get_request_context), +): + """Get one completed archive for a session.""" + from openviking_cli.exceptions import NotFoundError + + service = get_service() + session = service.sessions.session(_ctx, session_id) + await session.load() + try: + result = await session.get_session_archive(archive_id) + except NotFoundError: + return Response( + status="error", + error=ErrorInfo(code="NOT_FOUND", message=f"Archive {archive_id} not found"), + ) + return Response(status="ok", result=_to_jsonable(result)) + + @router.delete("/{session_id}") async def delete_session( session_id: str = Path(..., description="Session ID"), @@ -164,18 +201,6 @@ async def commit_session( polling progress via ``GET /tasks/{task_id}``. 
""" service = get_service() - tracker = get_task_tracker() - - # Reject if same session already has a commit in progress - if tracker.has_running("session_commit", session_id): - return Response( - status="error", - error=ErrorInfo( - code="CONFLICT", - message=f"Session {session_id} already has a commit in progress", - ), - ) - result = await service.sessions.commit_async(session_id, _ctx) return Response(status="ok", result=result).model_dump(exclude_none=True) diff --git a/openviking/session/__init__.py b/openviking/session/__init__.py index 13eefae7b..ae0973e21 100644 --- a/openviking/session/__init__.py +++ b/openviking/session/__init__.py @@ -4,10 +4,6 @@ from typing import Optional -from openviking.storage import VikingDBManager -from openviking_cli.utils import get_logger -from openviking_cli.utils.config import get_openviking_config - from openviking.session.compressor import ExtractionStats, SessionCompressor from openviking.session.memory_archiver import ( ArchivalCandidate, @@ -28,6 +24,9 @@ ToolSkillCandidateMemory, ) from openviking.session.session import Session, SessionCompression, SessionMeta, SessionStats +from openviking.storage import VikingDBManager +from openviking_cli.utils import get_logger +from openviking_cli.utils.config import get_openviking_config logger = get_logger(__name__) @@ -60,6 +59,7 @@ def create_session_compressor( logger.info("Using v2 memory compressor (templating system)") try: from openviking.session.compressor_v2 import SessionCompressorV2 + return SessionCompressorV2(vikingdb=vikingdb) except Exception as e: logger.warning(f"Failed to load v2 compressor, falling back to v1: {e}") diff --git a/openviking/session/session.py b/openviking/session/session.py index f62998c99..03974abe2 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -28,6 +28,8 @@ logger = get_logger(__name__) +_ARCHIVE_WAIT_POLL_SECONDS = 0.1 + @dataclass class SessionCompression: @@ -301,7 +303,7 @@ def add_message( # Update 
statistics if role == "user": self._stats.total_turns += 1 - self._stats.total_tokens += len(msg.content) // 4 + self._stats.total_tokens += msg.estimated_tokens self._append_to_jsonl(msg) @@ -347,6 +349,7 @@ async def commit_async(self) -> Dict[str, Any]: """ from openviking.service.task_tracker import get_task_tracker from openviking.storage.transaction import LockContext, get_lock_manager + from openviking_cli.exceptions import FailedPreconditionError # ===== Phase 1: Snapshot + clear (PathLock-protected) ===== # Fast pre-check: skip lock entirely if no messages (common case avoids @@ -361,6 +364,14 @@ async def commit_async(self) -> Dict[str, Any]: "archived": False, } + blocking_archive = await self._get_blocking_failed_archive_ref() + if blocking_archive: + raise FailedPreconditionError( + f"Session {self.session_id} has unresolved failed archive " + f"{blocking_archive['archive_id']}; fix it before committing again.", + details={"archive_id": blocking_archive["archive_id"]}, + ) + # Use filesystem-based distributed lock so this works across workers/processes. 
session_path = self._viking_fs._uri_to_path(self._session_uri, ctx=self.ctx) async with LockContext(get_lock_manager(), [session_path], lock_mode="point"): @@ -456,13 +467,32 @@ async def _run_memory_extraction( from openviking.telemetry import OperationTelemetry, bind_telemetry tracker = get_task_tracker() - tracker.start(task_id) memories_extracted: Dict[str, int] = {} active_count_updated = 0 telemetry = OperationTelemetry(operation="session_commit_phase2", enabled=True) + archive_index = self._archive_index_from_uri(archive_uri) + redo_task_id: Optional[str] = None try: + if not await self._wait_for_previous_archive_done(archive_index): + await self._write_failed_marker( + archive_uri, + stage="waiting_previous_done", + error=( + f"Previous archive archive_{archive_index - 1:03d} failed; " + "this archive cannot proceed" + ), + blocked_by=f"archive_{archive_index - 1:03d}", + ) + tracker.fail( + task_id, + f"Previous archive archive_{archive_index - 1:03d} failed; " + "cannot continue session commit", + ) + return + + tracker.start(task_id) with bind_telemetry(telemetry): # redo-log protection redo_task_id = str(uuid.uuid4()) @@ -500,6 +530,16 @@ async def _run_memory_extraction( content=summary, ctx=self.ctx, ) + await self._viking_fs.write_file( + uri=f"{archive_uri}/.meta.json", + content=json.dumps( + { + "overview_tokens": -(-len(summary) // 4), + "abstract_tokens": -(-len(abstract) // 4), + } + ), + ctx=self.ctx, + ) # Memory extraction if self._session_compressor: @@ -546,22 +586,11 @@ async def _run_memory_extraction( # Phase 2 complete — update meta with telemetry and commit info snapshot = telemetry.finish("ok") - if snapshot: - llm = snapshot.summary.get("tokens", {}).get("llm", {}) - self._meta.llm_token_usage["prompt_tokens"] += llm.get("input", 0) - self._meta.llm_token_usage["completion_tokens"] += llm.get("output", 0) - self._meta.llm_token_usage["total_tokens"] += llm.get("total", 0) - self._meta.commit_count = 
self._compression.compression_index - for cat, count in memories_extracted.items(): - self._meta.memories_extracted[cat] = ( - self._meta.memories_extracted.get(cat, 0) + count - ) - self._meta.memories_extracted["total"] = ( - self._meta.memories_extracted.get("total", 0) + count - ) - self._meta.last_commit_at = get_current_timestamp() - self._meta.message_count = len(self._messages) - await self._save_meta() + await self._merge_and_save_commit_meta( + archive_index=archive_index, + memories_extracted=memories_extracted, + telemetry_snapshot=snapshot, + ) # Write .done file last — signals that all state is finalized await self._write_done_file(archive_uri, first_message_id, last_message_id) @@ -577,6 +606,13 @@ async def _run_memory_extraction( ) logger.info(f"Session {self.session_id} memory extraction completed") except Exception as e: + if redo_task_id: + get_lock_manager().redo_log.mark_done(redo_task_id) + await self._write_failed_marker( + archive_uri, + stage="memory_extraction", + error=str(e), + ) tracker.fail(task_id, str(e)) logger.exception(f"Memory extraction failed for session {self.session_id}") @@ -602,6 +638,29 @@ async def _write_done_file( ctx=self.ctx, ) + async def _write_failed_marker( + self, + archive_uri: str, + stage: str, + error: str, + blocked_by: str = "", + ) -> None: + """Persist a terminal failure marker for the archive.""" + if not self._viking_fs: + return + payload = { + "stage": stage, + "error": error, + "failed_at": get_current_timestamp(), + } + if blocked_by: + payload["blocked_by"] = blocked_by + await self._viking_fs.write_file( + uri=f"{archive_uri}/.failed.json", + content=json.dumps(payload, ensure_ascii=False), + ctx=self.ctx, + ) + def _update_active_counts(self) -> int: """Update active_count for used contexts/skills.""" if not self._vikingdb_manager: @@ -634,71 +693,400 @@ async def _update_active_counts_async(self) -> int: logger.info(f"Updated active_count for {updated} contexts/skills") return updated - async 
def get_context_for_search(self, query: str, max_messages: int = 20) -> Dict[str, Any]: - """Get session context for intent analysis. + async def get_session_context(self, token_budget: int = 128_000) -> Dict[str, Any]: + """Get assembled session context with the latest summary archive and merged messages.""" + context = await self._collect_session_context_components() + merged_messages = context["messages"] + message_tokens = sum(m.estimated_tokens for m in merged_messages) + remaining_budget = max(0, token_budget - message_tokens) - Args: - query: Query string for the current request. - max_messages: Maximum number of current messages to retrieve (default 20) + latest_archive = context["latest_archive"] + include_latest_overview = bool( + latest_archive and latest_archive["overview_tokens"] <= remaining_budget + ) + latest_archive_tokens = latest_archive["overview_tokens"] if include_latest_overview else 0 + if include_latest_overview: + remaining_budget -= latest_archive_tokens + + included_pre_archive_abstracts: List[Dict[str, str]] = [] + pre_archive_tokens = 0 + for item in context["pre_archive_abstracts"]: + if item["tokens"] > remaining_budget: + break + included_pre_archive_abstracts.append( + {"archive_id": item["archive_id"], "abstract": item["abstract"]} + ) + pre_archive_tokens += item["tokens"] + remaining_budget -= item["tokens"] - Returns: - - latest_archive_overview: Latest completed archive overview, if any - - current_messages: Current message list (List[Message]) - """ + archive_tokens = latest_archive_tokens + pre_archive_tokens + included_archives = (1 if include_latest_overview else 0) + len( + included_pre_archive_abstracts + ) + dropped_archives = max( + 0, context["total_archives"] - context["failed_archives"] - included_archives + ) + + return { + "latest_archive_overview": ( + latest_archive["overview"] if include_latest_overview else "" + ), + "latest_archive_id": latest_archive["archive_id"] if latest_archive else "", + 
"pre_archive_abstracts": included_pre_archive_abstracts, + "messages": [m.to_dict() for m in merged_messages], + "estimatedTokens": message_tokens + archive_tokens, + "stats": { + "totalArchives": context["total_archives"], + "includedArchives": included_archives, + "droppedArchives": dropped_archives, + "failedArchives": context["failed_archives"], + "activeTokens": message_tokens, + "archiveTokens": archive_tokens, + }, + } + + async def get_context_for_search(self, query: str, max_messages: int = 20) -> Dict[str, Any]: + """Get session context for intent analysis.""" del query # Current query no longer affects historical archive selection. - current_messages = list(self._messages[-max_messages:]) if self._messages else [] - latest_archive_overview = await self._get_latest_completed_archive_overview() + context = await self._collect_session_context_components() + current_messages = context["messages"] + if max_messages > 0: + current_messages = current_messages[-max_messages:] + else: + current_messages = [] return { - "latest_archive_overview": latest_archive_overview, + "latest_archive_overview": ( + context["latest_archive"]["overview"] if context["latest_archive"] else "" + ), "current_messages": current_messages, } + async def get_context_for_assemble(self, token_budget: int = 128_000) -> Dict[str, Any]: + """Backward-compatible alias for the assembled session context.""" + return await self.get_session_context(token_budget=token_budget) + + async def get_session_archive(self, archive_id: str) -> Dict[str, Any]: + """Get one completed archive by archive ID.""" + from openviking_cli.exceptions import NotFoundError + + for archive in await self._get_completed_archive_refs(): + if archive["archive_id"] != archive_id: + continue + + overview = await self._read_archive_overview(archive["archive_uri"]) + if not overview: + break + + abstract = await self._read_archive_abstract(archive["archive_uri"], overview) + return { + "archive_id": archive_id, + "abstract": 
abstract, + "overview": overview, + "messages": [ + m.to_dict() for m in await self._read_archive_messages(archive["archive_uri"]) + ], + } + + raise NotFoundError(archive_id, "session archive") + # ============= Internal methods ============= - async def _get_latest_completed_archive_overview( - self, - exclude_archive_uri: Optional[str] = None, - ) -> str: - """Return the newest completed archive overview, skipping incomplete archives.""" + async def _collect_session_context_components(self) -> Dict[str, Any]: + """Collect the latest summary archive and merged pending/live messages.""" + completed_archives = await self._get_completed_archive_refs() + latest_archive = None + pre_archive_abstracts: List[Dict[str, Any]] = [] + failed_archives = 0 + + for archive in completed_archives: + if latest_archive is None: + overview = await self._read_archive_overview(archive["archive_uri"]) + if not overview: + failed_archives += 1 + continue + + latest_archive = { + "archive_id": archive["archive_id"], + "archive_uri": archive["archive_uri"], + "overview": overview, + "overview_tokens": await self._read_archive_overview_tokens( + archive["archive_uri"], overview + ), + } + continue + + abstract = await self._read_archive_abstract(archive["archive_uri"]) + if abstract: + pre_archive_abstracts.append( + { + "archive_id": archive["archive_id"], + "abstract": abstract, + "tokens": -(-len(abstract) // 4), + } + ) + else: + failed_archives += 1 + + return { + "latest_archive": latest_archive, + "pre_archive_abstracts": pre_archive_abstracts, + "total_archives": len(completed_archives), + "failed_archives": failed_archives, + "messages": await self._get_pending_archive_messages() + list(self._messages), + } + + async def _list_archive_refs(self) -> List[Dict[str, Any]]: + """List archive refs sorted by archive index descending.""" if not self._viking_fs or self.compression.compression_index <= 0: - return "" + return [] try: history_items = await 
self._viking_fs.ls(f"{self._session_uri}/history", ctx=self.ctx) except Exception: - return "" + return [] - archive_names: List[str] = [] + refs: List[Dict[str, Any]] = [] for item in history_items: name = item.get("name") if isinstance(item, dict) else item - if name and name.startswith("archive_"): - archive_names.append(name) - - def _archive_index(name: str) -> int: + if not name or not name.startswith("archive_"): + continue try: - return int(name.split("_")[1]) + index = int(name.split("_")[1]) except Exception: - return -1 + continue + + refs.append( + { + "archive_id": name, + "archive_uri": f"{self._session_uri}/history/{name}", + "index": index, + } + ) + return sorted(refs, key=lambda item: item["index"], reverse=True) + + async def _get_completed_archive_refs( + self, + exclude_archive_uri: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """Return completed archive refs sorted by archive index descending.""" + completed: List[Dict[str, Any]] = [] exclude = exclude_archive_uri.rstrip("/") if exclude_archive_uri else None - for name in sorted(archive_names, key=_archive_index, reverse=True): - archive_uri = f"{self._session_uri}/history/{name}" - if exclude and archive_uri == exclude: + + for archive in await self._list_archive_refs(): + if exclude and archive["archive_uri"] == exclude: continue try: - await self._viking_fs.read_file(f"{archive_uri}/.done", ctx=self.ctx) - overview = await self._viking_fs.read_file( - f"{archive_uri}/.overview.md", + await self._viking_fs.read_file(f"{archive['archive_uri']}/.done", ctx=self.ctx) + except Exception: + continue + completed.append(archive) + + return completed + + async def _get_blocking_failed_archive_ref(self) -> Optional[Dict[str, Any]]: + """Return the earliest unresolved failed archive, if any.""" + for archive in sorted(await self._list_archive_refs(), key=lambda item: item["index"]): + try: + await self._viking_fs.read_file(f"{archive['archive_uri']}/.done", ctx=self.ctx) + continue + except 
Exception: + pass + try: + await self._viking_fs.read_file( + f"{archive['archive_uri']}/.failed.json", ctx=self.ctx, ) - if overview: - return overview except Exception: continue + return archive + return None + + async def _read_archive_overview(self, archive_uri: str) -> str: + """Read archive overview text.""" + try: + overview = await self._viking_fs.read_file(f"{archive_uri}/.overview.md", ctx=self.ctx) + except Exception: + return "" + return overview or "" + + async def _read_archive_abstract(self, archive_uri: str, overview: str = "") -> str: + """Read archive abstract text, falling back to summary extraction.""" + try: + abstract = await self._viking_fs.read_file(f"{archive_uri}/.abstract.md", ctx=self.ctx) + except Exception: + abstract = "" - return "" + if abstract: + return abstract + + if not overview: + overview = await self._read_archive_overview(archive_uri) + return self._extract_abstract_from_summary(overview) + + async def _read_archive_overview_tokens(self, archive_uri: str, overview: str) -> int: + """Read overview token estimate from archive metadata.""" + overview_tokens = -(-len(overview) // 4) + try: + meta_content = await self._viking_fs.read_file( + f"{archive_uri}/.meta.json", ctx=self.ctx + ) + overview_tokens = json.loads(meta_content).get("overview_tokens", overview_tokens) + except Exception: + pass + return overview_tokens + + async def _read_archive_messages(self, archive_uri: str) -> List[Message]: + """Read archived messages from one archive.""" + try: + content = await self._viking_fs.read_file(f"{archive_uri}/messages.jsonl", ctx=self.ctx) + except Exception: + return [] + + messages: List[Message] = [] + for line in content.strip().split("\n"): + if not line.strip(): + continue + try: + messages.append(Message.from_dict(json.loads(line))) + except Exception: + continue + + return messages + + async def _get_latest_completed_archive_summary( + self, + exclude_archive_uri: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: 
+ """Return the newest readable completed archive summary.""" + for archive in await self._get_completed_archive_refs(exclude_archive_uri): + overview = await self._read_archive_overview(archive["archive_uri"]) + if not overview: + continue + + return { + "archive_id": archive["archive_id"], + "archive_uri": archive["archive_uri"], + "overview": overview, + "abstract": await self._read_archive_abstract(archive["archive_uri"], overview), + "overview_tokens": await self._read_archive_overview_tokens( + archive["archive_uri"], overview + ), + } + + return None + + async def _get_latest_completed_archive_overview( + self, + exclude_archive_uri: Optional[str] = None, + ) -> str: + """Return the newest completed archive overview, skipping incomplete archives.""" + summary = await self._get_latest_completed_archive_summary(exclude_archive_uri) + return summary["overview"] if summary else "" + + async def _get_pending_archive_messages(self) -> List[Message]: + """Return messages from incomplete archives newer than the latest completed archive.""" + latest_completed_index = 0 + incomplete_archives: List[Dict[str, Any]] = [] + for archive in sorted(await self._list_archive_refs(), key=lambda item: item["index"]): + try: + await self._viking_fs.read_file(f"{archive['archive_uri']}/.done", ctx=self.ctx) + latest_completed_index = archive["index"] + except Exception: + incomplete_archives.append(archive) + + pending_messages: List[Message] = [] + for archive in incomplete_archives: + if archive["index"] <= latest_completed_index: + continue + pending_messages.extend(await self._read_archive_messages(archive["archive_uri"])) + + return pending_messages + + @staticmethod + def _archive_index_from_uri(archive_uri: str) -> int: + """Parse archive_NNN suffix into an integer index.""" + match = re.search(r"archive_(\d+)$", archive_uri.rstrip("/")) + if not match: + raise ValueError(f"Invalid archive URI: {archive_uri}") + return int(match.group(1)) + + async def 
_wait_for_previous_archive_done(self, archive_index: int) -> bool: + """Wait until the previous archive is done, or report dependency failure.""" + if archive_index <= 1 or not self._viking_fs: + return True + + previous_archive_uri = ( + f"{self._session_uri}/history/archive_{archive_index - 1:03d}" + ) + while True: + try: + await self._viking_fs.read_file(f"{previous_archive_uri}/.done", ctx=self.ctx) + return True + except Exception: + pass + + try: + await self._viking_fs.read_file( + f"{previous_archive_uri}/.failed.json", + ctx=self.ctx, + ) + return False + except Exception: + pass + + await asyncio.sleep(_ARCHIVE_WAIT_POLL_SECONDS) + + async def _merge_and_save_commit_meta( + self, + archive_index: int, + memories_extracted: Dict[str, int], + telemetry_snapshot: Any, + ) -> None: + """Reload and merge latest meta state before persisting commit results.""" + latest_meta = self._meta + try: + meta_content = await self._viking_fs.read_file( + f"{self._session_uri}/.meta.json", + ctx=self.ctx, + ) + latest_meta = SessionMeta.from_dict(json.loads(meta_content)) + except Exception: + latest_meta = self._meta + + if telemetry_snapshot: + llm = telemetry_snapshot.summary.get("tokens", {}).get("llm", {}) + latest_meta.llm_token_usage["prompt_tokens"] += llm.get("input", 0) + latest_meta.llm_token_usage["completion_tokens"] += llm.get("output", 0) + latest_meta.llm_token_usage["total_tokens"] += llm.get("total", 0) + + latest_meta.commit_count = max(latest_meta.commit_count, archive_index) + for cat, count in memories_extracted.items(): + latest_meta.memories_extracted[cat] = ( + latest_meta.memories_extracted.get(cat, 0) + count + ) + latest_meta.memories_extracted["total"] = ( + latest_meta.memories_extracted.get("total", 0) + count + ) + latest_meta.last_commit_at = get_current_timestamp() + latest_meta.message_count = await self._read_live_message_count() + self._meta = latest_meta + await self._save_meta() + + async def _read_live_message_count(self) -> int: + 
"""Count current live session messages from persisted storage.""" + if not self._viking_fs: + return len(self._messages) + try: + content = await self._viking_fs.read_file( + f"{self._session_uri}/messages.jsonl", + ctx=self.ctx, + ) + except Exception: + return len(self._messages) + return len([line for line in content.strip().split("\n") if line.strip()]) def _extract_abstract_from_summary(self, summary: str) -> str: """Extract one-sentence overview from structured summary.""" diff --git a/openviking/sync_client.py b/openviking/sync_client.py index b74d545ad..d6c333c1c 100644 --- a/openviking/sync_client.py +++ b/openviking/sync_client.py @@ -51,6 +51,16 @@ def get_session(self, session_id: str, *, auto_create: bool = False) -> Dict[str """Get session details.""" return run_async(self._async_client.get_session(session_id, auto_create=auto_create)) + def get_session_context(self, session_id: str, token_budget: int = 128_000) -> Dict[str, Any]: + """Get assembled session context.""" + return run_async( + self._async_client.get_session_context(session_id, token_budget=token_budget) + ) + + def get_session_archive(self, session_id: str, archive_id: str) -> Dict[str, Any]: + """Get one completed archive for a session.""" + return run_async(self._async_client.get_session_archive(session_id, archive_id)) + def delete_session(self, session_id: str) -> None: """Delete a session.""" run_async(self._async_client.delete_session(session_id)) diff --git a/openviking_cli/client/base.py b/openviking_cli/client/base.py index faeaac5a9..30fe8febe 100644 --- a/openviking_cli/client/base.py +++ b/openviking_cli/client/base.py @@ -206,6 +206,18 @@ async def get_session(self, session_id: str, *, auto_create: bool = False) -> Di """Get session details.""" ... + @abstractmethod + async def get_session_context( + self, session_id: str, token_budget: int = 128_000 + ) -> Dict[str, Any]: + """Get assembled session context for a session.""" + ... 
+ + @abstractmethod + async def get_session_archive(self, session_id: str, archive_id: str) -> Dict[str, Any]: + """Get one completed archive for a session.""" + ... + @abstractmethod async def delete_session(self, session_id: str) -> None: """Delete a session.""" diff --git a/openviking_cli/client/http.py b/openviking_cli/client/http.py index da50d4e83..9bfc024b2 100644 --- a/openviking_cli/client/http.py +++ b/openviking_cli/client/http.py @@ -19,6 +19,7 @@ AlreadyExistsError, DeadlineExceededError, EmbeddingFailedError, + FailedPreconditionError, InternalError, InvalidArgumentError, InvalidURIError, @@ -44,6 +45,7 @@ "INVALID_URI": InvalidURIError, "NOT_FOUND": NotFoundError, "ALREADY_EXISTS": AlreadyExistsError, + "FAILED_PRECONDITION": FailedPreconditionError, "UNAUTHENTICATED": UnauthenticatedError, "PERMISSION_DENIED": PermissionDeniedError, "UNAVAILABLE": UnavailableError, @@ -711,6 +713,23 @@ async def get_session(self, session_id: str, *, auto_create: bool = False) -> Di response = await self._http.get(f"/api/v1/sessions/{session_id}", params=params) return self._handle_response(response) + async def get_session_context( + self, session_id: str, token_budget: int = 128_000 + ) -> Dict[str, Any]: + """Get assembled session context.""" + response = await self._http.get( + f"/api/v1/sessions/{session_id}/context", + params={"token_budget": token_budget}, + ) + return self._handle_response(response) + + async def get_session_archive(self, session_id: str, archive_id: str) -> Dict[str, Any]: + """Get one completed archive for a session.""" + response = await self._http.get( + f"/api/v1/sessions/{session_id}/archives/{archive_id}", + ) + return self._handle_response(response) + async def delete_session(self, session_id: str) -> None: """Delete a session.""" response = await self._http.delete(f"/api/v1/sessions/{session_id}") diff --git a/openviking_cli/client/sync_http.py b/openviking_cli/client/sync_http.py index a6b898a60..fba30a372 100644 --- 
a/openviking_cli/client/sync_http.py +++ b/openviking_cli/client/sync_http.py @@ -90,6 +90,14 @@ def get_session(self, session_id: str, *, auto_create: bool = False) -> Dict[str """Get session details.""" return run_async(self._async_client.get_session(session_id, auto_create=auto_create)) + def get_session_context(self, session_id: str, token_budget: int = 128_000) -> Dict[str, Any]: + """Get assembled session context.""" + return run_async(self._async_client.get_session_context(session_id, token_budget)) + + def get_session_archive(self, session_id: str, archive_id: str) -> Dict[str, Any]: + """Get one completed archive for a session.""" + return run_async(self._async_client.get_session_archive(session_id, archive_id)) + def delete_session(self, session_id: str) -> None: """Delete a session.""" run_async(self._async_client.delete_session(session_id)) diff --git a/openviking_cli/exceptions.py b/openviking_cli/exceptions.py index cd432552c..bcf73b2cd 100644 --- a/openviking_cli/exceptions.py +++ b/openviking_cli/exceptions.py @@ -78,6 +78,13 @@ def __init__(self, message: str, resource: Optional[str] = None): super().__init__(message, code="CONFLICT", details=details) +class FailedPreconditionError(OpenVikingError): + """Operation cannot proceed because a required precondition is unmet.""" + + def __init__(self, message: str, details: Optional[dict] = None): + super().__init__(message, code="FAILED_PRECONDITION", details=details) + + # ============= Authentication Errors ============= diff --git a/tests/server/test_api_sessions.py b/tests/server/test_api_sessions.py index 9fca12532..4d497498d 100644 --- a/tests/server/test_api_sessions.py +++ b/tests/server/test_api_sessions.py @@ -3,10 +3,65 @@ """Tests for session endpoints.""" +import asyncio +import json +from unittest.mock import patch + import httpx +import pytest +from openviking.message import Message from openviking.server.identity import RequestContext, Role from openviking_cli.session.user_id import 
UserIdentifier +from openviking_cli.utils.config.open_viking_config import OpenVikingConfigSingleton +from tests.utils.mock_agfs import MockLocalAGFS + + +@pytest.fixture(autouse=True) +def _configure_test_env(monkeypatch, tmp_path): + config_path = tmp_path / "ov.conf" + config_path.write_text( + json.dumps( + { + "storage": { + "workspace": str(tmp_path / "workspace"), + "agfs": {"backend": "local", "mode": "binding-client"}, + "vectordb": {"backend": "local"}, + }, + "embedding": { + "dense": { + "provider": "openai", + "model": "test-embedder", + "api_base": "http://127.0.0.1:11434/v1", + "dimension": 1024, + } + }, + "encryption": {"enabled": False}, + } + ), + encoding="utf-8", + ) + + mock_agfs = MockLocalAGFS(root_path=tmp_path / "mock_agfs_root") + + monkeypatch.setenv("OPENVIKING_CONFIG_FILE", str(config_path)) + OpenVikingConfigSingleton.reset_instance() + + with patch("openviking.utils.agfs_utils.create_agfs_client", return_value=mock_agfs): + yield + + OpenVikingConfigSingleton.reset_instance() + + +async def _wait_for_task(client: httpx.AsyncClient, task_id: str, timeout: float = 10.0): + for _ in range(int(timeout / 0.1)): + resp = await client.get(f"/api/v1/tasks/{task_id}") + if resp.status_code == 200: + task = resp.json()["result"] + if task["status"] in ("completed", "failed"): + return task + await asyncio.sleep(0.1) + raise TimeoutError(f"Task {task_id} did not complete within {timeout}s") async def test_create_session(client: httpx.AsyncClient): @@ -38,6 +93,66 @@ async def test_get_session(client: httpx.AsyncClient): assert body["result"]["session_id"] == session_id +async def test_get_session_context(client: httpx.AsyncClient): + create_resp = await client.post("/api/v1/sessions", json={}) + session_id = create_resp.json()["result"]["session_id"] + + await client.post( + f"/api/v1/sessions/{session_id}/messages", + json={"role": "user", "content": "Current live message"}, + ) + + resp = await 
client.get(f"/api/v1/sessions/{session_id}/context") + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "ok" + assert body["result"]["latest_archive_overview"] == "" + assert body["result"]["latest_archive_id"] == "" + assert body["result"]["pre_archive_abstracts"] == [] + assert [m["parts"][0]["text"] for m in body["result"]["messages"]] == ["Current live message"] + + +async def test_get_session_context_includes_incomplete_archive_messages( + client: httpx.AsyncClient, service +): + create_resp = await client.post("/api/v1/sessions", json={}) + session_id = create_resp.json()["result"]["session_id"] + + await client.post( + f"/api/v1/sessions/{session_id}/messages", + json={"role": "user", "content": "Archived seed"}, + ) + commit_resp = await client.post(f"/api/v1/sessions/{session_id}/commit") + assert commit_resp.status_code == 200 + + ctx = RequestContext(user=UserIdentifier.the_default_user(), role=Role.ROOT) + session = service.sessions.session(ctx, session_id) + await session.load() + pending_messages = [ + Message.create_user("Pending user message"), + Message.create_assistant("Pending assistant response"), + ] + await session._viking_fs.write_file( + uri=f"{session.uri}/history/archive_002/messages.jsonl", + content="\n".join(msg.to_jsonl() for msg in pending_messages) + "\n", + ctx=session.ctx, + ) + + await client.post( + f"/api/v1/sessions/{session_id}/messages", + json={"role": "user", "content": "Current live message"}, + ) + + resp = await client.get(f"/api/v1/sessions/{session_id}/context") + assert resp.status_code == 200 + body = resp.json() + assert [m["parts"][0]["text"] for m in body["result"]["messages"]] == [ + "Pending user message", + "Pending assistant response", + "Current live message", + ] + + async def test_add_message(client: httpx.AsyncClient): create_resp = await client.post("/api/v1/sessions", json={}) session_id = create_resp.json()["result"]["session_id"] @@ -175,3 +290,120 @@ async def 
fake_extract(_session_id: str, _ctx): body = resp.json() assert body["status"] == "ok" assert body["result"] == [{"uri": "viking://user/memories/mock.md"}] + + +async def test_get_session_context_endpoint_returns_trimmed_latest_archive_and_messages( + client: httpx.AsyncClient, +): + create_resp = await client.post("/api/v1/sessions", json={}) + session_id = create_resp.json()["result"]["session_id"] + + await client.post( + f"/api/v1/sessions/{session_id}/messages", + json={"role": "user", "content": "archived message"}, + ) + commit_resp = await client.post(f"/api/v1/sessions/{session_id}/commit") + task_id = commit_resp.json()["result"]["task_id"] + await _wait_for_task(client, task_id) + + await client.post( + f"/api/v1/sessions/{session_id}/messages", + json={ + "role": "assistant", + "parts": [ + {"type": "text", "text": "Running tool"}, + { + "type": "tool", + "tool_id": "tool_123", + "tool_name": "demo_tool", + "tool_uri": f"viking://session/{session_id}/tools/tool_123", + "tool_input": {"x": 1}, + "tool_status": "running", + }, + ], + }, + ) + + resp = await client.get(f"/api/v1/sessions/{session_id}/context?token_budget=1") + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "ok" + + result = body["result"] + assert result["latest_archive_overview"] == "" + assert result["latest_archive_id"] == "archive_001" + assert result["pre_archive_abstracts"] == [] + assert len(result["messages"]) == 1 + assert result["messages"][0]["role"] == "assistant" + assert any( + part["type"] == "tool" and part["tool_id"] == "tool_123" + for part in result["messages"][0]["parts"] + ) + assert result["stats"]["totalArchives"] == 1 + assert result["stats"]["includedArchives"] == 0 + assert result["stats"]["droppedArchives"] == 1 + assert result["stats"]["failedArchives"] == 0 + + +async def test_get_session_archive_endpoint_returns_archive_details(client: httpx.AsyncClient): + create_resp = await client.post("/api/v1/sessions", json={}) + 
session_id = create_resp.json()["result"]["session_id"] + + await client.post( + f"/api/v1/sessions/{session_id}/messages", + json={"role": "user", "content": "archived question"}, + ) + await client.post( + f"/api/v1/sessions/{session_id}/messages", + json={"role": "assistant", "content": "archived answer"}, + ) + commit_resp = await client.post(f"/api/v1/sessions/{session_id}/commit") + task_id = commit_resp.json()["result"]["task_id"] + await _wait_for_task(client, task_id) + + resp = await client.get(f"/api/v1/sessions/{session_id}/archives/archive_001") + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "ok" + assert body["result"]["archive_id"] == "archive_001" + assert body["result"]["overview"] + assert body["result"]["abstract"] + assert [m["parts"][0]["text"] for m in body["result"]["messages"]] == [ + "archived question", + "archived answer", + ] + + +async def test_commit_endpoint_rejects_after_failed_archive( + client: httpx.AsyncClient, + service, +): + create_resp = await client.post("/api/v1/sessions", json={}) + session_id = create_resp.json()["result"]["session_id"] + + async def failing_extract(*args, **kwargs): + del args, kwargs + raise RuntimeError("synthetic extraction failure") + + service.sessions._session_compressor.extract_long_term_memories = failing_extract + + await client.post( + f"/api/v1/sessions/{session_id}/messages", + json={"role": "user", "content": "first round"}, + ) + commit_resp = await client.post(f"/api/v1/sessions/{session_id}/commit") + task_id = commit_resp.json()["result"]["task_id"] + task = await _wait_for_task(client, task_id) + assert task["status"] == "failed" + + await client.post( + f"/api/v1/sessions/{session_id}/messages", + json={"role": "user", "content": "second round"}, + ) + resp = await client.post(f"/api/v1/sessions/{session_id}/commit") + + assert resp.status_code == 412 + body = resp.json() + assert body["status"] == "error" + assert body["error"]["code"] == 
"FAILED_PRECONDITION" + assert "unresolved failed archive" in body["error"]["message"] diff --git a/tests/server/test_http_client_sdk.py b/tests/server/test_http_client_sdk.py index c757c2ce5..2ea3dcab8 100644 --- a/tests/server/test_http_client_sdk.py +++ b/tests/server/test_http_client_sdk.py @@ -3,9 +3,13 @@ """SDK tests using AsyncHTTPClient against a real uvicorn server.""" +import asyncio + import pytest_asyncio +import pytest from openviking_cli.client.http import AsyncHTTPClient +from openviking_cli.exceptions import FailedPreconditionError from tests.server.conftest import SAMPLE_MD_CONTENT, TEST_TMP_DIR @@ -100,11 +104,67 @@ async def test_sdk_session_lifecycle(http_client): info = await client.get_session(session_id) assert info["session_id"] == session_id + context = await client.get_session_context(session_id) + assert context["latest_archive_overview"] == "" + assert context["latest_archive_id"] == "" + assert context["pre_archive_abstracts"] == [] + assert [m["parts"][0]["text"] for m in context["messages"]] == ["Hello from SDK"] + # List sessions = await client.list_sessions() assert isinstance(sessions, list) +async def test_sdk_get_session_archive(http_client): + client, _ = http_client + + session_info = await client.create_session() + session_id = session_info["session_id"] + + await client.add_message(session_id, "user", "Archive me") + commit_result = await client.commit_session(session_id) + task_id = commit_result["task_id"] + + for _ in range(100): + task = await client.get_task(task_id) + if task and task["status"] in ("completed", "failed"): + break + await asyncio.sleep(0.1) + + archive = await client.get_session_archive(session_id, "archive_001") + assert archive["archive_id"] == "archive_001" + assert archive["overview"] + assert archive["abstract"] + assert [m["parts"][0]["text"] for m in archive["messages"]] == ["Archive me"] + + +async def test_sdk_commit_raises_failed_precondition_after_failed_archive(http_client): + client, svc = 
http_client + + session_info = await client.create_session() + session_id = session_info["session_id"] + + async def failing_extract(*args, **kwargs): + del args, kwargs + raise RuntimeError("synthetic extraction failure") + + svc.session_compressor.extract_long_term_memories = failing_extract + + await client.add_message(session_id, "user", "First round") + commit_result = await client.commit_session(session_id) + task_id = commit_result["task_id"] + + for _ in range(100): + task = await client.get_task(task_id) + if task and task["status"] in ("completed", "failed"): + break + await asyncio.sleep(0.1) + + await client.add_message(session_id, "user", "Second round") + with pytest.raises(FailedPreconditionError, match="unresolved failed archive"): + await client.commit_session(session_id) + + # =================================================================== # Search # =================================================================== diff --git a/tests/session/test_session_commit.py b/tests/session/test_session_commit.py index ed1fddbd2..49c5eb2bc 100644 --- a/tests/session/test_session_commit.py +++ b/tests/session/test_session_commit.py @@ -4,11 +4,15 @@ """Commit tests""" import asyncio +import json + +import pytest from openviking import AsyncOpenViking from openviking.message import TextPart from openviking.service.task_tracker import get_task_tracker from openviking.session import Session +from openviking_cli.exceptions import FailedPreconditionError async def _wait_for_task(task_id: str, timeout: float = 30.0) -> dict: @@ -201,3 +205,31 @@ async def test_active_count_incremented_after_commit(self, client_with_resource_ assert count_after == count_before + 1, ( f"active_count not incremented: before={count_before}, after={count_after}" ) + + async def test_commit_blocks_after_failed_archive(self, client: AsyncOpenViking): + """A failed archive should block the next commit until it is resolved.""" + session = 
client.session(session_id="failed_archive_blocks_new_commit") + + async def failing_extract(*args, **kwargs): + del args, kwargs + raise RuntimeError("synthetic extraction failure") + + session._session_compressor.extract_long_term_memories = failing_extract + + session.add_message("user", [TextPart("First round message")]) + result = await session.commit_async() + task_result = await _wait_for_task(result["task_id"]) + + assert task_result["status"] == "failed" + + failed_marker = await session._viking_fs.read_file( + f"{result['archive_uri']}/.failed.json", + ctx=session.ctx, + ) + failed_payload = json.loads(failed_marker) + assert failed_payload["stage"] == "memory_extraction" + assert "synthetic extraction failure" in failed_payload["error"] + + session.add_message("user", [TextPart("Second round message")]) + with pytest.raises(FailedPreconditionError, match="unresolved failed archive"): + await session.commit_async() diff --git a/tests/session/test_session_context.py b/tests/session/test_session_context.py index 8eeacd6f9..e5f8d3fce 100644 --- a/tests/session/test_session_context.py +++ b/tests/session/test_session_context.py @@ -4,11 +4,101 @@ """Context retrieval tests""" import asyncio +import json +from unittest.mock import patch + +import pytest +import pytest_asyncio from openviking import AsyncOpenViking -from openviking.message import TextPart +from openviking.message import Message, TextPart +from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult from openviking.service.task_tracker import get_task_tracker from openviking.session import Session +from openviking_cli.utils.config.embedding_config import EmbeddingConfig +from openviking_cli.utils.config.open_viking_config import OpenVikingConfigSingleton +from openviking_cli.utils.config.vlm_config import VLMConfig +from tests.utils.mock_agfs import MockLocalAGFS + + +def _install_fake_embedder(monkeypatch): + class FakeEmbedder(DenseEmbedderBase): + def __init__(self): + 
super().__init__(model_name="test-fake-embedder") + + def embed(self, text: str, is_query: bool = False) -> EmbedResult: + return EmbedResult(dense_vector=[0.1] * 1024) + + def embed_batch(self, texts: list[str], is_query: bool = False) -> list[EmbedResult]: + return [self.embed(text, is_query=is_query) for text in texts] + + def get_dimension(self) -> int: + return 1024 + + monkeypatch.setattr(EmbeddingConfig, "get_embedder", lambda self: FakeEmbedder()) + + +def _install_fake_vlm(monkeypatch): + async def _fake_get_completion(self, prompt, thinking=False, max_retries=0): + return "# Test Summary\n\nFake summary for testing.\n\n## Details\nTest content." + + async def _fake_get_vision_completion(self, prompt, images, thinking=False): + return "Fake image description for testing." + + monkeypatch.setattr(VLMConfig, "is_available", lambda self: True) + monkeypatch.setattr(VLMConfig, "get_completion_async", _fake_get_completion) + monkeypatch.setattr(VLMConfig, "get_vision_completion_async", _fake_get_vision_completion) + + +def _write_test_config(tmp_path): + config_path = tmp_path / "ov.conf" + config_path.write_text( + json.dumps( + { + "storage": { + "workspace": str(tmp_path / "workspace"), + "agfs": {"backend": "local", "mode": "binding-client"}, + "vectordb": {"backend": "local"}, + }, + "embedding": { + "dense": { + "provider": "openai", + "model": "test-embedder", + "api_base": "http://127.0.0.1:11434/v1", + "dimension": 1024, + } + }, + "encryption": {"enabled": False}, + } + ), + encoding="utf-8", + ) + return config_path + + +@pytest_asyncio.fixture(scope="function") +async def client(test_data_dir, monkeypatch, tmp_path): + config_path = _write_test_config(tmp_path) + mock_agfs = MockLocalAGFS(root_path=tmp_path / "mock_agfs_root") + + OpenVikingConfigSingleton.reset_instance() + await AsyncOpenViking.reset() + monkeypatch.setenv("OPENVIKING_CONFIG_FILE", str(config_path)) + _install_fake_embedder(monkeypatch) + _install_fake_vlm(monkeypatch) + + with 
patch("openviking.utils.agfs_utils.create_agfs_client", return_value=mock_agfs): + client = AsyncOpenViking(path=str(test_data_dir)) + await client.initialize() + yield client + await client.close() + + OpenVikingConfigSingleton.reset_instance() + await AsyncOpenViking.reset() + + +def _estimate_tokens(text: str) -> int: + return -(-len(text) // 4) async def _wait_for_task(task_id: str, timeout: float = 30.0) -> dict: @@ -88,6 +178,64 @@ async def test_get_context_skips_incomplete_latest_archive(self, client: AsyncOp assert context["latest_archive_overview"] == completed_overview + async def test_get_context_includes_incomplete_archive_messages(self, client: AsyncOpenViking): + """Pending archive messages should be merged with current live messages.""" + session = client.session(session_id="archive_context_pending_messages_test") + + session.add_message("user", [TextPart("First message")]) + result = await session.commit_async() + await _wait_for_task(result["task_id"]) + + pending_messages = [ + Message.create_user("Pending user message"), + Message.create_assistant("Pending assistant response"), + ] + await session._viking_fs.write_file( + uri=f"{session.uri}/history/archive_002/messages.jsonl", + content="\n".join(msg.to_jsonl() for msg in pending_messages) + "\n", + ctx=session.ctx, + ) + + session.add_message("user", [TextPart("Current live message")]) + context = await session.get_context_for_search(query="test") + + assert [m.content for m in context["current_messages"]] == [ + "Pending user message", + "Pending assistant response", + "Current live message", + ] + + async def test_get_context_max_messages_applies_after_pending_merge( + self, client: AsyncOpenViking + ): + """max_messages should trim the merged pending + live message sequence.""" + session = client.session(session_id="archive_context_pending_max_messages_test") + + session.add_message("user", [TextPart("First message")]) + result = await session.commit_async() + await 
_wait_for_task(result["task_id"]) + + pending_messages = [ + Message.create_user("Pending 1"), + Message.create_assistant("Pending 2"), + ] + await session._viking_fs.write_file( + uri=f"{session.uri}/history/archive_002/messages.jsonl", + content="\n".join(msg.to_jsonl() for msg in pending_messages) + "\n", + ctx=session.ctx, + ) + + session.add_message("user", [TextPart("Live 1")]) + session.add_message("assistant", [TextPart("Live 2")]) + + context = await session.get_context_for_search(query="test", max_messages=3) + + assert [m.content for m in context["current_messages"]] == [ + "Pending 2", + "Live 1", + "Live 2", + ] + async def test_get_context_empty_session(self, session: Session): """Test getting context from empty session""" context = await session.get_context_for_search(query="test") @@ -100,20 +248,366 @@ async def test_get_context_after_commit(self, client: AsyncOpenViking): """Test getting context after commit""" session = client.session(session_id="post_commit_context_test") - # Add messages session.add_message("user", [TextPart("Test message before commit")]) session.add_message("assistant", [TextPart("Response before commit")]) - # Commit result = await session.commit_async() await _wait_for_task(result["task_id"]) - # Add new messages session.add_message("user", [TextPart("New message after commit")]) - # Getting context should include archive summary context = await session.get_context_for_search(query="test") assert isinstance(context, dict) assert context["latest_archive_overview"] assert len(context["current_messages"]) == 1 + + async def test_get_context_tracks_multiple_rapid_commits_by_done_boundary( + self, client: AsyncOpenViking + ): + """Context should only advance latest overview when the earlier archive is .done.""" + session = client.session(session_id="archive_context_done_boundary_test") + first_gate = asyncio.Event() + second_gate = asyncio.Event() + second_started = asyncio.Event() + + async def gated_extract(messages, 
**kwargs): + del kwargs + contents = " ".join(m.content for m in messages) + if "First round" in contents: + await first_gate.wait() + return [] + second_started.set() + await second_gate.wait() + return [] + + session._session_compressor.extract_long_term_memories = gated_extract + + session.add_message("user", [TextPart("First round user")]) + session.add_message("assistant", [TextPart("First round assistant")]) + result1 = await session.commit_async() + + session.add_message("user", [TextPart("Second round user")]) + session.add_message("assistant", [TextPart("Second round assistant")]) + result2 = await session.commit_async() + + context = await session.get_context_for_search(query="test") + assert context["latest_archive_overview"] == "" + assert [m.content for m in context["current_messages"]] == [ + "First round user", + "First round assistant", + "Second round user", + "Second round assistant", + ] + + first_gate.set() + await asyncio.wait_for(second_started.wait(), timeout=5.0) + + first_overview = await session._viking_fs.read_file( + f"{result1['archive_uri']}/.overview.md", + ctx=session.ctx, + ) + context = await session.get_context_for_search(query="test") + assert context["latest_archive_overview"] == first_overview + assert [m.content for m in context["current_messages"]] == [ + "Second round user", + "Second round assistant", + ] + + second_gate.set() + await _wait_for_task(result1["task_id"]) + await _wait_for_task(result2["task_id"]) + + second_overview = await session._viking_fs.read_file( + f"{result2['archive_uri']}/.overview.md", + ctx=session.ctx, + ) + context = await session.get_context_for_search(query="test") + assert context["latest_archive_overview"] == second_overview + assert context["current_messages"] == [] + + +class TestGetSessionContext: + """Test get_session_context""" + + async def test_get_session_context_returns_latest_archive_overview_and_history( + self, client: AsyncOpenViking, monkeypatch + ): + session = 
client.session(session_id="assemble_trim_test") + summaries = [ + "# Session Summary\n\n" + ("A" * 80), + "# Session Summary\n\n" + ("B" * 20), + ] + + async def fake_generate(_messages, latest_archive_overview=""): + del latest_archive_overview + return summaries.pop(0) + + monkeypatch.setattr(session, "_generate_archive_summary_async", fake_generate) + + session.add_message("user", [TextPart("first turn")]) + session.add_message("assistant", [TextPart("first reply")]) + result = await session.commit_async() + await _wait_for_task(result["task_id"]) + + session.add_message("user", [TextPart("second turn")]) + session.add_message("assistant", [TextPart("second reply")]) + result = await session.commit_async() + await _wait_for_task(result["task_id"]) + + session.add_message("user", [TextPart("active tail")]) + + newest_summary = "# Session Summary\n\n" + ("B" * 20) + active_tokens = sum(message.estimated_tokens for message in session.messages) + token_budget = active_tokens + _estimate_tokens(newest_summary) + + context = await session.get_session_context(token_budget=token_budget) + + assert context["latest_archive_overview"] == newest_summary + assert context["latest_archive_id"] == "archive_002" + assert context["pre_archive_abstracts"] == [] + assert len(context["messages"]) == 1 + assert context["messages"][0]["parts"][0]["text"] == "active tail" + assert context["estimatedTokens"] == token_budget + assert context["stats"] == { + "totalArchives": 2, + "includedArchives": 1, + "droppedArchives": 1, + "failedArchives": 0, + "activeTokens": active_tokens, + "archiveTokens": _estimate_tokens(newest_summary), + } + + async def test_get_session_context_counts_active_tool_parts( + self, session_with_tool_call: tuple[Session, str, str] + ): + session, _message_id, tool_id = session_with_tool_call + + context = await session.get_session_context() + + assert len(context["messages"]) == 1 + tool_parts = [part for part in context["messages"][0]["parts"] if part["type"] == 
"tool"] + assert tool_parts[0]["tool_id"] == tool_id + assert context["stats"]["activeTokens"] == session.messages[0].estimated_tokens + assert context["stats"]["activeTokens"] > _estimate_tokens("Executing tool...") + + async def test_get_session_context_reads_latest_overview_and_previous_abstracts( + self, client: AsyncOpenViking, monkeypatch + ): + """Overview should only be read for the latest archive; older archives use abstracts.""" + session = client.session(session_id="assemble_lazy_read_test") + summaries = [ + "# Summary\n\n" + ("A" * 80), + "# Summary\n\n" + ("B" * 80), + "# Summary\n\n" + ("C" * 80), + ] + + async def fake_generate(_messages, latest_archive_overview=""): + del latest_archive_overview + return summaries.pop(0) + + monkeypatch.setattr(session, "_generate_archive_summary_async", fake_generate) + + for word in ("first", "second", "third"): + session.add_message("user", [TextPart(f"{word} turn")]) + session.add_message("assistant", [TextPart(f"{word} reply")]) + result = await session.commit_async() + await _wait_for_task(result["task_id"]) + + session.add_message("user", [TextPart("active tail")]) + + newest_summary = "# Summary\n\n" + ("C" * 80) + previous_abstract = "# Summary" + active_tokens = sum(m.estimated_tokens for m in session.messages) + token_budget = ( + active_tokens + + _estimate_tokens(newest_summary) + + (_estimate_tokens(previous_abstract) * 2) + ) + + original_read_file = session._viking_fs.read_file + read_uris: list[str] = [] + + async def tracking_read_file(*args, **kwargs): + uri = args[0] if args else kwargs.get("uri") + read_uris.append(uri) + return await original_read_file(*args, **kwargs) + + monkeypatch.setattr(session._viking_fs, "read_file", tracking_read_file) + + context = await session.get_session_context(token_budget=token_budget) + + assert context["latest_archive_overview"] == newest_summary + assert context["latest_archive_id"] == "archive_003" + assert context["pre_archive_abstracts"] == [ + 
{"archive_id": "archive_002", "abstract": "# Summary"}, + {"archive_id": "archive_001", "abstract": "# Summary"}, + ] + assert context["stats"]["includedArchives"] == 3 + assert context["stats"]["droppedArchives"] == 0 + + overview_reads = [u for u in read_uris if u.endswith(".overview.md")] + abstract_reads = [u for u in read_uris if u.endswith(".abstract.md")] + assert all("archive_003" in u for u in overview_reads), ( + f"Only newest archive overview should be read, got: {overview_reads}" + ) + assert all( + "archive_003" not in u and ("archive_002" in u or "archive_001" in u) + for u in abstract_reads + ), f"Only previous archive abstracts should be read, got: {abstract_reads}" + assert not any("archive_001/.overview.md" in u for u in overview_reads), ( + "Oldest archive overview should not be read" + ) + assert not any("archive_003/.abstract.md" in u for u in abstract_reads), ( + "Latest archive abstract should not be read for context history" + ) + + async def test_get_session_context_drops_oldest_pre_archive_abstracts_first( + self, client: AsyncOpenViking, monkeypatch + ): + session = client.session(session_id="assemble_trim_oldest_abstracts_test") + summaries = [ + "# Summary\n\n" + ("A" * 80), + "# Summary\n\n" + ("B" * 80), + "# Summary\n\n" + ("C" * 80), + ] + + async def fake_generate(_messages, latest_archive_overview=""): + del latest_archive_overview + return summaries.pop(0) + + monkeypatch.setattr(session, "_generate_archive_summary_async", fake_generate) + + for word in ("first", "second", "third"): + session.add_message("user", [TextPart(f"{word} turn")]) + session.add_message("assistant", [TextPart(f"{word} reply")]) + result = await session.commit_async() + await _wait_for_task(result["task_id"]) + + session.add_message("user", [TextPart("active tail")]) + + newest_summary = "# Summary\n\n" + ("C" * 80) + previous_abstract = "# Summary" + active_tokens = sum(m.estimated_tokens for m in session.messages) + token_budget = ( + active_tokens + 
_estimate_tokens(newest_summary) + _estimate_tokens(previous_abstract) + ) + + context = await session.get_session_context(token_budget=token_budget) + + assert context["latest_archive_overview"] == newest_summary + assert context["latest_archive_id"] == "archive_003" + assert context["pre_archive_abstracts"] == [ + {"archive_id": "archive_002", "abstract": "# Summary"} + ] + assert context["estimatedTokens"] == token_budget + assert context["stats"]["totalArchives"] == 3 + assert context["stats"]["includedArchives"] == 2 + assert context["stats"]["droppedArchives"] == 1 + + async def test_get_session_context_falls_back_to_older_completed_archive( + self, client: AsyncOpenViking, monkeypatch + ): + session = client.session(session_id="assemble_failed_archive_test") + summaries = [ + "# Session Summary\n\narchive one", + "# Session Summary\n\narchive two", + ] + + async def fake_generate(_messages, latest_archive_overview=""): + del latest_archive_overview + return summaries.pop(0) + + monkeypatch.setattr(session, "_generate_archive_summary_async", fake_generate) + + session.add_message("user", [TextPart("turn one")]) + result = await session.commit_async() + await _wait_for_task(result["task_id"]) + + session.add_message("user", [TextPart("turn two")]) + result = await session.commit_async() + await _wait_for_task(result["task_id"]) + + original_read_file = session._viking_fs.read_file + + async def flaky_read_file(*args, **kwargs): + uri = args[0] if args else kwargs.get("uri") + if isinstance(uri, str) and uri.endswith("archive_002/.overview.md"): + raise RuntimeError("simulated archive read failure") + return await original_read_file(*args, **kwargs) + + monkeypatch.setattr(session._viking_fs, "read_file", flaky_read_file) + + context = await session.get_session_context(token_budget=128_000) + + assert context["latest_archive_overview"] == "# Session Summary\n\narchive one" + assert context["latest_archive_id"] == "archive_001" + assert 
context["pre_archive_abstracts"] == [] + assert context["stats"]["totalArchives"] == 2 + assert context["stats"]["includedArchives"] == 1 + assert context["stats"]["droppedArchives"] == 0 + assert context["stats"]["failedArchives"] == 1 + + async def test_get_session_context_budget_trim_keeps_latest_archive_id( + self, client: AsyncOpenViking, monkeypatch + ): + session = client.session(session_id="assemble_trim_id_test") + + async def fake_generate(_messages, latest_archive_overview=""): + del latest_archive_overview + return "# Session Summary\n\n" + ("Z" * 80) + + monkeypatch.setattr(session, "_generate_archive_summary_async", fake_generate) + + session.add_message("user", [TextPart("turn one")]) + result = await session.commit_async() + await _wait_for_task(result["task_id"]) + + context = await session.get_session_context(token_budget=1) + + assert context["latest_archive_overview"] == "" + assert context["latest_archive_id"] == "archive_001" + assert context["pre_archive_abstracts"] == [] + assert context["stats"]["includedArchives"] == 0 + assert context["stats"]["droppedArchives"] == 1 + + +class TestGetSessionArchive: + """Test get_session_archive""" + + async def test_get_session_archive_returns_messages_and_summary( + self, client: AsyncOpenViking, monkeypatch + ): + session = client.session(session_id="session_archive_expand_test") + summaries = [ + "# Session Summary\n\narchive one", + "# Session Summary\n\narchive two", + ] + + async def fake_generate(_messages, latest_archive_overview=""): + del latest_archive_overview + return summaries.pop(0) + + monkeypatch.setattr(session, "_generate_archive_summary_async", fake_generate) + + session.add_message("user", [TextPart("turn one")]) + session.add_message("assistant", [TextPart("reply one")]) + result = await session.commit_async() + await _wait_for_task(result["task_id"]) + + session.add_message("user", [TextPart("turn two")]) + result = await session.commit_async() + await 
_wait_for_task(result["task_id"]) + + archive = await session.get_session_archive("archive_001") + + assert archive["archive_id"] == "archive_001" + assert archive["abstract"] == "# Session Summary" + assert archive["overview"] == "# Session Summary\n\narchive one" + assert [m["parts"][0]["text"] for m in archive["messages"]] == ["turn one", "reply one"] + + async def test_get_session_archive_raises_for_missing_archive(self, client: AsyncOpenViking): + session = client.session(session_id="missing_session_archive_test") + + with pytest.raises(Exception, match="Session archive not found: archive_999"): + await session.get_session_archive("archive_999") diff --git a/tests/test_session_task_tracking.py b/tests/test_session_task_tracking.py index 270afd011..723742a0f 100644 --- a/tests/test_session_task_tracking.py +++ b/tests/test_session_task_tracking.py @@ -222,11 +222,11 @@ async def failing_extract(_context, _user, _session_id): assert "memory_extraction_failed" in result["error"] -# ── Duplicate commit rejection ── +# ── Duplicate commit acceptance ── -async def test_duplicate_commit_rejected(api_client): - """Second commit on same session should be rejected while first is running.""" +async def test_duplicate_commit_returns_second_task(api_client): + """Second commit on same session should also be accepted with its own task.""" client, service = api_client session_id = await _new_session_with_message(client) @@ -237,11 +237,14 @@ async def test_duplicate_commit_rejected(api_client): # First commit resp1 = await client.post(f"/api/v1/sessions/{session_id}/commit") assert resp1.json()["result"]["status"] == "accepted" + task_id_1 = resp1.json()["result"]["task_id"] - # Second commit should be rejected + # Second commit should also be accepted resp2 = await client.post(f"/api/v1/sessions/{session_id}/commit") - assert resp2.json()["status"] == "error" - assert "already has a commit in progress" in resp2.json()["error"]["message"] + assert resp2.status_code == 200 + 
assert resp2.json()["result"]["status"] == "accepted" + task_id_2 = resp2.json()["result"]["task_id"] + assert task_id_1 != task_id_2 gate.set() await asyncio.sleep(0.1)