diff --git a/examples/Cargo.toml b/examples/Cargo.toml index f6fe43167e..2cb501e568 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -159,6 +159,10 @@ path = "usage_metrics.rs" name = "tasks" path = "tasks.rs" +[[example]] +name = "stop-with-args-custom-schema" +path = "stop_with_args_custom_schema.rs" + [[example]] name = "responses-api" path = "responses_api.rs" diff --git a/examples/stop_with_args_custom_schema.rs b/examples/stop_with_args_custom_schema.rs new file mode 100644 index 0000000000..0f59797c4a --- /dev/null +++ b/examples/stop_with_args_custom_schema.rs @@ -0,0 +1,91 @@ +//! Demonstrates how to plug a custom JSON schema into the stop tool for an OpenAI-powered agent. +//! +//! Set the `OPENAI_API_KEY` environment variable before running the example. The agent guides the +//! model to call the `stop` tool with a structured payload that matches the schema defined below. +//! The on-stop hook prints the structured payload that made the agent stop. +use anyhow::Result; +use schemars::{JsonSchema, Schema, schema_for}; +use serde::{Deserialize, Serialize}; +use serde_json::to_string_pretty; +use swiftide::agents::tools::control::StopWithArgs; +use swiftide::agents::{Agent, StopReason}; +use swiftide::traits::Tool; + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +enum TaskStatus { + Succeeded, + Failed, + Cancelled, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[schemars(deny_unknown_fields)] +struct StopPayload { + status: TaskStatus, + summary: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + details: Option, +} + +fn stop_schema() -> Schema { + schema_for!(StopPayload) +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + + let schema = stop_schema(); + let stop_tool = StopWithArgs::with_parameters_schema(schema.clone()); + + println!( + "stop tool schema:\n{}", + to_string_pretty(&stop_tool.tool_spec())?, + ); + + let openai = swiftide::integrations::openai::OpenAI::builder() + .default_prompt_model("gpt-4o-mini") + .default_embed_model("text-embedding-3-small") + .build()?; + + let mut builder = Agent::builder(); + builder + .llm(&openai) + .without_default_stop_tool() + .tools([stop_tool.clone()]) + .on_stop(|_, reason, _| { + Box::pin(async move { + if let StopReason::RequestedByTool(_, payload) = reason + && let Some(payload) = payload + { + println!( + "agent stopped with structured payload:\n{}", + to_string_pretty(&payload).unwrap_or_else(|_| payload.to_string()), + ); + } + Ok(()) + }) + }); + + if let Some(prompt) = builder.system_prompt_mut() { + prompt + .with_role("Workflow finisher") + .with_guidelines([ + "Summarize the work that was just completed and recommend next actions.", + "When you are done, call the `stop` tool using the provided JSON schema.", + "Always include the `details` field; use null when there is nothing to add.", + ]) + .with_constraints(["Never fabricate task status values outside the schema."]); + } + + let mut agent = builder.build()?; + + agent + .query_once( + "You completed onboarding five merchants today. Prepare a final handoff report and stop.", + ) + .await?; + + Ok(()) +} diff --git a/swiftide-agents/src/state.rs b/swiftide-agents/src/state.rs index 12ee4fc3b2..4ef2919e09 100644 --- a/swiftide-agents/src/state.rs +++ b/swiftide-agents/src/state.rs @@ -3,6 +3,7 @@ use std::borrow::Cow; use serde::{Deserialize, Serialize}; +use serde_json::Value; use swiftide_core::chat_completion::ToolCall; #[derive(Clone, Debug, Default, strum_macros::EnumDiscriminants, strum_macros::EnumIs)] @@ -29,7 +30,7 @@ impl State { #[derive(Clone, Debug, strum_macros::EnumIs, PartialEq, Serialize, Deserialize)] pub enum StopReason { /// A tool called stop - RequestedByTool(ToolCall, Option>), + RequestedByTool(ToolCall, Option), /// Agent failed to complete with optional message AgentFailed(Option>), @@ -52,9 +53,9 @@ pub enum StopReason { } impl StopReason { - pub fn as_requested_by_tool(&self) -> Option<(&ToolCall, Option<&str>)> { + pub fn as_requested_by_tool(&self) -> Option<(&ToolCall, Option<&Value>)> { if let StopReason::RequestedByTool(t, message) = self { - Some((t, message.as_deref())) + Some((t, message.as_ref())) } else { None } diff --git a/swiftide-agents/src/tools/control.rs b/swiftide-agents/src/tools/control.rs index 4266c96245..22577e70e9 100644 --- a/swiftide-agents/src/tools/control.rs +++ b/swiftide-agents/src/tools/control.rs @@ -1,7 +1,7 @@ //! Control tools manage control flow during agent's lifecycle. use anyhow::Result; use async_trait::async_trait; -use schemars::schema_for; +use schemars::{Schema, schema_for}; use std::borrow::Cow; use swiftide_core::{ AgentContext, ToolFeedback, @@ -42,11 +42,42 @@ impl From for Box { } /// `StopWithArgs` is an alternative stop tool that takes arguments -#[derive(Clone, Debug, Default)] -pub struct StopWithArgs {} +#[derive(Clone, Debug)] +pub struct StopWithArgs { + parameters_schema: Option, + expects_output_field: bool, +} + +impl Default for StopWithArgs { + fn default() -> Self { + Self { + parameters_schema: Some(schema_for!(DefaultStopWithArgsSpec)), + expects_output_field: true, + } + } +} + +impl StopWithArgs { + /// Create a new `StopWithArgs` tool with a custom parameters schema. + /// + /// When providing a custom schema the full argument payload will be forwarded to the + /// stop output without requiring an `output` field wrapper. + pub fn with_parameters_schema(schema: Schema) -> Self { + Self { + parameters_schema: Some(schema), + expects_output_field: false, + } + } + + fn parameters_schema(&self) -> Schema { + self.parameters_schema + .clone() + .unwrap_or_else(|| schema_for!(DefaultStopWithArgsSpec)) + } +} #[derive(Clone, Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)] -struct StopWithArgsSpec { +struct DefaultStopWithArgsSpec { pub output: String, } @@ -57,13 +88,21 @@ impl Tool for StopWithArgs { _agent_context: &dyn AgentContext, tool_call: &ToolCall, ) -> Result { - let args: StopWithArgsSpec = serde_json::from_str( - tool_call - .args() - .ok_or(ToolError::missing_arguments("output"))?, - )?; + let raw_args = tool_call + .args() + .ok_or_else(|| ToolError::missing_arguments("arguments"))?; + + let json: serde_json::Value = serde_json::from_str(raw_args)?; - Ok(ToolOutput::stop_with_args(args.output)) + let output = if self.expects_output_field { + json.get("output") + .cloned() + .ok_or_else(|| ToolError::missing_arguments("output"))? + } else { + json + }; + + Ok(ToolOutput::stop_with_args(output)) } fn name(&self) -> Cow<'_, str> { @@ -71,10 +110,12 @@ impl Tool for StopWithArgs { } fn tool_spec(&self) -> ToolSpec { + let schema = self.parameters_schema(); + ToolSpec::builder() .name("stop") .description("When you have completed, your task, call this with your expected output") - .parameters_schema(schema_for!(StopWithArgsSpec)) + .parameters_schema(schema) .build() .unwrap() } @@ -181,6 +222,8 @@ impl From for Box { #[cfg(test)] mod tests { use super::*; + use schemars::schema_for; + use serde_json::json; fn dummy_tool_call(name: &str, args: Option<&str>) -> ToolCall { let mut builder = ToolCall::builder().name(name).id("1").to_owned(); @@ -230,4 +273,36 @@ mod tests { // On unit; existing feedback is always present assert_eq!(out, ToolOutput::Stop(None)); } + + #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)] + struct CustomStopArgs { + value: i32, + } + + #[test] + fn test_stop_with_args_custom_schema_in_spec() { + let schema = schema_for!(CustomStopArgs); + let tool = StopWithArgs::with_parameters_schema(schema.clone()); + let spec = tool.tool_spec(); + assert_eq!(spec.parameters_schema, Some(schema)); + } + + #[tokio::test] + async fn test_stop_with_args_custom_schema_forwards_payload() { + let schema = schema_for!(CustomStopArgs); + let tool = StopWithArgs::with_parameters_schema(schema); + let ctx = (); + let args = r#"{"value":42}"#; + let tool_call = dummy_tool_call("stop", Some(args)); + let out = tool.invoke(&ctx, &tool_call).await.unwrap(); + assert_eq!(out, ToolOutput::stop_with_args(json!({"value": 42}))); + } + + #[test] + fn test_stop_with_args_default_schema_matches_previous() { + let tool = StopWithArgs::default(); + let spec = tool.tool_spec(); + let expected = schema_for!(DefaultStopWithArgsSpec); + assert_eq!(spec.parameters_schema, Some(expected)); + } } diff --git a/swiftide-core/src/chat_completion/tools.rs b/swiftide-core/src/chat_completion/tools.rs index 327f08b674..54c110f140 100644 --- a/swiftide-core/src/chat_completion/tools.rs +++ b/swiftide-core/src/chat_completion/tools.rs @@ -18,7 +18,7 @@ pub enum ToolOutput { Fail(String), /// Stops an agent with an optional message - Stop(Option>), + Stop(Option), /// Indicates that the agent failed and should stop AgentFailed(Option>), @@ -37,7 +37,7 @@ impl ToolOutput { ToolOutput::Stop(None) } - pub fn stop_with_args(output: impl Into>) -> Self { + pub fn stop_with_args(output: impl Into) -> Self { ToolOutput::Stop(Some(output.into())) } @@ -73,9 +73,9 @@ impl ToolOutput { } /// Get the inner text if the output is a `Stop` variant. - pub fn as_stop(&self) -> Option<&str> { + pub fn as_stop(&self) -> Option<&serde_json::Value> { match self { - ToolOutput::Stop(args) => args.as_deref(), + ToolOutput::Stop(args) => args.as_ref(), _ => None, } } @@ -107,7 +107,13 @@ impl std::fmt::Display for ToolOutput { match self { ToolOutput::Text(value) => write!(f, "{value}"), ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"), - ToolOutput::Stop(args) => write!(f, "Stop {}", args.as_deref().unwrap_or_default()), + ToolOutput::Stop(args) => { + if let Some(value) = args { + write!(f, "Stop {value}") + } else { + write!(f, "Stop") + } + } ToolOutput::FeedbackRequired(_) => { write!(f, "Feedback required") } diff --git a/swiftide-integrations/src/openai/chat_completion.rs b/swiftide-integrations/src/openai/chat_completion.rs index 5c931b4cad..0928ad92c4 100644 --- a/swiftide-integrations/src/openai/chat_completion.rs +++ b/swiftide-integrations/src/openai/chat_completion.rs @@ -25,11 +25,13 @@ use swiftide_core::chat_completion::{Usage, UsageBuilder}; use swiftide_core::metrics::emit_usage; use super::GenericOpenAI; -use super::ensure_tool_schema_additional_properties_false; use super::openai_error_to_language_model_error; use super::responses_api::{ build_responses_request_from_chat, response_to_chat_completion, responses_stream_adapter, }; +use super::{ + ensure_tool_schema_additional_properties_false, ensure_tool_schema_required_matches_properties, +}; use tracing_futures::Instrument; #[async_trait] @@ -488,12 +490,15 @@ fn tools_to_openai(spec: &ToolSpec) -> Result { None => json!({ "type": "object", "properties": {}, + "required": [], "additionalProperties": false, }), }; ensure_tool_schema_additional_properties_false(&mut parameters) .context("tool schema must allow no additional properties")?; + ensure_tool_schema_required_matches_properties(&mut parameters) + .context("tool schema must list required properties")?; tracing::debug!( parameters = serde_json::to_string_pretty(¶meters).unwrap(), tool = %spec.name, diff --git a/swiftide-integrations/src/openai/mod.rs b/swiftide-integrations/src/openai/mod.rs index 9a488c84a5..53ad59343c 100644 --- a/swiftide-integrations/src/openai/mod.rs +++ b/swiftide-integrations/src/openai/mod.rs @@ -287,6 +287,44 @@ pub(crate) fn ensure_tool_schema_additional_properties_false( Ok(()) } +pub(crate) fn ensure_tool_schema_required_matches_properties( + parameters: &mut Value, +) -> anyhow::Result<()> { + let object = parameters + .as_object_mut() + .context("tool schema must be a JSON object")?; + + let property_names: Vec = if let Some(Value::Object(map)) = object.get("properties") { + map.keys().cloned().collect() + } else { + object + .entry("required".to_string()) + .or_insert_with(|| Value::Array(Vec::new())); + return Ok(()); + }; + + let required_entry = object + .entry("required".to_string()) + .or_insert_with(|| Value::Array(Vec::new())); + + let required_array = required_entry + .as_array_mut() + .context("tool schema 'required' must be an array")?; + + for name in property_names { + let name_ref = name.as_str(); + let already_present = required_array + .iter() + .any(|value| value.as_str().is_some_and(|s| s == name_ref)); + + if !already_present { + required_array.push(Value::String(name)); + } + } + + Ok(()) +} + impl OpenAI { /// Creates a new `OpenAIBuilder` for constructing `OpenAI` instances. pub fn builder() -> OpenAIBuilder { diff --git a/swiftide-integrations/src/openai/responses_api.rs b/swiftide-integrations/src/openai/responses_api.rs index 7d4fcb1ef3..d1a98a6915 100644 --- a/swiftide-integrations/src/openai/responses_api.rs +++ b/swiftide-integrations/src/openai/responses_api.rs @@ -19,7 +19,7 @@ use swiftide_core::chat_completion::{ use super::{ GenericOpenAI, ensure_tool_schema_additional_properties_false, - openai_error_to_language_model_error, + ensure_tool_schema_required_matches_properties, openai_error_to_language_model_error, }; use crate::openai::LanguageModelError; @@ -143,12 +143,15 @@ fn tool_spec_to_responses_tool(spec: &ToolSpec) -> Result { None => json!({ "type": "object", "properties": {}, + "required": [], "additionalProperties": false, }), }; ensure_tool_schema_additional_properties_false(&mut parameters) .context("tool schema must allow no additional properties")?; + ensure_tool_schema_required_matches_properties(&mut parameters) + .context("tool schema must list required properties")?; let function = FunctionArgs::default() .name(&spec.name) @@ -214,7 +217,8 @@ fn chat_messages_to_input_items(messages: &[ChatMessage]) -> LmResult { serde_json::Value::String(text.clone()) } - ToolOutput::Stop(message) | ToolOutput::AgentFailed(message) => { + ToolOutput::Stop(message) => message.clone().unwrap_or(serde_json::Value::Null), + ToolOutput::AgentFailed(message) => { serde_json::Value::String(message.clone().unwrap_or_default().into_owned()) } _ => serde_json::Value::Null,