Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ path = "usage_metrics.rs"
name = "tasks"
path = "tasks.rs"

[[example]]
name = "stop-with-args-custom-schema"
path = "stop_with_args_custom_schema.rs"

[[example]]
name = "responses-api"
path = "responses_api.rs"
Expand Down
91 changes: 91 additions & 0 deletions examples/stop_with_args_custom_schema.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
//! Demonstrates how to plug a custom JSON schema into the stop tool for an OpenAI-powered agent.
//!
//! Set the `OPENAI_API_KEY` environment variable before running the example. The agent guides the
//! model to call the `stop` tool with a structured payload that matches the schema defined below.
//! The on-stop hook prints the structured payload that made the agent stop.
use anyhow::Result;
use schemars::{JsonSchema, Schema, schema_for};
use serde::{Deserialize, Serialize};
use serde_json::to_string_pretty;
use swiftide::agents::tools::control::StopWithArgs;
use swiftide::agents::{Agent, StopReason};
use swiftide::traits::Tool;

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
enum TaskStatus {
Succeeded,
Failed,
Cancelled,
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(deny_unknown_fields)]
struct StopPayload {
status: TaskStatus,
summary: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
details: Option<serde_json::Value>,
}

fn stop_schema() -> Schema {
schema_for!(StopPayload)
}

#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt::init();

let schema = stop_schema();
let stop_tool = StopWithArgs::with_parameters_schema(schema.clone());

println!(
"stop tool schema:\n{}",
to_string_pretty(&stop_tool.tool_spec())?,
);

let openai = swiftide::integrations::openai::OpenAI::builder()
.default_prompt_model("gpt-4o-mini")
.default_embed_model("text-embedding-3-small")
.build()?;

let mut builder = Agent::builder();
builder
.llm(&openai)
.without_default_stop_tool()
.tools([stop_tool.clone()])
.on_stop(|_, reason, _| {
Box::pin(async move {
if let StopReason::RequestedByTool(_, payload) = reason
&& let Some(payload) = payload
{
println!(
"agent stopped with structured payload:\n{}",
to_string_pretty(&payload).unwrap_or_else(|_| payload.to_string()),
);
}
Ok(())
})
});

if let Some(prompt) = builder.system_prompt_mut() {
prompt
.with_role("Workflow finisher")
.with_guidelines([
"Summarize the work that was just completed and recommend next actions.",
"When you are done, call the `stop` tool using the provided JSON schema.",
"Always include the `details` field; use null when there is nothing to add.",
])
.with_constraints(["Never fabricate task status values outside the schema."]);
}

let mut agent = builder.build()?;

agent
.query_once(
"You completed onboarding five merchants today. Prepare a final handoff report and stop.",
)
.await?;

Ok(())
}
7 changes: 4 additions & 3 deletions swiftide-agents/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::borrow::Cow;

use serde::{Deserialize, Serialize};
use serde_json::Value;
use swiftide_core::chat_completion::ToolCall;

#[derive(Clone, Debug, Default, strum_macros::EnumDiscriminants, strum_macros::EnumIs)]
Expand All @@ -29,7 +30,7 @@ impl State {
#[derive(Clone, Debug, strum_macros::EnumIs, PartialEq, Serialize, Deserialize)]
pub enum StopReason {
/// A tool called stop
RequestedByTool(ToolCall, Option<Cow<'static, str>>),
RequestedByTool(ToolCall, Option<Value>),

/// Agent failed to complete with optional message
AgentFailed(Option<Cow<'static, str>>),
Expand All @@ -52,9 +53,9 @@ pub enum StopReason {
}

impl StopReason {
pub fn as_requested_by_tool(&self) -> Option<(&ToolCall, Option<&str>)> {
pub fn as_requested_by_tool(&self) -> Option<(&ToolCall, Option<&Value>)> {
if let StopReason::RequestedByTool(t, message) = self {
Some((t, message.as_deref()))
Some((t, message.as_ref()))
} else {
None
}
Expand Down
97 changes: 86 additions & 11 deletions swiftide-agents/src/tools/control.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Control tools manage control flow during agent's lifecycle.
use anyhow::Result;
use async_trait::async_trait;
use schemars::schema_for;
use schemars::{Schema, schema_for};
use std::borrow::Cow;
use swiftide_core::{
AgentContext, ToolFeedback,
Expand Down Expand Up @@ -42,11 +42,42 @@ impl From<Stop> for Box<dyn Tool> {
}

/// `StopWithArgs` is an alternative stop tool that takes arguments
#[derive(Clone, Debug, Default)]
pub struct StopWithArgs {}
#[derive(Clone, Debug)]
pub struct StopWithArgs {
parameters_schema: Option<Schema>,
expects_output_field: bool,
}

impl Default for StopWithArgs {
fn default() -> Self {
Self {
parameters_schema: Some(schema_for!(DefaultStopWithArgsSpec)),
expects_output_field: true,
}
}
}

impl StopWithArgs {
/// Create a new `StopWithArgs` tool with a custom parameters schema.
///
/// When providing a custom schema the full argument payload will be forwarded to the
/// stop output without requiring an `output` field wrapper.
pub fn with_parameters_schema(schema: Schema) -> Self {
Self {
parameters_schema: Some(schema),
expects_output_field: false,
}
}

fn parameters_schema(&self) -> Schema {
self.parameters_schema
.clone()
.unwrap_or_else(|| schema_for!(DefaultStopWithArgsSpec))
}
}

#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
struct StopWithArgsSpec {
struct DefaultStopWithArgsSpec {
pub output: String,
}

Expand All @@ -57,24 +88,34 @@ impl Tool for StopWithArgs {
_agent_context: &dyn AgentContext,
tool_call: &ToolCall,
) -> Result<ToolOutput, ToolError> {
let args: StopWithArgsSpec = serde_json::from_str(
tool_call
.args()
.ok_or(ToolError::missing_arguments("output"))?,
)?;
let raw_args = tool_call
.args()
.ok_or_else(|| ToolError::missing_arguments("arguments"))?;

let json: serde_json::Value = serde_json::from_str(raw_args)?;

Ok(ToolOutput::stop_with_args(args.output))
let output = if self.expects_output_field {
json.get("output")
.cloned()
.ok_or_else(|| ToolError::missing_arguments("output"))?
} else {
json
};

Ok(ToolOutput::stop_with_args(output))
}

fn name(&self) -> Cow<'_, str> {
"stop".into()
}

fn tool_spec(&self) -> ToolSpec {
let schema = self.parameters_schema();

ToolSpec::builder()
.name("stop")
.description("When you have completed, your task, call this with your expected output")
.parameters_schema(schema_for!(StopWithArgsSpec))
.parameters_schema(schema)
.build()
.unwrap()
}
Expand Down Expand Up @@ -181,6 +222,8 @@ impl From<ApprovalRequired> for Box<dyn Tool> {
#[cfg(test)]
mod tests {
use super::*;
use schemars::schema_for;
use serde_json::json;

fn dummy_tool_call(name: &str, args: Option<&str>) -> ToolCall {
let mut builder = ToolCall::builder().name(name).id("1").to_owned();
Expand Down Expand Up @@ -230,4 +273,36 @@ mod tests {
// On unit; existing feedback is always present
assert_eq!(out, ToolOutput::Stop(None));
}

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
struct CustomStopArgs {
value: i32,
}

#[test]
fn test_stop_with_args_custom_schema_in_spec() {
let schema = schema_for!(CustomStopArgs);
let tool = StopWithArgs::with_parameters_schema(schema.clone());
let spec = tool.tool_spec();
assert_eq!(spec.parameters_schema, Some(schema));
}

#[tokio::test]
async fn test_stop_with_args_custom_schema_forwards_payload() {
let schema = schema_for!(CustomStopArgs);
let tool = StopWithArgs::with_parameters_schema(schema);
let ctx = ();
let args = r#"{"value":42}"#;
let tool_call = dummy_tool_call("stop", Some(args));
let out = tool.invoke(&ctx, &tool_call).await.unwrap();
assert_eq!(out, ToolOutput::stop_with_args(json!({"value": 42})));
}

#[test]
fn test_stop_with_args_default_schema_matches_previous() {
let tool = StopWithArgs::default();
let spec = tool.tool_spec();
let expected = schema_for!(DefaultStopWithArgsSpec);
assert_eq!(spec.parameters_schema, Some(expected));
}
}
16 changes: 11 additions & 5 deletions swiftide-core/src/chat_completion/tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub enum ToolOutput {
Fail(String),

/// Stops an agent with an optional message
Stop(Option<Cow<'static, str>>),
Stop(Option<serde_json::Value>),

/// Indicates that the agent failed and should stop
AgentFailed(Option<Cow<'static, str>>),
Expand All @@ -37,7 +37,7 @@ impl ToolOutput {
ToolOutput::Stop(None)
}

pub fn stop_with_args(output: impl Into<Cow<'static, str>>) -> Self {
pub fn stop_with_args(output: impl Into<serde_json::Value>) -> Self {
ToolOutput::Stop(Some(output.into()))
}

Expand Down Expand Up @@ -73,9 +73,9 @@ impl ToolOutput {
}

/// Get the inner text if the output is a `Stop` variant.
pub fn as_stop(&self) -> Option<&str> {
pub fn as_stop(&self) -> Option<&serde_json::Value> {
match self {
ToolOutput::Stop(args) => args.as_deref(),
ToolOutput::Stop(args) => args.as_ref(),
_ => None,
}
}
Expand Down Expand Up @@ -107,7 +107,13 @@ impl std::fmt::Display for ToolOutput {
match self {
ToolOutput::Text(value) => write!(f, "{value}"),
ToolOutput::Fail(value) => write!(f, "Tool call failed: {value}"),
ToolOutput::Stop(args) => write!(f, "Stop {}", args.as_deref().unwrap_or_default()),
ToolOutput::Stop(args) => {
if let Some(value) = args {
write!(f, "Stop {value}")
} else {
write!(f, "Stop")
}
}
ToolOutput::FeedbackRequired(_) => {
write!(f, "Feedback required")
}
Expand Down
7 changes: 6 additions & 1 deletion swiftide-integrations/src/openai/chat_completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ use swiftide_core::chat_completion::{Usage, UsageBuilder};
use swiftide_core::metrics::emit_usage;

use super::GenericOpenAI;
use super::ensure_tool_schema_additional_properties_false;
use super::openai_error_to_language_model_error;
use super::responses_api::{
build_responses_request_from_chat, response_to_chat_completion, responses_stream_adapter,
};
use super::{
ensure_tool_schema_additional_properties_false, ensure_tool_schema_required_matches_properties,
};
use tracing_futures::Instrument;

#[async_trait]
Expand Down Expand Up @@ -488,12 +490,15 @@ fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTool> {
None => json!({
"type": "object",
"properties": {},
"required": [],
"additionalProperties": false,
}),
};

ensure_tool_schema_additional_properties_false(&mut parameters)
.context("tool schema must allow no additional properties")?;
ensure_tool_schema_required_matches_properties(&mut parameters)
.context("tool schema must list required properties")?;
tracing::debug!(
parameters = serde_json::to_string_pretty(&parameters).unwrap(),
tool = %spec.name,
Expand Down
38 changes: 38 additions & 0 deletions swiftide-integrations/src/openai/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,44 @@ pub(crate) fn ensure_tool_schema_additional_properties_false(
Ok(())
}

pub(crate) fn ensure_tool_schema_required_matches_properties(
parameters: &mut Value,
) -> anyhow::Result<()> {
let object = parameters
.as_object_mut()
.context("tool schema must be a JSON object")?;

let property_names: Vec<String> = if let Some(Value::Object(map)) = object.get("properties") {
map.keys().cloned().collect()
} else {
object
.entry("required".to_string())
.or_insert_with(|| Value::Array(Vec::new()));
return Ok(());
};

let required_entry = object
.entry("required".to_string())
.or_insert_with(|| Value::Array(Vec::new()));

let required_array = required_entry
.as_array_mut()
.context("tool schema 'required' must be an array")?;

for name in property_names {
let name_ref = name.as_str();
let already_present = required_array
.iter()
.any(|value| value.as_str().is_some_and(|s| s == name_ref));

if !already_present {
required_array.push(Value::String(name));
}
}

Ok(())
}

impl OpenAI {
/// Creates a new `OpenAIBuilder` for constructing `OpenAI` instances.
pub fn builder() -> OpenAIBuilder {
Expand Down
Loading