Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 167 additions & 3 deletions R/helper_vimpute.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,52 @@ register_robust_learners <- function() {
# print(pred)

### +++++++++++++++++++++++++++++++++ Helper Functions +++++++++++++++++++++++++++++++++ ###


#
#
#
ensure_dummy_rows_for_factors <- function(dt, target_col) {
dt <- data.table::copy(dt)

factor_cols <- names(dt)[sapply(dt, is.factor)]
factor_cols <- setdiff(factor_cols, target_col)

for (col in factor_cols) {
lvls <- levels(dt[[col]])
present <- unique(dt[[col]])
missing_lvls <- setdiff(lvls, present)

if (length(missing_lvls) > 0) {
for (lvl in missing_lvls) {
dummy <- dt[1]
for (fc in factor_cols) {
dummy[[fc]] <- levels(dt[[fc]])[1]
}
dummy[[col]] <- lvl
dummy[[target_col]] <- levels(dt[[target_col]])[1]
dt <- rbind(dt, dummy)
}
}
}
dt
}
#
#
#
needs_ranger_classif <- function(y, X) {
tab <- table(y)
imbalance <- min(tab) / sum(tab) < 0.05
high_dim <- ncol(X) > nrow(X) / 5
rare_levels <- any(sapply(X, function(col) {
is.factor(col) && any(table(col) < 10)
}))
multicollinear <- ncol(X) > 1 && {
mm <- model.matrix(~ ., data = X)
qr(mm)$rank < ncol(mm)
}
imbalance || high_dim || rare_levels || multicollinear
}
#
#
#
Expand Down Expand Up @@ -296,16 +342,19 @@ precheck <- function(
pmm,
formula,
method,
sequential
sequential,
pmm_k
) {

# check missing data
variables = colnames(data)
variables_NA = colnames(data)[apply(data, 2, function(x) any(is.na(x)))] # alle Variablen die missind data haben
if (length(variables_NA) == 0) {
stop ("Error: No missing data available")
stop("Error: No missing data available")
} else {
message ("Variables with Missing Data: ", paste (variables_NA, collapse = ","))
# if (verbose) {
# message("Variables with Missing Data: ", paste(variables_NA, collapse = ", "))
# }
}

# check data structure
Expand Down Expand Up @@ -351,6 +400,22 @@ precheck <- function(
}
check_pmm(pmm, variables)

# check pmm_k
if (any(unlist(pmm))) {
if (
!is.numeric(pmm_k) ||
length(pmm_k) != 1 ||
is.na(pmm_k) ||
pmm_k < 1 ||
pmm_k %% 1 != 0
) {
stop(
"Error: 'pmm_k' must be a single positive integer (>= 1) ",
"when predictive mean matching (PMM) is enabled."
)
}
}

# check methods
supported_methods <- c("ranger", "regularized", "xgboost", "robust")

Expand Down Expand Up @@ -379,6 +444,42 @@ precheck <- function(
stop("Error: 'method' must either be empty, a single string, a single-element list, have the same length as 'variables' or 'considered_variables' (if specified), or the number of variables must match NAs.")
}

# check method for regularized
# ---- Check regularized method for target and predictors ----
for (var in variables_NA) {
y_obs <- data[[var]][!is.na(data[[var]])]

# Target variable check
if (method[[var]] %in% c("regularized", "glmnet")) {
if (is.factor(y_obs) && any(table(y_obs) <= 1)) {
warning(paste0("Variable '", var, "' has too few observations per class for 'regularized'. Falling back to 'robust'."))
method[[var]] <- "robust"
next
}
if (is.numeric(y_obs) && length(unique(y_obs)) < 3) {
warning(paste0("Variable '", var, "' has too few unique values for 'regularized'. Falling back to 'robust'."))
method[[var]] <- "robust"
next
}

# Predictor check
predictors <- setdiff(names(data), var)
for (col in predictors) {
x_obs <- data[[col]][!is.na(data[[col]])]
if (is.factor(x_obs) && any(table(x_obs) <= 1)) {
warning(paste0("Predictor '", col, "' has too few observations per class. Falling back to 'robust' for target '", var, "'."))
method[[var]] <- "robust"
break
}
if (is.numeric(x_obs) && length(unique(x_obs)) < 2) {
warning(paste0("Predictor '", col, "' has too few unique values. Falling back to 'robust' for target '", var, "'."))
method[[var]] <- "robust"
break
}
}
}
}

# warning if more than 50% missing values
if (nrow(data) == 0) stop("Error: Data has no rows.")
missing_counts <- colSums(is.na(data))
Expand All @@ -388,13 +489,20 @@ precheck <- function(
}

# Datatypes
ordered_cols <- names(data)[sapply(data, inherits, "ordered")]
if (length(ordered_cols) > 0) {
data[, (ordered_cols) := lapply(.SD, function(x) factor(as.character(x))), .SDcols = ordered_cols]
}

data[, (variables) := lapply(.SD, function(x) {
if (is.numeric(x)) {
as.numeric(x) # Integer & Double in Numeric
} else if (is.character(x)) {
as.factor(x) # Strings in Factors
} else if (is.logical(x)) {
as.numeric(x) # TRUE/FALSE -> 1/0
# } else if (inherits(x, "ordered")) {
# as.factor(x)
} else if (is.factor(x)) {
x
} else {
Expand Down Expand Up @@ -513,3 +621,59 @@ check_factor_levels <- function(data, original_levels) {
}
}
}
#
#
#
# helper: inverse Transformation
inverse_transform <- function(x, method) {
switch(method,
exp = log(x),
log = exp(x),
sqrt = x^2,
inverse = 1 / x,
stop("Unknown transformation: ", method)
)
}
#
#
#
# helper decimal places
get_decimal_places <- function(x) {
if (is.na(x)) return(0)
if (x == floor(x)) return(0)
nchar(sub(".*\\.", "", as.character(x)))
}
#
#
#
# helper: ranger regression prediction via per-tree median
predict_ranger_median <- function(graph_learner, newdata, target_name = NULL) {
model_names <- names(graph_learner$model)
ranger_idx <- grep("regr\\.ranger$", model_names)
if (length(ranger_idx) == 0) {
return(NULL)
}

ranger_model <- graph_learner$model[[ranger_idx[1]]]$model
if (is.list(ranger_model) && !inherits(ranger_model, "ranger") && "model" %in% names(ranger_model)) {
ranger_model <- ranger_model$model
}
if (!inherits(ranger_model, "ranger")) {
return(NULL)
}

pred_dt <- as.data.table(newdata)
if (!is.null(target_name) && target_name %in% colnames(pred_dt)) {
pred_dt <- pred_dt[, setdiff(colnames(pred_dt), target_name), with = FALSE]
}

tree_preds <- predict(ranger_model, data = as.data.frame(pred_dt), predict.all = TRUE)$predictions
if (is.null(dim(tree_preds))) {
return(as.numeric(tree_preds))
}
if (length(dim(tree_preds)) != 2) {
return(NULL)
}

apply(tree_preds, 1, median)
}
121 changes: 68 additions & 53 deletions R/rangerImpute.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
#' Random Forest Imputation
#'
#' Impute missing values based on a random forest model using [ranger::ranger()]
#' Impute missing values based on random-forest models via [vimpute()].
#' @param formula model formula for the imputation
#' @param data A `data.frame` containing the data
#' @param imp_var `TRUE`/`FALSE` if a `TRUE`/`FALSE` variables for each imputed
#' variable should be created show the imputation status
#' @param imp_suffix suffix used for TF imputation variables
#' @param ... Arguments passed to [ranger::ranger()]
#' @param ... Additional arguments. Currently ignored because
#' `rangerImpute()` delegates to [vimpute()].
#' @param verbose Show the number of observations used for training
#' and evaluating the RF-Model. This parameter is also passed down to
#' [ranger::ranger()] to show computation status.
#' @param median Use the median (rather than the arithmetic mean) to average
#' the values of individual trees for a more robust estimate.
#' and evaluating the RF-Model.
#' @param median `TRUE`/`FALSE`. If `TRUE`, ranger regression predictions are
#' aggregated tree-wise using the median (via [vimpute()]).
#' @return the imputed data set.
#' @family imputation methods
#' @examples
Expand All @@ -26,62 +26,77 @@ rangerImpute <- function(formula, data, imp_var = TRUE,
lhs <- gsub(" ", "", strsplit(formchar[2], "\\+")[[1]])
rhs <- formchar[3]
rhs2 <- gsub(" ", "", strsplit(rhs, "\\+")[[1]])
#Missings in RHS variables
rhs_na <- apply(subset(data, select = rhs2), 1, function(x) any(is.na(x)))
rhs2 <- rhs2[rhs2 != "1"]
dots <- list(...)
if (length(dots) > 0) {
warning("Additional ranger arguments are ignored; only `median` is passed to vimpute().")
}

data_out <- data
data_out_df <- as.data.frame(data_out)

for (lhsV in lhs) {
form <- as.formula(paste(lhsV, "~", rhs))
lhs_vector <- data[[lhsV]]
lhs_isfactor <- inherits(lhs_vector, "factor")

lhs_vector <- data_out[[lhsV]]
imp_col <- paste0(lhsV, "_", imp_suffix)

if (!any(is.na(lhs_vector))) {
cat(paste0("No missings in ", lhsV, ".\n"))
} else {
lhs_na <- is.na(lhs_vector)
if (verbose)
message("Training model for ", lhsV, " on ", sum(!rhs_na & !lhs_na), " observations")

if(lhs_isfactor){
mod <- ranger::ranger(form, subset(data, !rhs_na & !lhs_na), probability = TRUE, ..., verbose = verbose)
}else{
mod <- ranger::ranger(form, subset(data, !rhs_na & !lhs_na), ..., verbose = verbose)
}

if (verbose)
message("Evaluating model for ", lhsV, " on ", sum(!rhs_na & lhs_na), " observations")

if(lhs_isfactor){
predictions <- predict(mod, subset(data, !rhs_na & lhs_na))$predictions
predict_levels <- colnames(predictions)

predictions <- apply(predictions,1,function(z,lev){
z <- cumsum(z)
z_lev <- lev[z>runif(1)]
return(z_lev[1])
},lev=predict_levels)

}else{
if (median & inherits(lhs_vector, "numeric")) {
predictions <- apply(
predict(mod, subset(data, !rhs_na & lhs_na), predict.all = TRUE)$predictions,
1, median)
if (imp_var) {
if (imp_col %in% colnames(data_out)) {
data_out[[imp_col]] <- as.logical(data_out[[imp_col]])
warning(paste("The following TRUE/FALSE imputation status variables will be updated:", imp_col))
} else {
predictions <- predict(mod, subset(data, !rhs_na & lhs_na))$predictions
data_out[[imp_col]] <- is.na(lhs_vector)
}
}
next
}

considered <- unique(c(lhsV, rhs2))
method <- setNames(as.list(rep("ranger", length(considered))), considered)
pmm <- setNames(as.list(rep(FALSE, length(considered))), considered)

data[!rhs_na & lhs_na, lhsV] <- predictions
if (verbose) {
rhs_na <- if (length(rhs2) > 0) {
apply(data_out_df[, rhs2, drop = FALSE], 1, function(x) any(is.na(x)))
} else {
rep(FALSE, nrow(data_out_df))
}
lhs_na <- is.na(lhs_vector)
message("Training model for ", lhsV, " on ", sum(!rhs_na & !lhs_na), " observations")
message("Evaluating model for ", lhsV, " on ", sum(!rhs_na & lhs_na), " observations")
}


out <- vimpute(
data = data_out_df[, considered, drop = FALSE],
considered_variables = considered,
method = method,
pmm = pmm,
sequential = FALSE,
nseq = 1,
imp_var = imp_var,
pred_history = FALSE,
tune = FALSE,
verbose = verbose,
ranger_median = median
)

data_out[[lhsV]] <- out[[lhsV]]
data_out_df[[lhsV]] <- out[[lhsV]]

if (imp_var) {
if (imp_var %in% colnames(data)) {
data[, paste(lhsV, "_", imp_suffix, sep = "")] <- as.logical(data[, paste(lhsV, "_", imp_suffix, sep = "")])
warning(paste("The following TRUE/FALSE imputation status variables will be updated:",
paste(lhsV, "_", imp_suffix, sep = "")))
} else {
data$NEWIMPTFVARIABLE <- is.na(lhs_vector)
colnames(data)[ncol(data)] <- paste(lhsV, "_", imp_suffix, sep = "")
vimpute_imp_col <- paste0(lhsV, "_imp")
if (imp_col %in% colnames(data_out)) {
data_out[[imp_col]] <- as.logical(data_out[[imp_col]])
warning(paste("The following TRUE/FALSE imputation status variables will be updated:", imp_col))
}
if (vimpute_imp_col %in% colnames(out)) {
data_out[[imp_col]] <- as.logical(out[[vimpute_imp_col]])
} else if (!(imp_col %in% colnames(data_out))) {
data_out[[imp_col]] <- is.na(lhs_vector)
}
}
}
data
}

data_out
}
Loading
Loading