Add skip_checks parameter to hub connection functions
This changeset adds an optional skip_checks parameter to
connect_hub.R and connect_model_output.R per the requirements
outlined in hubverse-org#37.

When working with hub data on a local filesystem, the behavior
is unchanged. When working with hub data in an S3 bucket, the
connect functions will now skip data checks by default to
improve performance. The former connection behavior for
S3-based hubs can be obtained by explicitly setting
skip_checks=FALSE.

This commit fixes the test suite to work when using
skip_checks=FALSE to force the previous behavior. The
next commit will add new tests to ensure the new behavior
works as intended.
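
For example (a sketch of the new defaults; the S3 path is the test hub
used in the package examples, the local path is a placeholder, and the
package and arrow are assumed to be attached):

    # Local hub: behavior is unchanged; data checks run as before.
    hub_con <- connect_hub("path/to/local/hub")

    # S3 hub: data checks are now skipped by default to improve performance.
    hub_path <- s3_bucket("hubverse/hubutils/testhubs/simple/")
    hub_con <- connect_hub(hub_path)

    # Opt back in to the previous (checked) behavior.
    hub_con <- connect_hub(hub_path, skip_checks = FALSE)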
bsweger committed Jul 25, 2024
1 parent 1b9e98f commit 2795333
Showing 6 changed files with 106 additions and 35 deletions.
66 changes: 49 additions & 17 deletions R/connect_hub.R
@@ -25,6 +25,10 @@
#' `admin.json` and is ignored by default.
#' If supplied, it will override hub configuration setting. Multiple formats can
#' be supplied to `connect_hub` but only a single file format can be supplied to `connect_mod_out`.
#' @param skip_checks Logical. If `FALSE` (default), checks the `file_format` parameter against
#' the hub's model output files and excludes invalid model output files when opening hub datasets.
#' Setting to `TRUE` will improve performance but will result in an error if the model output
#' directory includes invalid files.
#' @inheritParams create_hub_schema
#'
#' @return
@@ -65,7 +69,7 @@
#' # Connect to a simple forecasting Hub stored in an AWS S3 bucket.
#' \dontrun{
#' hub_path <- s3_bucket("hubverse/hubutils/testhubs/simple/")
#' hub_con <- connect_hub(hub_path)
#' hub_con <- connect_hub(hub_path, skip_checks = FALSE)
#' hub_con
#' }
connect_hub <- function(hub_path,
@@ -75,7 +79,8 @@ connect_hub <- function(hub_path,
"double", "integer",
"logical", "Date"
),
partitions = list(model_id = arrow::utf8())) {
partitions = list(model_id = arrow::utf8()),
skip_checks = FALSE) {
UseMethod("connect_hub")
}

@@ -88,7 +93,8 @@ connect_hub.default <- function(hub_path,
"double", "integer",
"logical", "Date"
),
partitions = list(model_id = arrow::utf8())) {
partitions = list(model_id = arrow::utf8()),
skip_checks = FALSE) {
rlang::check_required(hub_path)
output_type_id_datatype <- rlang::arg_match(output_type_id_datatype)

@@ -112,8 +118,15 @@
}
hub_name <- config_admin$name

# Only keep file formats of which files actually exist in model_output_dir.
file_format <- check_file_format(model_output_dir, file_format)
# Based on the skip_checks param: 1) set a flag that determines whether to
# check for invalid files when opening model output data, and 2) if skip_checks
# is FALSE, only keep file formats for which files actually exist in model_output_dir.
if (isTRUE(skip_checks)) {
exclude_invalid_files_flag <- FALSE
} else {
file_format <- check_file_format(model_output_dir, file_format)
exclude_invalid_files_flag <- TRUE
}

if (length(file_format) == 0L) {
dataset <- list()
@@ -123,7 +136,8 @@
file_format = file_format,
config_tasks = config_tasks,
output_type_id_datatype = output_type_id_datatype,
partitions = partitions
partitions = partitions,
exclude_invalid_files_flag = exclude_invalid_files_flag
)
}
if (inherits(dataset, "UnionDataset")) {
@@ -165,7 +179,8 @@ connect_hub.SubTreeFileSystem <- function(hub_path,
"logical",
"Date"
),
partitions = list(model_id = arrow::utf8())) {
partitions = list(model_id = arrow::utf8()),
skip_checks = FALSE) {
rlang::check_required(hub_path)
output_type_id_datatype <- rlang::arg_match(output_type_id_datatype)

@@ -174,6 +189,11 @@ connect_hub.SubTreeFileSystem <- function(hub_path,
{.path {hub_path$base_path}}")
}

# set skip_checks value if not specified by user
if (missing(skip_checks)) {
skip_checks <- get_skip_check_option(hub_path)
}

config_admin <- hubUtils::read_config(hub_path, "admin")
config_tasks <- hubUtils::read_config(hub_path, "tasks")

@@ -187,8 +207,15 @@ connect_hub.SubTreeFileSystem <- function(hub_path,
}
hub_name <- config_admin$name

# Only keep file formats of which files actually exist in model_output_dir.
file_format <- check_file_format(model_output_dir, file_format)
# Based on the skip_checks param: 1) set a flag that determines whether to
# check for invalid files when opening model output data, and 2) if skip_checks
# is FALSE, only keep file formats for which files actually exist in model_output_dir.
if (isTRUE(skip_checks)) {
exclude_invalid_files_flag <- FALSE
} else {
file_format <- check_file_format(model_output_dir, file_format)
exclude_invalid_files_flag <- TRUE
}

if (length(file_format) == 0L) {
dataset <- list()
@@ -198,7 +225,8 @@ connect_hub.SubTreeFileSystem <- function(hub_path,
file_format = file_format,
config_tasks = config_tasks,
output_type_id_datatype = output_type_id_datatype,
partitions = partitions
partitions = partitions,
exclude_invalid_files_flag = exclude_invalid_files_flag
)
}

@@ -217,7 +245,7 @@ connect_hub.SubTreeFileSystem <- function(hub_path,
# files in dataset
warn_unopened_files(file_format, dataset, model_output_dir)

structure(dataset,
x <- structure(dataset,
class = c("hub_connection", class(dataset)),
hub_name = hub_name,
file_format = file_format,
@@ -238,7 +266,8 @@ open_hub_dataset <- function(model_output_dir,
"double", "integer",
"logical", "Date"
),
partitions = list(model_id = arrow::utf8())) {
partitions = list(model_id = arrow::utf8()),
exclude_invalid_files_flag) {
file_format <- rlang::arg_match(file_format)
schema <- create_hub_schema(config_tasks,
partitions = partitions,
@@ -253,23 +282,23 @@
col_types = schema,
unify_schemas = FALSE,
strings_can_be_null = TRUE,
factory_options = list(exclude_invalid_files = TRUE)
factory_options = list(exclude_invalid_files = exclude_invalid_files_flag)
),
parquet = arrow::open_dataset(
model_output_dir,
format = "parquet",
partitioning = "model_id",
schema = schema,
unify_schemas = FALSE,
factory_options = list(exclude_invalid_files = TRUE)
factory_options = list(exclude_invalid_files = exclude_invalid_files_flag)
),
arrow = arrow::open_dataset(
model_output_dir,
format = "arrow",
partitioning = "model_id",
schema = schema,
unify_schemas = FALSE,
factory_options = list(exclude_invalid_files = TRUE)
factory_options = list(exclude_invalid_files = exclude_invalid_files_flag)
)
)
}
@@ -284,14 +313,16 @@ open_hub_datasets <- function(model_output_dir,
"logical", "Date"
),
partitions = list(model_id = arrow::utf8()),
exclude_invalid_files_flag,
call = rlang::caller_env()) {
if (length(file_format) == 1L) {
open_hub_dataset(
model_output_dir = model_output_dir,
file_format = file_format,
config_tasks = config_tasks,
output_type_id_datatype,
partitions = partitions
partitions = partitions,
exclude_invalid_files_flag
)
} else {
cons <- purrr::map(
Expand All @@ -301,7 +332,8 @@ open_hub_datasets <- function(model_output_dir,
file_format = .x,
config_tasks = config_tasks,
output_type_id_datatype = output_type_id_datatype,
partitions = partitions
partitions = partitions,
exclude_invalid_files_flag
)
)

46 changes: 35 additions & 11 deletions R/connect_model_output.R
@@ -10,22 +10,32 @@
connect_model_output <- function(model_output_dir,
file_format = c("csv", "parquet", "arrow"),
partition_names = "model_id",
schema = NULL) {
schema = NULL,
skip_checks = FALSE) {
UseMethod("connect_model_output")
}

#' @export
connect_model_output.default <- function(model_output_dir,
file_format = c("csv", "parquet", "arrow"),
partition_names = "model_id",
schema = NULL) {
schema = NULL,
skip_checks = FALSE) {
rlang::check_required(model_output_dir)
if (!dir.exists(model_output_dir)) {
cli::cli_abort(c("x" = "Directory {.path {model_output_dir}} does not exist."))
}

file_format <- rlang::arg_match(file_format)
# Only keep file formats of which files actually exist in model_output_dir.
file_format <- check_file_format(model_output_dir, file_format, error = TRUE)
# Based on the skip_checks param: 1) set a flag that determines whether to
# check for invalid files when opening model output data, and 2) if skip_checks
# is FALSE, only keep file formats for which files actually exist in model_output_dir.
if (isTRUE(skip_checks)) {
exclude_invalid_files_flag <- FALSE
} else {
file_format <- check_file_format(model_output_dir, file_format, error = TRUE)
exclude_invalid_files_flag <- TRUE
}

if (file_format == "csv") {
dataset <- arrow::open_dataset(
@@ -35,7 +45,7 @@ connect_model_output.default <- function(model_output_dir,
col_types = schema,
unify_schemas = TRUE,
strings_can_be_null = TRUE,
factory_options = list(exclude_invalid_files = TRUE)
factory_options = list(exclude_invalid_files = exclude_invalid_files_flag)
)
} else {
dataset <- arrow::open_dataset(
@@ -44,7 +54,7 @@ connect_model_output.default <- function(model_output_dir,
partitioning = partition_names,
schema = schema,
unify_schemas = TRUE,
factory_options = list(exclude_invalid_files = TRUE)
factory_options = list(exclude_invalid_files = exclude_invalid_files_flag)
)
}

@@ -64,11 +74,25 @@
connect_model_output.SubTreeFileSystem <- function(model_output_dir,
file_format = c("csv", "parquet", "arrow"),
partition_names = "model_id",
schema = NULL) {
schema = NULL,
skip_checks = FALSE) {
rlang::check_required(model_output_dir)

# set skip_checks value if not specified by user
if (missing(skip_checks)) {
skip_checks <- get_skip_check_option(model_output_dir)
}

file_format <- rlang::arg_match(file_format)
# Only keep file formats of which files actually exist in model_output_dir.
file_format <- check_file_format(model_output_dir, file_format, error = TRUE)
# Based on the skip_checks param: 1) set a flag that determines whether to
# check for invalid files when opening model output data, and 2) if skip_checks
# is FALSE, only keep file formats for which files actually exist in model_output_dir.
if (isTRUE(skip_checks)) {
exclude_invalid_files_flag <- FALSE
} else {
file_format <- check_file_format(model_output_dir, file_format, error = TRUE)
exclude_invalid_files_flag <- TRUE
}

if (file_format == "csv") {
dataset <- arrow::open_dataset(
@@ -78,7 +102,7 @@ connect_model_output.SubTreeFileSystem <- function(model_output_dir,
schema = schema,
unify_schemas = TRUE,
strings_can_be_null = TRUE,
factory_options = list(exclude_invalid_files = TRUE)
factory_options = list(exclude_invalid_files = exclude_invalid_files_flag)
)
} else {
dataset <- arrow::open_dataset(
@@ -87,7 +111,7 @@ connect_model_output.SubTreeFileSystem <- function(model_output_dir,
partitioning = partition_names,
schema = schema,
unify_schemas = TRUE,
factory_options = list(exclude_invalid_files = TRUE)
factory_options = list(exclude_invalid_files = exclude_invalid_files_flag)
)
}

8 changes: 8 additions & 0 deletions R/utils-connect_hub.R
@@ -232,3 +232,11 @@ get_dir_file_formats <- function(model_output_dir) {

intersect(all_ext, c("csv", "parquet", "arrow"))
}

get_skip_check_option <- function(dir_path) {
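# Default skip_checks based on hub location: skip data checks for S3-based hubs,
# run them for local filesystems.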
if (dir_path[["url_scheme"]] == "s3") {
skip_checks <- TRUE
} else {
skip_checks <- FALSE
}
}
13 changes: 10 additions & 3 deletions man/connect_hub.Rd

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion tests/testthat/_snaps/hub-connection.md
@@ -708,7 +708,7 @@
---

Code
connect_model_output(mod_out_path)
connect_model_output(mod_out_path, skip_checks = FALSE)
Condition
Error in `connect_model_output()`:
! No files of file format "csv" found in model output directory.
6 changes: 3 additions & 3 deletions tests/testthat/test-hub-connection.R
@@ -77,7 +77,7 @@ test_that("connect_hub returns empty list when model output folder is empty", {

# S3
hub_path <- s3_bucket("hubverse/hubutils/testhubs/empty/")
hub_con <- suppressWarnings(connect_hub(hub_path))
hub_con <- suppressWarnings(connect_hub(hub_path, skip_checks=FALSE))
attr(hub_con, "model_output_dir") <- "test/model_output_dir"
attr(hub_con, "hub_path") <- "test/hub_path"
expect_snapshot(hub_con)
@@ -256,7 +256,7 @@ test_that("connect_model_output fails on empty model_output_dir", {
)

mod_out_path <- s3_bucket("hubverse/hubutils/testhubs/empty/model-output")
expect_snapshot(connect_model_output(mod_out_path), error = TRUE)
expect_snapshot(connect_model_output(mod_out_path, skip_checks=FALSE), error = TRUE)
})


@@ -325,7 +325,7 @@ test_that("connect_hub works on S3 bucket simple forecasting hub on AWS", {
# Simple forecasting Hub example ----

hub_path <- s3_bucket("hubverse/hubutils/testhubs/simple/")
hub_con <- connect_hub(hub_path)
hub_con <- connect_hub(hub_path, skip_checks=FALSE)

# Tests that paths are assigned to attributes correctly
expect_equal(
Expand Down
