Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ did.Rproj
.claude/
.vscode/
.revdep_manual/
vignettes/*_cache/
vignettes/*_cache/
2 changes: 2 additions & 0 deletions R/DIDparams.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ DIDparams <- function(yname,
nT=NULL,
tlist=NULL,
glist=NULL,
est_method_vars=NULL,
call=NULL) {

out <- list(yname=yname,
Expand All @@ -64,6 +65,7 @@ DIDparams <- function(yname,
pl=pl,
cores=cores,
est_method=est_method,
est_method_vars=est_method_vars,
base_period=base_period,
panel=panel,
true_repeated_cross_sections=true_repeated_cross_sections,
Expand Down
2 changes: 2 additions & 0 deletions R/DIDparams2.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ DIDparams2 <- function(did_tensors, args, call=NULL) {
xformla <- args$xformla # formula of covariates
panel <- args$panel
est_method <- args$est_method
est_method_vars <- args$est_method_vars
bstrap <- args$bstrap
biters <- args$biters
cband <- args$cband
Expand Down Expand Up @@ -58,6 +59,7 @@ DIDparams2 <- function(did_tensors, args, call=NULL) {
xformla=xformla,
panel=panel,
est_method=est_method,
est_method_vars=est_method_vars,
bstrap=bstrap,
biters=biters,
cband=cband,
Expand Down
7 changes: 4 additions & 3 deletions R/MP.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
#' @param alp the significance level, default is 0.05
#' @param DIDparams a [`DIDparams`] object. A way to optionally return the parameters
#' of the call to [att_gt()] or [conditional_did_pretest()].
#' @param ... additional named elements to include in the MP object
#'
#' @return MP object
#' @export
MP <- function(group, t, att, V_analytical, se, c, inffunc, n=NULL, W=NULL, Wpval=NULL, aggte=NULL, alp = 0.05, DIDparams=NULL) {
out <- list(group=group, t=t, att=att, V_analytical=V_analytical, se=se, c=c,
MP <- function(group, t, att, V_analytical, se, c, inffunc, n=NULL, W=NULL, Wpval=NULL, aggte=NULL, alp = 0.05, DIDparams=NULL, ...) {
out <- c(list(group=group, t=t, att=att, V_analytical=V_analytical, se=se, c=c,
inffunc=inffunc, n=n, W=W, Wpval=Wpval, aggte=aggte, alp = alp,
DIDparams=DIDparams, call=DIDparams$call)
DIDparams=DIDparams, call=DIDparams$call), list(...))
class(out) <- "MP"
out
}
Expand Down
31 changes: 29 additions & 2 deletions R/att_gt.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@
#' the user allows for anticipation) to be equal to 0, but one
#' extra estimate in an earlier period.
#'
#' @param est_method_vars Optional character vector of column names from `data`
#' to pass through to a custom `est_method` function. These columns are
#' preserved through preprocessing, subsetted to match each (g,t) partition,
#' and passed to `est_method` as an additional `data` argument (a data.frame).
#' Ignored when using built-in estimation methods. Default is `NULL`.
#' @param ... Additional arguments to be passed to a custom `est_method`
#' function. These are ignored when using built-in estimation methods
#' (`"dr"`, `"ipw"`, `"reg"`).
Expand Down Expand Up @@ -206,6 +211,7 @@ att_gt <- function(yname,
print_details = FALSE,
pl = FALSE,
cores = 1,
est_method_vars = NULL,
...) {
# Capture extra arguments for custom est_method
extra_args <- list(...)
Expand All @@ -217,6 +223,22 @@ att_gt <- function(yname,
"\". Extra arguments are only passed to custom est_method functions.")
}

# Validate est_method_vars
if (!is.null(est_method_vars)) {
if (!is.character(est_method_vars)) {
stop("est_method_vars must be a character vector of column names from data.")
}
missing_emv <- setdiff(est_method_vars, colnames(data))
if (length(missing_emv) > 0) {
stop("The following est_method_vars are not found in data: ",
paste(missing_emv, collapse = ", "), ".")
}
if (!inherits(est_method, "function")) {
warning("est_method_vars is specified but est_method is not a custom function. ",
"est_method_vars will be ignored.")
}
}

# Validate est_method
if (!inherits(est_method, "function")) {
if (!is.character(est_method) || length(est_method) != 1) {
Expand Down Expand Up @@ -255,6 +277,7 @@ att_gt <- function(yname,
biters = biters,
clustervars = clustervars,
est_method = est_method,
est_method_vars = est_method_vars,
base_period = base_period,
print_details = print_details,
faster_mode = faster_mode,
Expand Down Expand Up @@ -290,6 +313,7 @@ att_gt <- function(yname,
biters = biters,
clustervars = clustervars,
est_method = est_method,
est_method_vars = est_method_vars,
base_period = base_period,
print_details = print_details,
pl = pl,
Expand Down Expand Up @@ -485,6 +509,9 @@ att_gt <- function(yname,
}


# Return this list
return(MP(group = group, t = tt, att = att, V_analytical = V, se = se, c = cval, inffunc = inffunc, n = n, W = W, Wpval = Wpval, alp = alp, DIDparams = dp))
# Build the MP object, append extra_gt results form est_method calls if any
extra_gt <- Filter(Negate(is.null), lapply(attgt.list, function(x) x$extra))
MP(group = group, t = tt, att = att, V_analytical = V, se = se, c = cval,
inffunc = inffunc, n = n, W = W, Wpval = Wpval, alp = alp,
DIDparams = dp, extra_gt = if (length(extra_gt)) extra_gt)
}
35 changes: 30 additions & 5 deletions R/compute.att_gt.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ compute.att_gt <- function(dp) {
xformla <- dp$xformla
weightsname <- dp$weightsname
est_method <- dp$est_method
est_method_vars <- dp$est_method_vars
extra_args <- if (is.null(dp$extra_args)) list() else dp$extra_args
base_period <- dp$base_period
panel <- dp$panel
Expand Down Expand Up @@ -249,13 +250,24 @@ compute.att_gt <- function(dp) {
attgt <- tryCatch({
if (inherits(est_method, "function")) {
# user-specified function
res <- do.call(est_method, c(list(
base_args <- list(
y1 = Ypost, y0 = Ypre,
D = G,
covariates = covariates,
i.weights = w,
inffunc = TRUE
), extra_args))
)
# forward cell identity if est_method can accept it
fmls <- names(formals(est_method))
if ("g" %in% fmls) {
base_args$g <- glist[g]
base_args$t <- tlist[(t + tfac)]
}
# add passthrough variables if specified
if (!is.null(est_method_vars)) {
base_args$data <- disdat[, est_method_vars, with = FALSE]
}
res <- do.call(est_method, c(base_args, extra_args))
} else if (est_method == "ipw") {
# inverse-probability weights
res <- DRDID::std_ipw_did_panel(Ypost, Ypre, G,
Expand Down Expand Up @@ -416,14 +428,25 @@ compute.att_gt <- function(dp) {
attgt <- tryCatch({
if (inherits(est_method, "function")) {
# user-specified function
res <- do.call(est_method, c(list(
base_args <- list(
y = Y,
post = post,
D = G,
covariates = covariates,
i.weights = w,
inffunc = TRUE
), extra_args))
)
# forward cell identity if est_method can accept it
fmls <- names(formals(est_method))
if ("g" %in% fmls) {
base_args$g <- glist[g]
base_args$t <- tlist[(t + tfac)]
}
# add passthrough variables if specified
if (!is.null(est_method_vars)) {
base_args$data <- disdat[, est_method_vars, with = FALSE]
}
res <- do.call(est_method, c(base_args, extra_args))
} else if (est_method == "ipw") {
# inverse-probability weights
res <- DRDID::std_ipw_did_rc(
Expand Down Expand Up @@ -485,8 +508,10 @@ compute.att_gt <- function(dp) {
} # end panel if

# save results for this att(g,t)
extra <- if (custom_est_method) attgt[!names(attgt) %in% c("ATT", "att.inf.func")] else NULL
if (!length(extra)) extra <- NULL
attgt.list[[counter]] <- list(
att = attgt$ATT, group = glist[g], year = tlist[(t + tfac)], post = post.treat
att = attgt$ATT, group = glist[g], year = tlist[(t + tfac)], post = post.treat, extra = extra
)


Expand Down
61 changes: 45 additions & 16 deletions R/compute.att_gt2.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ get_did_cohort_index <- function(group, time, tfac, pret, dp2){
run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){

extra_args <- if (is.null(dp2$extra_args)) list() else dp2$extra_args
est_method_vars <- dp2$est_method_vars
gt_label <- if (!is.null(g_val) && !is.null(t_val)) paste0(" for group ", g_val, " in time period ", t_val) else ""

if(dp2$panel){
Expand Down Expand Up @@ -150,13 +151,24 @@ run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){

if (inherits(dp2$est_method, "function")) {
# user-specified function
attgt <- do.call(dp2$est_method, c(list(
y1=cohort_data[, y1],
y0=cohort_data[, y0],
D=cohort_data[, D],
covariates=covariates,
i.weights=cohort_data[, i.weights],
inffunc=TRUE), extra_args))
base_args <- list(
y1=cohort_data[, y1],
y0=cohort_data[, y0],
D=cohort_data[, D],
covariates=covariates,
i.weights=cohort_data[, i.weights],
inffunc=TRUE)
# forward cell identity if est_method can accept it
fmls <- names(formals(dp2$est_method))
if ("g" %in% fmls) {
base_args$g <- g_val
base_args$t <- t_val
}
# add passthrough variables if specified
if (!is.null(est_method_vars)) {
base_args$data <- dp2$time_invariant_data[valid_obs, est_method_vars, with = FALSE]
}
attgt <- do.call(dp2$est_method, c(base_args, extra_args))
} else if (dp2$est_method == "ipw") {
# inverse-probability weights
attgt <- std_ipw_did_panel(y1=cohort_data[, y1],
Expand Down Expand Up @@ -252,13 +264,24 @@ run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){

if (inherits(dp2$est_method, "function")) {
# user-specified function
attgt <- do.call(dp2$est_method, c(list(
y=cohort_data[, y],
post=cohort_data[, post],
D=cohort_data[, D],
covariates=covariates,
i.weights=cohort_data[, i.weights],
inffunc=TRUE), extra_args))
base_args <- list(
y=cohort_data[, y],
post=cohort_data[, post],
D=cohort_data[, D],
covariates=covariates,
i.weights=cohort_data[, i.weights],
inffunc=TRUE)
# forward cell identity if est_method can accept it
fmls <- names(formals(dp2$est_method))
if ("g" %in% fmls) {
base_args$g <- g_val
base_args$t <- t_val
}
# add passthrough variables if specified
if (!is.null(est_method_vars)) {
base_args$data <- dp2$time_invariant_data[valid_obs, est_method_vars, with = FALSE]
}
attgt <- do.call(dp2$est_method, c(base_args, extra_args))
} else if (dp2$est_method == "ipw") {
# inverse-probability weights
attgt <- std_ipw_did_rc(y=cohort_data[, y],
Expand Down Expand Up @@ -306,7 +329,13 @@ run_DRDID <- function(cohort_data, covariates, dp2, g_val = NULL, t_val = NULL){

}

return(list(att = attgt$ATT, inf_func = inf_func_vector))
result <- list(att = attgt$ATT, inf_func = inf_func_vector)
# forward extra fields from custom est_method (if any)
if (custom_est_method) {
extra <- attgt[!names(attgt) %in% c("ATT", "att.inf.func")]
if (length(extra) > 0) result$extra <- extra
}
return(result)

}

Expand Down Expand Up @@ -495,7 +524,7 @@ compute.att_gt2 <- function(dp2) {

# Save ATT and influence function
inffunc_updates <- inf_func
gt_result <- list(att = att, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates)
gt_result <- list(att = att, group = dp2$treated_groups[g], year = dp2$time_periods[t+tfac], post = post.treat, inffunc_updates = inffunc_updates, extra = gt_result$extra)
return(gt_result)
}
}
Expand Down
4 changes: 3 additions & 1 deletion R/pre_process_did.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pre_process_did <- function(yname,
faster_mode = FALSE,
pl = FALSE,
cores = 1,
est_method_vars = NULL,
call = NULL) {
#-----------------------------------------------------------------------------
# Data pre-processing and error checking
Expand Down Expand Up @@ -79,7 +80,7 @@ pre_process_did <- function(yname,
}

# drop irrelevant columns from data
data <- cbind.data.frame(data[,c(idname, tname, yname, gname, weightsname, clustervars)], model.frame(xformla, data=data, na.action=na.pass))
data <- cbind.data.frame(data[,c(idname, tname, yname, gname, weightsname, clustervars, est_method_vars)], model.frame(xformla, data=data, na.action=na.pass))

# check if any covariates were missing
n_orig <- nrow(data)
Expand Down Expand Up @@ -399,6 +400,7 @@ pre_process_did <- function(yname,
pl=pl,
cores=cores,
est_method=est_method,
est_method_vars=est_method_vars,
base_period=base_period,
panel=panel,
true_repeated_cross_sections=true_repeated_cross_sections,
Expand Down
3 changes: 2 additions & 1 deletion R/pre_process_did2.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ validate_args <- function(args, data){
#' @noRd
did_standardization <- function(data, args){
# keep relevant columns in data
cols_to_keep <- c(args$idname, args$tname, args$gname, args$yname, args$weightsname, args$clustervars)
cols_to_keep <- c(args$idname, args$tname, args$gname, args$yname, args$weightsname, args$clustervars, args$est_method_vars)

model_frame <- model.frame(args$xformla, data = data, na.action = na.pass)
# Subset the dataset to keep only the relevant columns
Expand Down Expand Up @@ -570,6 +570,7 @@ pre_process_did2 <- function(yname,
faster_mode=FALSE,
pl = FALSE,
cores = 1,
est_method_vars = NULL,
call = NULL) {


Expand Down
7 changes: 7 additions & 0 deletions man/DIDparams.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions man/att_gt.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading