diff --git a/.gitignore b/.gitignore index c392361..a1e44ce 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,4 @@ did.Rproj .claude/ .vscode/ .revdep_manual/ -vignettes/*_cache/ \ No newline at end of file +vignettes/*_cache/ diff --git a/R/DIDparams.R b/R/DIDparams.R index 7aafd3a..6e6a08d 100644 --- a/R/DIDparams.R +++ b/R/DIDparams.R @@ -43,6 +43,7 @@ DIDparams <- function(yname, nT=NULL, tlist=NULL, glist=NULL, + est_method_vars=NULL, call=NULL) { out <- list(yname=yname, @@ -64,6 +65,7 @@ DIDparams <- function(yname, pl=pl, cores=cores, est_method=est_method, + est_method_vars=est_method_vars, base_period=base_period, panel=panel, true_repeated_cross_sections=true_repeated_cross_sections, diff --git a/R/DIDparams2.R b/R/DIDparams2.R index 108b50b..5bb778e 100644 --- a/R/DIDparams2.R +++ b/R/DIDparams2.R @@ -16,6 +16,7 @@ DIDparams2 <- function(did_tensors, args, call=NULL) { xformla <- args$xformla # formula of covariates panel <- args$panel est_method <- args$est_method + est_method_vars <- args$est_method_vars bstrap <- args$bstrap biters <- args$biters cband <- args$cband @@ -58,6 +59,7 @@ DIDparams2 <- function(did_tensors, args, call=NULL) { xformla=xformla, panel=panel, est_method=est_method, + est_method_vars=est_method_vars, bstrap=bstrap, biters=biters, cband=cband, diff --git a/R/MP.R b/R/MP.R index ffeb0ae..a953045 100644 --- a/R/MP.R +++ b/R/MP.R @@ -20,13 +20,14 @@ #' @param alp the significance level, default is 0.05 #' @param DIDparams a [`DIDparams`] object. A way to optionally return the parameters #' of the call to [att_gt()] or [conditional_did_pretest()]. +#' @param ... 
additional named elements to include in the MP object #' #' @return MP object #' @export -MP <- function(group, t, att, V_analytical, se, c, inffunc, n=NULL, W=NULL, Wpval=NULL, aggte=NULL, alp = 0.05, DIDparams=NULL) { - out <- list(group=group, t=t, att=att, V_analytical=V_analytical, se=se, c=c, +MP <- function(group, t, att, V_analytical, se, c, inffunc, n=NULL, W=NULL, Wpval=NULL, aggte=NULL, alp = 0.05, DIDparams=NULL, ...) { + out <- c(list(group=group, t=t, att=att, V_analytical=V_analytical, se=se, c=c, inffunc=inffunc, n=n, W=W, Wpval=Wpval, aggte=aggte, alp = alp, - DIDparams=DIDparams, call=DIDparams$call) + DIDparams=DIDparams, call=DIDparams$call), list(...)) class(out) <- "MP" out } diff --git a/R/att_gt.R b/R/att_gt.R index b5663b4..2ed6845 100644 --- a/R/att_gt.R +++ b/R/att_gt.R @@ -131,6 +131,11 @@ #' the user allows for anticipation) to be equal to 0, but one #' extra estimate in an earlier period. #' +#' @param est_method_vars Optional character vector of column names from `data` +#' to pass through to a custom `est_method` function. These columns are +#' preserved through preprocessing, subsetted to match each (g,t) partition, +#' and passed to `est_method` as an additional `data` argument (a data.frame). +#' Ignored when using built-in estimation methods. Default is `NULL`. #' @param ... Additional arguments to be passed to a custom `est_method` #' function. These are ignored when using built-in estimation methods #' (`"dr"`, `"ipw"`, `"reg"`). @@ -206,6 +211,7 @@ att_gt <- function(yname, print_details = FALSE, pl = FALSE, cores = 1, + est_method_vars = NULL, ...) { # Capture extra arguments for custom est_method extra_args <- list(...) @@ -217,6 +223,22 @@ att_gt <- function(yname, "\". 
Extra arguments are only passed to custom est_method functions.") } + # Validate est_method_vars + if (!is.null(est_method_vars)) { + if (!is.character(est_method_vars)) { + stop("est_method_vars must be a character vector of column names from data.") + } + missing_emv <- setdiff(est_method_vars, colnames(data)) + if (length(missing_emv) > 0) { + stop("The following est_method_vars are not found in data: ", + paste(missing_emv, collapse = ", "), ".") + } + if (!inherits(est_method, "function")) { + warning("est_method_vars is specified but est_method is not a custom function. ", + "est_method_vars will be ignored.") + } + } + # Validate est_method if (!inherits(est_method, "function")) { if (!is.character(est_method) || length(est_method) != 1) { @@ -255,6 +277,7 @@ att_gt <- function(yname, biters = biters, clustervars = clustervars, est_method = est_method, + est_method_vars = est_method_vars, base_period = base_period, print_details = print_details, faster_mode = faster_mode, @@ -290,6 +313,7 @@ att_gt <- function(yname, biters = biters, clustervars = clustervars, est_method = est_method, + est_method_vars = est_method_vars, base_period = base_period, print_details = print_details, pl = pl, @@ -485,6 +509,9 @@ att_gt <- function(yname, } - # Return this list - return(MP(group = group, t = tt, att = att, V_analytical = V, se = se, c = cval, inffunc = inffunc, n = n, W = W, Wpval = Wpval, alp = alp, DIDparams = dp)) + # Build the MP object, append extra_gt results from est_method calls if any + extra_gt <- Filter(Negate(is.null), lapply(attgt.list, function(x) x$extra)) + MP(group = group, t = tt, att = att, V_analytical = V, se = se, c = cval, + inffunc = inffunc, n = n, W = W, Wpval = Wpval, alp = alp, + DIDparams = dp, extra_gt = if (length(extra_gt)) extra_gt) } diff --git a/R/compute.att_gt.R b/R/compute.att_gt.R index bef72c3..4374f49 100644 --- a/R/compute.att_gt.R +++ b/R/compute.att_gt.R @@ -26,6 +26,7 @@ compute.att_gt <- function(dp) { xformla <- 
dp$xformla weightsname <- dp$weightsname est_method <- dp$est_method + est_method_vars <- dp$est_method_vars extra_args <- if (is.null(dp$extra_args)) list() else dp$extra_args base_period <- dp$base_period panel <- dp$panel @@ -249,13 +250,24 @@ compute.att_gt <- function(dp) { attgt <- tryCatch({ if (inherits(est_method, "function")) { # user-specified function - res <- do.call(est_method, c(list( + base_args <- list( y1 = Ypost, y0 = Ypre, D = G, covariates = covariates, i.weights = w, inffunc = TRUE - ), extra_args)) + ) + # forward cell identity if est_method can accept it + fmls <- names(formals(est_method)) + if ("g" %in% fmls) { + base_args$g <- glist[g] + base_args$t <- tlist[(t + tfac)] + } + # add passthrough variables if specified + if (!is.null(est_method_vars)) { + base_args$data <- disdat[, est_method_vars, with = FALSE] + } + res <- do.call(est_method, c(base_args, extra_args)) } else if (est_method == "ipw") { # inverse-probability weights res <- DRDID::std_ipw_did_panel(Ypost, Ypre, G, @@ -416,14 +428,25 @@ compute.att_gt <- function(dp) { attgt <- tryCatch({ if (inherits(est_method, "function")) { # user-specified function - res <- do.call(est_method, c(list( + base_args <- list( y = Y, post = post, D = G, covariates = covariates, i.weights = w, inffunc = TRUE - ), extra_args)) + ) + # forward cell identity if est_method can accept it + fmls <- names(formals(est_method)) + if ("g" %in% fmls) { + base_args$g <- glist[g] + base_args$t <- tlist[(t + tfac)] + } + # add passthrough variables if specified + if (!is.null(est_method_vars)) { + base_args$data <- disdat[, est_method_vars, with = FALSE] + } + res <- do.call(est_method, c(base_args, extra_args)) } else if (est_method == "ipw") { # inverse-probability weights res <- DRDID::std_ipw_did_rc( @@ -485,8 +508,10 @@ compute.att_gt <- function(dp) { } # end panel if # save results for this att(g,t) + extra <- if (custom_est_method) attgt[!names(attgt) %in% c("ATT", "att.inf.func")] else NULL + if 
(!length(extra)) extra <- NULL attgt.list[[counter]] <- list( - att = attgt$ATT, group = glist[g], year = tlist[(t + tfac)], post = post.treat + att = attgt$ATT, group = glist[g], year = tlist[(t + tfac)], post = post.treat, extra = extra ) diff --git a/R/compute.att_gt2.R b/R/compute.att_gt2.R index bfdf3ea..7f1eb77 100644 --- a/R/compute.att_gt2.R +++ b/R/compute.att_gt2.R @@ -91,6 +91,7 @@ get_did_cohort_index <- function(group, time, tfac, pret, dp2){ run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){ extra_args <- if (is.null(dp2$extra_args)) list() else dp2$extra_args + est_method_vars <- dp2$est_method_vars gt_label <- if (!is.null(g_val) && !is.null(t_val)) paste0(" for group ", g_val, " in time period ", t_val) else "" if(dp2$panel){ @@ -150,13 +151,24 @@ run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){ if (inherits(dp2$est_method, "function")) { # user-specified function - attgt <- do.call(dp2$est_method, c(list( - y1=cohort_data[, y1], - y0=cohort_data[, y0], - D=cohort_data[, D], - covariates=covariates, - i.weights=cohort_data[, i.weights], - inffunc=TRUE), extra_args)) + base_args <- list( + y1=cohort_data[, y1], + y0=cohort_data[, y0], + D=cohort_data[, D], + covariates=covariates, + i.weights=cohort_data[, i.weights], + inffunc=TRUE) + # forward cell identity if est_method can accept it + fmls <- names(formals(dp2$est_method)) + if ("g" %in% fmls) { + base_args$g <- g_val + base_args$t <- t_val + } + # add passthrough variables if specified + if (!is.null(est_method_vars)) { + base_args$data <- dp2$time_invariant_data[valid_obs, est_method_vars, with = FALSE] + } + attgt <- do.call(dp2$est_method, c(base_args, extra_args)) } else if (dp2$est_method == "ipw") { # inverse-probability weights attgt <- std_ipw_did_panel(y1=cohort_data[, y1], @@ -252,13 +264,24 @@ run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){ if (inherits(dp2$est_method, "function")) { # 
user-specified function - attgt <- do.call(dp2$est_method, c(list( - y=cohort_data[, y], - post=cohort_data[, post], - D=cohort_data[, D], - covariates=covariates, - i.weights=cohort_data[, i.weights], - inffunc=TRUE), extra_args)) + base_args <- list( + y=cohort_data[, y], + post=cohort_data[, post], + D=cohort_data[, D], + covariates=covariates, + i.weights=cohort_data[, i.weights], + inffunc=TRUE) + # forward cell identity if est_method can accept it + fmls <- names(formals(dp2$est_method)) + if ("g" %in% fmls) { + base_args$g <- g_val + base_args$t <- t_val + } + # add passthrough variables if specified + if (!is.null(est_method_vars)) { + base_args$data <- dp2$time_invariant_data[valid_obs, est_method_vars, with = FALSE] + } + attgt <- do.call(dp2$est_method, c(base_args, extra_args)) } else if (dp2$est_method == "ipw") { # inverse-probability weights attgt <- std_ipw_did_rc(y=cohort_data[, y], @@ -306,7 +329,13 @@ run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){ } - return(list(att = attgt$ATT, inf_func = inf_func_vector)) + result <- list(att = attgt$ATT, inf_func = inf_func_vector) + # forward extra fields from custom est_method (if any) + if (custom_est_method) { + extra <- attgt[!names(attgt) %in% c("ATT", "att.inf.func")] + if (length(extra) > 0) result$extra <- extra + } + return(result) } @@ -495,7 +524,7 @@ compute.att_gt2 <- function(dp2) { # Save ATT and influence function inffunc_updates <- inf_func - gt_result <- list(att = att, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates) + gt_result <- list(att = att, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates, extra = gt_result$extra) return(gt_result) } } diff --git a/R/pre_process_did.R b/R/pre_process_did.R index 3ee13c5..e9c0f3a 100644 --- a/R/pre_process_did.R +++ b/R/pre_process_did.R @@ -32,6 +32,7 @@ pre_process_did <- 
function(yname, faster_mode = FALSE, pl = FALSE, cores = 1, + est_method_vars = NULL, call = NULL) { #----------------------------------------------------------------------------- # Data pre-processing and error checking @@ -79,7 +80,7 @@ pre_process_did <- function(yname, } # drop irrelevant columns from data - data <- cbind.data.frame(data[,c(idname, tname, yname, gname, weightsname, clustervars)], model.frame(xformla, data=data, na.action=na.pass)) + data <- cbind.data.frame(data[,c(idname, tname, yname, gname, weightsname, clustervars, est_method_vars)], model.frame(xformla, data=data, na.action=na.pass)) # check if any covariates were missing n_orig <- nrow(data) @@ -399,6 +400,7 @@ pre_process_did <- function(yname, pl=pl, cores=cores, est_method=est_method, + est_method_vars=est_method_vars, base_period=base_period, panel=panel, true_repeated_cross_sections=true_repeated_cross_sections, diff --git a/R/pre_process_did2.R b/R/pre_process_did2.R index 733fbe5..b589f76 100644 --- a/R/pre_process_did2.R +++ b/R/pre_process_did2.R @@ -107,7 +107,7 @@ validate_args <- function(args, data){ #' @noRd did_standardization <- function(data, args){ # keep relevant columns in data - cols_to_keep <- c(args$idname, args$tname, args$gname, args$yname, args$weightsname, args$clustervars) + cols_to_keep <- c(args$idname, args$tname, args$gname, args$yname, args$weightsname, args$clustervars, args$est_method_vars) model_frame <- model.frame(args$xformla, data = data, na.action = na.pass) # Subset the dataset to keep only the relevant columns @@ -570,6 +570,7 @@ pre_process_did2 <- function(yname, faster_mode=FALSE, pl = FALSE, cores = 1, + est_method_vars = NULL, call = NULL) { diff --git a/man/DIDparams.Rd b/man/DIDparams.Rd index 30c0ff6..c22e4ac 100644 --- a/man/DIDparams.Rd +++ b/man/DIDparams.Rd @@ -32,6 +32,7 @@ DIDparams( nT = NULL, tlist = NULL, glist = NULL, + est_method_vars = NULL, call = NULL ) } @@ -191,6 +192,12 @@ of rows in a panel dataset).} \item{glist}{a 
vector containing each group} +\item{est_method_vars}{Optional character vector of column names from \code{data} +to pass through to a custom \code{est_method} function. These columns are +preserved through preprocessing, subsetted to match each (g,t) partition, +and passed to \code{est_method} as an additional \code{data} argument (a data.frame). +Ignored when using built-in estimation methods. Default is \code{NULL}.} + \item{call}{Function call to att_gt} } \description{ diff --git a/man/att_gt.Rd b/man/att_gt.Rd index ac31192..d076c13 100644 --- a/man/att_gt.Rd +++ b/man/att_gt.Rd @@ -27,6 +27,7 @@ att_gt( print_details = FALSE, pl = FALSE, cores = 1, + est_method_vars = NULL, ... ) } @@ -177,6 +178,12 @@ Default is \code{FALSE}.} \item{cores}{The number of cores to use for parallel processing} +\item{est_method_vars}{Optional character vector of column names from \code{data} +to pass through to a custom \code{est_method} function. These columns are +preserved through preprocessing, subsetted to match each (g,t) partition, +and passed to \code{est_method} as an additional \code{data} argument (a data.frame). +Ignored when using built-in estimation methods. Default is \code{NULL}.} + \item{...}{Additional arguments to be passed to a custom \code{est_method} function. These are ignored when using built-in estimation methods (\code{"dr"}, \code{"ipw"}, \code{"reg"}).} diff --git a/man/pre_process_did.Rd b/man/pre_process_did.Rd index 61c27df..4934e6b 100644 --- a/man/pre_process_did.Rd +++ b/man/pre_process_did.Rd @@ -27,6 +27,7 @@ pre_process_did( faster_mode = FALSE, pl = FALSE, cores = 1, + est_method_vars = NULL, call = NULL ) } @@ -177,6 +178,12 @@ it is recommended for use with large datasets.} \item{cores}{The number of cores to use for parallel processing} +\item{est_method_vars}{Optional character vector of column names from \code{data} +to pass through to a custom \code{est_method} function. 
These columns are +preserved through preprocessing, subsetted to match each (g,t) partition, +and passed to \code{est_method} as an additional \code{data} argument (a data.frame). +Ignored when using built-in estimation methods. Default is \code{NULL}.} + \item{call}{Function call to att_gt} } \value{ diff --git a/man/pre_process_did2.Rd b/man/pre_process_did2.Rd index b7b90d3..1cac978 100644 --- a/man/pre_process_did2.Rd +++ b/man/pre_process_did2.Rd @@ -27,6 +27,7 @@ pre_process_did2( faster_mode = FALSE, pl = FALSE, cores = 1, + est_method_vars = NULL, call = NULL ) } @@ -177,6 +178,12 @@ it is recommended for use with large datasets.} \item{cores}{The number of cores to use for parallel processing} +\item{est_method_vars}{Optional character vector of column names from \code{data} +to pass through to a custom \code{est_method} function. These columns are +preserved through preprocessing, subsetted to match each (g,t) partition, +and passed to \code{est_method} as an additional \code{data} argument (a data.frame). +Ignored when using built-in estimation methods. 
Default is \code{NULL}.} + \item{call}{Function call to att_gt} } \value{ diff --git a/tests/testthat/test-att_gt.R b/tests/testthat/test-att_gt.R index 5763d4f..ea1f87d 100644 --- a/tests/testthat/test-att_gt.R +++ b/tests/testthat/test-att_gt.R @@ -1275,3 +1275,158 @@ test_that("faster_mode time indexing with universal base period", { expect_equal(res_slow$att, res_fast$att) expect_equal(res_slow$se, as.numeric(res_fast$se)) }) + +test_that("est_method_vars passes through variables to custom est_method", { + set.seed(09142024) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + # Add a column that we want to pass through + data$fold_id <- sample(1:3, nrow(data), replace = TRUE) + + # Custom est_method that checks for the data argument + my_est <- function(y1, y0, D, covariates, i.weights, inffunc, data) { + # Verify data is a data.frame with the right column + stopifnot(is.data.frame(data)) + stopifnot("fold_id" %in% names(data)) + stopifnot(nrow(data) == length(y1)) + + # Use DRDID to compute the actual estimate + DRDID::drdid_imp_panel(y1 = y1, y0 = y0, D = D, + covariates = covariates, + i.weights = i.weights, + inffunc = inffunc) + } + + # faster_mode = TRUE (panel) + res_fast <- att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = my_est, est_method_vars = c("fold_id"), + bstrap = FALSE, cband = FALSE, faster_mode = TRUE) + expect_equal(res_fast$att[1], 1, tol = .5) + + # faster_mode = FALSE (panel) + res_slow <- att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = my_est, est_method_vars = c("fold_id"), + bstrap = FALSE, cband = FALSE, faster_mode = FALSE) + expect_equal(res_slow$att[1], 1, tol = .5) + + # ATTs should be the same across modes + expect_equal(res_fast$att, res_slow$att) +}) + +test_that("est_method_vars works with multiple variables", { + set.seed(09142024) + sp <- did::reset.sim() + data <- 
did::build_sim_dataset(sp) + data$fold_id <- sample(1:3, nrow(data), replace = TRUE) + data$stratum <- sample(letters[1:5], nrow(data), replace = TRUE) + + my_est <- function(y1, y0, D, covariates, i.weights, inffunc, data) { + stopifnot(all(c("fold_id", "stratum") %in% names(data))) + DRDID::drdid_imp_panel(y1 = y1, y0 = y0, D = D, + covariates = covariates, + i.weights = i.weights, inffunc = inffunc) + } + + res <- att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = my_est, + est_method_vars = c("fold_id", "stratum"), + bstrap = FALSE, cband = FALSE) + expect_equal(res$att[1], 1, tol = .5) +}) + +test_that("est_method_vars validation errors", { + set.seed(09142024) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + # Non-existent column + expect_error( + att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = function(...) NULL, + est_method_vars = c("nonexistent_col")), + "not found in data" + ) + + # Non-character input + expect_error( + att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = function(...) 
NULL, + est_method_vars = 42), + "character vector" + ) + + # Warning when used with built-in est_method + expect_warning( + att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = "dr", + est_method_vars = c("X")), + "not a custom function" + ) +}) + +test_that("extra_gt captures additional outputs from custom est_method", { + set.seed(09142024) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + # Custom est_method that returns extra fields + my_est_extra <- function(y1, y0, D, covariates, i.weights, inffunc) { + res <- DRDID::drdid_imp_panel(y1 = y1, y0 = y0, D = D, + covariates = covariates, + i.weights = i.weights, + inffunc = inffunc) + # Add extra fields + res$n_treated <- sum(D) + res$n_control <- sum(1 - D) + res$my_diagnostic <- "ok" + res + } + + # faster_mode = TRUE + res_fast <- att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = my_est_extra, + bstrap = FALSE, cband = FALSE, faster_mode = TRUE) + + # extra_gt should exist and be a list + expect_true(!is.null(res_fast$extra_gt)) + expect_true(is.list(res_fast$extra_gt)) + expect_equal(length(res_fast$extra_gt), length(res_fast$att)) + + # Each entry should have the extra fields + first_extra <- res_fast$extra_gt[[1]] + expect_true("n_treated" %in% names(first_extra)) + expect_true("n_control" %in% names(first_extra)) + expect_true("my_diagnostic" %in% names(first_extra)) + expect_equal(first_extra$my_diagnostic, "ok") + + # faster_mode = FALSE + res_slow <- att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = my_est_extra, + bstrap = FALSE, cband = FALSE, faster_mode = FALSE) + + expect_true(!is.null(res_slow$extra_gt)) + expect_equal(length(res_slow$extra_gt), length(res_slow$att)) + expect_true("n_treated" %in% names(res_slow$extra_gt[[1]])) +}) + +test_that("extra_gt is NULL for built-in est_method", { + 
set.seed(09142024) + sp <- did::reset.sim() + data <- did::build_sim_dataset(sp) + + res <- att_gt(yname = "Y", xformla = ~X, data = data, + tname = "period", idname = "id", gname = "G", + est_method = "dr", bstrap = FALSE, cband = FALSE) + + # Built-in methods should not produce extra_gt + expect_null(res$extra_gt) +})