Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#![deny(missing_docs)]
//! Macros for the `aisdk` library.

use proc_macro::TokenStream;
use quote::quote;
use syn::parse::Parser;
Expand All @@ -21,17 +24,22 @@ use syn::{
///
/// #[tool]
/// /// Returns the username
/// fn get_username(id: String) -> Result<String, String> {
/// fn get_username(id: String) -> Tool {
/// // Your code here
/// Ok(format!("user_{}", id))
/// }
/// ```
///
/// - `get_username` becomes the name of the tool
/// - `Returns the username` becomes the description of the tool
/// - `"Returns the username"` becomes the description of the tool
/// - `id: String` becomes the input of the tool. converted to `{"id": "string"}`
/// as json schema
///
/// The function should return a `Result<String, String>` eventhough the return statement
/// returns a `Tool` object. This is because the macro will automatically convert the
/// function into a `Tool` object and return it. You should return what the model can
/// understand as a `String`.
///
/// In the event that the model refuses to send an argument, the default implementation
/// will be used. this works perfectly for arguments that are `Option`s. Make sure to
/// use `Option` types for arguments that are optional or implement a default for those
Expand Down
3 changes: 2 additions & 1 deletion src/core/language_model/generate_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,10 @@ impl Deref for GenerateTextResponse {
mod tests {
use super::*;
use crate::core::{
AssistantMessage, ToolCallInfo, ToolResultInfo,
AssistantMessage,
language_model::{LanguageModelResponseContentType, Usage},
messages::TaggedMessage,
tools::{ToolCallInfo, ToolResultInfo},
};

#[test]
Expand Down
5 changes: 4 additions & 1 deletion src/core/language_model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ pub mod stream_text;
use crate::core::messages::{AssistantMessage, TaggedMessage, TaggedMessageHelpers};
use crate::core::tools::ToolList;
use crate::core::utils;
use crate::core::{Message, ToolCallInfo, ToolResultInfo};
use crate::core::{
Message,
tools::{ToolCallInfo, ToolResultInfo},
};
use crate::error::{Error, Result};
use async_trait::async_trait;
use derive_builder::Builder;
Expand Down
2 changes: 1 addition & 1 deletion src/core/messages.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::core::{
ToolCallInfo, ToolResultInfo,
language_model::{LanguageModelResponseContentType, Usage},
tools::{ToolCallInfo, ToolResultInfo},
};

/// Role for model messages.
Expand Down
4 changes: 2 additions & 2 deletions src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
pub mod language_model;
pub mod messages;
pub mod provider;
pub mod tools;
pub(crate) mod tools;
pub mod utils;

pub use aisdk_macros::tool;
Expand All @@ -23,4 +23,4 @@ pub use language_model::{

pub use messages::{AssistantMessage, Message, Role, SystemMessage, UserMessage};
pub use provider::Provider;
pub use tools::{Tool, ToolCallInfo, ToolResultInfo};
pub use tools::{Tool, ToolExecute};
170 changes: 167 additions & 3 deletions src/core/tools.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,89 @@
//! Tools are a way to extend the capabilities of a language model. aisdk provides a
//! macro to simplify the process of defining and registering tools. This module provides
//! The necessary types and functions for defining and using tools both by the macro and
//! by the user.
//!
//! The Tool struct is the core component of a tool. It contains the `name`, `description`,
//! and `input_schema` of the tool as well as the logic to execute. The `execute`
//! method is the main entry point for executing the tool. The language model is responsible
//! for calling this method using `input_schema` to generate the arguments for the tool.
//!
//!
//! The tool macro generates the necessary code for registering the tool with the SDK.
//! It infers the necessary fields for the Tool struct from a valid rust function.
//!
//! # Example
//! ```
//! use aisdk::core::Tool;
//! use aisdk_macros::tool;
//!
//! #[tool]
//! /// Adds two numbers together.
//! pub fn sum(a: u8, b: u8) -> Tool {
//! Ok(format!("{}", a + b))
//! }
//!
//! let tool: Tool = get_weather();
//!
//! assert_eq!(tool.name, "get_weather");
//! assert_eq!(tool.description, "Adds two numbers together.");
//! assert_eq!(tool.input_schema.to_value(), serde_json::json!({
//! "type": "object",
//! "required": ["a", "b"],
//! "properties": {
//! "a": {
//! "type": "integer",
//! "format": "uint8",
//! "minimum": 0,
//! "maximum": 255
//! },
//! "b": {
//! "type": "integer",
//! "format": "uint8",
//! "minimum": 0,
//! "maximum": 255
//! }
//! }
//! }));
//!
//!
//! ```
//!
//! # Example with struct
//! ```
//! use aisdk::core::{Tool, ToolExecute};
//! use serde_json::Value;
//!
//! let tool: Tool = Tool {
//! name: "sum".to_string(),
//! description: "Adds two numbers together.".to_string(),
//! input_schema: serde_json::json!({
//! "type": "object",
//! "required": ["a", "b"],
//! "properties": {
//! "a": {
//! "type": "integer",
//! "format": "uint8",
//! "minimum": 0,
//! "maximum": 255
//! },
//! "b": {
//! "type": "integer",
//! "format": "uint8",
//! "minimum": 0,
//! "maximum": 255
//! }
//! }
//! }),
//! execute:
//! ToolExecute::new(Box::new(|params: Value| {
//! let a = params["a"].as_u64().unwrap();
//! let b = params["b"].as_u64().unwrap();
//! Ok(format!("{}", a + b))
//! })),
//! ```
//!

use crate::error::{Error, Result};
use derive_builder::Builder;
use schemars::Schema;
Expand All @@ -9,16 +95,22 @@ use tokio::task::JoinHandle;

pub type ToolFn = Box<dyn Fn(Value) -> std::result::Result<String, String> + Send + Sync>;

/// Holds the function that will be called when the tool is executed. the function
/// should take a single argument of type `Value` and returns a
/// `Result<String, String>`.
#[derive(Clone)]
pub struct ToolExecute {
inner: Arc<ToolFn>,
}

impl ToolExecute {
pub fn call(&self, map: Value) -> Result<String> {
pub(crate) fn call(&self, map: Value) -> Result<String> {
(*self.inner)(map).map_err(Error::ToolCallError)
}

/// Creates a new `ToolExecute` instance with the given function.
/// The function should take a single argument of type `Value` and return a
/// `Result<String, String>`.
pub fn new(f: ToolFn) -> Self {
Self { inner: Arc::new(f) }
}
Expand Down Expand Up @@ -48,6 +140,74 @@ impl<'de> Deserialize<'de> for ToolExecute {
}
}

/// The `Tool` struct represents a tool that can be executed by a language model.
/// It contains the name, description, input schema, and execution logic of the tool.
/// The `execute` method is the main entry point for executing the tool and is called.
/// by the language model.
///
/// `name` and `description` help the model identify and understand the tool. `input_schema`
/// defines the structure of the input data that the tool expects. `Schema` is a type from
/// the [`schemars`](https://docs.rs/schemars/latest/schemars/) crate that can be used to
/// define the input schema.
///
/// The execute method is responsible for executing the tool and returning the result to
/// the language model. It takes a single argument of type `Value` and returns a
/// `Result<String, String>`.
///
/// # Example
/// ```
/// use aisdk::core::Tool;
/// use aisdk_macros::tool;
///
/// let tool: Tool = Tool {
/// name: "sum".to_string(),
/// description: "Adds two numbers together.".to_string(),
/// input_schema: serde_json::json!({
/// "type": "object",
/// "required": ["a", "b"],
/// "properties": {
/// "a": {
/// "type": "integer",
/// "format": "uint8",
/// "minimum": 0,
/// "maximum": 255
/// },
/// "b": {
/// "type": "integer",
/// "format": "uint8",
/// "minimum": 0,
/// "maximum": 255
/// }
/// }
/// }),
/// execute: ToolExecute::new(Box::new(|params| {
/// let a = params["a"].as_u64().unwrap();
/// let b = params["b"].as_u64().unwrap();
/// Ok(format!("{}", a + b))
/// })),
/// };
///
/// assert_eq!(tool.name, "sum");
/// assert_eq!(tool.description, "Adds two numbers together.");
/// assert_eq!(tool.input_schema.to_value(), serde_json::json!({
/// "type": "object",
/// "required": ["a", "b"],
/// "properties": {
/// "a": {
/// "type": "integer",
/// "format": "uint8",
/// "minimum": 0,
/// "maximum": 255
/// },
/// "b": {
/// "type": "integer",
/// "format": "uint8",
/// "minimum": 0,
/// "maximum": 255
/// }
/// }
/// }));
/// ```
#[derive(Builder, Clone, Default)]
#[builder(pattern = "owned", setter(into), build_fn(error = "Error"))]
pub struct Tool {
Expand All @@ -71,6 +231,7 @@ impl Debug for Tool {
}

impl Tool {
/// Creates a new `Tool` instance with default values.
pub fn new() -> Self {
Self {
name: "".to_string(),
Expand Down Expand Up @@ -121,20 +282,23 @@ impl ToolList {
#[derive(Default, Debug, Clone, PartialEq)]
/// Describes a tool
pub struct ToolDetails {
// the name of the tool, usually a function name.
/// The name of the tool, usually a function name.
pub name: String,
// uniquely identifies a tool, provided by the LLM.
/// Uniquely identifies a tool, usually provided by the provider.
pub id: String,
}

/// Contains information necessary to call a tool
#[derive(Default, Debug, Clone, PartialEq)]
pub struct ToolCallInfo {
/// The details of the tool to be called.
pub tool: ToolDetails,
/// The input parameters for the tool.
pub input: serde_json::Value,
}

impl ToolCallInfo {
/// Creates a new `ToolCallInfo` instance with the given name.
pub fn new(name: impl Into<String>) -> Self {
Self {
tool: ToolDetails {
Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ pub enum Error {
#[error("Invalid input: {0}")]
InvalidInput(String),

/// An error related to tool execution. This includes errors caused by the
/// tool itself as well by the SDK when interacting with the tool.
#[error("Tool error: {0}")]
ToolCallError(String),

Expand Down
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#![deny(missing_docs)]

//! `aisdk` is An open-source Rust library for building AI-powered applications, inspired by the Vercel AI SDK. It provides a type-safe interface for interacting with Large Language Models (LLMs).

pub mod core;
pub mod error;
#[cfg(feature = "prompt")]
Expand Down
2 changes: 1 addition & 1 deletion src/providers/anthropic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::core::language_model::{
};
use crate::core::messages::AssistantMessage;
use crate::core::tools::ToolDetails;
use crate::core::{LanguageModelStreamChunkType, ToolCallInfo};
use crate::core::{LanguageModelStreamChunkType, tools::ToolCallInfo};
use crate::error::ProviderError;
use crate::providers::anthropic::client::{
AnthropicClient, AnthropicDelta, AnthropicError, AnthropicMessageDeltaUsage,
Expand Down
2 changes: 1 addition & 1 deletion src/providers/openai/conversions.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
//! Helper functions and conversions for the OpenAI provider.

use crate::core::Tool;
use crate::core::language_model::{
LanguageModelOptions, LanguageModelResponseContentType, ReasoningEffort, Usage,
};
use crate::core::messages::Message;
use crate::core::tools::Tool;
use async_openai::types::responses::{
CreateResponse, Function, Input, InputContent, InputItem, InputMessage, InputMessageType,
ReasoningConfig, ReasoningSummary, Role, TextConfig, TextResponseFormat, ToolDefinition,
Expand Down
Loading