diff --git a/DESCRIPTION b/DESCRIPTION
index 504841e22..f7da45d82 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -86,6 +86,7 @@ Collate:
     'import-standalone-types-check.R'
     'interpolate.R'
     'live.R'
+    'model.R'
     'parallel-chat.R'
     'params.R'
     'provider-anthropic.R'
diff --git a/NAMESPACE b/NAMESPACE
index 8750d8f34..8a470a7ac 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -16,6 +16,7 @@ export(ContentText)
 export(ContentThinking)
 export(ContentToolRequest)
 export(ContentToolResult)
+export(Model)
 export(Provider)
 export(SystemTurn)
 export(Turn)
diff --git a/R/batch-chat.R b/R/batch-chat.R
index 0f947421a..f5019ef12 100644
--- a/R/batch-chat.R
+++ b/R/batch-chat.R
@@ -193,6 +193,7 @@ BatchJob <- R6::R6Class(
 
     # Internal state
     provider = NULL,
+    model = NULL,
     started_at = NULL,
     stage = NULL,
     batch = NULL,
@@ -208,6 +209,7 @@ BatchJob <- R6::R6Class(
       call = caller_env(2)
     ) {
       self$provider <- chat$get_provider()
+      self$model <- chat$get_model_obj()
       check_has_batch_support(self$provider, call = call)
 
       user_turns <- as_user_turns(prompts, call = call)
@@ -326,7 +328,7 @@ BatchJob <- R6::R6Class(
     retrieve = function() {
       self$results <- batch_retrieve(self$provider, self$batch)
-      log_turns(self$provider, self$result_turns())
+      log_turns(self$provider, self$model, self$result_turns())
 
       self$stage <- "done"
       self$save_state()
@@ -335,7 +337,12 @@ BatchJob <- R6::R6Class(
 
     result_turns = function() {
       map2(self$results, self$user_turns, function(result, user_turn) {
-        batch_result_turn(self$provider, result, has_type = !is.null(self$type))
+        batch_result_turn(
+          self$provider,
+          self$model,
+          result,
+          has_type = !is.null(self$type)
+        )
       })
     },
@@ -379,7 +386,6 @@ BatchJob <- R6::R6Class(
 provider_hash <- function(x) {
   list(
     name = x@name,
-    model = x@model,
     base_url = x@base_url
   )
 }
diff --git a/R/chat.R b/R/chat.R
index 0e7da10ba..a863d037b 100644
--- a/R/chat.R
+++ b/R/chat.R
@@ -24,6 +24,7 @@ Chat <- R6::R6Class(
   "Chat",
   public = list(
     #' @param provider A provider object.
+    #' @param model A model object.
     #' @param system_prompt System prompt to start the conversation with.
     #' @param echo One of the following options:
     #'   * `none`: don't emit any output (default when running in a function).
@@ -33,8 +34,14 @@ Chat <- R6::R6Class(
     #'
     #'   Note this only affects the `chat()` method. You can override the default
     #'   by setting the `ellmer_echo` option.
-    initialize = function(provider, system_prompt = NULL, echo = "none") {
+    initialize = function(
+      provider,
+      model,
+      system_prompt = NULL,
+      echo = "none"
+    ) {
       private$provider <- provider
+      private$model <- model
       private$echo <- echo
       private$callback_on_tool_request <- CallbackManager$new(args = "request")
       private$callback_on_tool_result <- CallbackManager$new(args = "result")
@@ -78,7 +85,7 @@ Chat <- R6::R6Class(
       check_turn(assistant)
 
       if (log_tokens) {
-        log_turn(private$provider, assistant)
+        log_turn(private$provider, private$model, assistant)
       }
 
       private$.turns[[length(private$.turns) + 1]] <- user
@@ -97,7 +104,7 @@ Chat <- R6::R6Class(
 
     #' @description Retrieve the model name
     get_model = function() {
-      private$provider@model
+      private$model@name
     },
 
     #' @description Update the system prompt
@@ -381,6 +388,11 @@ Chat <- R6::R6Class(
       private$provider
     },
 
+    #' @description Get the underlying model object. For expert use only.
+    get_model_obj = function() {
+      private$model
+    },
+
     #' @description Retrieve the list of registered tools.
     get_tools = function() {
       private$tools
@@ -422,6 +434,7 @@ Chat <- R6::R6Class(
   ),
   private = list(
     provider = NULL,
+    model = NULL,
     .turns = list(),
     echo = NULL,
@@ -579,6 +592,7 @@ Chat <- R6::R6Class(
       response <- chat_perform(
         provider = private$provider,
+        model = private$model,
         mode = if (stream) "stream" else "value",
         turns = c(private$.turns, list(user_turn)),
         tools = if (is.null(type)) private$tools,
@@ -603,11 +617,17 @@ Chat <- R6::R6Class(
           result <- stream_merge_chunks(private$provider, result, chunk)
         }
 
-        turn <- value_turn(private$provider, result, has_type = !is.null(type))
+        turn <- value_turn(
+          private$provider,
+          private$model,
+          result,
+          has_type = !is.null(type)
+        )
         turn <- match_tools(turn, private$tools)
       } else {
         turn <- value_turn(
           private$provider,
+          private$model,
           resp_body_json(response),
           has_type = !is.null(type)
         )
@@ -661,6 +681,7 @@ Chat <- R6::R6Class(
     ) {
       response <- chat_perform(
         provider = private$provider,
+        model = private$model,
         mode = if (stream) "async-stream" else "async-value",
         turns = c(private$.turns, list(user_turn)),
         tools = if (is.null(type)) private$tools,
@@ -685,12 +706,18 @@ Chat <- R6::R6Class(
           result <- stream_merge_chunks(private$provider, result, chunk)
         }
 
-        turn <- value_turn(private$provider, result, has_type = !is.null(type))
+        turn <- value_turn(
+          private$provider,
+          private$model,
+          result,
+          has_type = !is.null(type)
+        )
       } else {
         result <- await(response)
 
         turn <- value_turn(
           private$provider,
+          private$model,
           resp_body_json(result),
           has_type = !is.null(type)
         )
@@ -738,6 +765,7 @@ Chat <- R6::R6Class(
 #' @export
 print.Chat <- function(x, ...) {
   provider <- x$get_provider()
+  model <- x$get_model()
   turns <- x$get_turns(include_system_prompt = TRUE)
 
   assistant_turns <- keep(turns, \(x) x@role == "assistant")
@@ -746,7 +774,7 @@ print.Chat <- function(x, ...) {
   cat(paste_c(
     "\n"
diff --git a/R/httr2.R b/R/httr2.R
index 7a20e2839..1613b7747 100644
--- a/R/httr2.R
+++ b/R/httr2.R
@@ -3,6 +3,7 @@
 # We will reconsider this in the future if necessary.
 chat_perform <- function(
   provider,
+  model,
   mode = c("value", "stream", "async-stream", "async-value"),
   turns,
   tools = NULL,
@@ -14,6 +15,7 @@ chat_perform <- function(
 
   req <- chat_request(
     provider = provider,
+    model = model,
     turns = turns,
     tools = tools,
     stream = stream,
diff --git a/R/model.R b/R/model.R
new file mode 100644
index 000000000..56c5dde8b
--- /dev/null
+++ b/R/model.R
@@ -0,0 +1,29 @@
+#' A chatbot model
+#'
+#' A Model captures the details of a specific language model and its
+#' configuration parameters. This includes the model name, parameters like
+#' temperature and max_tokens, and any extra arguments to include in API
+#' requests.
+#'
+#' Model objects are typically created internally by [chat_openai()],
+#' [chat_anthropic()], and other `chat_*()` functions. You generally won't
+#' need to create them directly unless you're implementing a custom provider.
+#'
+#' @export
+#' @param name Name of the model (e.g., "gpt-4", "claude-sonnet-4").
+#' @param params A list of standard parameters created by [params()].
+#' @param extra_args Arbitrary extra arguments to be included in the request body.
+#' @return An S7 Model object.
+#' @examples
+#' Model(
+#'   name = "gpt-4",
+#'   params = params(temperature = 0.7)
+#' )
+Model <- new_class(
+  "Model",
+  properties = list(
+    name = prop_string(),
+    params = class_list,
+    extra_args = class_list
+  )
+)
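Taken together with the Chat$new() change above, the intended call pattern looks roughly like this. A minimal sketch using only the constructors this patch exports; the model name and temperature are illustrative, and a bare Provider has no chat_body() method, so this object can be constructed but not used to chat:

library(ellmer)

# Connection details live on the Provider...
provider <- Provider(
  name = "CoolModels",
  base_url = "https://cool-models.com"
)

# ...while the model name, params(), and extra request args live on the Model.
model <- Model(
  name = "gpt-4",
  params = params(temperature = 0.7)
)

chat <- Chat$new(provider = provider, model = model, system_prompt = "Be terse.")
chat$get_model()      # "gpt-4", i.e. model@name
chat$get_model_obj()  # the full S7 Model object, for expert use only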
diff --git a/R/parallel-chat.R b/R/parallel-chat.R
index 3b232f677..f927896b0 100644
--- a/R/parallel-chat.R
+++ b/R/parallel-chat.R
@@ -78,6 +78,7 @@ parallel_chat <- function(
   my_parallel_turns <- function(conversations) {
     parallel_turns(
       provider = chat$get_provider(),
+      model = chat$get_model_obj(),
       conversations = conversations,
       tools = chat$get_tools(),
       max_active = max_active,
@@ -127,7 +128,7 @@ parallel_chat <- function(
   map(seq_along(conversations), function(i) {
     if (is_ok[[i]]) {
       turns <- conversations[[i]]
-      log_turns(chat$get_provider(), turns)
+      log_turns(chat$get_provider(), chat$get_model_obj(), turns)
       chat$clone()$set_turns(turns)
     } else {
       assistant_turns[[i]]
@@ -191,6 +192,7 @@ parallel_chat_structured <- function(
   on_error <- arg_match(on_error)
 
   provider <- chat$get_provider()
+  model <- chat$get_model_obj()
   needs_wrapper <- type_needs_wrapper(type, provider)
 
   # First build up list of cumulative conversations
@@ -200,6 +202,7 @@ parallel_chat_structured <- function(
 
   turns <- parallel_turns(
     provider = provider,
+    model = model,
     conversations = conversations,
     tools = chat$get_tools(),
     type = wrap_type_if_needed(type, needs_wrapper),
@@ -207,7 +210,7 @@ parallel_chat_structured <- function(
     rpm = rpm,
     on_error = on_error
   )
-  log_turns(provider, turns)
+  log_turns(provider, model, turns)
 
   multi_convert(
     provider,
@@ -227,7 +230,7 @@ multi_convert <- function(
   include_tokens = FALSE,
   include_cost = FALSE
 ) {
-  needs_wrapper <- type_needs_wrapper(type, provider)
+  needs_wrapper <- type_needs_wrapper(type, provider, NULL)
 
   rows <- map(turns, \(turn) {
     if (turn_failed(turn)) {
@@ -306,6 +309,7 @@ turn_failed <- function(turn) {
 
 parallel_turns <- function(
   provider,
+  model,
   conversations,
   tools,
   type = NULL,
@@ -316,6 +320,7 @@ parallel_turns <- function(
   reqs <- map(conversations, function(turns) {
     chat_request(
       provider = provider,
+      model = model,
       turns = turns,
       type = type,
       tools = tools,
@@ -352,7 +357,7 @@ parallel_turns <- function(
       resp
     } else {
       json <- resp_body_json(resp)
-      turn <- value_turn(provider, json, has_type = !is.null(type))
+      turn <- value_turn(provider, model, json, has_type = !is.null(type))
       turn@duration <- resp_timing(resp)[["total"]] %||% NA_real_
       turn
     }
diff --git a/R/provider-anthropic.R b/R/provider-anthropic.R
index b0ae9f4fe..be0711d5e 100644
--- a/R/provider-anthropic.R
+++ b/R/provider-anthropic.R
@@ -100,17 +100,24 @@ chat_anthropic <- function(
 
   provider <- ProviderAnthropic(
     name = "Anthropic",
-    model = model,
-    params = params %||% params(),
-    extra_args = api_args,
     extra_headers = api_headers,
     base_url = base_url,
     beta_headers = beta_headers,
     credentials = credentials,
     cache = cache
   )
+  model_obj <- Model(
+    name = model,
+    params = params %||% params(),
+    extra_args = api_args
+  )
 
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 #' @rdname chat_anthropic
@@ -180,6 +187,7 @@ method(chat_path, ProviderAnthropic) <- function(provider) {
 }
 method(chat_body, ProviderAnthropic) <- function(
   provider,
+  model,
   stream = TRUE,
   turns = list(),
   tools = list(),
@@ -213,7 +221,7 @@ method(chat_body, ProviderAnthropic) <- function(
   }
 
   tools <- as_json(provider, unname(tools))
-  params <- chat_params(provider, provider@params)
+  params <- chat_params(provider, model)
   if (has_name(params, "budget_tokens")) {
     thinking <- list(
       type = "enabled",
@@ -225,7 +233,7 @@ method(chat_body, ProviderAnthropic) <- function(
   }
 
   compact(list2(
-    model = provider@model,
+    model = model@name,
     system = system,
     messages = messages,
     stream = stream,
@@ -236,9 +244,9 @@ method(chat_body, ProviderAnthropic) <- function(
   ))
 }
 
-method(chat_params, ProviderAnthropic) <- function(provider, params) {
+method(chat_params, ProviderAnthropic) <- function(provider, model) {
   params <- standardise_params(
-    params,
+    model@params,
     c(
       temperature = "temperature",
       top_p = "top_p",
@@ -326,7 +334,7 @@ method(stream_merge_chunks, ProviderAnthropic) <- function(
   result
 }
 
-method(value_tokens, ProviderAnthropic) <- function(provider, json) {
+method(value_tokens, ProviderAnthropic) <- function(provider, model, json) {
   tokens(
     # Hack in pricing for cache writes
     input = json$usage$input_tokens +
@@ -338,6 +346,7 @@ method(value_tokens, ProviderAnthropic) <- function(provider, json) {
 
 method(value_turn, ProviderAnthropic) <- function(
   provider,
+  model,
   result,
   has_type = FALSE
 ) {
@@ -366,8 +375,8 @@ method(value_turn, ProviderAnthropic) <- function(
     }
   })
 
-  tokens <- value_tokens(provider, result)
-  cost <- get_token_cost(provider, tokens)
+  tokens <- value_tokens(provider, model, result)
+  cost <- get_token_cost(provider, model, tokens)
   AssistantTurn(contents, json = result, tokens = unlist(tokens), cost = cost)
 }
 
@@ -584,11 +593,12 @@ method(batch_retrieve, ProviderAnthropic) <- function(provider, batch) {
 
 method(batch_result_turn, ProviderAnthropic) <- function(
   provider,
+  model,
   result,
   has_type = FALSE
 ) {
   if (result$type == "succeeded") {
-    value_turn(provider, result$message, has_type = has_type)
+    value_turn(provider, model, result$message, has_type = has_type)
   } else {
     NULL
   }
@@ -604,7 +614,6 @@ models_anthropic <- function(
 ) {
   provider <- ProviderAnthropic(
     name = "Anthropic",
-    model = "",
     base_url = base_url,
     credentials = function() api_key,
     cache = "none"
diff --git a/R/provider-aws.R b/R/provider-aws.R
index 0132a13ba..7f1e1bb90 100644
--- a/R/provider-aws.R
+++ b/R/provider-aws.R
@@ -73,11 +73,19 @@ chat_aws_bedrock <- function(
     base_url = base_url,
     model = model,
     profile = profile,
-    params = params,
-    extra_args = api_args,
     extra_headers = api_headers
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  model_obj <- Model(
+    name = model,
+    params = params,
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 
@@ -90,7 +98,7 @@ models_aws_bedrock <- function(profile = NULL, base_url = NULL) {
   provider <- provider_aws_bedrock(
     base_url = base_url,
     model = "",
-    profile = profile,
+    profile = profile
   )
   req <- base_request(provider)
   req <- req_url_path_append(req, "foundation-models")
@@ -112,8 +120,6 @@ provider_aws_bedrock <- function(
   base_url,
   model = "",
   profile = NULL,
-  params = list(),
-  extra_args = list(),
   extra_headers = character()
 ) {
   cache <- aws_creds_cache(profile)
@@ -128,13 +134,11 @@ provider_aws_bedrock <- function(
   ProviderAWSBedrock(
     name = "AWS/Bedrock",
     base_url = base_url,
-    model = model,
     profile = profile,
     region = credentials$region,
     cache = cache,
-    params = params,
-    extra_args = extra_args,
-    extra_headers = extra_headers
+    extra_headers = extra_headers,
+    model = model
   )
 }
 
@@ -171,10 +175,10 @@ method(base_request_error, ProviderAWSBedrock) <- function(provider, req) {
   })
 }
 
-method(chat_params, ProviderAWSBedrock) <- function(provider, params) {
+method(chat_params, ProviderAWSBedrock) <- function(provider, model) {
   # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html
   standardise_params(
-    params,
+    model@params,
     c(
       temperature = "temperature",
       topP = "top_p",
@@ -186,6 +190,7 @@ method(chat_params, ProviderAWSBedrock) <- function(provider, params) {
 
 method(chat_request, ProviderAWSBedrock) <- function(
   provider,
+  model,
   stream = TRUE,
   turns = list(),
   tools = list(),
@@ -229,9 +234,9 @@ method(chat_request, ProviderAWSBedrock) <- function(
   }
 
   # Merge params into inferenceConfig, giving precedence to manual api_args
-  params <- chat_params(provider, provider@params)
+  params <- chat_params(provider, model)
 
-  extra_args <- provider@extra_args
+  extra_args <- model@extra_args
   extra_args$inferenceConfig <- modify_list(params, extra_args$inferenceConfig)
 
   # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
@@ -324,7 +329,7 @@ method(stream_merge_chunks, ProviderAWSBedrock) <- function(
   result
 }
 
-method(value_tokens, ProviderAWSBedrock) <- function(provider, json) {
+method(value_tokens, ProviderAWSBedrock) <- function(provider, model, json) {
   tokens(
     input = json$usage$inputTokens,
     output = json$usage$outputTokens,
@@ -333,6 +338,7 @@ method(value_tokens, ProviderAWSBedrock) <- function(provider, json) {
 
 method(value_turn, ProviderAWSBedrock) <- function(
   provider,
+  model,
   result,
   has_type = FALSE
 ) {
@@ -357,8 +363,8 @@ method(value_turn, ProviderAWSBedrock) <- function(
     }
   })
 
-  tokens <- value_tokens(provider, result)
-  cost <- get_token_cost(provider, tokens)
+  tokens <- value_tokens(provider, model, result)
+  cost <- get_token_cost(provider, model, tokens)
   AssistantTurn(contents, json = result, tokens = unlist(tokens), cost = cost)
 }
 
diff --git a/R/provider-azure.R b/R/provider-azure.R
index 634e3198a..861c64378 100644
--- a/R/provider-azure.R
+++ b/R/provider-azure.R
@@ -77,14 +77,21 @@ chat_azure_openai <- function(
   provider <- ProviderAzureOpenAI(
     name = "Azure/OpenAI",
     base_url = paste0(endpoint, "/openai/deployments/", model),
-    model = model,
-    params = params,
     api_version = api_version,
     credentials = credentials,
-    extra_args = api_args,
     extra_headers = api_headers
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  model_obj <- Model(
+    name = model,
+    params = params,
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 chat_azure_openai_test <- function(
diff --git a/R/provider-cloudflare.R b/R/provider-cloudflare.R
index 1c23079c8..0a337977b 100644
--- a/R/provider-cloudflare.R
+++ b/R/provider-cloudflare.R
@@ -58,14 +58,21 @@ chat_cloudflare <- function(
   provider <- ProviderCloudflare(
     name = "Cloudflare",
     base_url = base_url,
-    model = model,
-    params = params,
     credentials = credentials,
-    extra_args = api_args,
     extra_headers = api_headers
   )
+  model_obj <- Model(
+    name = model,
+    params = params,
+    extra_args = api_args
+  )
 
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 ProviderCloudflare <- new_class("ProviderCloudflare", parent = ProviderOpenAI)
diff --git a/R/provider-databricks.R b/R/provider-databricks.R
index 9737dbf48..755b113be 100644
--- a/R/provider-databricks.R
+++ b/R/provider-databricks.R
@@ -68,13 +68,20 @@ chat_databricks <- function(
   provider <- ProviderDatabricks(
     name = "Databricks",
     base_url = workspace,
-    model = model,
-    params = params,
-    extra_args = api_args,
     credentials = credentials,
     extra_headers = api_headers
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  model_obj <- Model(
+    name = model,
+    params = params,
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 ProviderDatabricks <- new_class(
@@ -93,6 +100,7 @@ method(base_request, ProviderDatabricks) <- function(provider) {
 
 method(chat_body, ProviderDatabricks) <- function(
   provider,
+  model,
   stream = TRUE,
   turns = list(),
   tools = list(),
@@ -100,13 +108,14 @@ method(chat_body, ProviderDatabricks) <- function(
 ) {
   body <- chat_body(
     super(provider, ProviderOpenAI),
+    model,
     stream = stream,
     turns = turns,
     tools = tools,
     type = type
   )
 
-  params <- chat_params(provider, provider@params)
+  params <- chat_params(provider, model)
   body <- modify_list(body, params)
 
   # Databricks doesn't support stream options
@@ -115,10 +124,10 @@ method(chat_body, ProviderDatabricks) <- function(
   body
 }
 
-method(chat_params, ProviderDatabricks) <- function(provider, params) {
+method(chat_params, ProviderDatabricks) <- function(provider, model) {
   # https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/api-reference#chat-request
   standardise_params(
-    params,
+    model@params,
     c(
       temperature = "temperature",
       top_p = "topP",
diff --git a/R/provider-deepseek.R b/R/provider-deepseek.R
index e66b37e42..5d84acd61 100644
--- a/R/provider-deepseek.R
+++ b/R/provider-deepseek.R
@@ -50,21 +50,28 @@ chat_deepseek <- function(
   provider <- ProviderDeepSeek(
     name = "DeepSeek",
    base_url = base_url,
-    model = model,
-    params = params,
-    extra_args = api_args,
     credentials = credentials,
     extra_headers = api_headers
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  model_obj <- Model(
+    name = model,
+    params = params,
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 ProviderDeepSeek <- new_class("ProviderDeepSeek", parent = ProviderOpenAI)
 
-method(chat_params, ProviderDeepSeek) <- function(provider, params) {
+method(chat_params, ProviderDeepSeek) <- function(provider, model) {
   # https://platform.deepseek.com/api-docs/api/create-chat-completion
   standardise_params(
-    params,
+    model@params,
     c(
       frequency_penalty = "frequency_penalty",
       max_tokens = "max_tokens",
diff --git a/R/provider-github.R b/R/provider-github.R
index 379c2170e..e31c9e99e 100644
--- a/R/provider-github.R
+++ b/R/provider-github.R
@@ -91,7 +91,7 @@ models_github <- function(
 
   provider <- ProviderOpenAI(
     name = "github",
-    model = "",
+    base_url = base_url,
     credentials = credentials
   )
diff --git a/R/provider-google.R b/R/provider-google.R
index 1bcb9772e..42e38c589 100644
--- a/R/provider-google.R
+++ b/R/provider-google.R
@@ -60,13 +60,21 @@ chat_google_gemini <- function(
   provider <- ProviderGoogleGemini(
     name = "Google/Gemini",
     base_url = base_url,
-    model = model,
-    params = params %||% params(),
-    extra_args = api_args,
     extra_headers = api_headers,
-    credentials = credentials
+    credentials = credentials,
+    model = model
+  )
+  model_obj <- Model(
+    name = model,
+    params = params %||% params(),
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
 }
 
 chat_google_gemini_test <- function(
@@ -107,13 +115,21 @@ chat_google_vertex <- function(
   provider <- ProviderGoogleGemini(
     name = "Google/Vertex",
     base_url = vertex_url(location, project_id),
-    model = model,
-    params = params %||% params(),
-    extra_args = api_args,
     extra_headers = api_headers,
-    credentials = credentials
+    credentials = credentials,
+    model = model
+  )
+  model_obj <- Model(
+    name = model,
+    params = params %||% params(),
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
 }
 
 # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent
@@ -153,6 +169,7 @@ method(base_request, ProviderGoogleGemini) <- function(provider) {
 
 method(chat_request, ProviderGoogleGemini) <- function(
   provider,
+  model,
   stream = TRUE,
   turns = list(),
   tools = list(),
@@ -179,12 +196,13 @@ method(chat_request, ProviderGoogleGemini) <- function(
 
   body <- chat_body(
     provider = provider,
+    model = model,
     stream = stream,
     turns = turns,
     tools = tools,
     type = type
   )
-  body <- modify_list(body, provider@extra_args)
+  body <- modify_list(body, model@extra_args)
 
   req <- req_body_json(req, body)
   req <- req_headers(req, !!!provider@extra_headers)
@@ -193,6 +211,7 @@ method(chat_request, ProviderGoogleGemini) <- function(
 
 method(chat_body, ProviderGoogleGemini) <- function(
   provider,
+  model,
   stream = TRUE,
   turns = list(),
   tools = list(),
@@ -204,7 +223,7 @@ method(chat_body, ProviderGoogleGemini) <- function(
     system <- list(parts = list(text = ""))
   }
 
-  generation_config <- chat_params(provider, provider@params)
+  generation_config <- chat_params(provider, model)
   if (!is.null(type)) {
     generation_config$response_mime_type <- "application/json"
     generation_config$response_schema <- as_json(provider, type)
@@ -236,9 +255,9 @@ method(chat_body, ProviderGoogleGemini) <- function(
   ))
 }
 
-method(chat_params, ProviderGoogleGemini) <- function(provider, params) {
+method(chat_params, ProviderGoogleGemini) <- function(provider, model) {
   standardise_params(
-    params,
+    model@params,
     c(
       temperature = "temperature",
       topP = "top_p",
@@ -278,7 +297,7 @@ method(stream_merge_chunks, ProviderGoogleGemini) <- function(
   }
 }
 
-method(value_tokens, ProviderGoogleGemini) <- function(provider, json) {
+method(value_tokens, ProviderGoogleGemini) <- function(provider, model, json) {
   usage <- json$usageMetadata
   cached <- usage$cachedContentTokenCount %||% 0
 
@@ -291,6 +310,7 @@ method(value_tokens, ProviderGoogleGemini) <- function(provider, json) {
 
 method(value_turn, ProviderGoogleGemini) <- function(
   provider,
+  model,
   result,
   has_type = FALSE
 ) {
@@ -322,8 +342,8 @@ method(value_turn, ProviderGoogleGemini) <- function(
     }
   })
   contents <- compact(contents)
-  tokens <- value_tokens(provider, result)
-  cost <- get_token_cost(provider, tokens)
+  tokens <- value_tokens(provider, model, result)
+  cost <- get_token_cost(provider, model, tokens)
   AssistantTurn(contents, json = result, tokens = unlist(tokens), cost = cost)
 }
 
diff --git a/R/provider-groq.R b/R/provider-groq.R
index 48e3c9ae6..48a5224b5 100644
--- a/R/provider-groq.R
+++ b/R/provider-groq.R
@@ -53,13 +53,20 @@ chat_groq <- function(
   provider <- ProviderGroq(
     name = "Groq",
     base_url = base_url,
-    model = model,
-    params = params,
-    extra_args = api_args,
     credentials = credentials,
     extra_headers = api_headers
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  model_obj <- Model(
+    name = model,
+    params = params,
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 ProviderGroq <- new_class("ProviderGroq", parent = ProviderOpenAI)
diff --git a/R/provider-huggingface.R b/R/provider-huggingface.R
index f291983cd..554002bf1 100644
--- a/R/provider-huggingface.R
+++ b/R/provider-huggingface.R
@@ -56,13 +56,20 @@ chat_huggingface <- function(
   provider <- ProviderHuggingFace(
     name = "HuggingFace",
     base_url = base_url,
-    model = model,
-    params = params,
-    extra_args = api_args,
     credentials = credentials,
     extra_headers = api_headers
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  model_obj <- Model(
+    name = model,
+    params = params,
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 ProviderHuggingFace <- new_class("ProviderHuggingFace", parent = ProviderOpenAI)
diff --git a/R/provider-mistral.R b/R/provider-mistral.R
index a063e9f7a..138588f96 100644
--- a/R/provider-mistral.R
+++ b/R/provider-mistral.R
@@ -44,13 +44,20 @@ chat_mistral <- function(
   provider <- ProviderMistral(
     name = "Mistral",
     base_url = mistral_base_url,
-    model = model,
-    params = params,
-    extra_args = api_args,
     credentials = credentials,
     extra_headers = api_headers
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  model_obj <- Model(
+    name = model,
+    params = params,
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 mistral_base_url <- "https://api.mistral.ai/v1/"
@@ -91,6 +98,7 @@ method(base_request, ProviderMistral) <- function(provider) {
 
 method(chat_body, ProviderMistral) <- function(
   provider,
+  model,
   stream = TRUE,
   turns = list(),
   tools = list(),
@@ -98,6 +106,7 @@ method(chat_body, ProviderMistral) <- function(
 ) {
   body <- chat_body(
     super(provider, ProviderOpenAI),
+    model,
     stream = stream,
     turns = turns,
     tools = tools,
@@ -110,9 +119,9 @@ method(chat_body, ProviderMistral) <- function(
   body
 }
 
-method(chat_params, ProviderMistral) <- function(provider, params) {
+method(chat_params, ProviderMistral) <- function(provider, model) {
   standardise_params(
-    params,
+    model@params,
     c(
       temperature = "temperature",
       top_p = "top_p",
@@ -137,7 +146,6 @@ mistral_key <- function() {
 
 models_mistral <- function(api_key = mistral_key()) {
   provider <- ProviderMistral(
     name = "Mistral",
-    model = "",
     base_url = mistral_base_url,
     credentials = function() api_key
   )
diff --git a/R/provider-ollama.R b/R/provider-ollama.R
index eb10b475a..9f7b0c9f6 100644
--- a/R/provider-ollama.R
+++ b/R/provider-ollama.R
@@ -85,14 +85,22 @@ chat_ollama <- function(
   provider <- ProviderOllama(
     name = "Ollama",
     base_url = file.path(base_url, "v1"), ## the v1 portion of the path is added for openAI compatible API
-    model = model,
-    params = params %||% params(),
-    extra_args = api_args,
     credentials = credentials,
-    extra_headers = api_headers
+    extra_headers = api_headers,
+    model = model
+  )
+  model_obj <- Model(
+    name = model,
+    params = params %||% params(),
+    extra_args = api_args
   )
 
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 ProviderOllama <- new_class(
@@ -103,10 +111,10 @@ ProviderOllama <- new_class(
   )
 )
 
-method(chat_params, ProviderOllama) <- function(provider, params) {
+method(chat_params, ProviderOllama) <- function(provider, model) {
   # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
   standardise_params(
-    params,
+    model@params,
     c(
       frequency_penalty = "frequency_penalty",
       presence_penalty = "presence_penalty",
diff --git a/R/provider-openai-responses.R b/R/provider-openai-responses.R
index d082e7ec3..c06ef4f10 100644
--- a/R/provider-openai-responses.R
+++ b/R/provider-openai-responses.R
@@ -55,14 +55,21 @@ chat_openai_responses <- function(
   provider <- ProviderOpenAIResponses(
     name = "OpenAI",
     base_url = base_url,
-    model = model,
-    params = params %||% params(),
-    extra_args = api_args,
     extra_headers = api_headers,
     credentials = credentials,
     service_tier = service_tier
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  model_obj <- Model(
+    name = model,
+    params = params %||% params(),
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 chat_openai_responses_test <- function(
   system_prompt = "Be terse.",
@@ -101,6 +108,7 @@ method(chat_path, ProviderOpenAIResponses) <- function(provider) {
 
 # https://platform.openai.com/docs/api-reference/responses
 method(chat_body, ProviderOpenAIResponses) <- function(
   provider,
+  model,
   stream = TRUE,
   turns = list(),
   tools = list(),
@@ -124,7 +132,7 @@ method(chat_body, ProviderOpenAIResponses) <- function(
   }
 
   # https://platform.openai.com/docs/api-reference/responses/create#responses-create-include
-  params <- chat_params(provider, provider@params)
+  params <- chat_params(provider, model)
 
   if (has_name(params, "reasoning_effort")) {
     reasoning <- list(
@@ -157,9 +165,9 @@ method(chat_body, ProviderOpenAIResponses) <- function(
 }
 
-method(chat_params, ProviderOpenAIResponses) <- function(provider, params) {
+method(chat_params, ProviderOpenAIResponses) <- function(provider, model) {
   standardise_params(
-    params,
+    model@params,
     c(
       temperature = "temperature",
       top_p = "top_p",
@@ -205,7 +213,11 @@ method(stream_merge_chunks, ProviderOpenAIResponses) <- function(
   }
 }
 
-method(value_tokens, ProviderOpenAIResponses) <- function(provider, json) {
+method(value_tokens, ProviderOpenAIResponses) <- function(
+  provider,
+  model,
+  json
+) {
   usage <- json$usage
   cached_tokens <- usage$input_tokens_details$cached_tokens %||% 0
 
diff --git a/R/provider-openai.R b/R/provider-openai.R
index 959799298..8e5f5b0d3 100644
--- a/R/provider-openai.R
+++ b/R/provider-openai.R
@@ -69,13 +69,20 @@ chat_openai <- function(
   provider <- ProviderOpenAI(
     name = "OpenAI",
     base_url = base_url,
-    model = model,
-    params = params %||% params(),
-    extra_args = api_args,
     extra_headers = api_headers,
     credentials = credentials
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  model <- Model(
+    name = model,
+    params = params %||% params(),
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 chat_openai_test <- function(
   system_prompt = "Be terse.",
@@ -147,6 +154,7 @@ method(chat_path, ProviderOpenAI) <- function(provider) {
 
 # https://platform.openai.com/docs/api-reference/chat/create
 method(chat_body, ProviderOpenAI) <- function(
   provider,
+  model,
   stream = TRUE,
   turns = list(),
   tools = list(),
@@ -168,11 +176,11 @@ method(chat_body, ProviderOpenAI) <- function(
     response_format <- NULL
   }
 
-  params <- chat_params(provider, provider@params)
+  params <- chat_params(provider, model)
 
   compact(list2(
     messages = messages,
-    model = provider@model,
+    model = model@name,
     !!!params,
     stream = stream,
     stream_options = if (stream) list(include_usage = TRUE),
@@ -182,9 +190,9 @@ method(chat_body, ProviderOpenAI) <- function(
 }
 
-method(chat_params, ProviderOpenAI) <- function(provider, params) {
+method(chat_params, ProviderOpenAI) <- function(provider, model) {
   standardise_params(
-    params,
+    model@params,
     c(
       temperature = "temperature",
       top_p = "top_p",
@@ -227,7 +235,7 @@ method(stream_merge_chunks, ProviderOpenAI) <- function(
   }
 }
 
-method(value_tokens, ProviderOpenAI) <- function(provider, json) {
+method(value_tokens, ProviderOpenAI) <- function(provider, model, json) {
   usage <- json$usage
   cached_tokens <- usage$prompt_tokens_details$cached_tokens %||% 0
 
@@ -240,6 +248,7 @@ method(value_tokens, ProviderOpenAI) <- function(provider, json) {
 
 method(value_turn, ProviderOpenAI) <- function(
   provider,
+  model,
   result,
   has_type = FALSE
 ) {
@@ -272,8 +281,8 @@ method(value_turn, ProviderOpenAI) <- function(
     content <- c(content, calls)
   }
 
-  tokens <- value_tokens(provider, result)
-  cost <- get_token_cost(provider, tokens)
+  tokens <- value_tokens(provider, model, result)
+  cost <- get_token_cost(provider, model, tokens)
   AssistantTurn(content, json = result, tokens = unlist(tokens), cost = cost)
 }
 
@@ -525,11 +534,12 @@ method(batch_retrieve, ProviderOpenAI) <- function(provider, batch) {
 
 method(batch_result_turn, ProviderOpenAI) <- function(
   provider,
+  model,
   result,
   has_type = FALSE
 ) {
   if (result$status_code == 200) {
-    value_turn(provider, result$body, has_type = has_type)
+    value_turn(provider, model, result$body, has_type = has_type)
   } else {
     NULL
   }
@@ -553,7 +563,6 @@ models_openai <- function(
 
   provider <- ProviderOpenAI(
     name = "OpenAI",
-    model = "",
     base_url = base_url,
     credentials = credentials
   )
diff --git a/R/provider-openrouter.R b/R/provider-openrouter.R
index 928f5bfe9..62a6fd5a4 100644
--- a/R/provider-openrouter.R
+++ b/R/provider-openrouter.R
@@ -47,13 +47,20 @@ chat_openrouter <- function(
   provider <- ProviderOpenRouter(
     name = "OpenRouter",
     base_url = "https://openrouter.ai/api/v1",
-    model = model,
-    params = params,
-    extra_args = api_args,
     credentials = credentials,
     extra_headers = api_headers
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  model_obj <- Model(
+    name = model,
+    params = params,
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 chat_openrouter_test <- function(..., echo = "none") {
@@ -65,10 +72,10 @@ ProviderOpenRouter <- new_class(
   parent = ProviderOpenAI,
 )
 
-method(chat_params, ProviderOpenRouter) <- function(provider, params) {
+method(chat_params, ProviderOpenRouter) <- function(provider, model) {
   # https://openrouter.ai/docs/api-reference/parameters
   standardise_params(
-    params,
+    model@params,
     c(
       temperature = "temperature",
      top_p = "top_p",
diff --git a/R/provider-perplexity.R b/R/provider-perplexity.R
index 73274a452..f76e36b1c 100644
--- a/R/provider-perplexity.R
+++ b/R/provider-perplexity.R
@@ -53,13 +53,20 @@ chat_perplexity <- function(
   provider <- ProviderPerplexity(
     name = "Perplexity",
     base_url = base_url,
-    model = model,
-    params = params,
-    extra_args = api_args,
     credentials = credentials,
     extra_headers = api_headers
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  model_obj <- Model(
+    name = model,
+    params = params,
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 ProviderPerplexity <- new_class(
@@ -67,10 +74,10 @@ ProviderPerplexity <- new_class(
   parent = ProviderOpenAI,
 )
 
-method(chat_params, ProviderPerplexity) <- function(provider, params) {
+method(chat_params, ProviderPerplexity) <- function(provider, model) {
   # https://docs.perplexity.ai/api-reference/chat-completions-post
   standardise_params(
-    params,
+    model@params,
     c(
       max_tokens = "max_tokens",
       temperature = "temperature",
diff --git a/R/provider-portkey.R b/R/provider-portkey.R
index e00c73cf5..40f5d3381 100644
--- a/R/provider-portkey.R
+++ b/R/provider-portkey.R
@@ -50,14 +50,21 @@ chat_portkey <- function(
   provider <- ProviderPortkeyAI(
     name = "PortkeyAI",
     base_url = base_url,
-    model = model,
-    params = params,
-    extra_args = api_args,
     credentials = credentials,
     virtual_key = virtual_key,
     extra_headers = api_headers
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  model_obj <- Model(
+    name = model,
+    params = params,
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 chat_portkey_test <- function(
@@ -122,7 +129,6 @@ models_portkey <- function(
 ) {
   provider <- ProviderPortkeyAI(
     name = "PortkeyAI",
-    model = "",
     base_url = base_url,
     credentials = function() api_key,
     virtual_key = virtual_key
diff --git a/R/provider-snowflake.R b/R/provider-snowflake.R
index 24c2dd010..562e9e2ba 100644
--- a/R/provider-snowflake.R
+++ b/R/provider-snowflake.R
@@ -61,13 +61,20 @@ chat_snowflake <- function(
     base_url = snowflake_url(account),
     account = account,
     credentials = credentials,
-    model = model,
-    params = params,
-    extra_args = api_args,
     extra_headers = api_headers
   )
+  model_obj <- Model(
+    name = model,
+    params = params,
+    extra_args = api_args
+  )
 
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 ProviderSnowflakeCortex <- new_class(
@@ -100,6 +107,7 @@ method(chat_path, ProviderSnowflakeCortex) <- function(provider) {
 
 # See: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#api-reference
 method(chat_body, ProviderSnowflakeCortex) <- function(
   provider,
+  model,
   stream = TRUE,
   turns = list(),
   tools = list(),
@@ -115,10 +123,10 @@ method(chat_body, ProviderSnowflakeCortex) <- function(
     response_format <- NULL
   }
 
-  params <- chat_params(provider, provider@params)
+  params <- chat_params(provider, model)
   compact(list2(
     messages = messages,
-    model = provider@model,
+    model = model@name,
     !!!params,
     stream = stream,
     tools = tools,
@@ -145,9 +153,9 @@ method(as_json, list(ProviderSnowflakeCortex, TypeObject)) <- function(
 }
 
 # See: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#optional-json-arguments
-method(chat_params, ProviderSnowflakeCortex) <- function(provider, params) {
+method(chat_params, ProviderSnowflakeCortex) <- function(provider, model) {
   standardise_params(
-    params,
+    model@params,
     c(
       temperature = "temperature",
       top_p = "top_p",
@@ -215,7 +223,11 @@ method(stream_merge_chunks, ProviderSnowflakeCortex) <- function(
   result
 }
 
-method(value_tokens, ProviderSnowflakeCortex) <- function(provider, json) {
+method(value_tokens, ProviderSnowflakeCortex) <- function(
+  provider,
+  model,
+  json
+) {
   tokens(
     input = json$usage$prompt_tokens,
     output = json$usage$completion_tokens
@@ -224,6 +236,7 @@ method(value_tokens, ProviderSnowflakeCortex) <- function(provider, json) {
 
 method(value_turn, ProviderSnowflakeCortex) <- function(
   provider,
+  model,
   result,
   has_type = FALSE
 ) {
@@ -252,8 +265,8 @@ method(value_turn, ProviderSnowflakeCortex) <- function(
       )
     }
   })
-  tokens <- value_tokens(provider, result)
-  cost <- get_token_cost(provider, tokens)
+  tokens <- value_tokens(provider, model, result)
+  cost <- get_token_cost(provider, model, tokens)
   AssistantTurn(contents, json = result, tokens = unlist(tokens), cost = cost)
 }
 
diff --git a/R/provider-vllm.R b/R/provider-vllm.R
index 39da5d744..254f0a3b9 100644
--- a/R/provider-vllm.R
+++ b/R/provider-vllm.R
@@ -57,13 +57,20 @@ chat_vllm <- function(
   provider <- ProviderVllm(
     name = "VLLM",
     base_url = base_url,
-    model = model,
-    params = params,
-    extra_args = api_args,
     credentials = credentials,
     extra_headers = api_headers
   )
-  Chat$new(provider = provider, system_prompt = system_prompt, echo = echo)
+  model_obj <- Model(
+    name = model,
+    params = params,
+    extra_args = api_args
+  )
+  Chat$new(
+    provider = provider,
+    model = model_obj,
+    system_prompt = system_prompt,
+    echo = echo
+  )
 }
 
 chat_vllm_test <- function(..., echo = "none") {
diff --git a/R/provider.R b/R/provider.R
index 360433b5d..9f7744737 100644
--- a/R/provider.R
+++ b/R/provider.R
@@ -14,36 +14,29 @@ NULL
 #'
 #' @export
 #' @param name Name of the provider.
-#' @param model Name of the model.
 #' @param base_url The base URL for the API.
-#' @param params A list of standard parameters created by [params()].
 #' @param credentials A zero-argument function that returns the credentials to use
 #'   for authentication. Can either return a string, representing an API key,
 #'   or a named list of headers.
-#' @param extra_args Arbitrary extra arguments to be included in the request body.
 #' @param extra_headers Arbitrary extra headers to be added to the request.
 #' @return An S7 Provider object.
 #' @examples
 #' Provider(
 #'   name = "CoolModels",
-#'   model = "my_model",
 #'   base_url = "https://cool-models.com"
 #' )
 Provider <- new_class(
   "Provider",
   properties = list(
     name = prop_string(),
-    model = prop_string(),
     base_url = prop_string(),
-    params = class_list,
-    extra_args = class_list,
     extra_headers = class_character,
     credentials = class_function | NULL
   )
 )
 
-test_provider <- function(name = "", model = "", base_url = "", ...) {
-  Provider(name = name, model = model, base_url = base_url, ...)
+test_provider <- function(name = "", base_url = "", ...) {
+  Provider(name = name, base_url = base_url, ...)
 }
 
 # Create a request------------------------------------
@@ -65,6 +58,7 @@ chat_request <- new_generic(
   "provider",
   function(
     provider,
+    model,
     stream = TRUE,
     turns = list(),
     tools = list(),
@@ -76,6 +70,7 @@ chat_request <- new_generic(
 
 method(chat_request, Provider) <- function(
   provider,
+  model,
   stream = TRUE,
   turns = list(),
   tools = list(),
@@ -86,12 +81,13 @@ method(chat_request, Provider) <- function(
 
   body <- chat_body(
     provider = provider,
+    model = model,
     stream = stream,
     turns = turns,
     tools = tools,
     type = type
   )
-  body <- modify_list(body, provider@extra_args)
+  body <- modify_list(body, model@extra_args)
 
   req <- req_body_json(req, body)
   req <- req_headers(req, !!!provider@extra_headers)
@@ -103,6 +99,7 @@ chat_body <- new_generic(
   "provider",
   function(
     provider,
+    model,
     stream = TRUE,
     turns = list(),
     tools = list(),
@@ -130,7 +127,7 @@ method(chat_resp_stream, Provider) <- function(provider, resp) {
 chat_params <- new_generic(
   "chat_params",
   "provider",
-  function(provider, params) {
+  function(provider, model) {
     S7_dispatch()
   }
 )
@@ -161,18 +158,24 @@ stream_merge_chunks <- new_generic(
 
 # Extract data from non-streaming results --------------------------------------
 
-value_turn <- new_generic("value_turn", "provider")
+value_turn <- new_generic(
+  "value_turn",
+  "provider",
+  function(provider, model, json, ...) {
+    S7_dispatch()
+  }
+)
 
 # Extract token counts from API response
 # Returns a named list produced by token_usage()
 value_tokens <- new_generic(
   "value_tokens",
   "provider",
-  function(provider, json) {
+  function(provider, model, json) {
     S7_dispatch()
   }
 )
 
-method(value_tokens, Provider) <- function(provider, json) {
+method(value_tokens, Provider) <- function(provider, model, json) {
   tokens()
 }
 
@@ -258,7 +261,7 @@ batch_retrieve <- new_generic(
 
 batch_result_turn <- new_generic(
   "batch_result_turn",
   "provider",
-  function(provider, result, has_type = FALSE) {
+  function(provider, model, result, has_type = FALSE) {
     S7_dispatch()
   }
 )
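For third-party providers, the generic signature changes above mean that a chat_params() implementation now receives the Model object and reads model@params, rather than a bare params list passed in by the caller. A hedged sketch of the new shape; ProviderAcme and its parameter map are invented for illustration, while method(), new_class(), and standardise_params() are used exactly as in the provider methods in this patch:

ProviderAcme <- new_class("ProviderAcme", parent = Provider)

method(chat_params, ProviderAcme) <- function(provider, model) {
  # Previously: function(provider, params), with params stored on the provider.
  # Now the params list travels on the Model and is read via model@params.
  standardise_params(
    model@params,
    c(
      temperature = "temperature",
      max_tokens = "max_tokens"
    )
  )
}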
diff --git a/R/tokens.R b/R/tokens.R
index c43c00432..0f4e59133 100644
--- a/R/tokens.R
+++ b/R/tokens.R
@@ -20,11 +20,11 @@ map_tokens <- function(x, f, ...) {
   out
 }
 
-log_tokens <- function(provider, tokens, cost) {
+log_tokens <- function(provider, model, tokens, cost) {
   i <- vctrs::vec_match(
     data.frame(
       provider = provider@name,
-      model = provider@model
+      model = model@name
     ),
     the$tokens[c("provider", "model")]
   )
@@ -32,7 +32,7 @@ log_tokens <- function(provider, tokens, cost) {
   if (is.na(i)) {
     new_row <- tokens_row(
       provider@name,
-      provider@model,
+      model@name,
      tokens$input,
       tokens$output,
       tokens$cached_input,
@@ -50,14 +50,14 @@ log_tokens <- function(provider, tokens, cost) {
   invisible()
 }
 
-log_turn <- function(provider, turn) {
-  log_tokens(provider, exec(tokens, !!!as.list(turn@tokens)), turn@cost)
+log_turn <- function(provider, model, turn) {
+  log_tokens(provider, model, exec(tokens, !!!as.list(turn@tokens)), turn@cost)
 }
 
-log_turns <- function(provider, turns) {
+log_turns <- function(provider, model, turns) {
   for (turn in turns) {
     if (S7_inherits(turn, AssistantTurn)) {
-      log_turn(provider, turn)
+      log_turn(provider, model, turn)
     }
   }
 }
@@ -110,14 +110,14 @@ token_usage <- function() {
 
 # Cost ----------------------------------------------------------------------
 
 has_cost <- function(provider, model) {
-  needle <- data.frame(provider = provider@name, model = model)
+  needle <- data.frame(provider = provider@name, model = model@name)
   vctrs::vec_in(needle, prices[c("provider", "model")])
 }
 
-get_token_cost <- function(provider, tokens, variant = "") {
+get_token_cost <- function(provider, model, tokens, variant = "") {
   needle <- data.frame(
     provider = provider@name,
-    model = provider@model,
+    model = model@name,
     variant = variant
   )
   idx <- vctrs::vec_match(needle, prices[c("provider", "model", "variant")])
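All of the token and cost bookkeeping above now keys on the (provider@name, model@name) pair instead of a model string stored on the provider. The lookup itself is plain vctrs row matching; a self-contained toy version of the same pattern, with an invented price table (the real prices data lives inside ellmer):

prices <- data.frame(
  provider = c("OpenAI", "Anthropic"),
  model = c("gpt-4", "claude-sonnet-4"),
  input_per_mtok = c(30, 3)
)

# Same pattern as get_token_cost(): build a one-row needle, match it by row.
needle <- data.frame(provider = "OpenAI", model = "gpt-4")
idx <- vctrs::vec_match(needle, prices[c("provider", "model")])
prices$input_per_mtok[idx]
#> [1] 30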
diff --git a/man/Chat.Rd b/man/Chat.Rd
index 66ec02d58..7f1d6bb1b 100644
--- a/man/Chat.Rd
+++ b/man/Chat.Rd
@@ -45,6 +45,7 @@ chat$chat("Tell me a funny joke")
 \item \href{#method-Chat-register_tool}{\code{Chat$register_tool()}}
 \item \href{#method-Chat-register_tools}{\code{Chat$register_tools()}}
 \item \href{#method-Chat-get_provider}{\code{Chat$get_provider()}}
+\item \href{#method-Chat-get_model_obj}{\code{Chat$get_model_obj()}}
 \item \href{#method-Chat-get_tools}{\code{Chat$get_tools()}}
 \item \href{#method-Chat-set_tools}{\code{Chat$set_tools()}}
 \item \href{#method-Chat-on_tool_request}{\code{Chat$on_tool_request()}}
@@ -57,7 +58,7 @@ chat$chat("Tell me a funny joke")
 \if{latex}{\out{\hypertarget{method-Chat-new}{}}}
 \subsection{Method \code{new()}}{
 \subsection{Usage}{
-\if{html}{\out{<div class="r">}}\preformatted{Chat$new(provider, system_prompt = NULL, echo = "none")}\if{html}{\out{</div>}}
+\if{html}{\out{<div class="r">}}\preformatted{Chat$new(provider, model, system_prompt = NULL, echo = "none")}\if{html}{\out{</div>}}
 }
 
 \subsection{Arguments}{
@@ -65,6 +66,8 @@ chat$chat("Tell me a funny joke")
 \describe{
 \item{\code{provider}}{A provider object.}
 
+\item{\code{model}}{A model object.}
+
 \item{\code{system_prompt}}{System prompt to start the conversation with.}
 
 \item{\code{echo}}{One of the following options:
@@ -447,6 +450,16 @@ Get the underlying provider object. For expert use only.
 \subsection{Usage}{
 \if{html}{\out{<div class="r">}}\preformatted{Chat$get_provider()}\if{html}{\out{</div>}}
 }
 
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-Chat-get_model_obj"></a>}}
+\if{latex}{\out{\hypertarget{method-Chat-get_model_obj}{}}}
+\subsection{Method \code{get_model_obj()}}{
+Get the underlying model object. For expert use only.
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{Chat$get_model_obj()}\if{html}{\out{</div>}}
+}
+
 }
 \if{html}{\out{<hr>}}
 \if{html}{\out{<a id="method-Chat-get_tools"></a>}}
diff --git a/man/Model.Rd b/man/Model.Rd
new file mode 100644
index 000000000..63d9a1a02
--- /dev/null
+++ b/man/Model.Rd
@@ -0,0 +1,35 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/model.R
+\name{Model}
+\alias{Model}
+\title{A chatbot model}
+\usage{
+Model(name = stop("Required"), params = list(), extra_args = list())
+}
+\arguments{
+\item{name}{Name of the model (e.g., "gpt-4", "claude-sonnet-4").}
+
+\item{params}{A list of standard parameters created by \code{\link[=params]{params()}}.}
+
+\item{extra_args}{Arbitrary extra arguments to be included in the request body.}
+}
+\value{
+An S7 Model object.
+}
+\description{
+A Model captures the details of a specific language model and its
+configuration parameters. This includes the model name, parameters like
+temperature and max_tokens, and any extra arguments to include in API
+requests.
+}
+\details{
+Model objects are typically created internally by \code{\link[=chat_openai]{chat_openai()}},
+\code{\link[=chat_anthropic]{chat_anthropic()}}, and other \verb{chat_*()} functions. You generally won't
+need to create them directly unless you're implementing a custom provider.
+}
+\examples{
+Model(
+  name = "gpt-4",
+  params = params(temperature = 0.7)
+)
+}
diff --git a/man/Provider.Rd b/man/Provider.Rd
index 38187849a..1a6e9c14b 100644
--- a/man/Provider.Rd
+++ b/man/Provider.Rd
@@ -6,10 +6,7 @@
 \usage{
 Provider(
   name = stop("Required"),
-  model = stop("Required"),
   base_url = stop("Required"),
-  params = list(),
-  extra_args = list(),
   extra_headers = character(0),
   credentials = function() NULL
 )
@@ -17,14 +14,8 @@ Provider(
 \arguments{
 \item{name}{Name of the provider.}
 
-\item{model}{Name of the model.}
-
 \item{base_url}{The base URL for the API.}
 
-\item{params}{A list of standard parameters created by \code{\link[=params]{params()}}.}
-
-\item{extra_args}{Arbitrary extra arguments to be included in the request body.}
-
 \item{extra_headers}{Arbitrary extra headers to be added to the request.}
 
 \item{credentials}{A zero-argument function that returns the credentials to use
@@ -48,7 +39,6 @@ the various generics that control the behavior of each provider.
 \examples{
 Provider(
   name = "CoolModels",
-  model = "my_model",
   base_url = "https://cool-models.com"
 )
 }
diff --git a/tests/testthat/testthat-problems.rds b/tests/testthat/testthat-problems.rds
new file mode 100644
index 000000000..af2df6f3e
Binary files /dev/null and b/tests/testthat/testthat-problems.rds differ
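The user-facing surface of this refactor stays small: existing chat_*() calls are unchanged, and the new object is only reachable through the accessors documented above. A sketch of the accessor behavior, assuming an OpenAI API key is configured; the model name and parameters are illustrative:

chat <- chat_openai(model = "gpt-4.1", params = params(temperature = 0.5))

chat$get_model()           # "gpt-4.1": just the name, via model@name
m <- chat$get_model_obj()  # the underlying S7 Model; expert use only
m@params$temperature       # 0.5
m@extra_args               # list(), unless api_args was supplied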