Skip to content

Commit 89a396a

Browse files
authored
Merge pull request #94 from mlverse/bugfix/93_parsnip
fix #93
2 parents dd451e9 + 01b92eb commit 89a396a

File tree

4 files changed

+40
-3
lines changed

4 files changed

+40
-3
lines changed

DESCRIPTION

+3-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Suggests:
3131
testthat (>= 3.0.0),
3232
modeldata,
3333
recipes,
34+
rsample,
3435
parsnip,
3536
dials,
3637
withr,
@@ -43,6 +44,7 @@ Suggests:
4344
tidyr,
4445
purrr,
4546
tune,
46-
workflows
47+
workflows,
48+
yardstick
4749
VignetteBuilder: knitr
4850
Config/testthat/edition: 3

R/hardhat.R

+3-1
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,9 @@ tabnet_bridge <- function(processed, config = tabnet_config(), tabnet_model, fro
329329

330330
}
331331
if (task == "supervised") {
332-
stopifnot("Error: found missing values in the response vector" = sum(is.na(outcomes))==0)
332+
if (sum(is.na(outcomes))>0) {
333+
rlang::abort("Error: found missing values in the response vector")
334+
}
333335
if (is.null(tabnet_model)) {
334336
# new supervised model needs network initialization
335337
tabnet_model_lst <- tabnet_initialize(predictors, outcomes, config = config)

R/parsnip.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ update.tabnet <- function(object, parameters = NULL, epochs = NULL, penalty = NU
307307
num_steps = NULL, feature_reusage = NULL, virtual_batch_size = NULL,
308308
num_independent = NULL, num_shared = NULL, momentum = NULL, ...) {
309309
rlang::check_installed("parsnip")
310-
eng_args <- parsnip::update_engine_parameters(object$eng_args, ...)
310+
eng_args <- parsnip::update_engine_parameters(object$eng_args, fresh=TRUE, ...)
311311
args <- list(
312312
epochs = rlang::enquo(epochs),
313313
penalty = rlang::enquo(penalty),

tests/testthat/test-parsnip.R

+33
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,36 @@ test_that("Check we can finalize a workflow", {
6262
expect_equal(rlang::eval_tidy(wf$fit$actions$model$spec$args$penalty), 0.01)
6363
expect_equal(rlang::eval_tidy(wf$fit$actions$model$spec$args$epochs), 1)
6464
})
65+
66+
test_that("Check we can finalize a workflow from a tune_grid", {
67+
68+
library(parsnip)
69+
data("ames", package = "modeldata")
70+
71+
model <- tabnet(epochs = tune()) %>%
72+
set_mode("regression") %>%
73+
set_engine("torch")
74+
75+
wf <- workflows::workflow() %>%
76+
workflows::add_model(model) %>%
77+
workflows::add_formula(Sale_Price ~ .)
78+
79+
custom_grid <- tidyr::crossing(epochs = c(1,2,3))
80+
cv_folds <- ames %>%
81+
rsample::vfold_cv(v = 2, repeats = 1)
82+
83+
at <- tune::tune_grid(
84+
object = wf,
85+
resamples = cv_folds,
86+
grid = custom_grid,
87+
metrics = yardstick::metric_set(yardstick::rmse),
88+
control = tune::control_grid(verbose = F)
89+
)
90+
91+
best_rmse <- tune::select_best(at, "rmse")
92+
93+
expect_error(
94+
final_wf <- tune::finalize_workflow(wf, best_rmse),
95+
regexp = NA
96+
)
97+
})

0 commit comments

Comments
 (0)