diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..b781633 --- /dev/null +++ b/.env.example @@ -0,0 +1,8 @@ +# API Keys (at least one required) +OPENAI_API_KEY= +ANTHROPIC_API_KEY= +GEMINI_API_KEY= + +# Backend-v2 +BIND_ADDR=127.0.0.1:8080 +LOG_LEVEL=info diff --git a/.gitignore b/.gitignore index c8a702b..a8a299d 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,9 @@ Thumbs.db # Backend data (downloaded via scripts/setup-data.sh) backend/data/ +# Speedwagon store (auto-created on first request) +backend-v2/.speedwagon/ + # Claude Code / OMC / Plugin state .claude/ CLAUDE.md diff --git a/Cargo.lock b/Cargo.lock index 7db674f..4166f2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,6 +31,7 @@ dependencies = [ "axum-extra", "bytes", "chrono", + "dashmap", "dotenvy", "futures-util", "http-body-util", @@ -97,7 +98,7 @@ dependencies = [ [[package]] name = "ailoy" version = "0.3.0" -source = "git+https://github.com/brekkylab/ailoy.git?rev=922d25b2c0733b7c95df4d4fe232a73cfdcef928#922d25b2c0733b7c95df4d4fe232a73cfdcef928" +source = "git+https://github.com/brekkylab/ailoy.git?rev=098a8289272ccce3cf5719e20732050cb047d04a#098a8289272ccce3cf5719e20732050cb047d04a" dependencies = [ "anyhow", "async-stream", @@ -1930,6 +1931,20 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.11.0" @@ -2913,6 +2928,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.5" @@ -4428,7 +4449,7 @@ dependencies = [ "libc", "libloading", "log", - "lru 0.12.5", + "lru 0.16.4", "msb_krun_arch", "msb_krun_hvf", "msb_krun_polly", @@ -7646,7 +7667,7 @@ checksum = "f2f6fb2847f6742cd76af783a2a2c49e9375d0a111c7bef6f71cd9e738c72d6e" dependencies = [ "memoffset", "tempfile", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 3b73529..be4d94b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,5 @@ members = [ ] [workspace.dependencies] -ailoy = { git = "https://github.com/brekkylab/ailoy.git", rev = "922d25b2c0733b7c95df4d4fe232a73cfdcef928", features = ["sandbox"] } -# ailoy = { path = "../ailoy", features = ["sandbox"] } +ailoy = { git = "https://github.com/brekkylab/ailoy.git", rev = "098a8289272ccce3cf5719e20732050cb047d04a", features = ["sandbox"] } speedwagon = { path = "./speedwagon", default-features = false } diff --git a/backend-v2/Cargo.toml b/backend-v2/Cargo.toml index 6cb1c83..abed982 100644 --- a/backend-v2/Cargo.toml +++ b/backend-v2/Cargo.toml @@ -20,6 +20,7 @@ axum = { version = "0.8", features = ["multipart"] } axum-extra = { version = "0.10", features = ["typed-header"] } bytes = "1" chrono = { version = "0.4.42", features = ["serde", "clock"] } +dashmap = "6" dotenvy = "0.15" futures-util = "0.3" schemars = { version = "0.8", features = ["uuid1", "chrono"] } @@ -28,7 +29,7 @@ serde_json = "1.0.149" speedwagon = { workspace = true } sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio-rustls"] } thiserror = "2.0.17" -tokio = { version = "1.48.0", features = ["sync", "fs", "rt", "rt-multi-thread", "macros"] } +tokio = { version = "1.48.0", features = ["sync", "fs", "rt", "rt-multi-thread", "macros", "time"] } tokio-stream = "0.1" tower-http = { version = "0.6", features = ["cors"] } tracing = "0.1" diff --git a/backend-v2/src/error.rs b/backend-v2/src/error.rs index a51a20c..e59a40c 100644 --- a/backend-v2/src/error.rs +++ b/backend-v2/src/error.rs @@ -1,6 +1,11 @@ use schemars::JsonSchema; use serde::Serialize; +use axum::{Json, http::StatusCode}; + +pub type ApiError = (StatusCode, Json); +pub type ApiResult = Result; + #[derive(Debug, JsonSchema, Serialize)] pub struct AppError { error: String, @@ -12,19 +17,12 @@ impl AppError { error: errstr.into(), } } - // fn not_found(msg: impl Into) -> ApiErr { - // (StatusCode::NOT_FOUND, Json(Self { error: msg.into() })) - // } - // fn conflict(msg: impl Into) -> ApiErr { - // (StatusCode::CONFLICT, Json(Self { error: msg.into() })) - // } - // fn bad_request(msg: impl Into) -> ApiErr { - // (StatusCode::BAD_REQUEST, Json(Self { error: msg.into() })) - // } - // fn internal(msg: impl Into) -> ApiErr { - // ( - // StatusCode::INTERNAL_SERVER_ERROR, - // Json(Self { error: msg.into() }), - // ) - // } + + pub fn internal(msg: impl Into) -> ApiError { + (StatusCode::INTERNAL_SERVER_ERROR, Json(Self::new(msg))) + } + + pub fn not_found(msg: impl Into) -> ApiError { + (StatusCode::NOT_FOUND, Json(Self::new(msg))) + } } diff --git a/backend-v2/src/main.rs b/backend-v2/src/main.rs index c039a8b..84814ca 100644 --- a/backend-v2/src/main.rs +++ b/backend-v2/src/main.rs @@ -1,12 +1,16 @@ +use std::path::PathBuf; use std::sync::Arc; use agent_k_backend::{repository, router, state::AppState}; use aide::axum::ApiRouter; use aide::openapi::{Info, OpenApi}; use aide::scalar::Scalar; +use ailoy::agent::default_provider_mut; +use ailoy::lang_model::LangModelProvider; use axum::Extension; use axum::response::IntoResponse; -use tokio::sync::Mutex; +use speedwagon::{Store, build_tools}; +use tokio::sync::RwLock; use tower_http::cors::{Any, CorsLayer}; #[tokio::main] @@ -46,7 +50,35 @@ async fn main() -> std::io::Result<()> { .await .expect("failed to initialise repository"); - let app_state = Arc::new(Mutex::new(AppState::new(repo))); + let store_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(".speedwagon"); + let store = Arc::new(RwLock::new( + Store::new(store_path).expect("speedwagon store init"), + )); + + // Register API keys with the global provider (needed by Agent::try_with_tools) + { + let mut provider = default_provider_mut().await; + + if let Ok(key) = std::env::var("OPENAI_API_KEY") { + provider + .models + .insert("openai/*".into(), LangModelProvider::openai(key)); + } + if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") { + provider + .models + .insert("anthropic/*".into(), LangModelProvider::anthropic(key)); + } + if let Ok(key) = std::env::var("GEMINI_API_KEY") { + provider + .models + .insert("google/*".into(), LangModelProvider::gemini(key)); + } + + provider.tools = build_tools(store.clone()); + } + + let app_state = Arc::new(AppState::new(repo, store)); let app = router::get_router(app_state) .finish_api(&mut openapi) .merge( diff --git a/backend-v2/src/model/session.rs b/backend-v2/src/model/session.rs index 9cc7c5f..b0c924f 100644 --- a/backend-v2/src/model/session.rs +++ b/backend-v2/src/model/session.rs @@ -7,9 +7,9 @@ use uuid::Uuid; #[serde(deny_unknown_fields)] pub struct CreateSessionRequest {} -/// API representation of a session. +/// API response for GET /sessions (list) and POST /sessions (create) -- no messages #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] -pub struct Session { +pub struct SessionResponse { pub id: Uuid, pub created_at: DateTime, pub updated_at: DateTime, diff --git a/backend-v2/src/router.rs b/backend-v2/src/router.rs index a3fe7df..1c9e28c 100644 --- a/backend-v2/src/router.rs +++ b/backend-v2/src/router.rs @@ -6,11 +6,9 @@ use aide::axum::{ routing::{delete, post}, }; use ailoy::{ - agent::AgentBuilder, - lang_model::{LangModel, LangModelProvider}, + agent::{Agent, AgentBuilder, AgentCard}, message::{Message, MessageOutput, Part, Role}, runenv::{Sandbox, SandboxConfig}, - tool::{BuiltinToolProvider, make_builtin_tool}, }; use axum::{ Json, @@ -20,23 +18,23 @@ use axum::{ }; use chrono::Utc; use futures_util::StreamExt; -use tokio::sync::Mutex; +use speedwagon::SpeedwagonSpec; use uuid::Uuid; use crate::{ - error::AppError, - model::{CreateSessionRequest, SendMessageRequest, SendMessageResponse, Session}, + error::{ApiResult, AppError}, + model::{CreateSessionRequest, SendMessageRequest, SendMessageResponse, SessionResponse}, state::AppState, }; -const DEFAULT_MODEL: &str = "anthropic/claude-haiku-4-5-20251001"; +const DEFAULT_MODEL: &str = "openai/gpt-5.4-mini"; fn sandbox_name_for(id: &Uuid) -> String { let s = id.simple().to_string(); format!("session-{}", &s[..12]) } -pub fn get_router(state: Arc>) -> ApiRouter { +pub fn get_router(state: Arc) -> ApiRouter { ApiRouter::new() .api_route("/sessions", post(create_session)) .api_route("/sessions/{id}", delete(delete_session)) @@ -52,10 +50,81 @@ pub fn get_router(state: Arc>) -> ApiRouter { .with_state(state) } +async fn build_agent(sandbox: Sandbox) -> Result { + // Speedwagon RAG subagent + let sw_card = AgentCard { + name: "speedwagon".into(), + description: "Search the knowledge base for answers. \ + This tool has access to uploaded documents that may contain \ + information the model doesn't have. \ + Use it for any question that could be answered from the knowledge base." + .into(), + skills: vec![], + }; + let sw_spec = SpeedwagonSpec::new().card(sw_card.clone()).into_spec(); + + AgentBuilder::new(DEFAULT_MODEL) + .instruction(concat!( + "You are a versatile assistant with access to code execution tools ", + "(bash, python), web search, and a knowledge base (speedwagon). ", + "You MUST use the speedwagon tool to search the document corpus ", + "before answering ANY factual question — even if you think you already know the answer. ", + "The corpus contains authoritative information that may differ from your training data. ", + "Use bash and python tools for computation, data analysis, and code execution tasks. ", + "Only skip tools for greetings or casual conversation.", + )) + .tool("bash") + .tool("python_repl") + .tool("web_search") + .runenv(sandbox) + .subagent(sw_spec) + .build() + .await + .map_err(|e| e.to_string()) +} + +// Alternative: main agent uses speedwagon tools directly (no subagent delegation). +// Materialize speedwagon ToolFactory entries for the main agent's spec so it can +// call search functions itself, instead of routing through a dedicated subagent. +// +// async fn build_agent(sandbox: Arc, toolset: &ToolSet) -> Result { +// let (bash, python, web_search) = tokio::try_join!( +// make_builtin_tool(&BuiltinToolProvider::Bash {}), +// make_builtin_tool(&BuiltinToolProvider::PythonRepl {}), +// make_builtin_tool(&BuiltinToolProvider::WebSearch {}), +// ) +// .map_err(|e| e.to_string())?; + +// let model = build_lang_model(DEFAULT_MODEL)?; +// let stub_spec = AgentSpec::new(DEFAULT_MODEL); + +// let mut builder = AgentBuilder::new(model) +// .instruction(concat!( +// "You are a versatile assistant with access to code execution tools ", +// "(bash, python), web search, and a knowledge base. ", +// "You MUST use the knowledge base search tools ", +// "before answering ANY factual question. ", +// "Use bash and python tools for computation and code execution tasks. ", +// "Only skip tools for greetings or casual conversation.", +// )) +// .tool(bash) +// .tool(python) +// .tool(web_search) +// .runenv(sandbox); + +// // Materialize each speedwagon ToolFactory into a concrete Tool. +// // ToolFactory::make(spec) selects the right implementation (e.g. sandbox-aware). +// for (_name, factory) in toolset.iter() { +// builder = builder.tool(factory.make(&stub_spec)); +// } + +// builder.build().await.map_err(|e| e.to_string()) +// } + async fn create_session( - State(state): State>>, + State(state): State>, Json(_payload): Json, -) -> Result<(StatusCode, Json), (StatusCode, Json)> { +) -> ApiResult<(StatusCode, Json)> { let id = Uuid::new_v4(); let sandbox_name = sandbox_name_for(&id); @@ -64,23 +133,27 @@ async fn create_session( persist: true, ..Default::default() }; - let sandbox = Arc::new(Sandbox::new(cfg).await.map_err(internal)?); + let sandbox = Sandbox::new(cfg) + .await + .map_err(|e| AppError::internal(e.to_string()))?; - let agent = build_agent(sandbox).await.map_err(internal)?; + let agent = build_agent(sandbox) + .await + .map_err(|e| AppError::internal(e))?; let now = Utc::now(); - - { - let mut st = state.lock().await; - st.repository.create_session(id).await.map_err(internal)?; - st.insert_agent(id, agent); - } + state + .repository + .create_session(id) + .await + .map_err(|e| AppError::internal(e.to_string()))?; + state.insert_agent(id, agent); tracing::info!(%id, sandbox = %sandbox_name, "session created"); Ok(( StatusCode::CREATED, - Json(Session { + Json(SessionResponse { id, created_at: now, updated_at: now, @@ -89,42 +162,37 @@ async fn create_session( } async fn delete_session( - State(state): State>>, + State(state): State>, Path(id): Path, -) -> Result)> { - let agent_arc = { - let mut st = state.lock().await; - - // Verify session exists in DB (covers the in-memory-less case too). - if st - .repository - .get_session(id) - .await - .map_err(internal)? - .is_none() - { - return Err(( - StatusCode::NOT_FOUND, - Json(AppError::new("session not found")), - )); - } +) -> ApiResult { + if state + .repository + .get_session(id) + .await + .map_err(|e| AppError::internal(e.to_string()))? + .is_none() + { + return Err(AppError::not_found("session not found")); + } - st.repository.delete_session(id).await.map_err(internal)?; - st.remove_agent(&id) - }; + state + .repository + .delete_session(id) + .await + .map_err(|e| AppError::internal(e.to_string()))?; + let agent_arc = state.remove_agent(&id); - // Wait for any in-progress run before dropping. if let Some(arc) = agent_arc { drop(arc.lock().await); drop(arc); } - let name = sandbox_name_for(&id); - ailoy::runenv::remove_persisted(&name) - .await - .map_err(internal)?; + let sandbox_name = sandbox_name_for(&id); + if let Err(e) = ailoy::runenv::remove_persisted(&sandbox_name).await { + tracing::warn!(%id, "failed to remove persisted sandbox: {e}"); + } - tracing::info!(%id, sandbox = %name, "session deleted"); + tracing::info!(%id, "session deleted"); Ok(StatusCode::NO_CONTENT) } @@ -134,111 +202,95 @@ async fn delete_session( /// the session and its message history are in the DB. This function rebuilds /// the agent and restores the history so the next turn starts with full context. async fn resolve_agent( - state: &Arc>, + state: &Arc, id: Uuid, -) -> Result>, (StatusCode, Json)> { - // Fast path: agent already in memory. - { - let st = state.lock().await; - if let Some(arc) = st.get_agent(&id) { - return Ok(arc); - } +) -> ApiResult>> { + if let Some(arc) = state.get_agent(&id) { + return Ok(arc); } - // Slow path: session must exist in DB. - let (session_exists, history, repo) = { - let st = state.lock().await; - let exists = st - .repository - .get_session(id) - .await - .map_err(internal)? - .is_some(); - let history = if exists { - st.repository.get_messages(id).await.map_err(internal)? - } else { - vec![] - }; - (exists, history, st.repository.clone()) - }; + let session_exists = state + .repository + .get_session(id) + .await + .map_err(|e| AppError::internal(e.to_string()))? + .is_some(); if !session_exists { - return Err(( - StatusCode::NOT_FOUND, - Json(AppError::new("session not found")), - )); + return Err(AppError::not_found("session not found")); } - // Build agent outside the mutex (async I/O). + let history = state + .repository + .get_messages(id) + .await + .map_err(|e| AppError::internal(e.to_string()))?; + let sandbox_name = sandbox_name_for(&id); let cfg = SandboxConfig { - name: Some(sandbox_name.clone()), + name: Some(sandbox_name), persist: true, ..Default::default() }; - let sandbox = Arc::new(Sandbox::new(cfg).await.map_err(internal)?); - let mut agent = build_agent(sandbox).await.map_err(internal)?; - - // Restore persisted history so the agent has full conversation context. - agent.state.history = history; + let sandbox = Sandbox::new(cfg) + .await + .map_err(|e| AppError::internal(e.to_string()))?; - tracing::info!(%id, sandbox = %sandbox_name, "agent lazy-created with history restored"); + let mut agent = build_agent(sandbox) + .await + .map_err(|e| AppError::internal(e))?; - let _ = repo; // repo clone kept alive until here + agent.state.history = history; + tracing::info!(%id, "agent lazy-created with history restored"); - // Insert — if another request won the race, use theirs. - let mut st = state.lock().await; - if let Some(existing) = st.get_agent(&id) { + if let Some(existing) = state.get_agent(&id) { return Ok(existing); } - st.insert_agent(id, agent); - Ok(st.get_agent(&id).unwrap()) + state.insert_agent(id, agent); + Ok(state.get_agent(&id).unwrap()) } async fn get_message_history( - State(state): State>>, + State(state): State>, Path(id): Path, -) -> Result>, (StatusCode, Json)> { - let st = state.lock().await; - if st +) -> ApiResult>> { + if state .repository .get_session(id) .await - .map_err(internal)? + .map_err(|e| AppError::internal(e.to_string()))? .is_none() { - return Err(( - StatusCode::NOT_FOUND, - Json(AppError::new("session not found")), - )); + return Err(AppError::not_found("session not found")); } - let messages = st.repository.get_messages(id).await.map_err(internal)?; + let messages = state + .repository + .get_messages(id) + .await + .map_err(|e| AppError::internal(e.to_string()))?; Ok(Json(messages)) } async fn clear_message_history( - State(state): State>>, + State(state): State>, Path(id): Path, -) -> Result)> { - let agent_arc = { - let st = state.lock().await; - if st - .repository - .get_session(id) - .await - .map_err(internal)? - .is_none() - { - return Err(( - StatusCode::NOT_FOUND, - Json(AppError::new("session not found")), - )); - } - st.repository.clear_messages(id).await.map_err(internal)?; - st.get_agent(&id) - }; +) -> ApiResult { + if state + .repository + .get_session(id) + .await + .map_err(|e| AppError::internal(e.to_string()))? + .is_none() + { + return Err(AppError::not_found("session not found")); + } + state + .repository + .clear_messages(id) + .await + .map_err(|e| AppError::internal(e.to_string()))?; - if let Some(arc) = agent_arc { + if let Some(arc) = state.get_agent(&id) { arc.lock().await.state.history.clear(); } @@ -247,10 +299,10 @@ async fn clear_message_history( } async fn send_message( - State(state): State>>, + State(state): State>, Path(id): Path, Json(payload): Json, -) -> Result, (StatusCode, Json)> { +) -> ApiResult> { let agent_arc = resolve_agent(&state, id).await?; let prev_len = agent_arc.lock().await.get_history().len(); @@ -261,37 +313,31 @@ async fn send_message( let mut stream = agent.run(msg); let mut outputs: Vec = Vec::new(); while let Some(item) = stream.next().await { - outputs.push(item.map_err(internal)?); + outputs.push(item.map_err(|e| AppError::internal(e.to_string()))?); } outputs }; - // Persist newly added history entries. let new_messages = { let agent = agent_arc.lock().await; agent.get_history()[prev_len..].to_vec() }; state - .lock() - .await .repository .append_messages(id, &new_messages) .await - .map_err(internal)?; + .map_err(|e| AppError::internal(e.to_string()))?; Ok(Json(outputs)) } async fn send_message_stream( - State(state): State>>, + State(state): State>, Path(id): Path, Json(payload): Json, -) -> Result< - Sse> + Send + 'static>, - (StatusCode, Json), -> { +) -> ApiResult> + Send + 'static>> { let agent_arc = resolve_agent(&state, id).await?; - let repo = state.lock().await.repository.clone(); + let repo = state.repository.clone(); let prev_len = agent_arc.lock().await.get_history().len(); let content = payload.content; @@ -315,10 +361,8 @@ async fn send_message_stream( } } } - // Drop the mutable stream borrow before taking an immutable borrow of history. drop(run); - // Persist after stream is fully consumed. let new_msgs = agent.get_history()[prev_len..].to_vec(); if let Err(e) = repo.append_messages(id, &new_msgs).await { tracing::error!(%id, "failed to persist messages: {e}"); @@ -329,58 +373,3 @@ async fn send_message_stream( Ok(Sse::new(stream).keep_alive(KeepAlive::default())) } - -async fn build_agent(sandbox: Arc) -> Result { - let (bash, python, web_search) = tokio::try_join!( - make_builtin_tool(&BuiltinToolProvider::Bash {}), - make_builtin_tool(&BuiltinToolProvider::PythonRepl {}), - make_builtin_tool(&BuiltinToolProvider::WebSearch {}), - ) - .map_err(|e| e.to_string())?; - let model = build_lang_model(DEFAULT_MODEL)?; - AgentBuilder::new(model) - .tool(bash) - .tool(python) - .tool(web_search) - .runenv(sandbox) - .build() - .await - .map_err(|e| e.to_string()) -} - -fn build_lang_model(model_full_id: &str) -> Result { - if let Some(m) = model_full_id.strip_prefix("anthropic/") { - let key = std::env::var("ANTHROPIC_API_KEY") - .map_err(|_| "ANTHROPIC_API_KEY not set".to_string())?; - Ok(LangModel::new( - m.to_string(), - LangModelProvider::anthropic(key), - )) - } else if let Some(m) = model_full_id.strip_prefix("openai/") { - let key = - std::env::var("OPENAI_API_KEY").map_err(|_| "OPENAI_API_KEY not set".to_string())?; - Ok(LangModel::new( - m.to_string(), - LangModelProvider::openai(key), - )) - } else if let Some(m) = model_full_id.strip_prefix("google/") { - let key = - std::env::var("GEMINI_API_KEY").map_err(|_| "GEMINI_API_KEY not set".to_string())?; - Ok(LangModel::new( - m.to_string(), - LangModelProvider::gemini(key), - )) - } else { - Err(format!( - "unknown provider prefix in model id: {}", - model_full_id - )) - } -} - -fn internal(e: impl std::fmt::Display) -> (StatusCode, Json) { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(AppError::new(e.to_string())), - ) -} diff --git a/backend-v2/src/state.rs b/backend-v2/src/state.rs index 5d9659e..b048ee5 100644 --- a/backend-v2/src/state.rs +++ b/backend-v2/src/state.rs @@ -1,34 +1,37 @@ -use std::collections::HashMap; use std::sync::Arc; use ailoy::agent::Agent; +use dashmap::DashMap; +use speedwagon::SharedStore; use tokio::sync::Mutex; use uuid::Uuid; use crate::repository::AppRepository; pub struct AppState { - agents: HashMap>>, + agents: DashMap>>, pub repository: AppRepository, + pub store: SharedStore, } impl AppState { - pub fn new(repository: AppRepository) -> Self { + pub fn new(repository: AppRepository, store: SharedStore) -> Self { Self { - agents: HashMap::new(), + agents: DashMap::new(), repository, + store, } } - pub fn insert_agent(&mut self, id: Uuid, agent: Agent) { + pub fn insert_agent(&self, id: Uuid, agent: Agent) { self.agents.insert(id, Arc::new(Mutex::new(agent))); } - pub fn remove_agent(&mut self, id: &Uuid) -> Option>> { - self.agents.remove(id) + pub fn remove_agent(&self, id: &Uuid) -> Option>> { + self.agents.remove(id).map(|(_, v)| v) } pub fn get_agent(&self, id: &Uuid) -> Option>> { - self.agents.get(id).cloned() + self.agents.get(id).map(|entry| entry.value().clone()) } } diff --git a/backend-v2/tests/common/mod.rs b/backend-v2/tests/common/mod.rs index 63d6c18..1bb88bc 100644 --- a/backend-v2/tests/common/mod.rs +++ b/backend-v2/tests/common/mod.rs @@ -5,11 +5,37 @@ use std::sync::Arc; use agent_k_backend::{repository, router, state::AppState}; use aide::openapi::OpenApi; +use ailoy::{agent::default_provider_mut, lang_model::LangModelProvider, tool::ToolProvider}; use axum::{body::Body, http::Request}; use http_body_util::BodyExt; -use tokio::sync::Mutex; +use speedwagon::Store; +use tokio::sync::RwLock; use tower::ServiceExt; +// ── Provider setup ──────────────────────────────────────────────────────────── + +/// Register all available API keys and basic builtin tools with the global +/// `default_provider`. Call this once per test after `dotenvy::dotenv().ok()`. +pub async fn setup_provider() { + let mut provider = default_provider_mut().await; + if let Ok(key) = std::env::var("OPENAI_API_KEY") { + provider + .models + .insert("openai/*".into(), LangModelProvider::openai(key)); + } + if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") { + provider + .models + .insert("anthropic/*".into(), LangModelProvider::anthropic(key)); + } + if let Ok(key) = std::env::var("GEMINI_API_KEY") { + provider + .models + .insert("google/*".into(), LangModelProvider::gemini(key)); + } + provider.tools = ToolProvider::new().bash().python_repl().web_search(); +} + // ── App / state creation ────────────────────────────────────────────────────── /// In-memory SQLite repository — state does not survive across instances. @@ -19,15 +45,25 @@ pub async fn make_repo() -> repository::AppRepository { .unwrap() } +/// Create a SharedStore + ToolSet backed by a temporary directory. +pub fn make_test_store() -> speedwagon::SharedStore { + let store_path = std::env::temp_dir().join(format!("speedwagon-test-{}", uuid::Uuid::new_v4())); + let store = Arc::new(RwLock::new( + Store::new(store_path).expect("test store init"), + )); + store +} + /// Build an app from an already-constructed repository. pub fn make_app_with_repo(repo: repository::AppRepository) -> axum::Router { - let state = Arc::new(Mutex::new(AppState::new(repo))); + let store = make_test_store(); + let state = Arc::new(AppState::new(repo, store)); make_app_with_state(state) } /// Build an app from an already-constructed state (useful when tests need to /// inspect the state directly, e.g. to read agent internals). -pub fn make_app_with_state(state: Arc>) -> axum::Router { +pub fn make_app_with_state(state: Arc) -> axum::Router { router::get_router(state).finish_api(&mut OpenApi::default()) } diff --git a/backend-v2/tests/e2e_test.rs b/backend-v2/tests/e2e_test.rs new file mode 100644 index 0000000..3a446a6 --- /dev/null +++ b/backend-v2/tests/e2e_test.rs @@ -0,0 +1,91 @@ +#[path = "common/mod.rs"] +mod common; + +use std::sync::Arc; + +use aide::openapi::OpenApi; +use ailoy::lang_model::LangModelProvider; + +use agent_k_backend::repository; +use agent_k_backend::router::get_router; +use agent_k_backend::state::AppState; +use ailoy::agent::default_provider_mut; +use common::{extract_text, post_session, send_message}; +use speedwagon::{FileType, Store, build_tools}; +use tokio::sync::RwLock; + +#[tokio::test] +#[ignore = "requires OPENAI_API_KEY"] +async fn test_ingest_message_purge_cycle() { + dotenvy::dotenv().ok(); + + let store_path = std::env::temp_dir().join(format!("speedwagon-e2e-{}", uuid::Uuid::new_v4())); + let store = Arc::new(RwLock::new( + Store::new(store_path).expect("test store init"), + )); + + { + let mut provider = default_provider_mut().await; + if let Ok(key) = std::env::var("OPENAI_API_KEY") { + provider + .models + .insert("openai/*".into(), LangModelProvider::openai(key)); + } + provider.tools = build_tools(store.clone()); + } + + let test_content = b"The capital of Freedonia is Glorkville. This is a unique fact."; + let doc_id = store + .write() + .await + .ingest(test_content.iter().copied(), FileType::MD) + .await + .expect("ingest failed"); + + let repo = repository::create_repository("sqlite::memory:") + .await + .expect("test repo init"); + let state = Arc::new(AppState::new(repo, store.clone())); + let app = get_router(state).finish_api(&mut OpenApi::default()); + + let session_id = post_session(&app).await; + + let outputs = send_message(&app, session_id, "What is the capital of Freedonia?").await; + let arr = outputs.as_array().expect("response must be an array"); + + assert!(!arr.is_empty(), "messages should not be empty"); + + let has_assistant = arr.iter().any(|o| { + o.get("message") + .and_then(|m| m.get("role")) + .and_then(|r| r.as_str()) + == Some("assistant") + }); + assert!( + has_assistant, + "should contain at least one assistant message" + ); + + let text = extract_text(&outputs); + assert!(!text.is_empty(), "assistant text should not be empty"); + assert!( + text.contains("Glorkville"), + "response should mention 'Glorkville' from the ingested document, got: {text}", + ); + + // Purge the document + store.write().await.purge(doc_id).expect("purge failed"); + + // Send same message after purge + let outputs = send_message( + &app, + session_id, + "What is the capital of Freedonia?", + ) + .await; + let post_purge_text = extract_text(&outputs); + assert!( + !post_purge_text.is_empty(), + "post-purge response should not be empty" + ); +} diff --git a/backend-v2/tests/message_history_persistence.rs b/backend-v2/tests/message_history_persistence.rs index 80c1479..eea4dbd 100644 --- a/backend-v2/tests/message_history_persistence.rs +++ b/backend-v2/tests/message_history_persistence.rs @@ -15,9 +15,9 @@ use agent_k_backend::{repository, state::AppState}; use common::{ SessionGuard, clear_message_history, clear_message_history_status, extract_text, get_message_history, get_message_history_status, make_app_with_repo, make_app_with_state, - make_repo, post_session, send_message, send_message_status, try_delete_session, + make_repo, make_test_store, post_session, send_message, send_message_status, setup_provider, + try_delete_session, }; -use tokio::sync::Mutex; use uuid::Uuid; // ── tests ───────────────────────────────────────────────────────────────────── @@ -35,9 +35,8 @@ use uuid::Uuid; async fn session_is_found_after_restart_via_lazy_create() { // Dummy key: validated only when agent.run() reaches the Anthropic API, // which this test never does (it only checks the HTTP status code). - unsafe { - std::env::set_var("ANTHROPIC_API_KEY", "dummy-key-for-lazy-create-test"); - } + dotenvy::dotenv().ok(); + setup_provider().await; let dir = tempfile::tempdir().unwrap(); let db_url = format!("sqlite://{}", dir.path().join("test.db").display()); @@ -84,6 +83,7 @@ async fn session_is_found_after_restart_via_lazy_create() { #[ignore = "requires microsandbox + ANTHROPIC_API_KEY"] async fn agent_restores_history_and_processes_message() { dotenvy::dotenv().ok(); + setup_provider().await; let dir = tempfile::tempdir().unwrap(); let db_url = format!("sqlite://{}", dir.path().join("test.db").display()); @@ -143,9 +143,8 @@ async fn agent_restores_history_and_processes_message() { /// guard is needed. #[tokio::test] async fn unknown_session_returns_404() { - unsafe { - std::env::set_var("ANTHROPIC_API_KEY", "dummy"); - } + dotenvy::dotenv().ok(); + setup_provider().await; let dir = tempfile::tempdir().unwrap(); let db_url = format!("sqlite://{}", dir.path().join("test.db").display()); @@ -168,17 +167,12 @@ async fn unknown_session_returns_404() { /// Creates the session directly in the DB to avoid the microsandbox requirement. #[tokio::test] async fn get_messages_returns_empty_for_new_session() { - let state = Arc::new(Mutex::new(AppState::new(make_repo().await))); + let store = make_test_store(); + let state = Arc::new(AppState::new(make_repo().await, store)); let app = make_app_with_state(state.clone()); let id = Uuid::new_v4(); - state - .lock() - .await - .repository - .create_session(id) - .await - .unwrap(); + state.repository.create_session(id).await.unwrap(); let messages = get_message_history(&app, id).await; assert_eq!( @@ -206,18 +200,18 @@ async fn get_messages_returns_404_for_unknown_session() { async fn get_messages_returns_persisted_messages_in_order() { use ailoy::message::{Message, Part, Role}; - let state = Arc::new(Mutex::new(AppState::new(make_repo().await))); + let store = make_test_store(); + let state = Arc::new(AppState::new(make_repo().await, store)); let app = make_app_with_state(state.clone()); let id = Uuid::new_v4(); { - let st = state.lock().await; - st.repository.create_session(id).await.unwrap(); + state.repository.create_session(id).await.unwrap(); let msgs = vec![ Message::new(Role::User).with_contents([Part::text("first")]), Message::new(Role::Assistant).with_contents([Part::text("second")]), ]; - st.repository.append_messages(id, &msgs).await.unwrap(); + state.repository.append_messages(id, &msgs).await.unwrap(); } let body = get_message_history(&app, id).await; @@ -251,18 +245,18 @@ async fn clear_messages_returns_404_for_unknown_session() { async fn clear_messages_removes_persisted_messages() { use ailoy::message::{Message, Part, Role}; - let state = Arc::new(Mutex::new(AppState::new(make_repo().await))); + let store = make_test_store(); + let state = Arc::new(AppState::new(make_repo().await, store)); let app = make_app_with_state(state.clone()); let id = Uuid::new_v4(); { - let st = state.lock().await; - st.repository.create_session(id).await.unwrap(); + state.repository.create_session(id).await.unwrap(); let msgs = vec![ Message::new(Role::User).with_contents([Part::text("hello")]), Message::new(Role::Assistant).with_contents([Part::text("world")]), ]; - st.repository.append_messages(id, &msgs).await.unwrap(); + state.repository.append_messages(id, &msgs).await.unwrap(); } let before = get_message_history(&app, id).await; @@ -287,15 +281,15 @@ async fn clear_messages_removes_persisted_messages() { async fn clear_messages_does_not_delete_session() { use ailoy::message::{Message, Part, Role}; - let state = Arc::new(Mutex::new(AppState::new(make_repo().await))); + let store = make_test_store(); + let state = Arc::new(AppState::new(make_repo().await, store)); let app = make_app_with_state(state.clone()); let id = Uuid::new_v4(); { - let st = state.lock().await; - st.repository.create_session(id).await.unwrap(); + state.repository.create_session(id).await.unwrap(); let msgs = vec![Message::new(Role::User).with_contents([Part::text("ping")])]; - st.repository.append_messages(id, &msgs).await.unwrap(); + state.repository.append_messages(id, &msgs).await.unwrap(); } clear_message_history(&app, id).await; @@ -314,23 +308,22 @@ async fn clear_messages_does_not_delete_session() { async fn can_append_messages_after_clear() { use ailoy::message::{Message, Part, Role}; - let state = Arc::new(Mutex::new(AppState::new(make_repo().await))); + let store = make_test_store(); + let state = Arc::new(AppState::new(make_repo().await, store)); let app = make_app_with_state(state.clone()); let id = Uuid::new_v4(); { - let st = state.lock().await; - st.repository.create_session(id).await.unwrap(); + state.repository.create_session(id).await.unwrap(); let msgs = vec![Message::new(Role::User).with_contents([Part::text("old")])]; - st.repository.append_messages(id, &msgs).await.unwrap(); + state.repository.append_messages(id, &msgs).await.unwrap(); } clear_message_history(&app, id).await; { - let st = state.lock().await; let msgs = vec![Message::new(Role::User).with_contents([Part::text("new")])]; - st.repository.append_messages(id, &msgs).await.unwrap(); + state.repository.append_messages(id, &msgs).await.unwrap(); } let body = get_message_history(&app, id).await; @@ -348,9 +341,8 @@ async fn can_append_messages_after_clear() { #[tokio::test(flavor = "multi_thread")] #[ignore = "requires microsandbox runtime"] async fn clear_messages_also_clears_in_memory_agent_history() { - unsafe { - std::env::set_var("ANTHROPIC_API_KEY", "dummy-key-for-clear-history-test"); - } + dotenvy::dotenv().ok(); + setup_provider().await; let dir = tempfile::tempdir().unwrap(); let db_url = format!("sqlite://{}", dir.path().join("test.db").display()); diff --git a/backend-v2/tests/sandbox_per_session.rs b/backend-v2/tests/sandbox_per_session.rs index b57f619..e2ea0fb 100644 --- a/backend-v2/tests/sandbox_per_session.rs +++ b/backend-v2/tests/sandbox_per_session.rs @@ -14,14 +14,14 @@ use std::sync::Arc; use agent_k_backend::state::AppState; use common::{ delete_session, extract_text, extract_text_from_slice, make_app_with_state, make_repo, - post_session, send_message, send_message_stream, + make_test_store, post_session, send_message, send_message_stream, setup_provider, }; -use tokio::sync::Mutex; // ── helpers ─────────────────────────────────────────────────────────────────── -async fn make_state() -> Arc> { - Arc::new(Mutex::new(AppState::new(make_repo().await))) +async fn make_state() -> Arc { + let store = make_test_store(); + Arc::new(AppState::new(make_repo().await, store)) } // ── tests ───────────────────────────────────────────────────────────────────── @@ -34,11 +34,8 @@ async fn make_state() -> Arc> { #[tokio::test] #[ignore = "requires microsandbox; boots two VMs"] async fn two_sessions_get_isolated_sandboxes() { - if std::env::var("ANTHROPIC_API_KEY").is_err() { - unsafe { - std::env::set_var("ANTHROPIC_API_KEY", "dummy"); - } - } + dotenvy::dotenv().ok(); + setup_provider().await; let state = make_state().await; let app = make_app_with_state(state.clone()); @@ -48,9 +45,8 @@ async fn two_sessions_get_isolated_sandboxes() { assert_ne!(id1, id2, "two sessions must have different ids"); let (re1, re2) = { - let st = state.lock().await; - let a1 = st.get_agent(&id1).expect("session 1 not found"); - let a2 = st.get_agent(&id2).expect("session 2 not found"); + let a1 = state.get_agent(&id1).expect("session 1 not found"); + let a2 = state.get_agent(&id2).expect("session 2 not found"); // Agents are not running now, so try_lock succeeds. let guard1 = a1.try_lock().expect("agent 1 locked unexpectedly"); let guard2 = a2.try_lock().expect("agent 2 locked unexpectedly"); @@ -90,6 +86,7 @@ async fn two_sessions_get_isolated_sandboxes() { #[ignore = "requires microsandbox + ANTHROPIC_API_KEY"] async fn agent_writes_and_reads_file_via_bash_in_sandbox() { dotenvy::dotenv().ok(); + setup_provider().await; let state = make_state().await; let app = make_app_with_state(state.clone()); @@ -112,7 +109,7 @@ async fn agent_writes_and_reads_file_via_bash_in_sandbox() { ); // Verify via runenv directly that the file exists in the sandbox. - let agent_arc = state.lock().await.get_agent(&id).unwrap(); + let agent_arc = state.get_agent(&id).unwrap(); let agent = agent_arc.lock().await; let contents = agent .state @@ -138,11 +135,8 @@ async fn stream_returns_404_for_unknown_session() { use axum::{body::Body, http::Request}; use tower::ServiceExt; - if std::env::var("ANTHROPIC_API_KEY").is_err() { - unsafe { - std::env::set_var("ANTHROPIC_API_KEY", "dummy"); - } - } + dotenvy::dotenv().ok(); + setup_provider().await; let state = make_state().await; let app = make_app_with_state(state); @@ -168,6 +162,7 @@ async fn stream_returns_404_for_unknown_session() { #[ignore = "requires microsandbox + ANTHROPIC_API_KEY"] async fn agent_writes_and_reads_file_via_bash_streaming() { dotenvy::dotenv().ok(); + setup_provider().await; let state = make_state().await; let app = make_app_with_state(state.clone()); @@ -195,7 +190,7 @@ async fn agent_writes_and_reads_file_via_bash_streaming() { ); // Verify the file persisted in the sandbox after the stream ended. - let agent_arc = state.lock().await.get_agent(&id).unwrap(); + let agent_arc = state.get_agent(&id).unwrap(); let agent = agent_arc.lock().await; let contents = agent .state diff --git a/speedwagon/src/agent.rs b/speedwagon/src/agent.rs index d95fc1f..d312f2b 100644 --- a/speedwagon/src/agent.rs +++ b/speedwagon/src/agent.rs @@ -1,4 +1,4 @@ -use ailoy::agent::{Agent, AgentCard, AgentProvider, AgentSpec}; +use ailoy::agent::{AgentCard, AgentSpec}; pub const SYSTEM_PROMPT: &str = r#"You are an expert research assistant. Your task is to answer questions by systematically searching through a document corpus using the provided tools. Think step by step. @@ -62,17 +62,6 @@ impl SpeedwagonSpec { pub fn into_spec(self) -> AgentSpec { self.into() } - - pub async fn into_runtime(self) -> anyhow::Result { - Agent::try_new(self.spec).await - } - - pub async fn into_runtime_with_provider( - self, - provider: &AgentProvider, - ) -> anyhow::Result { - Agent::try_with_provider(self.spec, provider).await - } } impl Default for SpeedwagonSpec { diff --git a/speedwagon/src/main.rs b/speedwagon/src/main.rs index df9cc76..8e7847d 100644 --- a/speedwagon/src/main.rs +++ b/speedwagon/src/main.rs @@ -14,15 +14,18 @@ use std::{ sync::Arc, }; +use tokio::sync::RwLock; + use ailoy::{ agent::{Agent, AgentProvider}, + lang_model::LangModelProvider, message::{Message, Part, Role}, }; use anyhow::Result; use clap::Parser; use futures::StreamExt; use rustyline::{DefaultEditor, error::ReadlineError}; -use speedwagon::{FileType, SpeedwagonSpec, Store, build_toolset}; +use speedwagon::{FileType, SharedStore, SpeedwagonSpec, Store, build_tools}; use speedwagon::preset::{PresetKind, setup_docset}; @@ -53,10 +56,12 @@ fn resolve_dir(path: &str) -> PathBuf { } async fn build_agent(store_dir: &Path, model: &str, provider: &AgentProvider) -> Result { - let store = Arc::new(Store::new(store_dir)?); - let toolset = build_toolset(store); + let store: SharedStore = Arc::new(RwLock::new(Store::new(store_dir)?)); + let spec = SpeedwagonSpec::new().model(model).into_spec(); - Agent::try_with_tools(spec, provider, &toolset).await + let mut provider = provider.clone(); + provider.tools = build_tools(store); + Agent::try_with_provider(spec, &provider).await } async fn run_query(agent: &mut Agent, input: &str) -> Result<()> { @@ -114,15 +119,17 @@ async fn main() -> Result<()> { } let mut provider = AgentProvider::new(); + let mut model_provider = LangModelProvider::new(); if let Ok(key) = std::env::var("OPENAI_API_KEY") { - provider.model_openai(key); + model_provider.insert("openai/*".into(), LangModelProvider::openai(key)); } if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") { - provider.model_claude(key); + model_provider.insert("anthropic/*".into(), LangModelProvider::anthropic(key)); } if let Ok(key) = std::env::var("GEMINI_API_KEY") { - provider.model_gemini(key); + model_provider.insert("google/*".into(), LangModelProvider::gemini(key)); } + provider.models = model_provider; let mut agent = build_agent(&store_dir, &cli.model, &provider).await?; let doc_count = Store::new(&store_dir)?.count(); diff --git a/speedwagon/src/store/mod.rs b/speedwagon/src/store/mod.rs index 9aa10c7..acc6636 100644 --- a/speedwagon/src/store/mod.rs +++ b/speedwagon/src/store/mod.rs @@ -9,8 +9,11 @@ mod translator; use std::{ fs, path::{Path, PathBuf}, + sync::Arc, }; +use tokio::sync::RwLock; + use anyhow::{Context as _, Result}; use tantivy::Index; use uuid::Uuid; @@ -18,6 +21,8 @@ use uuid::Uuid; pub use document::{Document, FindResult}; pub use searcher::{SearchPage, SearchResult}; +pub type SharedStore = Arc>; + /// Speedwagon store layout: /// /// ```text diff --git a/speedwagon/src/store/parser.rs b/speedwagon/src/store/parser.rs index 6a32296..867f694 100644 --- a/speedwagon/src/store/parser.rs +++ b/speedwagon/src/store/parser.rs @@ -1,5 +1,6 @@ use ailoy::{ agent::{Agent, AgentProvider, AgentSpec}, + lang_model::LangModelProvider, message::{Message, Part, Role}, }; use anyhow::{Context as _, Result}; @@ -86,11 +87,17 @@ pub async fn get_title(content: &str) -> Result { None => { dotenvy::dotenv().ok(); - let mut provier = AgentProvider::new(); - provier.model_openai( - std::env::var("OPENAI_API_KEY").context("OPENAI_API_KEY not set in environment")?, + let mut provider = AgentProvider::new(); + let mut model_provider = LangModelProvider::new(); + model_provider.insert( + "openai/*".into(), + LangModelProvider::openai( + std::env::var("OPENAI_API_KEY") + .context("OPENAI_API_KEY not set in environment")?, + ), ); - TitleAgent::new(Some(provier)).generate(content).await + provider.models = model_provider; + TitleAgent::new(Some(provider)).generate(content).await } } } diff --git a/speedwagon/src/tool/calculate.rs b/speedwagon/src/tool/calculate.rs index e9bdcf9..220dc19 100644 --- a/speedwagon/src/tool/calculate.rs +++ b/speedwagon/src/tool/calculate.rs @@ -356,7 +356,10 @@ fn eval_func(name: &str, args: &[f64]) -> Result { } Ok(args[0].ln() / args[1].ln()) } - _ => Err(format!("log() expects 1 or 2 arguments, got {}", args.len())), + _ => Err(format!( + "log() expects 1 or 2 arguments, got {}", + args.len() + )), }, "log10" => { ensure_args(name, args, 1)?; @@ -420,13 +423,19 @@ fn eval_func(name: &str, args: &[f64]) -> Result { } "min" => { if args.len() < 2 { - return Err(format!("min() requires at least 2 arguments, got {}", args.len())); + return Err(format!( + "min() requires at least 2 arguments, got {}", + args.len() + )); } Ok(args.iter().cloned().fold(f64::INFINITY, f64::min)) } "max" => { if args.len() < 2 { - return Err(format!("max() requires at least 2 arguments, got {}", args.len())); + return Err(format!( + "max() requires at least 2 arguments, got {}", + args.len() + )); } Ok(args.iter().cloned().fold(f64::NEG_INFINITY, f64::max)) } diff --git a/speedwagon/src/tool/find.rs b/speedwagon/src/tool/find.rs index 781f0d7..b4cd2a8 100644 --- a/speedwagon/src/tool/find.rs +++ b/speedwagon/src/tool/find.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use ailoy::{ datatype::Value, message::ToolDescBuilder, @@ -8,7 +6,7 @@ use ailoy::{ }; use uuid::Uuid; -use crate::store::{FindResult, Store}; +use crate::store::{FindResult, SharedStore}; fn result_to_value(result: &FindResult) -> Value { let matches: Vec = result @@ -33,7 +31,7 @@ fn result_to_value(result: &FindResult) -> Value { }) } -pub fn build_find_in_document_tool(store: Arc) -> ToolFactory { +pub fn build_find_in_document_tool(store: SharedStore) -> ToolFactory { let desc = ToolDescBuilder::new("find_in_document") .description(concat!( "Find all occurrences of a regex pattern within a document. ", @@ -137,7 +135,8 @@ pub fn build_find_in_document_tool(store: Arc) -> ToolFactory { .unwrap_or(256) .max(0) as usize; - match store.find(id, &pattern, cursor, k, context_bytes) { + let guard = store.read().await; + match guard.find(id, &pattern, cursor, k, context_bytes) { Some(result) => result_to_value(&result), None => to_value!({"error": format!("document not found: {id_str}")}), } diff --git a/speedwagon/src/tool/mod.rs b/speedwagon/src/tool/mod.rs index d8fb1a8..05c6a94 100644 --- a/speedwagon/src/tool/mod.rs +++ b/speedwagon/src/tool/mod.rs @@ -3,25 +3,19 @@ mod find; mod read; mod search; -use std::sync::Arc; - -use ailoy::tool::ToolSet; +use ailoy::tool::ToolProvider; pub use calculate::*; pub use find::*; pub use read::*; pub use search::*; -use crate::store::Store; - -pub fn build_toolset(store: Arc) -> ToolSet { - let mut toolset = ToolSet::new(); +use crate::store::SharedStore; - toolset.insert("search_document", make_search_document_tool(store.clone())); - toolset.insert( - "find_in_document", - build_find_in_document_tool(store.clone()), - ); - toolset.insert("read_document", build_read_document_tool(store.clone())); - toolset.insert("calculate", build_calculate_tool()); - toolset +pub fn build_tools(store: SharedStore) -> ToolProvider { + let mut provider = ToolProvider::new().bash().python_repl().web_search(); + provider = provider.custom(make_search_document_tool(store.clone())); + provider = provider.custom(build_find_in_document_tool(store.clone())); + provider = provider.custom(build_read_document_tool(store.clone())); + provider = provider.custom(build_calculate_tool()); + provider } diff --git a/speedwagon/src/tool/read.rs b/speedwagon/src/tool/read.rs index 7ce63a5..a8ada44 100644 --- a/speedwagon/src/tool/read.rs +++ b/speedwagon/src/tool/read.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use ailoy::{ datatype::Value, message::ToolDescBuilder, @@ -8,9 +6,9 @@ use ailoy::{ }; use uuid::Uuid; -use crate::store::Store; +use crate::store::SharedStore; -pub fn build_read_document_tool(store: Arc) -> ToolFactory { +pub fn build_read_document_tool(store: SharedStore) -> ToolFactory { let desc = ToolDescBuilder::new("read_document") .description(concat!( "Read a byte range of a document's content. ", @@ -61,7 +59,8 @@ pub fn build_read_document_tool(store: Arc) -> ToolFactory { None => return to_value!({"error": "missing required parameter: len"}), }; - match store.read(id, offset, len) { + let guard = store.read().await; + match guard.read(id, offset, len) { Some(content) => Value::from(content), None => to_value!({"error": format!("document not found: {id_str}")}), } diff --git a/speedwagon/src/tool/search.rs b/speedwagon/src/tool/search.rs index c08633e..95ba444 100644 --- a/speedwagon/src/tool/search.rs +++ b/speedwagon/src/tool/search.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use ailoy::{ datatype::Value, message::ToolDescBuilder, @@ -7,7 +5,7 @@ use ailoy::{ tool::{ToolFactory, ToolFunc}, }; -use crate::store::{SearchPage, Store}; +use crate::store::{SearchPage, SharedStore}; fn result_to_value(page: &SearchPage) -> Value { let results: Vec = page @@ -26,7 +24,7 @@ fn result_to_value(page: &SearchPage) -> Value { Value::Array(results) } -pub fn make_search_document_tool(store: Arc) -> ToolFactory { +pub fn make_search_document_tool(store: SharedStore) -> ToolFactory { let desc = ToolDescBuilder::new("search_document") .description(concat!( "Search for relevant documents for a given query. ", @@ -102,7 +100,8 @@ pub fn make_search_document_tool(store: Arc) -> ToolFactory { .unwrap_or(10) .max(1) as u32; - match store.search(&query, page, page_size) { + let guard = store.read().await; + match guard.search(&query, page, page_size) { Ok(output) => result_to_value(&output), Err(e) => to_value!({"error": e.to_string()}), }