Skip to content

Commit 5b65a1b

Browse files
Inline model default for embedding functions in closures. (#52)
* Store argument defaults for embedding functions in closures. * refactor `embed` processing - only inline `model` arg - refactor recursive ast walker - process `embed` before computing `embedding_size` * add tests * preserve `...`, ensure `model` arg always named --------- Co-authored-by: Tomasz Kalinowski <[email protected]>
1 parent 7ad1630 commit 5b65a1b

File tree

3 files changed

+140
-10
lines changed

3 files changed

+140
-10
lines changed

R/store.R

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ ragnar_store_create <- function(
7171
if (is_motherduck_location(location)) {
7272
con <- motherduck_connection(location, create = TRUE, overwrite)
7373
} else {
74-
if (any(file.exists(c(location, location.wal <- paste0(location, ".wal"))))) {
74+
if (
75+
any(file.exists(c(location, location.wal <- paste0(location, ".wal"))))
76+
) {
7577
if (overwrite) {
7678
unlink(c(location, location.wal), force = TRUE)
7779
} else {
@@ -91,14 +93,10 @@ ragnar_store_create <- function(
9193
if (is.null(embed)) {
9294
embedding_size <- NULL
9395
} else {
96+
embed <- process_embed_func(embed)
9497
check_number_whole(embedding_size, min = 0)
9598
embedding_size <- as.integer(embedding_size)
9699

97-
if (!inherits(embed, "crate")) {
98-
environment(embed) <- baseenv()
99-
embed <- rlang::zap_srcref(embed)
100-
}
101-
102100
default_schema$embedding <- matrix(
103101
numeric(0),
104102
nrow = 0,
@@ -232,18 +230,80 @@ motherduck_connection <- function(location, create = FALSE, overwrite = FALSE) {
232230
if (create) {
233231
if (dbName %in% DBI::dbGetQuery(con, "SHOW DATABASES")$database_name) {
234232
if (overwrite) {
235-
DBI::dbExecute(con, glue::glue_sql(.con = con, "DROP DATABASE {`dbName`}"))
233+
DBI::dbExecute(
234+
con,
235+
glue::glue_sql(.con = con, "DROP DATABASE {`dbName`}")
236+
)
236237
} else {
237238
stop("Database already exists: ", dbName)
238239
}
239240
}
240-
DBI::dbExecute(con, glue::glue_sql(.con = con, "CREATE DATABASE {`dbName`}"))
241+
DBI::dbExecute(
242+
con,
243+
glue::glue_sql(.con = con, "CREATE DATABASE {`dbName`}")
244+
)
241245
}
242246

243247
DBI::dbExecute(con, glue::glue_sql(.con = con, "USE {`dbName`}"))
244248
con
245249
}
246250

251+
process_embed_func <- function(embed) {
252+
if (inherits(embed, "crate")) {
253+
return(embed)
254+
}
255+
environment(embed) <- baseenv()
256+
embed <- rlang::zap_srcref(embed)
257+
258+
embed_func_names <- grep(
259+
"^embed_",
260+
getNamespaceExports("ragnar"),
261+
value = TRUE
262+
)
263+
264+
walker <- function(x) {
265+
switch(
266+
typeof(x),
267+
list = {
268+
x <- lapply(x, walker)
269+
},
270+
language = {
271+
if (rlang::is_call(x, embed_func_names, ns = c("", "ragnar"))) {
272+
name <- rlang::call_name(x)
273+
fn <- get(name)
274+
ox <- x
275+
x <- rlang::call_match(x, fn, defaults = FALSE, dots_expand = FALSE)
276+
x <- as.list(x)
277+
278+
# ensure 'model' is explicit arg embedded in call
279+
if (!"model" %in% names(x)) {
280+
x["model"] <- formals(fn)["model"]
281+
}
282+
283+
# preserve `...` if they were present in the call (call_match() removes them)
284+
if (any(map_lgl(as.list(ox), identical, quote(...)))) {
285+
x <- c(x, quote(...))
286+
}
287+
x <- as.call(x)
288+
289+
# ensure the call is namespaced
290+
if (is.null(rlang::call_ns(x))) {
291+
x[[1L]] <- call("::", quote(ragnar), as.symbol(name))
292+
}
293+
} else {
294+
x <- as.call(lapply(x, walker))
295+
}
296+
},
297+
x
298+
)
299+
x
300+
}
301+
302+
body(embed) <- walker(body(embed))
303+
embed
304+
}
305+
306+
247307
#' Connect to `RagnarStore`
248308
#'
249309
#' @param location string, a filepath location.

tests/testthat/_snaps/store.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# embed functions get the defaults stored
2+
3+
Code
4+
store@embed
5+
Output
6+
function (x)
7+
ragnar::embed_openai(x = x, model = "text-embedding-3-small")
8+
<environment: base>
9+
10+
---
11+
12+
Code
13+
store@embed
14+
Output
15+
function (x)
16+
ragnar::embed_openai(x = x, model = "text-embedding-3-small")
17+
<environment: base>
18+
19+
---
20+
21+
Code
22+
store@embed
23+
Output
24+
function (x, ...)
25+
ragnar::embed_openai(x = x, ..., model = "text-embedding-3-small")
26+
<environment: base>
27+
28+
---
29+
30+
Code
31+
store@embed
32+
Output
33+
function (x)
34+
ragnar::embed_openai(x = x, model = "text-embedding-3-small")
35+
<environment: base>
36+
37+
---
38+
39+
Code
40+
store@embed
41+
Output
42+
function (x)
43+
ragnar::embed_ollama(x = x, model = "all-minilm")
44+
<environment: base>
45+

tests/testthat/test-store.R

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,6 @@ test_that("Allow a NULL embedding function", {
168168
})
169169

170170
test_that("works with MotherDuck", {
171-
172171
testthat::skip_if(Sys.getenv("motherduck_token", "") == "")
173172

174173
store <- ragnar_store_create(
@@ -184,7 +183,10 @@ test_that("works with MotherDuck", {
184183
)
185184

186185
expect_error(ragnar_store_insert(store, chunks), regexp = NA)
187-
expect_warning(ragnar_store_build_index(store), regexp = "MotherDuck does not support")
186+
expect_warning(
187+
ragnar_store_build_index(store),
188+
regexp = "MotherDuck does not support"
189+
)
188190
expect_error(ragnar_retrieve(store, "hello"), regexp = NA)
189191

190192
# Since we used insert, there's no checking if the hash is the same
@@ -198,3 +200,26 @@ test_that("works with MotherDuck", {
198200
val <- dbGetQuery(store@.con, "select origin, hash, text from chunks")
199201
expect_equal(nrow(val), 1)
200202
})
203+
204+
test_that("embed functions get the defaults stored", {
205+
store <- ragnar_store_create(embed = function(x) ragnar::embed_openai(x))
206+
expect_snapshot(store@embed)
207+
208+
# here embed_openai is implicitly obtained from ragnar::embed_openai
209+
store <- ragnar_store_create(embed = function(x) embed_openai(x))
210+
expect_snapshot(store@embed)
211+
212+
# if the embed function takes ..., they're preserved
213+
store <- ragnar_store_create(
214+
embed = function(x, ...) ragnar::embed_openai(x, ...)
215+
)
216+
expect_snapshot(store@embed)
217+
218+
# when using the partialized version, we should also add the defaults
219+
store <- ragnar_store_create(embed = embed_openai())
220+
expect_snapshot(store@embed)
221+
222+
# test other embed funcs
223+
store <- ragnar_store_create(embed = function(x) ragnar::embed_ollama(x))
224+
expect_snapshot(store@embed)
225+
})

0 commit comments

Comments
 (0)