Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions speedwagon/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::{
};

use ailoy::{
agent::{Agent, AgentProvider},
agent::{Agent, default_provider, default_provider_mut},
message::{Message, Part, Role},
};
use anyhow::Result;
Expand Down Expand Up @@ -52,11 +52,12 @@ fn resolve_dir(path: &str) -> PathBuf {
}
}

async fn build_agent(store_dir: &Path, model: &str, provider: &AgentProvider) -> Result<Agent> {
async fn build_agent(store_dir: &Path, model: &str) -> Result<Agent> {
let store = Arc::new(Store::new(store_dir)?);
let toolset = build_toolset(store);
let spec = SpeedwagonSpec::new().model(model).into_spec();
Agent::try_with_tools(spec, provider, &toolset).await
let provider = default_provider().await;
Agent::try_with_tools(spec, &provider, &toolset).await
}

async fn run_query(agent: &mut Agent, input: &str) -> Result<()> {
Expand Down Expand Up @@ -108,23 +109,28 @@ async fn main() -> Result<()> {
}
};

// Populate ailoy's process-global provider once at boot. Every
// `Agent::try_new`/`try_with_tools` (including speedwagon's title/purpose/
// description helpers) reads from this singleton.
{
let mut default = default_provider_mut().await;
if let Ok(key) = std::env::var("OPENAI_API_KEY") {
default.model_openai(key);
}
if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") {
default.model_claude(key);
}
if let Ok(key) = std::env::var("GEMINI_API_KEY") {
default.model_gemini(key);
}
}

if let Some(ref preset) = cli.preset {
let mut store = Store::new(&store_dir)?;
setup_docset(&mut store, preset).await?;
}

let mut provider = AgentProvider::new();
if let Ok(key) = std::env::var("OPENAI_API_KEY") {
provider.model_openai(key);
}
if let Ok(key) = std::env::var("ANTHROPIC_API_KEY") {
provider.model_claude(key);
}
if let Ok(key) = std::env::var("GEMINI_API_KEY") {
provider.model_gemini(key);
}

let mut agent = build_agent(&store_dir, &cli.model, &provider).await?;
let mut agent = build_agent(&store_dir, &cli.model).await?;
let doc_count = Store::new(&store_dir)?.count();

println!();
Expand Down Expand Up @@ -154,7 +160,7 @@ async fn main() -> Result<()> {
if input == "/exit" {
break;
} else if input == "/clear" {
agent = build_agent(&store_dir, &cli.model, &provider).await?;
agent = build_agent(&store_dir, &cli.model).await?;
println!("Conversation cleared.");
} else if input == "/list" {
let store = Store::new(&store_dir)?;
Expand Down Expand Up @@ -191,7 +197,7 @@ async fn main() -> Result<()> {
let mut write_store = Store::new(&store_dir)?;
let id = write_store.ingest(bytes, filetype).await?;
drop(write_store);
agent = build_agent(&store_dir, &cli.model, &provider).await?;
agent = build_agent(&store_dir, &cli.model).await?;
println!("Ingested (id: {id}) — agent rebuilt.");
} else if let Some(id_str) = input.strip_prefix("/purge ") {
let id_str = id_str.trim();
Expand All @@ -217,7 +223,7 @@ async fn main() -> Result<()> {
match write_store.purge(id)? {
Some(doc) => {
drop(write_store);
agent = build_agent(&store_dir, &cli.model, &provider).await?;
agent = build_agent(&store_dir, &cli.model).await?;
println!("Purged '{}' — agent rebuilt.", doc.title);
}
None => eprintln!("Document not found: {id}"),
Expand Down
194 changes: 53 additions & 141 deletions speedwagon/src/store/description.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,10 @@
//! Korean output lands ~1/3 the chars of English at the same budget;
//! per-language budgets are deferred until a near-domain Korean KB shows up.

use ailoy::{
agent::{Agent, AgentProvider, AgentSpec},
message::{Message, Part, Role},
};
use anyhow::{Context as _, Result};
use futures::StreamExt as _;
use ailoy::message::{Message, Part, Role};
use anyhow::Result;

const MODEL: &str = "openai/gpt-5.4-mini";
use super::helper::HelperAgent;

const DESCRIPTION_INSTRUCTION: &str = concat!(
"You write a self-contained description of a knowledge base. ",
Expand All @@ -27,60 +23,37 @@ const DESCRIPTION_INSTRUCTION: &str = concat!(
"or any metadata about how this knowledge base was assembled. ",
"Describe ONLY what documents are inside, as if a curator wrote it. ",
"Write the description in English regardless of the document language. ",
"Length: ~200 characters. Output a JSON object: {\"description\": \"<text>\"}."
"Length: ~200 characters. Output a JSON object: {\"result\": \"<string>\"}."
);

pub struct DescriptionAgent {
spec: AgentSpec,
provider: Option<AgentProvider>,
/// Borrowed input for `DescriptionAgent::generate`.
struct DescriptionInput<'a> {
kb_name: &'a str,
instruction: Option<&'a str>,
docs: &'a [(&'a str, &'a str)],
}

impl DescriptionAgent {
pub fn new(provider: Option<AgentProvider>) -> Self {
Self {
spec: AgentSpec::new(MODEL).instruction(DESCRIPTION_INSTRUCTION),
provider,
}
}
/// Generates a KB-level routing description via LLM. Reads from ailoy's
/// process-global default provider.
struct DescriptionAgent;

/// Only `purpose` is sent to the LLM; `title` is kept in the signature
/// so the caller can feed the same slice to `fallback_description`.
pub async fn generate(
&self,
kb_name: &str,
instruction: Option<&str>,
docs: &[(&str, &str)],
) -> Result<String> {
let user = build_user_message(kb_name, instruction, docs);
let query = Message::new(Role::User).with_contents([Part::text(user)]);
impl HelperAgent for DescriptionAgent {
type Input<'a> = DescriptionInput<'a>;
type Output = String;
const INSTRUCTION: &'static str = DESCRIPTION_INSTRUCTION;

let mut agent = match &self.provider {
Some(provider) => Agent::try_with_provider(self.spec.clone(), provider).await?,
None => Agent::try_new(self.spec.clone()).await?,
};
fn build_query(input: &DescriptionInput<'_>) -> Message {
let user = build_user_message(input.kb_name, input.instruction, input.docs);
Message::new(Role::User).with_contents([Part::text(user)])
}

let mut text_parts: Vec<String> = Vec::new();
{
let mut stream = agent.run(query);
while let Some(result) = stream.next().await {
let output = result?;
for part in &output.message.contents {
if let Some(text) = part.as_text() {
text_parts.push(text.to_string());
}
}
}
}
let raw = text_parts.join("");
Ok(parse_description_response(&raw))
fn fallback(input: &DescriptionInput<'_>) -> String {
let titles: Vec<&str> = input.docs.iter().map(|(t, _)| *t).collect();
fallback_description(input.docs.len(), &titles)
}
}

fn build_user_message(
kb_name: &str,
instruction: Option<&str>,
docs: &[(&str, &str)],
) -> String {
fn build_user_message(kb_name: &str, instruction: Option<&str>, docs: &[(&str, &str)]) -> String {
let mut s = String::new();
s.push_str(&format!("KB name: {kb_name}\n"));
if let Some(instr) = instruction {
Expand All @@ -90,74 +63,26 @@ fn build_user_message(
}
s.push_str(&format!("\nDocuments ({}):\n", docs.len()));
for (_title, purpose) in docs {
let p = if purpose.is_empty() { "(no purpose)" } else { *purpose };
let p = if purpose.is_empty() {
"(no purpose)"
} else {
*purpose
};
s.push_str(&format!("- {p}\n"));
}
s
}

/// Empty return signals the caller to use `fallback_description`.
fn parse_description_response(raw: &str) -> String {
let trimmed = raw.trim();
if trimmed.is_empty() {
return String::new();
}

if let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) {
if let Some(d) = value.get("description").and_then(|v| v.as_str()) {
return d.trim().to_string();
}
}

if let (Some(start), Some(end)) = (trimmed.find('{'), trimmed.rfind('}')) {
if start < end {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&trimmed[start..=end]) {
if let Some(d) = value.get("description").and_then(|v| v.as_str()) {
return d.trim().to_string();
}
}
}
}

String::new()
}

/// Reads `OPENAI_API_KEY` from the environment, runs `DescriptionAgent`, and
/// substitutes `fallback_description` if the LLM body is empty. Transport
/// errors (missing key, network) propagate.
pub async fn get_description(
kb_name: &str,
instruction: Option<&str>,
docs: &[(&str, &str)],
) -> Result<String> {
dotenvy::dotenv().ok();

let mut provider = AgentProvider::new();
provider.model_openai(
std::env::var("OPENAI_API_KEY").context("OPENAI_API_KEY not set in environment")?,
);

let agent = DescriptionAgent::new(Some(provider));
let result = agent.generate(kb_name, instruction, docs).await?;
if result.is_empty() {
log::warn!("description generation returned empty string; using fallback");
let titles: Vec<&str> = docs.iter().map(|(t, _)| *t).collect();
Ok(fallback_description(docs.len(), &titles))
} else {
Ok(result)
}
}

/// Deterministic fallback when the LLM call fails or returns empty.
pub fn fallback_description(doc_count: usize, top_titles: &[&str]) -> String {
fn fallback_description(doc_count: usize, top_titles: &[&str]) -> String {
if doc_count == 0 {
return String::new();
}
let titles: Vec<&str> = top_titles
.iter()
.copied()
.filter(|t| !t.is_empty())
.take(5)
.copied()
.collect();
if titles.is_empty() {
format!("{doc_count} documents")
Expand All @@ -166,45 +91,26 @@ pub fn fallback_description(doc_count: usize, top_titles: &[&str]) -> String {
}
}

/// Runs `DescriptionAgent` over the index's `(title, purpose)` slice. An
/// empty/malformed LLM response is substituted with `fallback_description`
/// (a deterministic count + top-titles string). Transport errors propagate.
pub(super) async fn get_description(
kb_name: &str,
instruction: Option<&str>,
docs: &[(&str, &str)],
) -> Result<String> {
DescriptionAgent::generate(DescriptionInput {
kb_name,
instruction,
docs,
})
.await
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn parse_direct_json() {
let raw = r#"{"description": "hello"}"#;
assert_eq!(parse_description_response(raw), "hello");
}

#[test]
fn parse_json_with_surrounding_text() {
let raw = r#"Here you go: {"description": "hello"} done."#;
assert_eq!(parse_description_response(raw), "hello");
}

#[test]
fn parse_trims_inner_whitespace() {
let raw = r#"{"description": " hello "}"#;
assert_eq!(parse_description_response(raw), "hello");
}

#[test]
fn parse_empty_input() {
assert_eq!(parse_description_response(""), "");
assert_eq!(parse_description_response(" "), "");
}

#[test]
fn parse_missing_field() {
let raw = r#"{"other": "value"}"#;
assert_eq!(parse_description_response(raw), "");
}

#[test]
fn parse_malformed_json() {
assert_eq!(parse_description_response("{not json"), "");
}

#[test]
fn fallback_zero_docs() {
assert_eq!(fallback_description(0, &[]), "");
Expand Down Expand Up @@ -292,6 +198,12 @@ mod tests {
std::fs::create_dir_all(root.join("origin")).unwrap();
std::fs::create_dir_all(root.join("corpus")).unwrap();

// Populate ailoy's process-global default provider for this test.
dotenvy::dotenv().ok();
ailoy::agent::default_provider_mut().await.model_openai(
std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY required for this test"),
);

let store = crate::store::Store::new(root).expect("open store");
let description = store
.describe("finance", Some("public-company financial filings"))
Expand Down
Loading