diff --git a/macros/src/lib.rs b/macros/src/lib.rs
index 56057bd..2f9421d 100644
--- a/macros/src/lib.rs
+++ b/macros/src/lib.rs
@@ -1,3 +1,6 @@
+#![deny(missing_docs)]
+//! Macros for the `aisdk` library.
+
 use proc_macro::TokenStream;
 use quote::quote;
 use syn::parse::Parser;
@@ -21,17 +24,22 @@ use syn::{
 /// /// #[tool]
 /// /// Returns the username
-/// fn get_username(id: String) -> Result {
+/// fn get_username(id: String) -> Tool {
 ///     // Your code here
 ///     Ok(format!("user_{}", id))
 /// }
 /// ```
 ///
 /// - `get_username` becomes the name of the tool
-/// - `Returns the username` becomes the description of the tool
+/// - `"Returns the username"` becomes the description of the tool
 /// - `id: String` becomes the input of the tool. converted to `{"id": "string"}`
 ///   as json schema
 ///
+/// The function's signature should declare `Tool` as the return type, even though the
+/// return statement produces a `Result`. This is because the macro automatically
+/// converts the function into a `Tool` object and returns it. The `Ok` value should be
+/// a `String` that the model can understand.
+///
 /// In the event that the model refuses to send an argument, the default implementation
 /// will be used. this works perfectly for arguments that are `Option`s. Make sure to
 /// use `Option` types for arguments that are optional or implement a default for those
diff --git a/src/core/language_model/generate_text.rs b/src/core/language_model/generate_text.rs
index 7c42ac6..8f0a946 100644
--- a/src/core/language_model/generate_text.rs
+++ b/src/core/language_model/generate_text.rs
@@ -1,3 +1,5 @@
+//! Text generation impl for `LanguageModelRequest`.
+
 use crate::error::Result;
 use crate::{
     Error,
@@ -16,12 +18,40 @@ use serde::ser::Error as SerdeError;
 use std::ops::Deref;
 
 impl<M: LanguageModel> LanguageModelRequest<M> {
-    /// Generates text using a specified language model.
+    /// Generates text and executes tools using the language model.
+    ///
+    /// This method performs non-streaming text generation, potentially involving multiple
+    /// steps of tool calling and execution until the conversation reaches a natural stopping point.
+    /// The model may call tools based on the configured options, and responses are processed
+    /// iteratively until completion.
+    ///
+    /// For streaming responses, use [`stream_text`](Self::stream_text) instead.
+    ///
+    /// # Returns
+    ///
+    /// A [`GenerateTextResponse`] containing the final conversation state and generated content.
+    ///
+    /// # Errors
+    ///
+    /// Returns an [`Error`] if the underlying language model fails to generate a response
+    /// or if tool execution encounters an error.
+    ///
+    /// # Examples
     ///
-    /// Generate a text and call tools for a given prompt using a language model.
-    /// This function does not stream the output. If you want to stream the output, use `StreamText` instead.
+    /// ```rust,no_run
+    /// # use aisdk::core::LanguageModelRequest;
+    /// # use aisdk::providers::openai::OpenAI;
+    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
+    /// let mut request = LanguageModelRequest::builder()
+    ///     .model(OpenAI::new("gpt-4"))
+    ///     .prompt("Calculate 2 + 2")
+    ///     .build();
     ///
-    /// Returns an `Error` if the underlying model fails to generate a response.
+    /// let response = request.generate_text().await?;
+    /// println!("Response: {}", response.text().unwrap_or_else(|| "No text".into()));
+    /// # Ok(())
+    /// # }
+    /// ```
     pub async fn generate_text(&mut self) -> Result<GenerateTextResponse> {
         let (system_prompt, messages) = resolve_message(&self.options, &self.prompt);
@@ -136,6 +166,22 @@ pub struct GenerateTextResponse {
 }
 
 impl GenerateTextResponse {
+    /// Deserializes the response text into a structured type.
+    ///
+    /// This method attempts to parse the generated text as JSON and deserialize it
+    /// into the specified type `T`. It requires that the response contains text content.
+    ///
+    /// # Type Parameters
+    ///
+    /// * `T` - The type to deserialize into, which must implement [`DeserializeOwned`].
+    ///
+    /// # Returns
+    ///
+    /// A result containing the deserialized value or a JSON error.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if there is no text response or if deserialization fails.
     pub fn into_schema<T: DeserializeOwned>(&self) -> std::result::Result<T, serde_json::Error> {
         if let Some(text) = &self.text() {
             serde_json::from_str(text)
@@ -145,6 +191,7 @@ impl GenerateTextResponse {
     }
 
     #[cfg(any(test, feature = "test-access"))]
+    /// Returns the step IDs of the messages in the response.
     pub fn step_ids(&self) -> Vec<usize> {
         self.options.messages.iter().map(|t| t.step_id).collect()
     }
@@ -162,9 +209,10 @@ impl Deref for GenerateTextResponse {
 mod tests {
     use super::*;
     use crate::core::{
-        AssistantMessage, ToolCallInfo, ToolResultInfo,
+        AssistantMessage,
         language_model::{LanguageModelResponseContentType, Usage},
         messages::TaggedMessage,
+        tools::{ToolCallInfo, ToolResultInfo},
     };
 
     #[test]
diff --git a/src/core/language_model/mod.rs b/src/core/language_model/mod.rs
index 864bf83..93504f4 100644
--- a/src/core/language_model/mod.rs
+++ b/src/core/language_model/mod.rs
@@ -12,7 +12,10 @@ pub mod stream_text;
 use crate::core::messages::{AssistantMessage, TaggedMessage, TaggedMessageHelpers};
 use crate::core::tools::ToolList;
 use crate::core::utils;
-use crate::core::{Message, ToolCallInfo, ToolResultInfo};
+use crate::core::{
+    Message,
+    tools::{ToolCallInfo, ToolResultInfo},
+};
 use crate::error::{Error, Result};
 use async_trait::async_trait;
 use derive_builder::Builder;
@@ -29,6 +32,8 @@ use std::task::{Context, Poll};
 // ============================================================================
 // Section: constants
 // ============================================================================
+
+/// Default maximum number of tool calling steps allowed in a single request.
 pub const DEFAULT_TOOL_STEP_COUNT: usize = 3;
 
 // ============================================================================
@@ -44,6 +49,7 @@ pub const DEFAULT_TOOL_STEP_COUNT: usize = 3;
 /// generation and streaming responses.
 #[async_trait]
 pub trait LanguageModel: Send + Sync + std::fmt::Debug {
+    /// Returns a human-readable name for the language model.
     fn name(&self) -> String;
     /// Performs a single, non-streaming text generation request.
     ///
@@ -71,29 +77,44 @@ pub trait LanguageModel: Send + Sync + std::fmt::Debug {
 // Section: hook types
 // ============================================================================
 
+/// Type alias for a hook function that determines when to stop generation.
+///
+/// Returns `true` if generation should stop.
 pub type StopWhenHook = Arc<dyn Fn(&LanguageModelOptions) -> bool + Send + Sync>;
+
+/// Type alias for a hook function called before each generation step.
 pub type OnStepStartHook = Arc<dyn Fn(&mut LanguageModelOptions) + Send + Sync>;
+
+/// Type alias for a hook function called after each generation step.
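+///
+/// A minimal sketch of wiring hooks into a request through the builder (the
+/// closure bodies are illustrative only):
+///
+/// ```rust,no_run
+/// # use aisdk::core::LanguageModelRequest;
+/// # use aisdk::providers::openai::OpenAI;
+/// let request = LanguageModelRequest::builder()
+///     .model(OpenAI::new("gpt-4"))
+///     .prompt("What is the weather in New York?")
+///     // Stop once more than three steps have run.
+///     .stop_when(|options| options.steps().len() > 3)
+///     // Inspect token usage after every step.
+///     .on_step_finish(|options| println!("usage so far: {:?}", options.usage()))
+///     .build();
+/// ```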
 pub type OnStepFinishHook = Arc<dyn Fn(&LanguageModelOptions) + Send + Sync>;
 
 // ============================================================================
 // Section: structs and impls
 // ============================================================================
 
-/// A "step" represents a single cycle of model interaction.
+/// Represents a single step in the language model interaction process.
+///
+/// A step contains all messages exchanged during one cycle of model interaction,
+/// including user input, assistant responses, and tool calls/results.
 pub struct Step {
+    /// The unique identifier for this step.
     pub step_id: usize,
+    /// The messages that occurred during this step.
     pub messages: Vec<Message>,
 }
 
 impl Step {
+    /// Creates a new `Step` with the given ID and messages.
     pub fn new(step_id: usize, messages: Vec<Message>) -> Self {
         Self { step_id, messages }
     }
 
+    /// Returns a reference to the messages in this step.
     pub fn messages(&self) -> &[Message] {
         &self.messages
     }
 
+    /// Calculates the total token usage for this step.
     pub fn usage(&self) -> Usage {
         self.messages()
             .iter()
@@ -104,6 +125,7 @@ impl Step {
             .fold(Usage::default(), |acc, u| &acc + u)
     }
 
+    /// Returns a vector of all tool calls in this step.
     pub fn tool_calls(&self) -> Option<Vec<ToolCallInfo>> {
         let calls: Vec<ToolCallInfo> = self
             .messages()
@@ -120,6 +142,7 @@ impl Step {
         if calls.is_empty() { None } else { Some(calls) }
     }
 
+    /// Returns a vector of all tool results in this step.
     pub fn tool_results(&self) -> Option<Vec<ToolResultInfo>> {
         let results: Vec<ToolResultInfo> = self
             .messages()
@@ -141,7 +164,10 @@ impl Step {
 // Section: options
 // ============================================================================
 
-/// Options for a language model request.
+/// Configuration options for language model requests.
+///
+/// This struct contains all the parameters that can be used to customize
+/// text generation, including sampling parameters, tools, and hooks.
 #[derive(Clone, Default, Builder)]
 #[builder(pattern = "owned", setter(into), build_fn(error = "Error"))]
 pub struct LanguageModelOptions {
@@ -155,19 +181,19 @@ pub struct LanguageModelOptions {
     /// by the model, calls will generate deterministic results.
     pub seed: Option,
 
-    /// Randomness.
+    /// Controls randomness in generation (0-100, scaled to 0.0-1.0).
     pub temperature: Option,
 
-    /// Nucleus sampling.
+    /// Nucleus sampling parameter (0-100, scaled to 0.0-1.0).
     pub top_p: Option,
 
-    /// Top-k sampling.
+    /// Top-k sampling parameter.
     pub top_k: Option,
 
-    /// Maximum number of retries.
+    /// Maximum number of retries for failed requests.
     pub max_retries: Option,
 
-    /// Maxoutput tokens.
+    /// Maximum number of output tokens to generate.
     pub max_output_tokens: Option,
 
     /// Stop sequences.
@@ -182,29 +208,28 @@ pub struct LanguageModelOptions {
     /// to repeatedly use the same words or phrases.
     pub frequency_penalty: Option,
 
-    /// Hook to stop tool calling if returns true
+    /// Hook to conditionally stop generation.
     pub stop_when: Option<StopWhenHook>,
 
-    /// Hook called before each step (language model request)
+    /// Hook called before each generation step.
     pub on_step_start: Option<OnStepStartHook>,
 
-    /// Hook called after each step finishes
+    /// Hook called after each generation step.
     pub on_step_finish: Option<OnStepFinishHook>,
 
-    /// Reasoning effort
+    /// Level of reasoning effort for the model.
     pub reasoning_effort: Option<ReasoningEffort>,
 
     /// List of tools to use.
     pub(crate) tools: Option<ToolList>,
 
-    /// Used to track message steps
+    /// Current step ID for tracking multi-step interactions.
     pub(crate) current_step_id: usize,
 
-    /// The messages to generate text from.
-    /// At least User Message is required.
+    /// The conversation messages for the request.
     pub(crate) messages: Vec<TaggedMessage>,
 
-    // The stop reasons. should be updated after each step.
+    /// The reason why generation stopped.
     pub(crate) stop_reason: Option<StopReason>,
 }
 
@@ -233,17 +258,17 @@ impl Debug for LanguageModelOptions {
 }
 
 impl LanguageModelOptions {
+    /// Creates a new builder for `LanguageModelOptions`.
     pub fn builder() -> LanguageModelOptionsBuilder {
         LanguageModelOptionsBuilder::default()
     }
 
+    /// Returns a vector of all messages in the conversation.
     pub fn messages(&self) -> Vec<Message> {
         self.messages.iter().map(|m| m.message.clone()).collect()
     }
 
-    /// Calls the requested tools, adds tool ouput message to messages,
-    /// and decrements the step count. uses the previous step id for tagging
-    /// the created messages.
+    /// Executes a tool call and adds the result to the message history.
     pub(crate) async fn handle_tool_call(&mut self, input: &ToolCallInfo) -> &mut Self {
         if let Some(tools) = &self.tools {
             let tool_result_task = tools.execute(input.clone()).await;
@@ -275,6 +300,7 @@ impl LanguageModelOptions {
         }
     }
 
+    /// Returns the step with the given index, if it exists.
     pub fn step(&self, index: usize) -> Option<Step> {
         let messages: Vec<Message> = self
             .messages
@@ -289,11 +315,13 @@ impl LanguageModelOptions {
         }
     }
 
+    /// Returns the most recent step, if any.
     pub fn last_step(&self) -> Option<Step> {
         let max_step = self.messages.iter().map(|t| t.step_id).max()?;
         self.step(max_step)
     }
 
+    /// Returns all steps in chronological order.
     pub fn steps(&self) -> Vec<Step> {
         let mut step_map: HashMap<usize, Vec<Message>> = HashMap::new();
         for tagged in &self.messages {
@@ -310,6 +338,7 @@ impl LanguageModelOptions {
         steps
     }
 
+    /// Calculates the total token usage across all steps.
     pub fn usage(&self) -> Usage {
         self.steps()
             .iter()
@@ -317,6 +346,7 @@ impl LanguageModelOptions {
             .fold(Usage::default(), |acc, u| &acc + &u)
     }
 
+    /// Returns the content of the last assistant message, excluding reasoning.
     pub fn content(&self) -> Option<&LanguageModelResponseContentType> {
         if let Some(msg) = self.messages.last() {
             match msg.message {
@@ -334,6 +364,7 @@ impl LanguageModelOptions {
         }
     }
 
+    /// Returns the text content of the last assistant message.
     pub fn text(&self) -> Option<String> {
         if let Some(msg) = self.messages.last() {
             match msg.message {
@@ -348,14 +379,17 @@ impl LanguageModelOptions {
         }
     }
 
+    /// Extracts all tool results from the conversation.
     pub fn tool_results(&self) -> Option<Vec<ToolResultInfo>> {
         self.messages.as_slice().extract_tool_results()
     }
 
+    /// Extracts all tool calls from the conversation.
     pub fn tool_calls(&self) -> Option<Vec<ToolCallInfo>> {
         self.messages.as_slice().extract_tool_calls()
     }
 
+    /// Returns the reason why generation stopped.
     pub fn stop_reason(&self) -> Option<StopReason> {
         self.stop_reason.clone()
     }
@@ -365,11 +399,16 @@ impl LanguageModelOptions {
 // Section: response types
 // ============================================================================
 
+/// The different types of content that can be generated by a language model.
 #[derive(Debug, Clone)]
 pub enum LanguageModelResponseContentType {
+    /// Plain text response.
     Text(String),
+    /// A tool call request.
     ToolCall(ToolCallInfo),
+    /// Reasoning or thinking content.
     Reasoning(String),
+    /// Feature not supported by the provider.
     NotSupported(String),
 }
 
@@ -386,16 +425,22 @@ impl From for LanguageModelResponseContentType {
 }
 
 impl LanguageModelResponseContentType {
+    /// Creates a new text content type.
     pub fn new(text: impl Into<String>) -> Self {
         Self::Text(text.into())
     }
 }
 
+/// Token usage statistics for a language model operation.
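+///
+/// All counters are optional, since providers do not report every figure.
+/// Per-step usages are additive and are rolled up into totals by [`Step::usage`]
+/// and [`LanguageModelOptions::usage`]. A small illustration (the concrete
+/// integer widths of the counters are elided in this diff):
+///
+/// ```rust,no_run
+/// # use aisdk::core::language_model::Usage;
+/// let usage = Usage {
+///     input_tokens: Some(12),
+///     output_tokens: Some(34),
+///     ..Default::default()
+/// };
+/// assert_eq!(usage.output_tokens, Some(34));
+/// ```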
 #[derive(Default, Debug, Clone, PartialEq)]
 pub struct Usage {
+    /// Number of input tokens processed.
     pub input_tokens: Option,
+    /// Number of output tokens generated.
     pub output_tokens: Option,
+    /// Number of tokens used for reasoning.
     pub reasoning_tokens: Option,
+    /// Number of cached tokens reused.
     pub cached_tokens: Option,
 }
 
@@ -423,7 +468,7 @@ pub struct LanguageModelResponse {
 }
 
 impl LanguageModelResponse {
-    /// Creates a new response with the generated text.
+    /// Creates a new response with the given text content.
     pub fn new(text: impl Into<String>) -> Self {
         Self {
             contents: vec![LanguageModelResponseContentType::new(text.into())],
@@ -432,29 +477,32 @@ impl LanguageModelResponse {
     }
 }
 
+/// Types of chunks that can be emitted during streaming text generation.
 #[derive(Default, Debug, Clone)]
 pub enum LanguageModelStreamChunkType {
-    /// The model has started generating text.
+    /// Indicates the start of generation.
     #[default]
     Start,
-    /// Text chunk
+    /// A chunk of generated text.
     Text(String),
-    /// Tool call argument chunk
+    /// A chunk of tool call data.
     ToolCall(String),
-    /// The model has stopped generating text successfully.
+    /// Successful completion of generation.
     End(AssistantMessage),
-    /// The model has failed to generate text. error specified by
-    /// the language model
+    /// Generation failed with an error message.
     Failed(String),
-    /// The model finsished generating text with incomplete response.
+    /// Generation ended with an incomplete response.
     Incomplete(String),
-    /// Return this for unimplemented features for a specific model.
+    /// Feature not supported by the provider.
     NotSupported(String),
 }
 
+/// A chunk of data from a streaming language model response.
 #[derive(Debug, Clone)]
 pub enum LanguageModelStreamChunk {
+    /// An incremental update during streaming.
     Delta(LanguageModelStreamChunkType),
+    /// The final result when streaming is complete.
     Done(AssistantMessage),
 }
 
@@ -462,13 +510,13 @@ pub enum LanguageModelStreamChunk {
 pub(crate) type ProviderStream =
     Pin>> + Send>>;
 
-// A mapping of `ProviderStream` to a channel like stream.
+/// A stream wrapper that provides a channel-based interface for language model streaming.
 pub struct LanguageModelStream {
     receiver: Receiver<LanguageModelStreamChunkType>,
 }
 
 impl LanguageModelStream {
-    // Creates a new MpmcStream with an associated Sender
+    /// Creates a new stream with an associated sender for pushing chunks.
     pub fn new() -> (Sender<LanguageModelStreamChunkType>, LanguageModelStream) {
         let (tx, rx) = mpsc::channel();
         (tx, LanguageModelStream { receiver: rx })
     }
 }
 
 impl Stream for LanguageModelStream {
     }
 }
 
+/// Reasons why text generation might stop.
 #[derive(Debug, Clone, PartialEq, Default)]
 pub enum StopReason {
+    /// Generation completed successfully.
     #[default]
-    // The model has finished generating text
     Finish,
-    // Provider specific reasons like timeout, rate limit etc
+    /// Provider-specific stop reason (e.g., timeout, rate limit).
     Provider(String),
-    // The user has explicitly provided a hook causing to stop
+    /// Generation was stopped by a user-provided hook.
     Hook,
-    // Problematic errors. Providers specific errors can be accessed
-    // through `Error::ProviderError`
+    /// Stopped due to an error. Provider-specific errors can be accessed
+    /// through `Error::ProviderError`.
     Error(Error),
-    // Anything that is not supported by the above reasons
+    /// Other unspecified reason.
     Other(String),
 }
 
-// will be converted to the appropriate level of reasoning
-// for a language model
+/// Levels of reasoning effort for language models that support it.
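+///
+/// Each provider converts these levels into whatever reasoning configuration it
+/// supports. A sketch of selecting a level through the builder:
+///
+/// ```rust,no_run
+/// # use aisdk::core::LanguageModelRequest;
+/// # use aisdk::core::language_model::ReasoningEffort;
+/// # use aisdk::providers::openai::OpenAI;
+/// let request = LanguageModelRequest::builder()
+///     .model(OpenAI::new("gpt-4"))
+///     .prompt("Prove that the square root of 2 is irrational.")
+///     .reasoning_effort(ReasoningEffort::High)
+///     .build();
+/// ```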
 #[derive(Debug, Clone, Copy, Default)]
 pub enum ReasoningEffort {
+    /// Low reasoning effort.
     #[default]
     Low,
+    /// Medium reasoning effort.
     Medium,
+    /// High reasoning effort.
     High,
 }
diff --git a/src/core/language_model/request.rs b/src/core/language_model/request.rs
index c26b373..410ff16 100644
--- a/src/core/language_model/request.rs
+++ b/src/core/language_model/request.rs
@@ -1,9 +1,8 @@
-//! Defines the central `LanguageModel` trait for interacting with text-based AI models.
+//! Defines the `LanguageModelRequest` struct and its builder for configuring text generation requests.
 //!
-//! This module provides the `LanguageModel` trait, which establishes the core
-//! contract for all language models supported by the SDK. It abstracts the
-//! underlying implementation details of different AI providers, offering a
-//! unified interface for various operations like text generation or streaming.
+//! This module provides the `LanguageModelRequest` type, which encapsulates a language model
+//! and options for generating text or streaming responses. It includes a type-state builder
+//! pattern to ensure requests are constructed correctly and safely.
 
 use crate::core::Message;
 use crate::core::language_model::{LanguageModel, LanguageModelOptions};
@@ -13,21 +12,54 @@ use std::fmt::Debug;
 use std::ops::{Deref, DerefMut};
 use std::sync::Arc;
 
-/// Options for text generation requests such as `generate_text` and `stream_text`.
+/// A request for text generation or streaming with a language model.
+///
+/// This struct holds the model instance and configuration options needed to perform
+/// text generation operations. It is typically constructed using the builder pattern
+/// to ensure all required fields are set.
+///
+/// # Type Parameters
+///
+/// * `M` - The language model type that implements the [`LanguageModel`] trait.
+///
+/// # Fields
+///
+/// * `model` - The language model instance to use for the request.
+/// * `prompt` - An optional text prompt. Mutually exclusive with the conversation
+///   messages set through the builder.
+/// * `options` - Additional configuration for the request, including any system
+///   prompt and conversation messages. See [`LanguageModelOptions`].
+///
+/// # Examples
+///
+/// ```rust
+/// use aisdk::core::LanguageModelRequest;
+/// use aisdk::providers::openai::OpenAI;
+///
+/// let request = LanguageModelRequest::builder()
+///     .model(OpenAI::new("gpt-4"))
+///     .prompt("Hello, world!")
+///     .build();
+/// ```
 #[derive(Debug, Clone)]
 pub struct LanguageModelRequest<M: LanguageModel> {
-    /// The Language Model to use.
+    /// The language model to use for text generation.
     pub model: M,
 
-    /// The prompt to generate text from.
-    /// Only one of prompt or messages should be set.
+    /// An optional simple text prompt for the request.
+    ///
+    /// This should not be set if `messages` are provided in the options.
     pub prompt: Option<String>,
 
-    /// Language model call options for the request
+    /// Configuration options for the language model request.
     pub(crate) options: LanguageModelOptions,
 }
 
 impl<M: LanguageModel> LanguageModelRequest<M> {
+    /// Creates a new builder for constructing a `LanguageModelRequest`.
+    ///
+    /// This method initiates the type-state builder pattern, starting with the
+    /// [`ModelStage`], where you must specify the language model.
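+    ///
+    /// Each stage only exposes the methods that are valid at that point, so an
+    /// out-of-order call is a compile error. A sketch of the full progression:
+    ///
+    /// ```rust,no_run
+    /// # use aisdk::core::LanguageModelRequest;
+    /// # use aisdk::providers::openai::OpenAI;
+    /// let request = LanguageModelRequest::builder()
+    ///     .model(OpenAI::new("gpt-4"))               // ModelStage -> SystemStage
+    ///     .system("You are a helpful assistant.")    // SystemStage -> ConversationStage
+    ///     .prompt("Hello!")                          // ConversationStage -> OptionsStage
+    ///     .build();
+    /// ```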
     pub fn builder() -> LanguageModelRequestBuilder<M, ModelStage> {
         LanguageModelRequestBuilder::default()
     }
 }
@@ -47,25 +79,41 @@ impl DerefMut for LanguageModelRequest {
     }
 }
 
-// State for GenerateOptionsBuilder
-// Following the type State builder pattern
-
-/// Initial state for setting the model
-/// returns SystemStage
+/// The initial builder state, where the language model must be set.
+///
+/// The builder stages are zero-sized type-state markers that ensure the builder
+/// is used in the correct order, preventing invalid request configurations at
+/// compile time.
+///
+/// Transitions to [`SystemStage`] after calling [`model`](LanguageModelRequestBuilder::model).
 pub struct ModelStage {}
 
-/// Secondary state for including system prompt or not
-/// returns ConversationStage
+/// The state after setting the model, where a system prompt can be optionally added.
+///
+/// Transitions to [`ConversationStage`] after calling [`system`](LanguageModelRequestBuilder::system),
+/// or directly to [`OptionsStage`] after calling [`prompt`](LanguageModelRequestBuilder::prompt) or [`messages`](LanguageModelRequestBuilder::messages).
 pub struct SystemStage {}
 
-/// Third state for conversation, Message or Prompt
-/// returns OptionsStage
+/// The state after optionally setting a system prompt, where conversation input must be provided.
+///
+/// Transitions to [`OptionsStage`] after calling [`prompt`](LanguageModelRequestBuilder::prompt) or [`messages`](LanguageModelRequestBuilder::messages).
 pub struct ConversationStage {}
 
-/// Final State for setting Options and config
-/// returns builder.build
+/// The final state, where additional options can be configured before building.
+///
+/// Transitions to the completed `LanguageModelRequest` after calling [`build`](LanguageModelRequestBuilder::build).
 pub struct OptionsStage {}
 
+/// A type-state builder for constructing `LanguageModelRequest` instances.
+///
+/// This builder uses phantom types to enforce a specific construction order,
+/// ensuring that required fields (like the model) are set before optional ones.
+///
+/// # Type Parameters
+///
+/// * `M` - The language model type.
+/// * `State` - The current builder state, determining available methods.
 pub struct LanguageModelRequestBuilder<M, State> {
     model: Option<M>,
     prompt: Option<String>,
@@ -76,12 +124,18 @@ pub struct LanguageModelRequestBuilder {
 impl<M, State> Deref for LanguageModelRequestBuilder<M, State> {
     type Target = LanguageModelOptions;
 
+    /// Dereferences to the underlying `LanguageModelOptions`.
+    ///
+    /// This allows direct access to the options fields during building.
     fn deref(&self) -> &Self::Target {
         &self.options
     }
 }
 
 impl<M, State> DerefMut for LanguageModelRequestBuilder<M, State> {
+    /// Mutably dereferences to the underlying `LanguageModelOptions`.
+    ///
+    /// This allows direct mutation of the options fields during building.
     fn deref_mut(&mut self) -> &mut Self::Target {
         &mut self.options
     }
@@ -98,8 +152,19 @@ impl LanguageModelRequestBuilder {
     }
 }
 
-/// ModelStage Builder
+/// Methods available in the [`ModelStage`] state.
 impl<M: LanguageModel> LanguageModelRequestBuilder<M, ModelStage> {
+    /// Sets the language model for the request.
+    ///
+    /// This is the first required step in building a request.
+    ///
+    /// # Parameters
+    ///
+    /// * `model` - The language model instance to use.
+    ///
+    /// # Returns
+    ///
+    /// The builder in the [`SystemStage`] state.
pub fn model(self, model: M) -> LanguageModelRequestBuilder { LanguageModelRequestBuilder { model: Some(model), @@ -110,8 +175,19 @@ impl LanguageModelRequestBuilder { } } -/// SystemStage Builder +/// Methods available in the [`SystemStage`] state. impl LanguageModelRequestBuilder { + /// Sets an optional system prompt for the request. + /// + /// The system prompt provides context or instructions to the model. + /// + /// # Parameters + /// + /// * `system` - The system prompt text. + /// + /// # Returns + /// + /// The builder in the [`ConversationStage`] state. pub fn system( self, system: impl Into, @@ -127,6 +203,17 @@ impl LanguageModelRequestBuilder { } } + /// Sets a simple text prompt for the request. + /// + /// This skips the system prompt and goes directly to options. + /// + /// # Parameters + /// + /// * `prompt` - The user prompt text. + /// + /// # Returns + /// + /// The builder in the [`OptionsStage`] state. pub fn prompt(self, prompt: impl Into) -> LanguageModelRequestBuilder { LanguageModelRequestBuilder { model: self.model, @@ -136,6 +223,17 @@ impl LanguageModelRequestBuilder { } } + /// Sets conversation messages for the request. + /// + /// This allows for multi-turn conversations with the model. + /// + /// # Parameters + /// + /// * `messages` - A vector of [`Message`] instances representing the conversation. + /// + /// # Returns + /// + /// The builder in the [`OptionsStage`] state. pub fn messages(self, messages: Vec) -> LanguageModelRequestBuilder { LanguageModelRequestBuilder { model: self.model, @@ -149,8 +247,17 @@ impl LanguageModelRequestBuilder { } } -/// ConversationStage Builder +/// Methods available in the [`ConversationStage`] state. impl LanguageModelRequestBuilder { + /// Sets a simple text prompt for the request. + /// + /// # Parameters + /// + /// * `prompt` - The user prompt text. + /// + /// # Returns + /// + /// The builder in the [`OptionsStage`] state. pub fn prompt(self, prompt: impl Into) -> LanguageModelRequestBuilder { LanguageModelRequestBuilder { model: self.model, @@ -160,6 +267,15 @@ impl LanguageModelRequestBuilder { } } + /// Sets conversation messages for the request. + /// + /// # Parameters + /// + /// * `messages` - A vector of [`Message`] instances. + /// + /// # Returns + /// + /// The builder in the [`OptionsStage`] state. pub fn messages(self, messages: Vec) -> LanguageModelRequestBuilder { LanguageModelRequestBuilder { model: self.model, @@ -172,52 +288,148 @@ impl LanguageModelRequestBuilder { } } } -/// OptionsStage Builder + +/// Methods available in the [`OptionsStage`] state. impl LanguageModelRequestBuilder { + /// Sets a JSON schema for structured output. + /// + /// The model will generate output conforming to the schema of type `T`. + /// + /// # Type Parameters + /// + /// * `T` - A type that implements [`JsonSchema`]. + /// + /// # Returns + /// + /// The builder with the schema set. pub fn schema(mut self) -> Self { self.schema = Some(schema_for!(T)); self } + + /// Sets a seed for deterministic generation. + /// + /// # Parameters + /// + /// * `seed` - The random seed value. + /// + /// # Returns + /// + /// The builder with the seed set. pub fn seed(mut self, seed: impl Into) -> Self { self.seed = Some(seed.into()); self } + /// Sets the temperature for generation randomness (0-100, scaled to 0.0-1.0). + /// + /// Higher values increase creativity, lower values increase determinism. + /// + /// # Parameters + /// + /// * `temperature` - The temperature value (0-100). 
+ /// + /// # Returns + /// + /// The builder with the temperature set. pub fn temperature(mut self, temperature: impl Into) -> Self { self.temperature = Some(temperature.into()); self } + /// Sets the top-p (nucleus) sampling parameter (0-100, scaled to 0.0-1.0). + /// + /// # Parameters + /// + /// * `top_p` - The top-p value (0-100). + /// + /// # Returns + /// + /// The builder with top-p set. pub fn top_p(mut self, top_p: impl Into) -> Self { self.top_p = Some(top_p.into()); self } + /// Sets the top-k sampling parameter. + /// + /// # Parameters + /// + /// * `top_k` - The top-k value. + /// + /// # Returns + /// + /// The builder with top-k set. pub fn top_k(mut self, top_k: impl Into) -> Self { self.top_k = Some(top_k.into()); self } + /// Sets stop sequences that halt generation. + /// + /// # Parameters + /// + /// * `stop_sequences` - A list of strings that stop generation when encountered. + /// + /// # Returns + /// + /// The builder with stop sequences set. pub fn stop_sequences(mut self, stop_sequences: impl Into>) -> Self { self.stop_sequences = Some(stop_sequences.into()); self } + /// Sets the maximum number of retries for failed requests. + /// + /// # Parameters + /// + /// * `max_retries` - The maximum retry count. + /// + /// # Returns + /// + /// The builder with max retries set. pub fn max_retries(mut self, max_retries: impl Into) -> Self { self.max_retries = Some(max_retries.into()); self } + /// Sets the frequency penalty to reduce repetition. + /// + /// # Parameters + /// + /// * `frequency_penalty` - The penalty value. + /// + /// # Returns + /// + /// The builder with frequency penalty set. pub fn frequency_penalty(mut self, frequency_penalty: impl Into) -> Self { self.frequency_penalty = Some(frequency_penalty.into()); self } + /// Adds a tool for the model to use during generation. + /// + /// # Parameters + /// + /// * `tool` - The tool to add. + /// + /// # Returns + /// + /// The builder with the tool added. pub fn with_tool(mut self, tool: Tool) -> Self { self.tools.get_or_insert_default().add_tool(tool); self } + /// Sets a condition to stop the generation loop. + /// + /// # Parameters + /// + /// * `hook` - A function that returns `true` when generation should stop. + /// + /// # Returns + /// + /// The builder with the stop condition set. pub fn stop_when(mut self, hook: F) -> Self where F: Fn(&LanguageModelOptions) -> bool + Send + Sync + 'static, @@ -226,6 +438,15 @@ impl LanguageModelRequestBuilder { self } + /// Sets a hook to run at the start of each generation step. + /// + /// # Parameters + /// + /// * `hook` - A function called before each step. + /// + /// # Returns + /// + /// The builder with the hook set. pub fn on_step_start(mut self, hook: F) -> Self where F: Fn(&mut LanguageModelOptions) + Send + Sync + 'static, @@ -234,6 +455,15 @@ impl LanguageModelRequestBuilder { self } + /// Sets a hook to run at the end of each generation step. + /// + /// # Parameters + /// + /// * `hook` - A function called after each step. + /// + /// # Returns + /// + /// The builder with the hook set. pub fn on_step_finish(mut self, hook: F) -> Self where F: Fn(&LanguageModelOptions) + Send + Sync + 'static, @@ -242,6 +472,15 @@ impl LanguageModelRequestBuilder { self } + /// Sets the reasoning effort level. + /// + /// # Parameters + /// + /// * `reasoning_effort` - The effort level. + /// + /// # Returns + /// + /// The builder with reasoning effort set. 
     pub fn reasoning_effort(
         mut self,
         reasoning_effort: impl Into<ReasoningEffort>,
     ) -> Self {
@@ -250,6 +489,13 @@
         self.reasoning_effort = Some(reasoning_effort.into());
         self
     }
 
+    /// Builds the `LanguageModelRequest`.
+    ///
+    /// This method consumes the builder and returns the configured request.
+    ///
+    /// # Returns
+    ///
+    /// The constructed `LanguageModelRequest`.
     pub fn build(self) -> LanguageModelRequest<M> {
         let model = self
             .model
diff --git a/src/core/language_model/stream_text.rs b/src/core/language_model/stream_text.rs
index 1f55ba1..20a01ee 100644
--- a/src/core/language_model/stream_text.rs
+++ b/src/core/language_model/stream_text.rs
@@ -1,3 +1,5 @@
+//! Text streaming impl for `LanguageModelRequest`.
+
 use crate::core::{
     AssistantMessage, LanguageModelStreamChunkType, Message,
     language_model::{
@@ -12,12 +14,48 @@ use futures::StreamExt;
 use std::ops::Deref;
 
 impl<M: LanguageModel> LanguageModelRequest<M> {
-    /// Generates Streaming text using a specified language model.
+    /// Streams text generation and tool execution using the language model.
+    ///
+    /// This method performs streaming text generation, providing real-time access to response chunks
+    /// as they are produced. It supports tool calling and execution in multiple steps, streaming
+    /// intermediate results and handling tool interactions dynamically.
+    ///
+    /// For non-streaming responses, use [`generate_text`](Self::generate_text) instead.
+    ///
+    /// # Returns
+    ///
+    /// A [`StreamTextResponse`] containing the stream of chunks and final conversation state.
     ///
-    /// Generate a text and call tools for a given prompt using a language model.
-    /// This function streams the output. If you do not want to stream the output, use `GenerateText` instead.
+    /// # Errors
     ///
-    /// Returns an `Error` if the underlying model fails to generate a response.
+    /// Returns an `Error` if the underlying language model fails to generate a response
+    /// or if tool execution encounters an error.
+    ///
+    /// # Examples
+    ///
+    /// ```rust,no_run
+    /// # use aisdk::core::LanguageModelRequest;
+    /// # use aisdk::core::language_model::LanguageModelStreamChunkType;
+    /// # use aisdk::providers::openai::OpenAI;
+    /// # use futures::StreamExt;
+    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
+    /// let mut request = LanguageModelRequest::builder()
+    ///     .model(OpenAI::new("gpt-4"))
+    ///     .prompt("Tell me a story")
+    ///     .build();
+    ///
+    /// let mut response = request.stream_text().await?;
+    /// while let Some(chunk) = response.stream.next().await {
+    ///     match chunk {
+    ///         LanguageModelStreamChunkType::Text(text) => {
+    ///             print!("{}", text);
+    ///         }
+    ///         _ => {}
+    ///     }
+    /// }
+    /// # Ok(())
+    /// # }
+    /// ```
     pub async fn stream_text(&mut self) -> Result<StreamTextResponse> {
         let (system_prompt, messages) = resolve_message(&self.options, &self.prompt);
@@ -161,15 +199,21 @@ impl LanguageModelRequest {
 // Section: response types
 // ============================================================================
 
-// Response from a stream call on `StreamText`.
+/// Response from a streaming text generation call.
+///
+/// This struct contains the streaming response from a language model,
+/// including the stream of chunks and the final options state.
 pub struct StreamTextResponse {
-    /// A stream of responses from the language model.
+    /// The stream of response chunks from the language model.
     pub stream: LanguageModelStream,
 
-    /// The reason the model stopped generating text.
+    /// The final options state after streaming completes.
options: LanguageModelOptions, } impl StreamTextResponse { + /// Returns the step IDs of all messages in the conversation. + /// + /// This is primarily used for testing and debugging purposes. #[cfg(any(test, feature = "test-access"))] pub fn step_ids(&self) -> Vec { self.options.messages.iter().map(|t| t.step_id).collect() diff --git a/src/core/messages.rs b/src/core/messages.rs index 6fe1f51..673c72b 100644 --- a/src/core/messages.rs +++ b/src/core/messages.rs @@ -1,23 +1,33 @@ +//! Message types for the `aisdk` library. + use crate::core::{ - ToolCallInfo, ToolResultInfo, language_model::{LanguageModelResponseContentType, Usage}, + tools::{ToolCallInfo, ToolResultInfo}, }; -/// Role for model messages. +/// The role of a participant in a conversation. #[derive(Debug, Clone)] pub enum Role { + /// System-level instructions or context. System, + /// Human user input. User, + /// AI assistant response. Assistant, } -/// Message Type for model messages. +/// A message in a conversation with a language model. #[derive(Debug, Clone)] pub enum Message { + /// A system message providing context or instructions. System(SystemMessage), + /// A user message containing input from the human. User(UserMessage), + /// An assistant message containing the model's response. Assistant(AssistantMessage), + /// A tool result message from executing a tool call. Tool(ToolResultInfo), + /// A developer-specific message for advanced use cases. Developer(String), } @@ -63,13 +73,15 @@ impl Message { } } -/// System message. +/// A system message that provides context or instructions to the model. #[derive(Debug, Clone)] pub struct SystemMessage { + /// The text content of the system message. pub content: String, } impl SystemMessage { + /// Creates a new system message with the given content. pub fn new(content: impl Into) -> Self { Self { content: content.into(), @@ -89,13 +101,15 @@ impl From<&str> for SystemMessage { } } -/// User message. +/// A user message containing input from the human participant. #[derive(Debug, Clone)] pub struct UserMessage { + /// The text content of the user message. pub content: String, } impl UserMessage { + /// Creates a new user message with the given content. pub fn new(content: impl Into) -> Self { Self { content: content.into(), @@ -115,14 +129,12 @@ impl From<&str> for UserMessage { } } -/// Assistant model message. +/// A message generated by the language model assistant. #[derive(Default, Debug, Clone)] -/// Message generated by the language model. wraps a `LanguageModelResponseContentType` -/// and adds additional metadata pub struct AssistantMessage { - /// The different types of language model responses (supports multiple) + /// The content of the assistant's response. pub content: LanguageModelResponseContentType, - /// usage detials + /// Optional usage statistics for the response. pub usage: Option, } @@ -136,22 +148,28 @@ impl From for AssistantMessage { } impl AssistantMessage { + /// Creates a new assistant message with the given content and usage. pub fn new(content: LanguageModelResponseContentType, usage: Option) -> Self { Self { content, usage } } } -/// Message State for type safe message list construction. -/// Initial state for initial message builder with either system or user message. +/// Type-state marker for the initial message builder state. +/// +/// In this state, the first message must be either a system prompt or a user message. #[derive(Debug, Clone)] pub struct Initial; -/// Message State for type safe message list construction. 
-/// Conversation state is used for only user and assistant message builder. +/// Type-state marker for the conversation message builder state. +/// +/// In this state, only user and assistant messages can be added. #[derive(Debug, Clone)] pub struct Conversation; -/// Message Builder with state for type safe message list construction. +/// A type-state builder for constructing message lists safely. +/// +/// This builder ensures that messages are added in a valid order, +/// preventing invalid conversation structures. #[derive(Debug, Clone)] pub struct MessageBuilder { messages: Vec, @@ -159,6 +177,9 @@ pub struct MessageBuilder { } impl MessageBuilder { + /// Creates a new message builder starting in the conversation state. + /// + /// This allows building conversations without requiring an initial system or user message. pub fn conversation_builder() -> MessageBuilder { MessageBuilder { messages: Vec::new(), @@ -177,12 +198,22 @@ impl Default for MessageBuilder { } impl MessageBuilder { + /// Builds the message list. pub fn build(self) -> Vec { self.messages } } impl MessageBuilder { + /// Adds a system message and transitions to the conversation state. + /// + /// # Parameters + /// + /// * `content` - The system message content. + /// + /// # Returns + /// + /// The builder in the conversation state. pub fn system(mut self, content: impl Into) -> MessageBuilder { self.messages.push(Message::System(content.into().into())); MessageBuilder { @@ -191,6 +222,15 @@ impl MessageBuilder { } } + /// Adds a user message and transitions to the conversation state. + /// + /// # Parameters + /// + /// * `content` - The user message content. + /// + /// # Returns + /// + /// The builder in the conversation state. pub fn user(mut self, content: impl Into) -> MessageBuilder { self.messages.push(Message::User(content.into().into())); MessageBuilder { @@ -201,6 +241,15 @@ impl MessageBuilder { } impl MessageBuilder { + /// Adds a user message to the conversation. + /// + /// # Parameters + /// + /// * `content` - The user message content. + /// + /// # Returns + /// + /// The builder with the message added. pub fn user(mut self, content: impl Into) -> MessageBuilder { self.messages.push(Message::User(content.into().into())); MessageBuilder { @@ -208,11 +257,18 @@ impl MessageBuilder { state: std::marker::PhantomData, } } + + /// Adds an assistant message to the conversation. + /// + /// # Parameters + /// + /// * `content` - The assistant message content. + /// + /// # Returns + /// + /// The builder with the message added. pub fn assistant(mut self, content: impl Into) -> MessageBuilder { - self.messages - // no need for usage as this method is supposed to be called by the user - // and there is no usage data. 
-            .push(Message::Assistant(content.into().into()));
+        self.messages.push(Message::Assistant(content.into().into()));
         MessageBuilder {
             messages: self.messages,
             state: std::marker::PhantomData,
diff --git a/src/core/mod.rs b/src/core/mod.rs
index ee44da7..4f51a15 100644
--- a/src/core/mod.rs
+++ b/src/core/mod.rs
@@ -10,7 +10,7 @@ pub mod language_model;
 pub mod messages;
 pub mod provider;
 
-pub mod tools;
+pub(crate) mod tools;
 pub mod utils;
 
 pub use aisdk_macros::tool;
@@ -23,4 +23,4 @@ pub use language_model::{
 pub use messages::{AssistantMessage, Message, Role, SystemMessage, UserMessage};
 pub use provider::Provider;
 
-pub use tools::{Tool, ToolCallInfo, ToolResultInfo};
+pub use tools::{Tool, ToolExecute};
diff --git a/src/core/tools.rs b/src/core/tools.rs
index 9200e15..b2dd944 100644
--- a/src/core/tools.rs
+++ b/src/core/tools.rs
@@ -1,3 +1,91 @@
+//! Tools are a way to extend the capabilities of a language model. aisdk provides a
+//! macro to simplify the process of defining and registering tools. This module provides
+//! the necessary types and functions for defining and using tools, both by the macro and
+//! by the user.
+//!
+//! The `Tool` struct is the core component of a tool. It contains the `name`, `description`,
+//! and `input_schema` of the tool as well as the logic to execute it. The `execute`
+//! method is the main entry point for executing the tool. The language model is responsible
+//! for calling this method, using `input_schema` to generate the arguments for the tool.
+//!
+//! The tool macro generates the necessary code for registering the tool with the SDK.
+//! It infers the necessary fields for the `Tool` struct from a valid Rust function.
+//!
+//! # Example
+//! ```
+//! use aisdk::core::Tool;
+//! use aisdk_macros::tool;
+//!
+//! #[tool]
+//! /// Adds two numbers together.
+//! pub fn sum(a: u8, b: u8) -> Tool {
+//!     Ok(format!("{}", a + b))
+//! }
+//!
+//! let tool: Tool = sum();
+//!
+//! assert_eq!(tool.name, "sum");
+//! assert_eq!(tool.description, "Adds two numbers together.");
+//! assert_eq!(tool.input_schema.to_value(), serde_json::json!({
+//!     "type": "object",
+//!     "required": ["a", "b"],
+//!     "properties": {
+//!         "a": {
+//!             "type": "integer",
+//!             "format": "uint8",
+//!             "minimum": 0,
+//!             "maximum": 255
+//!         },
+//!         "b": {
+//!             "type": "integer",
+//!             "format": "uint8",
+//!             "minimum": 0,
+//!             "maximum": 255
+//!         }
+//!     }
+//! }));
+//! ```
+//!
+//! # Example with struct
+//!
+//! ```rust
+//! use aisdk::core::{Tool, ToolExecute};
+//! use serde_json::Value;
+//!
+//! let tool: Tool = Tool {
+//!     name: "sum".to_string(),
+//!     description: "Adds two numbers together.".to_string(),
+//!     input_schema: serde_json::json!({
+//!         "type": "object",
+//!         "required": ["a", "b"],
+//!         "properties": {
+//!             "a": {
+//!                 "type": "integer",
+//!                 "format": "uint8",
+//!                 "minimum": 0,
+//!                 "maximum": 255
+//!             },
+//!             "b": {
+//!                 "type": "integer",
+//!                 "format": "uint8",
+//!                 "minimum": 0,
+//!                 "maximum": 255
+//!             }
+//!         }
+//!     })
+//!     // `Schema` wraps a JSON object, so a JSON value can be converted into it.
+//!     .try_into()
+//!     .unwrap(),
+//!     execute: ToolExecute::new(Box::new(|params: Value| {
+//!         let a = params["a"].as_u64().unwrap();
+//!         let b = params["b"].as_u64().unwrap();
+//!         Ok(format!("{}", a + b))
+//!     })),
+//! };
+//! ```
+//!
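+//! # Example with an optional argument
+//!
+//! A sketch of a tool whose argument the model may omit; per the `#[tool]` macro's
+//! docs, the argument's default is used when the model does not send it:
+//!
+//! ```
+//! use aisdk::core::Tool;
+//! use aisdk_macros::tool;
+//!
+//! #[tool]
+//! /// Greets a user.
+//! pub fn greet(name: Option<String>) -> Tool {
+//!     Ok(format!("Hello, {}!", name.unwrap_or_else(|| "stranger".to_string())))
+//! }
+//!
+//! let tool: Tool = greet();
+//! assert_eq!(tool.name, "greet");
+//! ```
+//!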
 use crate::error::{Error, Result};
 use derive_builder::Builder;
 use schemars::Schema;
@@ -9,16 +97,22 @@ use tokio::task::JoinHandle;
 
 pub type ToolFn = Box<dyn Fn(Value) -> std::result::Result<String, String> + Send + Sync>;
 
+/// Holds the function that will be called when the tool is executed. The function
+/// should take a single argument of type `Value` and return a
+/// `Result`.
 #[derive(Clone)]
 pub struct ToolExecute {
     inner: Arc<ToolFn>,
 }
 
 impl ToolExecute {
-    pub fn call(&self, map: Value) -> Result<String> {
+    pub(crate) fn call(&self, map: Value) -> Result<String> {
         (*self.inner)(map).map_err(Error::ToolCallError)
     }
 
+    /// Creates a new `ToolExecute` instance with the given function.
+    /// The function should take a single argument of type `Value` and return a
+    /// `Result`.
     pub fn new(f: ToolFn) -> Self {
         Self { inner: Arc::new(f) }
     }
@@ -48,6 +142,74 @@ impl<'de> Deserialize<'de> for ToolExecute {
     }
 }
 
+/// The `Tool` struct represents a tool that can be executed by a language model.
+/// It contains the name, description, input schema, and execution logic of the tool.
+/// The `execute` method is the main entry point for executing the tool and is called
+/// by the language model.
+///
+/// `name` and `description` help the model identify and understand the tool. `input_schema`
+/// defines the structure of the input data that the tool expects. `Schema` is a type from
+/// the [`schemars`](https://docs.rs/schemars/latest/schemars/) crate that can be used to
+/// define the input schema.
+///
+/// The execute method is responsible for executing the tool and returning the result to
+/// the language model. It takes a single argument of type `Value` and returns a
+/// `Result`.
+///
+/// # Example
+/// ```
+/// use aisdk::core::{Tool, ToolExecute};
+///
+/// let tool: Tool = Tool {
+///     name: "sum".to_string(),
+///     description: "Adds two numbers together.".to_string(),
+///     input_schema: serde_json::json!({
+///         "type": "object",
+///         "required": ["a", "b"],
+///         "properties": {
+///             "a": {
+///                 "type": "integer",
+///                 "format": "uint8",
+///                 "minimum": 0,
+///                 "maximum": 255
+///             },
+///             "b": {
+///                 "type": "integer",
+///                 "format": "uint8",
+///                 "minimum": 0,
+///                 "maximum": 255
+///             }
+///         }
+///     })
+///     .try_into()
+///     .unwrap(),
+///     execute: ToolExecute::new(Box::new(|params| {
+///         let a = params["a"].as_u64().unwrap();
+///         let b = params["b"].as_u64().unwrap();
+///         Ok(format!("{}", a + b))
+///     })),
+/// };
+///
+/// assert_eq!(tool.name, "sum");
+/// assert_eq!(tool.description, "Adds two numbers together.");
+/// assert_eq!(tool.input_schema.to_value(), serde_json::json!({
+///     "type": "object",
+///     "required": ["a", "b"],
+///     "properties": {
+///         "a": {
+///             "type": "integer",
+///             "format": "uint8",
+///             "minimum": 0,
+///             "maximum": 255
+///         },
+///         "b": {
+///             "type": "integer",
+///             "format": "uint8",
+///             "minimum": 0,
+///             "maximum": 255
+///         }
+///     }
+/// }));
+/// ```
 #[derive(Builder, Clone, Default)]
 #[builder(pattern = "owned", setter(into), build_fn(error = "Error"))]
 pub struct Tool {
@@ -71,6 +233,7 @@ impl Debug for Tool {
 }
 
 impl Tool {
+    /// Creates a new `Tool` instance with default values.
     pub fn new() -> Self {
         Self {
             name: "".to_string(),
@@ -121,20 +284,23 @@ impl ToolList {
 
 #[derive(Default, Debug, Clone, PartialEq)]
 /// Describes a tool
 pub struct ToolDetails {
-    // the name of the tool, usually a function name.
+    /// The name of the tool, usually a function name.
     pub name: String,
-    // uniquely identifies a tool, provided by the LLM.
+    /// Uniquely identifies a tool, usually assigned by the provider.
     pub id: String,
 }
 
 /// Contains information necessary to call a tool
 #[derive(Default, Debug, Clone, PartialEq)]
 pub struct ToolCallInfo {
+    /// The details of the tool to be called.
     pub tool: ToolDetails,
+    /// The input parameters for the tool.
     pub input: serde_json::Value,
 }
 
 impl ToolCallInfo {
+    /// Creates a new `ToolCallInfo` instance with the given name.
     pub fn new(name: impl Into<String>) -> Self {
         Self {
             tool: ToolDetails {
diff --git a/src/core/utils.rs b/src/core/utils.rs
index 7515a6c..5dc493f 100644
--- a/src/core/utils.rs
+++ b/src/core/utils.rs
@@ -1,6 +1,42 @@
+//! Utility functions for the `aisdk` library.
+
 use crate::core::{Message, language_model::LanguageModelOptions, messages::TaggedMessage};
 
-/// Returns true if the number of steps is equal to the provided step.
+/// Creates a hook that returns `true` if the number of conversation steps exceeds the given count.
+///
+/// This function is useful for stopping generation after a certain number of tool-calling
+/// iterations or conversation turns. The returned closure can be used with the `stop_when`
+/// option in `LanguageModelOptions`.
+///
+/// # Parameters
+///
+/// * `step` - The step count threshold. Returns `true` when the actual step count is greater than this value.
+///
+/// # Returns
+///
+/// A closure that takes `&LanguageModelOptions` and returns `bool`.
+///
+/// # Examples
+///
+/// ```rust,no_run
+/// use aisdk::core::{LanguageModelRequest, utils::step_count_is};
+/// use aisdk::providers::openai::OpenAI;
+///
+/// #[tokio::main]
+/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
+///     let result = LanguageModelRequest::builder()
+///         .model(OpenAI::new("gpt-4o"))
+///         .system("You are a helpful assistant with access to tools.")
+///         .prompt("What is the weather in New York?")
+///         .stop_when(step_count_is(3)) // Limit the agent loop to 3 steps
+///         .build()
+///         .generate_text()
+///         .await?;
+///
+///     println!("{}", result.text().unwrap());
+///     Ok(())
+/// }
+/// ```
 pub fn step_count_is(step: usize) -> impl Fn(&LanguageModelOptions) -> bool {
     move |options| options.steps().len() > step
 }
diff --git a/src/error.rs b/src/error.rs
index fd50ea9..32109c3 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -45,6 +45,8 @@ pub enum Error {
     #[error("Invalid input: {0}")]
     InvalidInput(String),
 
+    /// An error related to tool execution. This includes errors caused by the
+    /// tool itself as well as by the SDK when interacting with the tool.
     #[error("Tool error: {0}")]
     ToolCallError(String),
 
diff --git a/src/lib.rs b/src/lib.rs
index ff06584..1a482f1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,7 @@
+#![deny(missing_docs)]
+
+//! `aisdk` is an open-source Rust library for building AI-powered applications, inspired by the Vercel AI SDK. It provides a type-safe interface for interacting with Large Language Models (LLMs).
+
 pub mod core;
 pub mod error;
 #[cfg(feature = "prompt")]
diff --git a/src/providers/anthropic/mod.rs b/src/providers/anthropic/mod.rs
index 549b786..7b0164f 100644
--- a/src/providers/anthropic/mod.rs
+++ b/src/providers/anthropic/mod.rs
@@ -12,7 +12,7 @@ use crate::core::language_model::{
 };
 use crate::core::messages::AssistantMessage;
 use crate::core::tools::ToolDetails;
-use crate::core::{LanguageModelStreamChunkType, ToolCallInfo};
+use crate::core::{LanguageModelStreamChunkType, tools::ToolCallInfo};
 use crate::error::ProviderError;
 use crate::providers::anthropic::client::{
     AnthropicClient, AnthropicDelta, AnthropicError, AnthropicMessageDeltaUsage,
diff --git a/src/providers/openai/conversions.rs b/src/providers/openai/conversions.rs
index cc5059b..3b4d331 100644
--- a/src/providers/openai/conversions.rs
+++ b/src/providers/openai/conversions.rs
@@ -1,10 +1,10 @@
 //! Helper functions and conversions for the OpenAI provider.
+use crate::core::Tool; use crate::core::language_model::{ LanguageModelOptions, LanguageModelResponseContentType, ReasoningEffort, Usage, }; use crate::core::messages::Message; -use crate::core::tools::Tool; use async_openai::types::responses::{ CreateResponse, Function, Input, InputContent, InputItem, InputMessage, InputMessageType, ReasoningConfig, ReasoningSummary, Role, TextConfig, TextResponseFormat, ToolDefinition,