1 change: 1 addition & 0 deletions DESCRIPTION
@@ -86,6 +86,7 @@ Collate:
'import-standalone-types-check.R'
'interpolate.R'
'live.R'
'model.R'
'parallel-chat.R'
'params.R'
'provider-anthropic.R'
1 change: 1 addition & 0 deletions NAMESPACE
@@ -16,6 +16,7 @@ export(ContentText)
export(ContentThinking)
export(ContentToolRequest)
export(ContentToolResult)
export(Model)
export(Provider)
export(SystemTurn)
export(Turn)
12 changes: 9 additions & 3 deletions R/batch-chat.R
@@ -193,6 +193,7 @@ BatchJob <- R6::R6Class(

# Internal state
provider = NULL,
model = NULL,
started_at = NULL,
stage = NULL,
batch = NULL,
@@ -208,6 +209,7 @@
call = caller_env(2)
) {
self$provider <- chat$get_provider()
self$model <- chat$get_model_obj()
check_has_batch_support(self$provider, call = call)

user_turns <- as_user_turns(prompts, call = call)
@@ -326,7 +328,7 @@ BatchJob <- R6::R6Class(

retrieve = function() {
self$results <- batch_retrieve(self$provider, self$batch)
log_turns(self$provider, self$result_turns())
log_turns(self$provider, self$model, self$result_turns())

self$stage <- "done"
self$save_state()
@@ -335,7 +337,12 @@

result_turns = function() {
map2(self$results, self$user_turns, function(result, user_turn) {
batch_result_turn(self$provider, result, has_type = !is.null(self$type))
batch_result_turn(
self$provider,
self$model,
result,
has_type = !is.null(self$type)
)
})
},

@@ -379,7 +386,6 @@
provider_hash <- function(x) {
list(
name = x@name,
model = x@model,
base_url = x@base_url
)
}
40 changes: 34 additions & 6 deletions R/chat.R
@@ -24,6 +24,7 @@
"Chat",
public = list(
#' @param provider A provider object.
#' @param model A model object.
#' @param system_prompt System prompt to start the conversation with.
#' @param echo One of the following options:
#' * `none`: don't emit any output (default when running in a function).
@@ -33,8 +34,14 @@
#'
#' Note this only affects the `chat()` method. You can override the default
#' by setting the `ellmer_echo` option.
initialize = function(provider, system_prompt = NULL, echo = "none") {
initialize = function(
provider,
model,
system_prompt = NULL,
echo = "none"
) {
private$provider <- provider
private$model <- model
private$echo <- echo
private$callback_on_tool_request <- CallbackManager$new(args = "request")
private$callback_on_tool_result <- CallbackManager$new(args = "result")
@@ -78,7 +85,7 @@ Chat <- R6::R6Class(
check_turn(assistant)

if (log_tokens) {
log_turn(private$provider, assistant)
log_turn(private$provider, private$model, assistant)
}

private$.turns[[length(private$.turns) + 1]] <- user
@@ -97,7 +104,7 @@

#' @description Retrieve the model name
get_model = function() {
private$provider@model
private$model@name
},

#' @description Update the system prompt
@@ -381,6 +388,11 @@ Chat <- R6::R6Class(
private$provider
},

#' @description Get the underlying model object. For expert use only.
get_model_obj = function() {
private$model
},

#' @description Retrieve the list of registered tools.
get_tools = function() {
private$tools
@@ -422,6 +434,7 @@
),
private = list(
provider = NULL,
model = NULL,

.turns = list(),
echo = NULL,
@@ -579,6 +592,7 @@

response <- chat_perform(
provider = private$provider,
model = private$model,
mode = if (stream) "stream" else "value",
turns = c(private$.turns, list(user_turn)),
tools = if (is.null(type)) private$tools,
@@ -603,11 +617,17 @@

result <- stream_merge_chunks(private$provider, result, chunk)
}
turn <- value_turn(private$provider, result, has_type = !is.null(type))
turn <- value_turn(
private$provider,
private$model,
result,
has_type = !is.null(type)
)
turn <- match_tools(turn, private$tools)
} else {
turn <- value_turn(
private$provider,
private$model,
resp_body_json(response),
has_type = !is.null(type)
)
@@ -661,6 +681,7 @@
) {
response <- chat_perform(
provider = private$provider,
model = private$model,
mode = if (stream) "async-stream" else "async-value",
turns = c(private$.turns, list(user_turn)),
tools = if (is.null(type)) private$tools,
@@ -685,12 +706,18 @@

result <- stream_merge_chunks(private$provider, result, chunk)
}
turn <- value_turn(private$provider, result, has_type = !is.null(type))
turn <- value_turn(
private$provider,
private$model,
result,
has_type = !is.null(type)
)
} else {
result <- await(response)

turn <- value_turn(
private$provider,
private$model,
resp_body_json(result),
has_type = !is.null(type)
)
@@ -738,6 +765,7 @@
#' @export
print.Chat <- function(x, ...) {
provider <- x$get_provider()
model <- x$get_model()
turns <- x$get_turns(include_system_prompt = TRUE)

assistant_turns <- keep(turns, \(x) x@role == "assistant")
@@ -746,7 +774,7 @@ print.Chat <- function(x, ...) {

cat(paste_c(
"<Chat",
c(" ", provider@name, "/", provider@model),
c(" ", provider@name, "/", model),
c(" turns=", length(turns)),
turn_cost(total_tokens, total_cost, prefix = " "),
">\n"
2 changes: 2 additions & 0 deletions R/httr2.R
@@ -3,6 +3,7 @@
# We will reconsider this in the future if necessary.
chat_perform <- function(
provider,
model,
mode = c("value", "stream", "async-stream", "async-value"),
turns,
tools = NULL,
@@ -14,6 +15,7 @@ chat_perform <- function(

req <- chat_request(
provider = provider,
model = model,
turns = turns,
tools = tools,
stream = stream,
29 changes: 29 additions & 0 deletions R/model.R
@@ -0,0 +1,29 @@
#' A chatbot model
#'
#' A Model captures the details of a specific language model and its
#' configuration parameters. This includes the model name, parameters like
#' `temperature` and `max_tokens`, and any extra arguments to include in API
#' requests.
#'
#' Model objects are typically created internally by [chat_openai()],
#' [chat_anthropic()], and other `chat_*()` functions. You generally won't
#' need to create them directly unless you're implementing a custom provider.
#'
#' @export
#' @param name Name of the model (e.g., "gpt-4", "claude-sonnet-4").
#' @param params A list of standard parameters created by [params()].
#' @param extra_args Arbitrary extra arguments to be included in the request body.
#' @return An S7 Model object.
#' @examples
#' Model(
#' name = "gpt-4",
#' params = params(temperature = 0.7)
#' )
Model <- new_class(
"Model",
properties = list(
name = prop_string(),
params = class_list,
extra_args = class_list
)
)
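As a quick orientation for reviewers, here is a minimal sketch of how the new Model object surfaces through the Chat API after this change. It assumes `chat_openai()` keeps its existing `model` argument and that an API key is configured; the accessors shown come from the chat.R diff above.

```r
library(ellmer)

# Construct a standalone Model, mirroring the roxygen example above.
m <- Model(
  name = "gpt-4",
  params = params(temperature = 0.7)
)

# A Chat created by a chat_*() constructor now carries a Model alongside its
# Provider: get_model() returns just the name (model@name), while
# get_model_obj() returns the underlying S7 object for expert use.
chat <- chat_openai(model = "gpt-4")
chat$get_model()     # "gpt-4"
chat$get_model_obj() # <Model> with name, params, and extra_args
```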
13 changes: 9 additions & 4 deletions R/parallel-chat.R
@@ -78,6 +78,7 @@ parallel_chat <- function(
my_parallel_turns <- function(conversations) {
parallel_turns(
provider = chat$get_provider(),
model = chat$get_model_obj(),
conversations = conversations,
tools = chat$get_tools(),
max_active = max_active,
@@ -127,7 +128,7 @@ parallel_chat <- function(
map(seq_along(conversations), function(i) {
if (is_ok[[i]]) {
turns <- conversations[[i]]
log_turns(chat$get_provider(), turns)
log_turns(chat$get_provider(), chat$get_model_obj(), turns)
chat$clone()$set_turns(turns)
} else {
assistant_turns[[i]]
@@ -191,6 +192,7 @@ parallel_chat_structured <- function(
on_error <- arg_match(on_error)

provider <- chat$get_provider()
model <- chat$get_model_obj()
needs_wrapper <- type_needs_wrapper(type, provider)

# First build up list of cumulative conversations
@@ -200,14 +202,15 @@

turns <- parallel_turns(
provider = provider,
model = model,
conversations = conversations,
tools = chat$get_tools(),
type = wrap_type_if_needed(type, needs_wrapper),
max_active = max_active,
rpm = rpm,
on_error = on_error
)
log_turns(provider, turns)
log_turns(provider, model, turns)

multi_convert(
provider,
Expand All @@ -227,7 +230,7 @@ multi_convert <- function(
include_tokens = FALSE,
include_cost = FALSE
) {
needs_wrapper <- type_needs_wrapper(type, provider)
needs_wrapper <- type_needs_wrapper(type, provider, NULL)

rows <- map(turns, \(turn) {
if (turn_failed(turn)) {
@@ -306,6 +309,7 @@ turn_failed <- function(turn) {

parallel_turns <- function(
provider,
model,
conversations,
tools,
type = NULL,
@@ -316,6 +320,7 @@ parallel_turns <- function(
reqs <- map(conversations, function(turns) {
chat_request(
provider = provider,
model = model,
turns = turns,
type = type,
tools = tools,
@@ -352,7 +357,7 @@ parallel_turns <- function(
resp
} else {
json <- resp_body_json(resp)
turn <- value_turn(provider, json, has_type = !is.null(type))
turn <- value_turn(provider, model, json, has_type = !is.null(type))
turn@duration <- resp_timing(resp)[["total"]] %||% NA_real_
turn
}