diff --git a/.Rbuildignore b/.Rbuildignore index 4c78b301..7a8c3cdf 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -16,6 +16,11 @@ # GitHub / CI ^\.github$ +^CONTRIBUTING\.md$ + +# Hidden config files (development only) +^\.lintr$ +^\.editorconfig$ # R CMD build artifacts ^doc$ diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 910bdf98..95fcc043 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -9,13 +9,7 @@ Rules for AI agents (Copilot, Claude, etc.) working on this codebase. be captured, use `<-`. With `=` R treats it as a named argument. Example: `expect_message(result <- foo(), "pattern")`. - No space between `if`/`for`/`while` and `(`: write `if(`, not `if (`. -- Enforced by `inst/styler/bgms_style.R`. Run before committing: - ```r - source("inst/styler/bgms_style.R") - styler::style_pkg(style = bgms_style) - ``` - After running, check test files for `expect_*(result = ...)` and - revert those to `result <- ...`. +- Enforced by `inst/styler/bgms_style.R` (see Pre-commit checks). ## Exported R functions (Tier 1) @@ -80,6 +74,40 @@ Rules for AI agents (Copilot, Claude, etc.) working on this codebase. the same commit. - When adding a new exported function, add it to `_pkgdown.yml`. +## Pre-commit checks + +Before every commit that touches R code, run these checks and fix +any issues they report: + +1. **Style** — enforce the project code style: + ```r + source("inst/styler/bgms_style.R") + styler::style_pkg(style = bgms_style) + ``` + After running, check test files for `expect_*(result = ...)` and + revert those to `result <- ...`. + +2. **Lint** — catch issues the CI lint workflow will flag: + ```r + lintr::lint_package() + ``` + The `.lintr` config enables `T_and_F_symbol_linter()` and + `seq_linter()` (among defaults). Fix all findings before + committing. + +3. **Roxygen** — regenerate Rd files if any roxygen comment changed: + ```r + roxygen2::roxygenise() + ``` + Stage the regenerated `man/*.Rd` files in the same commit. + +4. **R CMD check** — if the change is non-trivial, run: + ```r + rcmdcheck::rcmdcheck(args = c("--no-manual", "--as-cran")) + ``` + CI uses `error-on = "warning"`, so warnings are treated as + failures. + ## Do not - Add `@keywords internal` to exported functions. diff --git a/.github/workflows/nightly-validation.yaml b/.github/workflows/nightly-validation.yaml new file mode 100644 index 00000000..18cef388 --- /dev/null +++ b/.github/workflows/nightly-validation.yaml @@ -0,0 +1,29 @@ +name: nightly-validation +on: + schedule: + - cron: '0 3 * * 1,4' # Monday and Thursday at 3 AM UTC + workflow_dispatch: {} + +jobs: + validate: + runs-on: ubuntu-latest + timeout-minutes: 120 + env: + GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} + R_KEEP_PKG_SOURCE: yes + BGMS_RUN_SLOW_TESTS: true + + steps: + - uses: actions/checkout@v5 + + - uses: r-lib/actions/setup-r@v2 + with: + use-public-rspm: true + + - uses: r-lib/actions/setup-r-dependencies@v2 + with: + extra-packages: any::devtools + + - name: Run full test suite (including slow tests) + run: Rscript -e 'devtools::test()' + diff --git a/.github/workflows/weekly-compliance.yaml b/.github/workflows/weekly-compliance.yaml new file mode 100644 index 00000000..281955ea --- /dev/null +++ b/.github/workflows/weekly-compliance.yaml @@ -0,0 +1,35 @@ +name: weekly-compliance +on: + schedule: + - cron: '0 5 * * 0' # Sunday at 5 AM UTC + workflow_dispatch: {} + +jobs: + compliance: + runs-on: ubuntu-latest + timeout-minutes: 120 + env: + GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} + R_KEEP_PKG_SOURCE: yes + + steps: + - uses: actions/checkout@v5 + + - uses: r-lib/actions/setup-r@v2 + with: + use-public-rspm: true + + - uses: r-lib/actions/setup-r-dependencies@v2 + + - name: Install mixedGM (needed for cross-package tests) + run: Rscript -e 'remotes::install_github("MaartenMarsman/mixedGM")' + + - name: Generate compliance fixtures (if missing) + run: | + if [ ! -f tests/compliance/fixtures/manifest.rds ]; then + Rscript tests/compliance/generate_fixtures.R || true + fi + + - name: Run bitwise compliance (OMRF configs only) + run: Rscript tests/compliance/test_compliance.R + diff --git a/.gitignore b/.gitignore index eef79de8..adccdc0f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ **/.Rhistory .RData .Ruserdata +.vscode/ src/*.o src/*.so src/*.dll @@ -18,4 +19,5 @@ src/sources.mk docs/* /doc/ /inst/doc/ -dev/plans/ +dev/ +/paper/ diff --git a/DESCRIPTION b/DESCRIPTION index cdb269eb..2bda4e14 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -17,9 +17,10 @@ Authors@R: c( ) Maintainer: Maarten Marsman Description: Bayesian estimation and variable selection for Markov random field - models of networks of binary, ordinal, and continuous variables. Supports - Gaussian graphical models, multi-group comparison via 'bgmCompare', and - provides simulation, prediction, and missing data imputation. + models of networks of binary, ordinal, continuous, and mixed variables. + Supports ordinal MRFs, Gaussian graphical models, mixed MRFs combining + discrete and continuous variables, multi-group comparison via 'bgmCompare', + and provides simulation, prediction, and missing data imputation. Copyright: Includes datasets 'ADHD' and 'Boredom', which are licensed under CC-BY 4. See individual data documentation for license and citation. License: GPL (>= 2) URL: https://Bayesian-Graphical-Modelling-Lab.github.io/bgms/, https://github.com/Bayesian-Graphical-Modelling-Lab/bgms @@ -39,6 +40,7 @@ LinkingTo: RcppParallel, dqrng, BH +Roxygen: list(markdown = TRUE) RoxygenNote: 7.3.3 Depends: R (>= 3.5) diff --git a/NAMESPACE b/NAMESPACE index aa26428e..bc34e829 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,8 +4,6 @@ S3method(coef,bgmCompare) S3method(coef,bgms) S3method(extract_arguments,bgmCompare) S3method(extract_arguments,bgms) -S3method(extract_category_thresholds,bgmCompare) -S3method(extract_category_thresholds,bgms) S3method(extract_ess,bgmCompare) S3method(extract_ess,bgms) S3method(extract_group_params,bgmCompare) @@ -13,6 +11,8 @@ S3method(extract_indicator_priors,bgmCompare) S3method(extract_indicator_priors,bgms) S3method(extract_indicators,bgmCompare) S3method(extract_indicators,bgms) +S3method(extract_main_effects,bgmCompare) +S3method(extract_main_effects,bgms) S3method(extract_pairwise_interactions,bgmCompare) S3method(extract_pairwise_interactions,bgms) S3method(extract_posterior_inclusion_probabilities,bgmCompare) @@ -40,6 +40,7 @@ export(extract_ess) export(extract_group_params) export(extract_indicator_priors) export(extract_indicators) +export(extract_main_effects) export(extract_pairwise_interactions) export(extract_pairwise_thresholds) export(extract_posterior_inclusion_probabilities) diff --git a/NEWS.md b/NEWS.md index b47bafc0..f11e4d76 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,7 @@ ## New features +* Mixed MRF models: `bgm()` accepts a per-variable `variable_type` vector that mixes `"ordinal"`, `"blume-capel"`, and `"continuous"` types to estimate networks with both discrete and continuous variables. `simulate.bgms()` and `predict.bgms()` also support mixed models. * Gaussian graphical models (GGM): `bgm(x, variable_type = "continuous")` fits a GGM with Bayesian edge selection. Pairwise effects are partial correlations from the precision matrix. * Missing data imputation: `na_action = "impute"` integrates over missing values during MCMC sampling for both ordinal and continuous models. diff --git a/R/RcppExports.R b/R/RcppExports.R index 4eca5302..e3007c1c 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -25,6 +25,10 @@ compute_conditional_probs <- function(observations, predict_vars, pairwise, main .Call(`_bgms_compute_conditional_probs`, observations, predict_vars, pairwise, main, num_categories, variable_type, baseline_category) } +compute_conditional_mixed <- function(x_observations, y_observations, predict_vars, Kxx, Kxy, Kyy, mux, muy, num_categories, variable_type, baseline_category) { + .Call(`_bgms_compute_conditional_mixed`, x_observations, y_observations, predict_vars, Kxx, Kxy, Kyy, mux, muy, num_categories, variable_type, baseline_category) +} + sample_omrf_gibbs <- function(num_states, num_variables, num_categories, pairwise, main, iter, seed) { .Call(`_bgms_sample_omrf_gibbs`, num_states, num_variables, num_categories, pairwise, main, iter, seed) } @@ -45,10 +49,22 @@ run_ggm_simulation_parallel <- function(pairwise_samples, main_samples, draw_ind .Call(`_bgms_run_ggm_simulation_parallel`, pairwise_samples, main_samples, draw_indices, num_states, num_variables, means, nThreads, seed, progress_type) } +sample_mixed_mrf_gibbs <- function(num_states, Kxx_r, Kxy_r, Kyy_r, mux_r, muy_r, num_categories_r, variable_type_r, baseline_category_r, iter, seed) { + .Call(`_bgms_sample_mixed_mrf_gibbs`, num_states, Kxx_r, Kxy_r, Kyy_r, mux_r, muy_r, num_categories_r, variable_type_r, baseline_category_r, iter, seed) +} + +run_mixed_simulation_parallel <- function(mux_samples, kxx_samples, muy_samples, kyy_samples, kxy_samples, draw_indices, num_states, p, q, num_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type) { + .Call(`_bgms_run_mixed_simulation_parallel`, mux_samples, kxx_samples, muy_samples, kyy_samples, kxy_samples, draw_indices, num_states, p, q, num_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type) +} + sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, na_impute = FALSE, missing_index_nullable = NULL) { .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, na_impute, missing_index_nullable) } +sample_mixed_mrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, sampler_type = "mh", target_acceptance = 0.80, max_tree_depth = 10L, num_leapfrogs = 100L, na_impute = FALSE, missing_index_discrete_nullable = NULL, missing_index_continuous_nullable = NULL) { + .Call(`_bgms_sample_mixed_mrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable) +} + sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", na_impute = FALSE, missing_index_nullable = NULL, beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, pairwise_scaling_factors_nullable = NULL) { .Call(`_bgms_sample_omrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable) } diff --git a/R/bgm.R b/R/bgm.R index 71d5f6cd..29e4a17f 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -1,16 +1,18 @@ #' Bayesian Estimation or Edge Selection for Markov Random Fields #' #' @description -#' The \code{bgm} function estimates the pseudoposterior distribution of -#' category thresholds (main effects) and pairwise interaction parameters of a -#' Markov Random Field (MRF) model for binary and/or ordinal variables. -#' Optionally, it performs Bayesian edge selection using spike-and-slab -#' priors to infer the network structure. +#' The \code{bgm} function estimates the pseudoposterior distribution of the +#' parameters of a Markov Random Field (MRF) for binary, ordinal, continuous, +#' or mixed (discrete and continuous) variables. Depending on the variable +#' types, the model is an ordinal MRF, a Gaussian graphical model (GGM), or a +#' mixed MRF. Optionally, it performs Bayesian edge selection using +#' spike-and-slab priors to infer the network structure. #' #' @details -#' This function models the joint distribution of binary and ordinal variables -#' using a Markov Random Field, with support for edge selection through Bayesian -#' variable selection. The statistical foundation of the model is described in +#' This function models the joint distribution of binary, ordinal, continuous, +#' or mixed variables using a Markov Random Field, with support for edge +#' selection through Bayesian variable selection. The statistical foundation +#' of the model is described in #' \insertCite{MarsmanVandenBerghHaslbeck_2025;textual}{bgms}, where the ordinal #' MRF model and its Bayesian estimation procedure were first introduced. While #' the implementation in \pkg{bgms} has since been extended and updated (e.g., @@ -172,8 +174,11 @@ #' Blume–Capel variables, all categories are retained. #' #' @param variable_type Character or character vector. Specifies the type of -#' each variable in \code{x}. Allowed values: \code{"ordinal"} or -#' \code{"blume-capel"}. Binary variables are automatically treated as +#' each variable in \code{x}. Allowed values: \code{"ordinal"}, +#' \code{"blume-capel"}, or \code{"continuous"}. A single string applies +#' to all variables. A per-variable vector that mixes discrete +#' (\code{"ordinal"} / \code{"blume-capel"}) and \code{"continuous"} +#' types fits a mixed MRF. Binary variables are automatically treated as #' \code{"ordinal"}. Default: \code{"ordinal"}. #' #' @param baseline_category Integer or vector. Baseline category used in @@ -212,6 +217,21 @@ #' score endpoints \eqn{(-b, m-b)}. #' Default: \code{FALSE}. #' +#' @param pseudolikelihood Character. Specifies the pseudo-likelihood +#' approximation used for mixed MRF models (ignored for pure ordinal or +#' pure continuous data). Options: +#' \describe{ +#' \item{\code{"conditional"}}{Conditions on the observed continuous +#' variables when computing the discrete full conditionals. Faster +#' because the discrete pseudo-likelihood does not depend on the +#' continuous precision matrix.} +#' \item{\code{"marginal"}}{Integrates out the continuous variables, +#' giving discrete full conditionals that account for induced +#' interactions through the continuous block. More expensive per +#' iteration.} +#' } +#' Default: \code{"conditional"}. +#' #' @param main_alpha,main_beta Double. Shape parameters of the #' beta-prime prior for threshold parameters. Must be positive. If equal, #' the prior is symmetric. Defaults: \code{main_alpha = 0.5} and @@ -307,16 +327,26 @@ #' Main components include: #' \itemize{ #' \item \code{posterior_summary_main}: Data frame with posterior summaries -#' (mean, sd, MCSE, ESS, Rhat) for category threshold parameters. +#' (mean, sd, MCSE, ESS, Rhat) for main-effect parameters. +#' For OMRF models these are category thresholds; +#' for mixed MRF models these are discrete thresholds and +#' continuous means. \code{NULL} for GGM models (no main effects). +#' \item \code{posterior_summary_quadratic}: Data frame with posterior +#' summaries for the precision matrix diagonal. Present for GGM and +#' mixed MRF models; \code{NULL} for OMRF models. #' \item \code{posterior_summary_pairwise}: Data frame with posterior #' summaries for pairwise interaction parameters. #' \item \code{posterior_summary_indicator}: Data frame with posterior #' summaries for edge inclusion indicators (if \code{edge_selection = TRUE}). #' -#' \item \code{posterior_mean_main}: Matrix of posterior mean thresholds -#' (rows = variables, cols = categories or parameters). +#' \item \code{posterior_mean_main}: Posterior mean of main-effect +#' parameters. \code{NULL} for GGM models. For OMRF: a matrix +#' (p x max_categories) of category thresholds. For mixed MRF: a list +#' with \code{$discrete} (threshold matrix) and \code{$continuous} +#' (q x 1 matrix of means). #' \item \code{posterior_mean_pairwise}: Symmetric matrix of posterior mean -#' pairwise interaction strengths. +#' pairwise interaction strengths. For GGM and mixed MRF models the +#' precision matrix diagonal is included on the matrix diagonal. #' \item \code{posterior_mean_indicator}: Symmetric matrix of posterior mean #' inclusion probabilities (if edge selection was enabled). #' @@ -409,6 +439,7 @@ bgm = function( display_progress = c("per-chain", "total", "none"), seed = NULL, standardize = FALSE, + pseudolikelihood = c("conditional", "marginal"), verbose = getOption("bgms.verbose", TRUE), interaction_scale, burnin, @@ -481,7 +512,8 @@ bgm = function( cores = cores, seed = seed, display_progress = display_progress, - verbose = verbose + verbose = verbose, + pseudolikelihood = pseudolikelihood ) raw = run_sampler(spec) diff --git a/R/bgm_spec.R b/R/bgm_spec.R index c745007a..64ec1fa9 100644 --- a/R/bgm_spec.R +++ b/R/bgm_spec.R @@ -25,22 +25,33 @@ new_bgm_spec = function(model_type, data, variables, missing, prior, # --- top-level structure --- stopifnot( is.character(model_type), length(model_type) == 1L, - model_type %in% c("ggm", "omrf", "compare") + model_type %in% c("ggm", "omrf", "compare", "mixed_mrf") ) # --- data sub-list --- stopifnot(is.list(data)) - stopifnot(is.matrix(data$x)) + if(model_type == "mixed_mrf") { + stopifnot(is.matrix(data$x_discrete)) + stopifnot(is.matrix(data$x_continuous)) + } else { + stopifnot(is.matrix(data$x)) + } stopifnot(is.character(data$data_columnnames)) stopifnot(is.integer(data$num_variables), length(data$num_variables) == 1L) stopifnot(is.integer(data$num_cases), length(data$num_cases) == 1L) - if(model_type != "ggm") { + if(model_type == "omrf" || model_type == "compare") { stopifnot( is.integer(data$num_categories), length(data$num_categories) == data$num_variables ) } + if(model_type == "mixed_mrf") { + stopifnot( + is.integer(data$num_categories), + length(data$num_categories) == data$num_discrete + ) + } if(model_type == "compare") { @@ -68,17 +79,31 @@ new_bgm_spec = function(model_type, data, variables, missing, prior, if(!is.null(missing$missing_index)) { stopifnot(is.matrix(missing$missing_index)) } + # mixed MRF uses separate indices for discrete and continuous + if(!is.null(missing$missing_index_discrete)) { + stopifnot(is.matrix(missing$missing_index_discrete)) + } + if(!is.null(missing$missing_index_continuous)) { + stopifnot(is.matrix(missing$missing_index_continuous)) + } # --- prior sub-list --- stopifnot(is.list(prior)) - if(model_type != "ggm") { + if(model_type %in% c("omrf", "compare")) { stopifnot(is.numeric(prior$pairwise_scale), length(prior$pairwise_scale) == 1L) stopifnot(is.numeric(prior$main_alpha), length(prior$main_alpha) == 1L) stopifnot(is.numeric(prior$main_beta), length(prior$main_beta) == 1L) stopifnot(is.logical(prior$standardize), length(prior$standardize) == 1L) stopifnot(is.matrix(prior$pairwise_scaling_factors)) } - if(model_type %in% c("ggm", "omrf")) { + if(model_type == "mixed_mrf") { + stopifnot(is.numeric(prior$pairwise_scale), length(prior$pairwise_scale) == 1L) + stopifnot(is.numeric(prior$main_alpha), length(prior$main_alpha) == 1L) + stopifnot(is.numeric(prior$main_beta), length(prior$main_beta) == 1L) + stopifnot(is.logical(prior$standardize), length(prior$standardize) == 1L) + stopifnot(is.character(prior$pseudolikelihood), length(prior$pseudolikelihood) == 1L) + } + if(model_type %in% c("ggm", "omrf", "mixed_mrf")) { stopifnot(is.logical(prior$edge_selection), length(prior$edge_selection) == 1L) stopifnot(is.character(prior$edge_prior), length(prior$edge_prior) == 1L) stopifnot(is.matrix(prior$inclusion_probability)) @@ -162,14 +187,14 @@ validate_bgm_spec = function(spec) { } # Edge selection consistency - if(mt %in% c("ggm", "omrf")) { + if(mt %in% c("ggm", "omrf", "mixed_mrf")) { if(spec$prior$edge_selection && spec$prior$edge_prior == "Not Applicable") { stop("bgm_spec: edge_selection = TRUE but edge_prior = 'Not Applicable'.") } } # Scaling factors dimensions - if(mt != "ggm") { + if(mt %in% c("omrf", "compare")) { nv = spec$data$num_variables sf = spec$prior$pairwise_scaling_factors if(nrow(sf) != nv || ncol(sf) != nv) { @@ -181,11 +206,24 @@ validate_bgm_spec = function(spec) { } # num_categories length (OMRF / compare) - if(mt != "ggm") { + if(mt == "omrf" || mt == "compare") { if(length(spec$data$num_categories) != spec$data$num_variables) { stop("bgm_spec: num_categories length doesn't match num_variables.") } } + if(mt == "mixed_mrf") { + if(length(spec$data$num_categories) != spec$data$num_discrete) { + stop("bgm_spec: num_categories length doesn't match num_discrete.") + } + allowed = c("adaptive-metropolis", "hybrid-nuts") + if(!(spec$sampler$update_method %in% allowed)) { + stop( + "bgm_spec: model_type = 'mixed_mrf' requires update_method in ", + paste(sQuote(allowed), collapse = " or "), ". Got '", + spec$sampler$update_method, "'." + ) + } + } invisible(spec) } @@ -201,7 +239,7 @@ validate_bgm_spec = function(spec) { # Parameters mirror the union of bgm() and bgmCompare() arguments. # ============================================================================== bgm_spec = function(x, - model_type = c("omrf", "ggm", "compare"), + model_type = c("omrf", "ggm", "compare", "mixed_mrf"), # Variable specification variable_type = "ordinal", baseline_category = 0L, @@ -248,7 +286,8 @@ bgm_spec = function(x, cores = parallel::detectCores(), seed = NULL, display_progress = c("per-chain", "total", "none"), - verbose = TRUE) { + verbose = TRUE, + pseudolikelihood = c("conditional", "marginal")) { model_type = match.arg(model_type) na_action = tryCatch(match.arg(na_action), error = function(e) { stop(paste0( @@ -272,16 +311,21 @@ bgm_spec = function(x, variable_type = variable_type, num_variables = num_variables, allow_continuous = allow_continuous, + allow_mixed = (model_type != "compare"), caller = if(model_type == "compare") "bgmCompare" else "bgm" ) variable_type = vt$variable_type is_ordinal = vt$variable_bool is_continuous = vt$is_continuous + is_mixed = vt$is_mixed # Resolve model_type if "omrf" default was kept but data is continuous if(model_type == "omrf" && is_continuous) { model_type = "ggm" } + if(model_type == "omrf" && is_mixed) { + model_type = "mixed_mrf" + } # --- Sampler (needs is_continuous and edge_selection early) ------------------ sampler = validate_sampler( @@ -301,6 +345,12 @@ bgm_spec = function(x, verbose = verbose ) + # Mixed MRF: remap "nuts" to the hybrid sampler that uses NUTS for the + # unconstrained block and component-wise MH for the SPD-constrained Kyy. + if(is_mixed && sampler$update_method == "nuts") { + sampler$update_method = "hybrid-nuts" + } + # --- Build by model type ---------------------------------------------------- if(model_type == "ggm") { spec = build_spec_ggm( @@ -320,6 +370,27 @@ bgm_spec = function(x, dirichlet_alpha = dirichlet_alpha, lambda = lambda ) + } else if(model_type == "mixed_mrf") { + pseudolikelihood = match.arg(pseudolikelihood) + spec = build_spec_mixed_mrf( + x = x, data_columnnames = data_columnnames, + num_variables = num_variables, + variable_type = variable_type, is_ordinal = is_ordinal, + baseline_category = baseline_category, + na_action = na_action, sampler = sampler, + pairwise_scale = pairwise_scale, main_alpha = main_alpha, + main_beta = main_beta, standardize = standardize, + pseudolikelihood = pseudolikelihood, + edge_selection = edge_selection, + edge_prior = edge_prior, + inclusion_probability = inclusion_probability, + beta_bernoulli_alpha = beta_bernoulli_alpha, + beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, + dirichlet_alpha = dirichlet_alpha, + lambda = lambda + ) } else if(model_type == "omrf") { spec = build_spec_omrf( x = x, data_columnnames = data_columnnames, @@ -546,6 +617,171 @@ build_spec_omrf = function(x, data_columnnames, num_variables, } +# ------------------------------------------------------------------ +# build_spec_mixed_mrf +# ------------------------------------------------------------------ +# Builds a bgm_spec for the mixed MRF model (discrete + continuous). +# Splits the input data matrix into discrete and continuous parts, +# validates and recodes discrete variables (ordinal/BC), and assembles +# the spec with metadata needed by sample_mixed_mrf() and +# build_output_mixed_mrf(). +# ------------------------------------------------------------------ +build_spec_mixed_mrf = function(x, data_columnnames, num_variables, + variable_type, is_ordinal, + baseline_category, + na_action, sampler, + pairwise_scale, main_alpha, main_beta, + standardize, pseudolikelihood, + edge_selection, edge_prior, + inclusion_probability, + beta_bernoulli_alpha, beta_bernoulli_beta, + beta_bernoulli_alpha_between, + beta_bernoulli_beta_between, + dirichlet_alpha, lambda) { + # Identify discrete vs continuous columns + cont_idx = which(variable_type == "continuous") + disc_idx = which(variable_type != "continuous") + p = length(disc_idx) + q = length(cont_idx) + + # Split data + x_disc = x[, disc_idx, drop = FALSE] + x_cont = x[, cont_idx, drop = FALSE] + + # Ensure integer matrix for discrete data + storage.mode(x_disc) = "integer" + # Ensure numeric matrix for continuous data + storage.mode(x_cont) = "double" + + # Discrete variable properties (subset to discrete columns) + is_ordinal_disc = is_ordinal[disc_idx] + vtype_disc = variable_type[disc_idx] + + # Baseline category for discrete variables + bc = validate_baseline_category( + baseline_category = baseline_category, + baseline_category_provided = !identical(baseline_category, 0L), + x = x_disc, + variable_bool = is_ordinal_disc + ) + + # Missing data handling + na_impute = FALSE + missing_index_discrete = NULL + missing_index_continuous = NULL + + if(na_action == "listwise") { + missing_rows = apply(x_disc, 1, anyNA) | apply(x_cont, 1, anyNA) + if(all(missing_rows)) { + stop(paste0( + "All rows in x contain at least one missing response.\n", + "You could try option na_action = \"impute\"." + )) + } + n_removed = sum(missing_rows) + if(n_removed > 0 && isTRUE(getOption("bgms.verbose", TRUE))) { + n_remaining = nrow(x_disc) - n_removed + message( + n_removed, " row", if(n_removed > 1) "s" else "", + " with missing values excluded (n = ", n_remaining, " remaining).\n", + "To impute missing values instead, use na_action = \"impute\"." + ) + } + x_disc = x_disc[!missing_rows, , drop = FALSE] + x_cont = x_cont[!missing_rows, , drop = FALSE] + if(nrow(x_disc) < 2) { + stop(paste0( + "After removing missing observations from the input matrix x,\n", + "there were less than two rows left in x." + )) + } + } else { + # Impute path: handle discrete and continuous sub-matrices separately + md_disc = handle_impute(x_disc) + md_cont = handle_impute(x_cont) + x_disc = md_disc$x + x_cont = md_cont$x + na_impute = md_disc$na_impute || md_cont$na_impute + if(md_disc$na_impute) missing_index_discrete = md_disc$missing_index + if(md_cont$na_impute) missing_index_continuous = md_cont$missing_index + } + + # Ordinal recoding (reformat discrete data) + ord = reformat_ordinal_data( + x = x_disc, is_ordinal = is_ordinal_disc, + baseline_category = bc + ) + x_disc_recoded = ord$x + num_categories = ord$num_categories + bc_final = ord$baseline_category + + # Edge prior (total variables = p + q) + ep = validate_edge_prior( + edge_selection = edge_selection, edge_prior = edge_prior, + inclusion_probability = inclusion_probability, + num_variables = num_variables, + beta_bernoulli_alpha = beta_bernoulli_alpha, + beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, + dirichlet_alpha = dirichlet_alpha, lambda = lambda + ) + + num_thresholds = sum(ifelse(is_ordinal_disc, num_categories, 2L)) + + new_bgm_spec( + model_type = "mixed_mrf", + data = list( + x_discrete = x_disc_recoded, + x_continuous = x_cont, + data_columnnames = data_columnnames, + data_columnnames_discrete = data_columnnames[disc_idx], + data_columnnames_continuous = data_columnnames[cont_idx], + num_variables = as.integer(num_variables), + num_discrete = as.integer(p), + num_continuous = as.integer(q), + num_cases = as.integer(nrow(x_disc_recoded)), + num_categories = as.integer(num_categories), + discrete_indices = disc_idx, + continuous_indices = cont_idx + ), + variables = list( + variable_type = variable_type, + is_ordinal = is_ordinal_disc, + is_continuous = FALSE, + is_mixed = TRUE, + baseline_category = as.integer(bc_final) + ), + missing = list( + na_action = na_action, + na_impute = na_impute, + missing_index_discrete = missing_index_discrete, + missing_index_continuous = missing_index_continuous + ), + prior = list( + pairwise_scale = pairwise_scale, + main_alpha = main_alpha, + main_beta = main_beta, + standardize = standardize, + pseudolikelihood = pseudolikelihood, + edge_selection = ep$edge_selection, + edge_prior = ep$edge_prior, + inclusion_probability = ep$inclusion_probability, + beta_bernoulli_alpha = beta_bernoulli_alpha, + beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, + dirichlet_alpha = dirichlet_alpha, + lambda = lambda + ), + sampler = sampler_sublist(sampler), + precomputed = list( + num_thresholds = as.integer(num_thresholds) + ) + ) +} + + build_spec_compare = function(x, y, group_indicator, data_columnnames, num_variables, variable_type, is_ordinal, is_continuous, @@ -862,6 +1098,8 @@ build_arguments = function(spec) { build_arguments_ggm(spec) } else if(mt == "omrf") { build_arguments_omrf(spec) + } else if(mt == "mixed_mrf") { + build_arguments_mixed_mrf(spec) } else { build_arguments_compare(spec) } @@ -892,7 +1130,8 @@ build_arguments_ggm = function(spec) { num_chains = spec$sampler$chains, data_columnnames = spec$data$data_columnnames, no_variables = spec$data$num_variables, - is_continuous = TRUE + is_continuous = TRUE, + model_type = "ggm" ) } @@ -934,7 +1173,52 @@ build_arguments_omrf = function(spec) { data_columnnames = spec$data$data_columnnames, baseline_category = spec$variables$baseline_category, pairwise_scaling_factors = spec$prior$pairwise_scaling_factors, - no_variables = spec$data$num_variables + no_variables = spec$data$num_variables, + model_type = "omrf" + ) +} + + +build_arguments_mixed_mrf = function(spec) { + list( + num_variables = spec$data$num_variables, + num_discrete = spec$data$num_discrete, + num_continuous = spec$data$num_continuous, + num_cases = spec$data$num_cases, + variable_type = spec$variables$variable_type, + iter = spec$sampler$iter, + warmup = spec$sampler$warmup, + pairwise_scale = spec$prior$pairwise_scale, + standardize = spec$prior$standardize, + pseudolikelihood = spec$prior$pseudolikelihood, + main_alpha = spec$prior$main_alpha, + main_beta = spec$prior$main_beta, + edge_selection = spec$prior$edge_selection, + edge_prior = spec$prior$edge_prior, + inclusion_probability = spec$prior$inclusion_probability, + beta_bernoulli_alpha = spec$prior$beta_bernoulli_alpha, + beta_bernoulli_beta = spec$prior$beta_bernoulli_beta, + beta_bernoulli_alpha_between = spec$prior$beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = spec$prior$beta_bernoulli_beta_between, + dirichlet_alpha = spec$prior$dirichlet_alpha, + lambda = spec$prior$lambda, + na_action = spec$missing$na_action, + version = packageVersion("bgms"), + update_method = spec$sampler$update_method, + target_accept = spec$sampler$target_accept, + nuts_max_depth = spec$sampler$nuts_max_depth, + num_chains = spec$sampler$chains, + num_categories = spec$data$num_categories, + data_columnnames = spec$data$data_columnnames, + data_columnnames_discrete = spec$data$data_columnnames_discrete, + data_columnnames_continuous = spec$data$data_columnnames_continuous, + discrete_indices = spec$data$discrete_indices, + continuous_indices = spec$data$continuous_indices, + baseline_category = spec$variables$baseline_category, + is_ordinal = spec$variables$is_ordinal, + no_variables = spec$data$num_variables, + is_mixed = TRUE, + model_type = "mixed_mrf" ) } @@ -967,7 +1251,8 @@ build_arguments_compare = function(spec) { num_categories = spec$data$num_categories, is_ordinal_variable = spec$variables$is_ordinal, group = sort(spec$data$group), - pairwise_scaling_factors = spec$prior$pairwise_scaling_factors + pairwise_scaling_factors = spec$prior$pairwise_scaling_factors, + model_type = "compare" ) } diff --git a/R/bgmcompare-methods.r b/R/bgmcompare-methods.r index c069709a..f8ed7145 100644 --- a/R/bgmcompare-methods.r +++ b/R/bgmcompare-methods.r @@ -243,9 +243,9 @@ print.summary.bgmCompare = function(x, digits = 3, ...) { #' @return A list with components: #' \describe{ #' \item{main_effects_raw}{Posterior means of the raw main-effect parameters -#' (variables x [baseline + differences]).} +#' (variables x (baseline + differences)).} #' \item{pairwise_effects_raw}{Posterior means of the raw pairwise-effect parameters -#' (pairs x [baseline + differences]).} +#' (pairs x (baseline + differences)).} #' \item{main_effects_groups}{Posterior means of group-specific main effects #' (variables x groups), computed as baseline plus projected differences.} #' \item{pairwise_effects_groups}{Posterior means of group-specific pairwise effects diff --git a/R/bgms-methods.R b/R/bgms-methods.R index 3168f4f2..3f53be18 100644 --- a/R/bgms-methods.R +++ b/R/bgms-methods.R @@ -21,7 +21,7 @@ print.bgms = function(x, ...) { arguments = extract_arguments(x) - # Model type + # Estimation method if(isTRUE(arguments$edge_selection)) { prior_msg = switch(arguments$edge_prior, "Bernoulli" = "Bayesian Edge Selection using a Bernoulli prior on edge inclusion", @@ -34,6 +34,21 @@ print.bgms = function(x, ...) { cat("Bayesian Estimation\n") } + # Model type + mt = arguments$model_type + if(!is.null(mt)) { + mt_label = switch(mt, + ggm = "GGM (Gaussian Graphical Model)", + omrf = "OMRF (Ordinal Markov Random Field)", + mixed_mrf = sprintf( + "Mixed MRF (%d discrete, %d continuous)", + arguments$num_discrete, arguments$num_continuous + ), + mt + ) + cat(paste0(" Model: ", mt_label, "\n")) + } + # Dataset info cat(paste0(" Number of variables: ", arguments$num_variables, "\n")) if(isTRUE(arguments$standardize)) { @@ -83,11 +98,24 @@ print.bgms = function(x, ...) { summary.bgms = function(object, ...) { arguments = extract_arguments(object) - if(!is.null(object$posterior_summary_main) && !is.null(object$posterior_summary_pairwise)) { + has_main = !is.null(object$posterior_summary_main) + has_quad = !is.null(object$posterior_summary_quadratic) + has_pair = !is.null(object$posterior_summary_pairwise) + + if((has_main || has_quad) && has_pair) { + mt = arguments$model_type + main_label = switch(mt, + ggm = NULL, + omrf = "Category thresholds:", + mixed_mrf = "Main effects (discrete thresholds and continuous means):", + "Main effects:" + ) out = list( main = object$posterior_summary_main, + quadratic = object$posterior_summary_quadratic, pairwise = object$posterior_summary_pairwise ) + attr(out, "main_label") = main_label if(!is.null(object$posterior_summary_indicator)) { out$indicator = object$posterior_summary_indicator @@ -118,12 +146,20 @@ print.summary.bgms = function(x, digits = 3, ...) { cat("Posterior summaries from Bayesian estimation:\n\n") if(!is.null(x$main)) { - cat("Category thresholds:\n") + main_label = attr(x, "main_label") %||% "Main effects:" + cat(main_label, "\n") print(round(head(x$main, 6), digits = digits)) if(nrow(x$main) > 6) cat("... (use `summary(fit)$main` to see full output)\n") cat("\n") } + if(!is.null(x$quadratic)) { + cat("Precision matrix diagonal:\n") + print(round(head(x$quadratic, 6), digits = digits)) + if(nrow(x$quadratic) > 6) cat("... (use `summary(fit)$quadratic` to see full output)\n") + cat("\n") + } + if(!is.null(x$pairwise)) { cat("Pairwise interactions:\n") pair = head(x$pairwise, 6) @@ -179,15 +215,21 @@ print.summary.bgms = function(x, digits = 3, ...) { #' @title Extract Coefficients from a bgms Object #' @name coef.bgms -#' @description Returns the posterior mean thresholds, pairwise effects, and edge inclusion indicators from a \code{bgms} model fit. +#' @description Returns the posterior mean main effects, pairwise effects, and edge inclusion indicators from a \code{bgms} model fit. #' #' @param object An object of class \code{bgms}. #' @param ... Ignored. #' #' @return A list with the following components: #' \describe{ -#' \item{main}{Posterior mean of the category threshold parameters.} -#' \item{pairwise}{Posterior mean of the pairwise interaction matrix.} +#' \item{main}{Posterior mean of the main-effect parameters. \code{NULL} for +#' GGM models (no main effects). For OMRF models this is a numeric matrix +#' (p x max_categories) of category thresholds. For mixed MRF models this +#' is a list with \code{$discrete} (p x max_categories matrix) and +#' \code{$continuous} (q x 1 matrix of means).} +#' \item{pairwise}{Posterior mean of the pairwise interaction matrix. For GGM +#' and mixed MRF models the precision matrix diagonal is included on the +#' matrix diagonal.} #' \item{indicator}{Posterior mean of the edge inclusion indicators (if available).} #' } #' diff --git a/R/bgms-package.R b/R/bgms-package.R index 464d643b..cc4af018 100644 --- a/R/bgms-package.R +++ b/R/bgms-package.R @@ -1,16 +1,24 @@ -#' bgms: Bayesian Analysis of Networks of Binary and/or Ordinal Variables +#' bgms: Bayesian Analysis of Graphical Models #' #' @description #' The \code{R} package \strong{bgms} provides tools for Bayesian analysis of -#' the ordinal Markov random field (MRF), a graphical model describing networks -#' of binary and/or ordinal variables \insertCite{MarsmanVandenBerghHaslbeck_2025}{bgms}. -#' The likelihood is approximated via a pseudolikelihood, and Markov chain Monte -#' Carlo (MCMC) methods are used to sample from the corresponding pseudoposterior -#' distribution of model parameters. +#' graphical models describing networks of binary, ordinal, continuous, and +#' mixed variables +#' \insertCite{MarsmanVandenBerghHaslbeck_2025}{bgms}. +#' Supported model families include ordinal Markov random fields (MRFs), +#' Gaussian graphical models (GGMs), and mixed MRFs that combine discrete +#' and continuous variables in a single network. The likelihood is approximated +#' via a pseudolikelihood, and Markov chain Monte Carlo (MCMC) methods are used +#' to sample from the corresponding pseudoposterior distribution of model +#' parameters. #' #' The main entry points are: #' \itemize{ #' \item \strong{bgm}: estimation in a one-sample design. +#' Use \code{variable_type = "ordinal"} for an MRF, +#' \code{"continuous"} for a GGM, or a per-variable vector +#' mixing \code{"ordinal"}, \code{"blume-capel"}, and +#' \code{"continuous"} for a mixed MRF. #' \item \strong{bgmCompare}: estimation and group comparison in an #' independent-sample design. #' } diff --git a/R/build_output.R b/R/build_output.R index 865340df..565b1536 100644 --- a/R/build_output.R +++ b/R/build_output.R @@ -11,6 +11,194 @@ # ============================================================================== +# ------------------------------------------------------------------ +# fill_mixed_symmetric +# ------------------------------------------------------------------ +# Fills a symmetric (p+q)×(p+q) matrix from a flat vector of edge +# values stored in discrete-discrete / continuous-continuous / cross +# block order. Used for both pairwise means and indicator means in +# the mixed MRF output builder. +# +# @param values Flat numeric vector of edge values. +# @param p Number of discrete variables. +# @param q Number of continuous variables. +# @param disc_idx Integer vector mapping discrete 1:p to original columns. +# @param cont_idx Integer vector mapping continuous 1:q to original columns. +# @param dimnames List of row/colnames for the result matrix. +# +# Returns: Symmetric matrix with values placed in original-column order. +# ------------------------------------------------------------------ +fill_mixed_symmetric = function(values, p, q, disc_idx, cont_idx, dimnames) { + n = length(dimnames[[1]]) + mat = matrix(0, nrow = n, ncol = n, dimnames = dimnames) + idx = 0L + + # Discrete-discrete block (upper triangle) + if(p > 1) { + for(i in seq_len(p - 1)) { + for(j in seq(i + 1, p)) { + idx = idx + 1L + oi = disc_idx[i] + oj = disc_idx[j] + mat[oi, oj] = values[idx] + mat[oj, oi] = values[idx] + } + } + } + + # Continuous-continuous block (upper triangle) + if(q > 1) { + for(i in seq_len(q - 1)) { + for(j in seq(i + 1, q)) { + idx = idx + 1L + oi = cont_idx[i] + oj = cont_idx[j] + mat[oi, oj] = values[idx] + mat[oj, oi] = values[idx] + } + } + } + + # Cross block (all p × q pairs) + if(p > 0 && q > 0) { + for(i in seq_len(p)) { + for(j in seq_len(q)) { + idx = idx + 1L + oi = disc_idx[i] + oj = cont_idx[j] + mat[oi, oj] = values[idx] + mat[oj, oi] = values[idx] + } + } + } + + mat +} + + +# ------------------------------------------------------------------ +# compute_mixed_parameter_indices +# ------------------------------------------------------------------ +# Computes slice indices for the mixed MRF flat parameter vector. +# Groups main-effect indices (discrete thresholds, continuous means), +# quadratic-effect indices (precision diagonal), and pairwise indices +# (discrete edges, precision off-diagonal, cross edges). +# +# @param num_thresholds Total number of discrete threshold parameters. +# @param p Number of discrete variables. +# @param q Number of continuous variables. +# +# Returns: List with components num_thresholds, main_idx, pairwise_idx. +# ------------------------------------------------------------------ +compute_mixed_parameter_indices = function(num_thresholds, p, q) { + nt = num_thresholds + nxx = as.integer(p * (p - 1) / 2) + nyy_total = as.integer(q * (q + 1) / 2) + nyy_offdiag = as.integer(q * (q - 1) / 2) + nxy = as.integer(p * q) + + # Offsets in the flat vector (1-based) + main_discrete_start = 1L + main_discrete_end = nt + pairwise_discrete_start = nt + 1L + pairwise_discrete_end = nt + nxx + main_continuous_start = nt + nxx + 1L + main_continuous_end = nt + nxx + q + pairwise_cross_start = nt + nxx + q + 1L + pairwise_cross_end = nt + nxx + q + nxy + pairwise_continuous_start = nt + nxx + q + nxy + 1L + + # Precision diagonal vs off-diagonal within the continuous block + precision_diag_within = integer(q) + precision_offdiag_within = integer(nyy_offdiag) + k_diag = 0L + k_off = 0L + pos = 0L + for(i in seq_len(q)) { + for(j in i:q) { + pos = pos + 1L + if(i == j) { + k_diag = k_diag + 1L + precision_diag_within[k_diag] = pos + } else { + k_off = k_off + 1L + precision_offdiag_within[k_off] = pos + } + } + } + precision_diag_abs = pairwise_continuous_start - 1L + precision_diag_within + precision_offdiag_abs = pairwise_continuous_start - 1L + precision_offdiag_within + + # Main: discrete thresholds + continuous means + # Quadratic: precision diagonal (not a main effect) + main_idx = c( + seq(main_discrete_start, main_discrete_end), + seq(main_continuous_start, main_continuous_end), + precision_diag_abs + ) + + # Pairwise: discrete + precision off-diagonal + cross + pairwise_idx = c( + if(nxx > 0) seq(pairwise_discrete_start, pairwise_discrete_end) else integer(0), + precision_offdiag_abs, + if(nxy > 0) seq(pairwise_cross_start, pairwise_cross_end) else integer(0) + ) + + list( + num_thresholds = nt, + num_quadratic = q, + main_idx = main_idx, + pairwise_idx = pairwise_idx + ) +} + + +# ------------------------------------------------------------------ +# build_raw_samples_list +# ------------------------------------------------------------------ +# Assembles the $raw_samples list shared by all output builders. +# +# @param raw Per-chain list (normalized). +# @param edge_selection Logical. +# @param edge_prior Character string naming the edge prior. +# @param names_main Character vector of main-effect parameter names. +# @param edge_names Character vector of edge parameter names. +# @param allocation_names Optional character vector; when non-NULL, added +# to $parameter_names$allocations. +# +# Returns: List with main, pairwise, indicator, allocations, nchains, +# niter, parameter_names. +# ------------------------------------------------------------------ +build_raw_samples_list = function(raw, edge_selection, edge_prior, + names_main, edge_names, + allocation_names = NULL) { + list( + main = lapply(raw, function(chain) chain$main_samples), + pairwise = lapply(raw, function(chain) chain$pairwise_samples), + indicator = if(edge_selection) { + lapply(raw, function(chain) chain$indicator_samples) + } else { + NULL + }, + allocations = if(edge_selection && + identical(edge_prior, "Stochastic-Block") && + "allocations" %in% names(raw[[1]])) { + lapply(raw, `[[`, "allocations") + } else { + NULL + }, + nchains = length(raw), + niter = nrow(raw[[1]]$main_samples), + parameter_names = list( + main = names_main, + pairwise = edge_names, + indicator = if(edge_selection) edge_names else NULL, + allocations = allocation_names + ) + ) +} + + # ============================================================================== # build_output() — dispatcher # ============================================================================== @@ -18,9 +206,10 @@ build_output = function(spec, raw) { stopifnot(inherits(spec, "bgm_spec")) switch(spec$model_type, - ggm = build_output_bgm(spec, raw), - omrf = build_output_bgm(spec, raw), - compare = build_output_compare(spec, raw), + ggm = build_output_bgm(spec, raw), + omrf = build_output_bgm(spec, raw), + mixed_mrf = build_output_mixed_mrf(spec, raw), + compare = build_output_compare(spec, raw), stop("Unknown model_type: ", spec$model_type) ) } @@ -161,13 +350,20 @@ build_output_bgm = function(spec, raw) { rownames(pairwise_summary) = edge_names results = list() - results$posterior_summary_main = main_summary + + if(is_continuous) { + # GGM has no main effects; the precision diagonal is quadratic + results$posterior_summary_main = NULL + results$posterior_summary_quadratic = main_summary + } else { + results$posterior_summary_main = main_summary + } results$posterior_summary_pairwise = pairwise_summary # --- Edge selection summaries ----------------------------------------------- has_sbm = FALSE if(edge_selection) { - indicator_summary = summarize_indicator(raw, param_names = edge_names)[, -1] + indicator_summary = summary_list$indicator[, -1] rownames(indicator_summary) = edge_names results$posterior_summary_indicator = indicator_summary @@ -180,18 +376,14 @@ build_output_bgm = function(spec, raw) { node_names = data_columnnames ) results$posterior_summary_pairwise_allocations = sbm_convergence$sbm_summary + co_occur_matrix = sbm_convergence$co_occur_matrix } } # --- Posterior mean: main --------------------------------------------------- if(is_continuous) { - # GGM: p × 1 matrix - results$posterior_mean_main = matrix( - main_summary$mean, - nrow = num_variables, - ncol = 1, - dimnames = list(data_columnnames, "precision_diag") - ) + # GGM has no main effects + results$posterior_mean_main = NULL } else { # OMRF: p × max_categories matrix num_params = ifelse(is_ordinal_variable, num_categories, 2L) @@ -226,6 +418,11 @@ build_output_bgm = function(spec, raw) { results$posterior_mean_pairwise = results$posterior_mean_pairwise + t(results$posterior_mean_pairwise) + # --- Precision diagonal on the pairwise matrix (GGM) ----------------------- + if(is_continuous) { + diag(results$posterior_mean_pairwise) = main_summary$mean + } + # --- Posterior mean: indicator + SBM ---------------------------------------- if(edge_selection) { indicator_means = indicator_summary$mean @@ -239,11 +436,7 @@ build_output_bgm = function(spec, raw) { t(results$posterior_mean_indicator) if(has_sbm) { - sbm_convergence2 = summarize_alloc_pairs( - allocations = lapply(raw, `[[`, "allocations"), - node_names = data_columnnames - ) - results$posterior_mean_coclustering_matrix = sbm_convergence2$co_occur_matrix + results$posterior_mean_coclustering_matrix = co_occur_matrix arguments = build_arguments(spec) sbm_summary = posterior_summary_SBM( @@ -261,33 +454,13 @@ build_output_bgm = function(spec, raw) { class(results) = "bgms" # --- raw_samples ------------------------------------------------------------ - results$raw_samples = list( - main = lapply(raw, function(chain) chain$main_samples), - pairwise = lapply(raw, function(chain) chain$pairwise_samples), - indicator = if(edge_selection) { - lapply(raw, function(chain) chain$indicator_samples) - } else { - NULL - }, - allocations = if(edge_selection && - identical(edge_prior, "Stochastic-Block") && - "allocations" %in% names(raw[[1]])) { - lapply(raw, `[[`, "allocations") - } else { - NULL - }, - nchains = length(raw), - niter = nrow(raw[[1]]$main_samples), - parameter_names = list( - main = names_main, - pairwise = edge_names, - indicator = if(edge_selection) edge_names else NULL, - allocations = if(identical(edge_prior, "Stochastic-Block")) { - if(is_continuous) data_columnnames else edge_names - } else { - NULL - } - ) + alloc_names = if(identical(edge_prior, "Stochastic-Block")) { + if(is_continuous) data_columnnames else edge_names + } else { + NULL + } + results$raw_samples = build_raw_samples_list( + raw, edge_selection, edge_prior, names_main, edge_names, alloc_names ) # --- easybgm compat shim (OMRF only) --------------------------------------- @@ -303,7 +476,7 @@ build_output_bgm = function(spec, raw) { results$indicator = extract_indicators(results) } results$interactions = extract_pairwise_interactions(results) - results$thresholds = extract_category_thresholds(results) + results$thresholds = extract_main_effects(results) } } @@ -319,6 +492,266 @@ build_output_bgm = function(spec, raw) { } +# ============================================================================== +# build_output_mixed_mrf() — Mixed MRF builder +# ============================================================================== +# +# Handles the mixed discrete + continuous parameter layout: +# C++ flat vector: [main_discrete | pairwise_discrete_ut | main_continuous | pairwise_cross | pairwise_continuous_ut] +# C++ indicators: [Gxx_ut | Gyy_ut | Gxy] +# +# Splits into main (discrete thresholds, continuous means), +# quadratic (precision diagonal), and pairwise (discrete, precision +# off-diag, cross). The precision diagonal is placed on the diagonal +# of the pairwise interaction matrix, not under main effects. +# ============================================================================== +build_output_mixed_mrf = function(spec, raw) { + d = spec$data + v = spec$variables + pr = spec$prior + s = spec$sampler + + p = d$num_discrete + q = d$num_continuous + num_variables = d$num_variables + data_columnnames = d$data_columnnames + disc_names = d$data_columnnames_discrete + cont_names = d$data_columnnames_continuous + disc_idx = d$discrete_indices + cont_idx = d$continuous_indices + is_ordinal = v$is_ordinal + num_categories = d$num_categories + edge_selection = pr$edge_selection + + # --- Compute index layout in flat parameter vector -------------------------- + layout = compute_mixed_parameter_indices( + num_thresholds = spec$precomputed$num_thresholds, + p = p, + q = q + ) + nt = layout$num_thresholds + main_idx = layout$main_idx + pairwise_idx = layout$pairwise_idx + + # --- Indicator index layout ------------------------------------------------- + # C++ indicator vector: [Gxx_ut | Gyy_ut | Gxy] + # All are pairwise, so indicator_samples maps directly to pairwise order: + # Discrete, continuous, cross edges — same order as pairwise_idx above. + + # --- Normalize raw output per chain ----------------------------------------- + raw = lapply(raw, function(chain) { + samples_t = t(chain$samples) + res = list( + main_samples = samples_t[, main_idx, drop = FALSE], + pairwise_samples = samples_t[, pairwise_idx, drop = FALSE], + userInterrupt = isTRUE(chain$userInterrupt), + chain_id = chain$chain_id + ) + if(!is.null(chain$indicator_samples)) { + res$indicator_samples = t(chain$indicator_samples) + } + if(!is.null(chain$allocation_samples)) { + res$allocations = t(chain$allocation_samples) + } + if(!is.null(chain$treedepth)) res[["treedepth__"]] = chain$treedepth + if(!is.null(chain$divergent)) res[["divergent__"]] = chain$divergent + if(!is.null(chain$energy)) res[["energy__"]] = chain$energy + res + }) + + # --- Parameter names -------------------------------------------------------- + # Main effect names (in internal order: discrete first, continuous second) + names_main = character() + for(si in seq_len(p)) { + if(is_ordinal[si]) { + cats = seq_len(num_categories[si]) + names_main = c(names_main, paste0(disc_names[si], " (", cats, ")")) + } else { + names_main = c( + names_main, + paste0(disc_names[si], " (linear)"), + paste0(disc_names[si], " (quadratic)") + ) + } + } + for(ji in seq_len(q)) { + names_main = c(names_main, paste0(cont_names[ji], " (mean)")) + } + for(ji in seq_len(q)) { + names_main = c(names_main, paste0(cont_names[ji], " (precision)")) + } + + # Pairwise edge names — internal order, mapped to original column names + # We need a mapping from internal index to original variable name + # Internal variables: [disc_1, ..., disc_p, cont_1, ..., cont_q] + # Their original names: c(disc_names, cont_names) + all_internal_names = c(disc_names, cont_names) + + edge_names = character() + # Discrete-discrete edges + if(p > 1) { + for(i in seq_len(p - 1)) { + for(j in seq(i + 1, p)) { + edge_names = c( + edge_names, + paste0(disc_names[i], "-", disc_names[j]) + ) + } + } + } + # Continuous-continuous edges (off-diagonal) + if(q > 1) { + for(i in seq_len(q - 1)) { + for(j in seq(i + 1, q)) { + edge_names = c( + edge_names, + paste0(cont_names[i], "-", cont_names[j]) + ) + } + } + } + # Cross edges (discrete-continuous) + if(p > 0 && q > 0) { + for(i in seq_len(p)) { + for(j in seq_len(q)) { + edge_names = c( + edge_names, + paste0(disc_names[i], "-", cont_names[j]) + ) + } + } + } + + # --- MCMC summaries --------------------------------------------------------- + summary_list = summarize_fit(raw, edge_selection = edge_selection) + main_summary = summary_list$main[, -1] + pairwise_summary = summary_list$pairwise[, -1] + + rownames(main_summary) = names_main + rownames(pairwise_summary) = edge_names + + # Split main_summary into true main effects and quadratic (precision diagonal) + n_main = nt + q # thresholds + continuous means + n_quad = layout$num_quadratic # precision diagonal entries + main_rows = seq_len(n_main) + quad_rows = n_main + seq_len(n_quad) + + results = list() + results$posterior_summary_main = main_summary[main_rows, , drop = FALSE] + results$posterior_summary_quadratic = main_summary[quad_rows, , drop = FALSE] + results$posterior_summary_pairwise = pairwise_summary + + # --- Edge selection summaries ----------------------------------------------- + edge_prior = pr$edge_prior + has_sbm = FALSE + if(edge_selection) { + indicator_summary = summary_list$indicator[, -1] + rownames(indicator_summary) = edge_names + results$posterior_summary_indicator = indicator_summary + + has_sbm = identical(edge_prior, "Stochastic-Block") && + "allocations" %in% names(raw[[1]]) + + if(has_sbm) { + sbm_convergence = summarize_alloc_pairs( + allocations = lapply(raw, `[[`, "allocations"), + node_names = all_internal_names + ) + results$posterior_summary_pairwise_allocations = sbm_convergence$sbm_summary + co_occur_matrix = sbm_convergence$co_occur_matrix + } + } + + # --- Posterior mean: main --------------------------------------------------- + # Discrete main effects: p × max_cats matrix (like OMRF) + num_params_disc = ifelse(is_ordinal, num_categories, 2L) + max_num_cats = max(num_params_disc) + pmm_disc = matrix(NA, nrow = p, ncol = max_num_cats) + start = 0L + stop = 0L + for(si in seq_len(p)) { + if(is_ordinal[si]) { + start = stop + 1L + stop = start + num_categories[si] - 1L + pmm_disc[si, seq_len(num_categories[si])] = main_summary$mean[start:stop] + } else { + start = stop + 1L + stop = start + 1L + pmm_disc[si, 1:2] = main_summary$mean[start:stop] + } + } + rownames(pmm_disc) = disc_names + colnames(pmm_disc) = paste0("cat (", seq_len(max_num_cats), ")") + + # Continuous main effects: q × 1 matrix (means only) + pmm_cont = matrix(main_summary$mean[nt + seq_len(q)], + nrow = q, ncol = 1, + dimnames = list(cont_names, "mean") + ) + + results$posterior_mean_main = list( + discrete = pmm_disc, + continuous = pmm_cont + ) + + # --- Posterior mean: pairwise as (p+q) × (p+q) matrix ----------------------- + dn = list(data_columnnames, data_columnnames) + results$posterior_mean_pairwise = fill_mixed_symmetric( + pairwise_summary$mean, p, q, disc_idx, cont_idx, dn + ) + + # --- Precision diagonal on the pairwise matrix (continuous block) ----------- + kyy_diag_means = main_summary$mean[nt + q + seq_len(q)] + for(j in seq_len(q)) { + results$posterior_mean_pairwise[cont_idx[j], cont_idx[j]] = kyy_diag_means[j] + } + + # --- Posterior mean: indicator ----------------------------------------------- + if(edge_selection) { + results$posterior_mean_indicator = fill_mixed_symmetric( + indicator_summary$mean, p, q, disc_idx, cont_idx, dn + ) + + if(has_sbm) { + results$posterior_mean_coclustering_matrix = co_occur_matrix + + arguments = build_arguments(spec) + sbm_summary = posterior_summary_SBM( + allocations = lapply(raw, `[[`, "allocations"), + arguments = arguments + ) + results$posterior_mean_allocations = sbm_summary$allocations_mean + results$posterior_mode_allocations = sbm_summary$allocations_mode + results$posterior_num_blocks = sbm_summary$blocks + } + } + + # --- arguments + class ------------------------------------------------------ + results$arguments = build_arguments(spec) + class(results) = "bgms" + + # --- raw_samples ------------------------------------------------------------ + alloc_names = if(identical(edge_prior, "Stochastic-Block")) { + all_internal_names + } else { + NULL + } + results$raw_samples = build_raw_samples_list( + raw, edge_selection, edge_prior, names_main, edge_names, alloc_names + ) + + # --- NUTS diagnostics ------------------------------------------------------- + if(s$update_method == "hybrid-nuts") { + results$nuts_diag = summarize_nuts_diagnostics( + raw, + nuts_max_depth = s$nuts_max_depth + ) + } + + results +} + + # ============================================================================== # build_output_compare() # ============================================================================== diff --git a/R/extractor_functions.R b/R/extractor_functions.R index 76e9bb14..82639041 100644 --- a/R/extractor_functions.R +++ b/R/extractor_functions.R @@ -407,7 +407,7 @@ extract_indicator_priors.bgmCompare = function(bgms_object) { #' interaction parameters.} #' } #' -#' @seealso [bgm()], [bgmCompare()], [extract_category_thresholds()] +#' @seealso [bgm()], [bgmCompare()], [extract_main_effects()] #' @family extractors #' @export extract_pairwise_interactions = function(bgms_object) { @@ -428,10 +428,17 @@ extract_pairwise_interactions.bgms = function(bgms_object) { mats = bgms_object$raw_samples$pairwise mat = do.call(rbind, mats) - edge_names = character() - for(i in 1:(num_vars - 1)) { - for(j in (i + 1):num_vars) { - edge_names = c(edge_names, paste0(var_names[i], "-", var_names[j])) + # Use stored parameter names when available (correct for all model types + # including mixed MRF where block order differs from upper-triangle order) + stored_names = bgms_object$raw_samples$parameter_names$pairwise + if(!is.null(stored_names)) { + edge_names = stored_names + } else { + edge_names = character() + for(i in 1:(num_vars - 1)) { + for(j in (i + 1):num_vars) { + edge_names = c(edge_names, paste0(var_names[i], "-", var_names[j])) + } } } @@ -492,46 +499,78 @@ extract_pairwise_interactions.bgmCompare = function(bgms_object) { stop("No pairwise interaction samples found in fit object.") } -#' Extract Category Threshold Estimates +#' Extract Main Effect Estimates +#' +#' @title Extract Main Effect Estimates #' #' @description -#' Retrieves category threshold parameters from a model fitted with -#' [bgm()] or [bgmCompare()]. +#' Retrieves posterior mean main-effect parameters from a model fitted with +#' [bgm()] or [bgmCompare()]. For OMRF models these are category thresholds; +#' for mixed MRF models these include discrete thresholds and continuous +#' means. GGM models have no main effects and return `NULL`. #' #' @param bgms_object A fitted model object of class `bgms` (from [bgm()]) #' or `bgmCompare` (from [bgmCompare()]). #' -#' @return +#' @return The structure depends on the model type: #' \describe{ -#' \item{bgms}{A matrix with one row per variable and one column per -#' category threshold, containing posterior means.} +#' \item{GGM (bgms)}{`NULL` (invisibly). GGM models have no main effects; +#' the precision matrix diagonal is on `coef(fit)$pairwise`.} +#' \item{OMRF (bgms)}{A numeric matrix with one row per variable and one +#' column per category threshold, containing posterior means. Columns +#' beyond the number of categories for a variable are `NA`.} +#' \item{Mixed MRF (bgms)}{A list with two elements: +#' \describe{ +#' \item{discrete}{A numeric matrix (p rows x max_categories columns) +#' of posterior mean thresholds for discrete variables.} +#' \item{continuous}{A numeric matrix (q rows x 1 column) of +#' posterior mean continuous variable means.} +#' }} #' \item{bgmCompare}{A matrix with one row per post-warmup iteration, -#' containing posterior samples of baseline threshold parameters.} +#' containing posterior samples of baseline main-effect parameters.} #' } #' -#' @seealso [bgm()], [bgmCompare()], [extract_pairwise_interactions()] +#' @examples +#' \donttest{ +#' fit = bgm(x = Wenchuan[, 1:3]) +#' extract_main_effects(fit) +#' } +#' +#' @seealso [bgm()], [bgmCompare()], [extract_pairwise_interactions()], +#' [extract_category_thresholds()] #' @family extractors #' @export -extract_category_thresholds = function(bgms_object) { - UseMethod("extract_category_thresholds") +extract_main_effects = function(bgms_object) { + UseMethod("extract_main_effects") } -#' @inheritParams extract_category_thresholds +#' @inheritParams extract_main_effects #' @exportS3Method #' @noRd -extract_category_thresholds.bgms = function(bgms_object) { +extract_main_effects.bgms = function(bgms_object) { arguments = extract_arguments(bgms_object) - var_names = arguments$data_columnnames + + # GGM: no main effects; precision diagonal is on the pairwise matrix + if(isTRUE(arguments$is_continuous)) { + return(invisible(NULL)) + } + + # Mixed MRF: return pre-built list from posterior_mean_main + if(isTRUE(arguments$is_mixed)) { + return(bgms_object$posterior_mean_main) + } # Current format (0.1.6.0+) if(!is.null(bgms_object$posterior_summary_main)) { vec = bgms_object$posterior_summary_main[, "mean"] - # Handle legacy field name (no_variables → num_variables in 0.1.6.0) + var_names = arguments$data_columnnames num_vars = arguments$num_variables %||% arguments$no_variables variable_type = arguments$variable_type if(length(variable_type) == 1) { variable_type = rep(variable_type, num_vars) } + + # OMRF: threshold matrix num_cats = arguments$num_categories max_cats = max(num_cats) mat = matrix(NA_real_, nrow = num_vars, ncol = max_cats) @@ -556,8 +595,8 @@ extract_category_thresholds.bgms = function(bgms_object) { "0.1.6.0", I("The '$thresholds' field is deprecated; please refit with bgms >= 0.1.6.0") ) + var_names = arguments$data_columnnames means = colMeans(bgms_object$thresholds) - # For binary variables in 0.1.4.x, there's 1 threshold per variable mat = matrix(means, nrow = length(means), ncol = 1) rownames(mat) = var_names return(mat) @@ -570,10 +609,10 @@ extract_category_thresholds.bgms = function(bgms_object) { ) } -#' @inheritParams extract_category_thresholds +#' @inheritParams extract_main_effects #' @exportS3Method #' @noRd -extract_category_thresholds.bgmCompare = function(bgms_object) { +extract_main_effects.bgmCompare = function(bgms_object) { arguments = extract_arguments(bgms_object) # Current format (0.1.6.0+) @@ -604,11 +643,39 @@ extract_category_thresholds.bgmCompare = function(bgms_object) { "0.1.6.0", I("The '$thresholds_gr*' fields are deprecated; please refit with bgms >= 0.1.6.0") ) - # Combine the two groups' thresholds return(cbind(bgms_object$thresholds_gr1, bgms_object$thresholds_gr2)) } - stop("No category threshold samples found in fit object.") + stop("No main effect samples found in fit object.") +} + + +#' Extract Category Threshold Estimates +#' +#' @title Extract Category Threshold Estimates +#' +#' @description +#' `r lifecycle::badge("deprecated")` +#' +#' `extract_category_thresholds()` was renamed to [extract_main_effects()] to +#' reflect that main effects include continuous means and precisions +#' (mixed MRF), not only category thresholds. +#' +#' @param bgms_object A fitted model object of class `bgms` (from [bgm()]) +#' or `bgmCompare` (from [bgmCompare()]). +#' +#' @return See [extract_main_effects()] for details. +#' +#' @seealso [extract_main_effects()] +#' @family extractors +#' @export +extract_category_thresholds = function(bgms_object) { + lifecycle::deprecate_warn( + "0.1.6.4", + "extract_category_thresholds()", + "extract_main_effects()" + ) + extract_main_effects(bgms_object) } #' Extract Group-Specific Parameters @@ -624,7 +691,7 @@ extract_category_thresholds.bgmCompare = function(bgms_object) { #' group) and `pairwise_effects_groups` (pairwise effects per group). #' #' @seealso [bgmCompare()], [extract_pairwise_interactions()], -#' [extract_category_thresholds()] +#' [extract_main_effects()] #' @family extractors #' @export extract_group_params = function(bgms_object) { @@ -842,13 +909,13 @@ extract_edge_indicators = function(bgms_object) { extract_indicators(bgms_object) } -#' Deprecated: Use extract_category_thresholds instead +#' Deprecated: Use extract_main_effects instead #' @param bgms_object A bgms or bgmCompare object. #' @keywords internal #' @export extract_pairwise_thresholds = function(bgms_object) { - lifecycle::deprecate_warn("0.1.4.2", "extract_pairwise_thresholds()", "extract_category_thresholds()") - extract_category_thresholds(bgms_object) + lifecycle::deprecate_warn("0.1.4.2", "extract_pairwise_thresholds()", "extract_main_effects()") + extract_main_effects(bgms_object) } @@ -887,6 +954,12 @@ extract_rhat.bgms = function(bgms_object) { names(result$main) = rownames(bgms_object$posterior_summary_main) } + # Precision diagonal (quadratic) Rhat + if(!is.null(bgms_object$posterior_summary_quadratic)) { + result$quadratic = bgms_object$posterior_summary_quadratic$Rhat + names(result$quadratic) = rownames(bgms_object$posterior_summary_quadratic) + } + # Pairwise interaction Rhat if(!is.null(bgms_object$posterior_summary_pairwise)) { result$pairwise = bgms_object$posterior_summary_pairwise$Rhat @@ -985,6 +1058,12 @@ extract_ess.bgms = function(bgms_object) { names(result$main) = rownames(bgms_object$posterior_summary_main) } + # Precision diagonal (quadratic) ESS + if(!is.null(bgms_object$posterior_summary_quadratic)) { + result$quadratic = bgms_object$posterior_summary_quadratic$n_eff + names(result$quadratic) = rownames(bgms_object$posterior_summary_quadratic) + } + # Pairwise interaction ESS if(!is.null(bgms_object$posterior_summary_pairwise)) { result$pairwise = bgms_object$posterior_summary_pairwise$n_eff diff --git a/R/mcmc_summary.R b/R/mcmc_summary.R index 4f99e918..8c3c4957 100644 --- a/R/mcmc_summary.R +++ b/R/mcmc_summary.R @@ -34,9 +34,9 @@ compute_rhat_ess = function(draws) { } # Basic summarizer for continuous parameters -summarize_manual = function(fit, component = c("main_samples", "pairwise_samples"), param_names = NULL) { +summarize_manual = function(fit, component = c("main_samples", "pairwise_samples"), param_names = NULL, array3d = NULL) { component = match.arg(component) # Add options later - array3d = combine_chains(fit, component) + if(is.null(array3d)) array3d = combine_chains(fit, component) nparam = dim(array3d)[3] result = matrix(NA, nparam, 5) @@ -61,9 +61,9 @@ summarize_manual = function(fit, component = c("main_samples", "pairwise_samples } # Summarize binary indicator variables -summarize_indicator = function(fit, component = c("indicator_samples"), param_names = NULL) { +summarize_indicator = function(fit, component = c("indicator_samples"), param_names = NULL, array3d = NULL) { component = match.arg(component) # Add options later - array3d = combine_chains(fit, component) + if(is.null(array3d)) array3d = combine_chains(fit, component) nparam = dim(array3d)[3] nchains = dim(array3d)[2] @@ -112,9 +112,9 @@ summarize_indicator = function(fit, component = c("indicator_samples"), param_na } # Summarize slab values where indicators are 1 -summarize_slab = function(fit, component = c("pairwise_samples"), param_names = NULL) { +summarize_slab = function(fit, component = c("pairwise_samples"), param_names = NULL, array3d = NULL) { component = match.arg(component) # Add options later - array3d = combine_chains(fit, component) + if(is.null(array3d)) array3d = combine_chains(fit, component) nparam = dim(array3d)[3] result = matrix(NA, nparam, 5) colnames(result) = c("mean", "sd", "mcse", "n_eff", "Rhat") @@ -151,12 +151,18 @@ summarize_slab = function(fit, component = c("pairwise_samples"), param_names = summarize_pair = function(fit, indicator_component = c("indicator_samples"), slab_component = c("pairwise_samples"), - param_names = NULL) { + param_names = NULL, + summ_ind = NULL, + summ_slab = NULL, + array3d_id = NULL, + array3d_pw = NULL) { indicator_component = match.arg(indicator_component) # Add options later slab_component = match.arg(slab_component) # Add options later - summ_ind = summarize_indicator(fit, component = indicator_component) - summ_slab = summarize_slab(fit, component = slab_component) + if(is.null(array3d_id)) array3d_id = combine_chains(fit, indicator_component) + if(is.null(array3d_pw)) array3d_pw = combine_chains(fit, slab_component) + if(is.null(summ_ind)) summ_ind = summarize_indicator(fit, component = indicator_component, array3d = array3d_id) + if(is.null(summ_slab)) summ_slab = summarize_slab(fit, component = slab_component, array3d = array3d_pw) nparam = nrow(summ_ind) # EAP = indicator_mean * slab_mean. @@ -170,8 +176,6 @@ summarize_pair = function(fit, n_eff = v / mcse2 rhat = rep(NA_real_, nparam) - array3d_pw = combine_chains(fit, slab_component) - array3d_id = combine_chains(fit, indicator_component) nchains = dim(array3d_pw)[2] n_total = prod(dim(array3d_pw)[1:2]) @@ -223,32 +227,41 @@ summarize_fit = function(fit, edge_selection = FALSE) { if(!edge_selection) { pair_summary = summarize_manual(fit, component = "pairwise_samples") - } else { - # Get indicators and slab draws - ind_summary = summarize_indicator(fit, component = "indicator_samples") - slab_summary = summarize_slab(fit, component = "pairwise_samples") + return(list(main = main_summary, pairwise = pair_summary)) + } - all_selected = ind_summary$mean == 1 + # Build 3D arrays once; reused by all summary functions below + array3d_ind = combine_chains(fit, "indicator_samples") + array3d_pw = combine_chains(fit, "pairwise_samples") - # Replace NA with FALSE, so only definite TRUEs are considered - all_selected[is.na(all_selected)] = FALSE + # Compute indicator and slab summaries once + ind_summary = summarize_indicator(fit, component = "indicator_samples", array3d = array3d_ind) + slab_summary = summarize_slab(fit, component = "pairwise_samples", array3d = array3d_pw) - # Use summarize_pair only where not always selected - full_summary = summarize_pair(fit, - indicator_component = "indicator_samples", - slab_component = "pairwise_samples" - ) - manual_summary = summarize_manual(fit, component = "pairwise_samples") + all_selected = ind_summary$mean == 1 - # Replace rows in full_summary with manual results for fully selected entries - if(any(all_selected)) { - full_summary[all_selected, ] = manual_summary[all_selected, ] - } + # Replace NA with FALSE, so only definite TRUEs are considered + all_selected[is.na(all_selected)] = FALSE - pair_summary = full_summary + # Pass pre-computed summaries and arrays to avoid recomputation + full_summary = summarize_pair(fit, + indicator_component = "indicator_samples", + slab_component = "pairwise_samples", + summ_ind = ind_summary, + summ_slab = slab_summary, + array3d_id = array3d_ind, + array3d_pw = array3d_pw + ) + manual_summary = summarize_manual(fit, component = "pairwise_samples", array3d = array3d_pw) + + # Replace rows in full_summary with manual results for fully selected entries + if(any(all_selected)) { + full_summary[all_selected, ] = manual_summary[all_selected, ] } - list(main = main_summary, pairwise = pair_summary) + pair_summary = full_summary + + list(main = main_summary, pairwise = pair_summary, indicator = ind_summary) } diff --git a/R/run_sampler.R b/R/run_sampler.R index 601807a6..37462b6b 100644 --- a/R/run_sampler.R +++ b/R/run_sampler.R @@ -7,6 +7,20 @@ # ============================================================================== +# ------------------------------------------------------------------ +# bb_between_default +# ------------------------------------------------------------------ +# Maps NULL to -1.0 (C++ sentinel for "no between-cluster prior"). +# +# @param value Scalar or NULL from the prior spec. +# +# Returns: value unchanged, or -1.0 when NULL. +# ------------------------------------------------------------------ +bb_between_default = function(value) { + if(is.null(value)) -1.0 else value +} + + # ============================================================================== # run_sampler() — main dispatcher # ============================================================================== @@ -14,9 +28,10 @@ run_sampler = function(spec) { stopifnot(inherits(spec, "bgm_spec")) raw = switch(spec$model_type, - ggm = run_sampler_ggm(spec), - omrf = run_sampler_omrf(spec), - compare = run_sampler_compare(spec), + ggm = run_sampler_ggm(spec), + omrf = run_sampler_omrf(spec), + mixed_mrf = run_sampler_mixed_mrf(spec), + compare = run_sampler_compare(spec), stop("Unknown model_type: ", spec$model_type) ) @@ -40,17 +55,8 @@ run_sampler_ggm = function(spec) { s = spec$sampler m = spec$missing - # C++ expects -1 for "no between-cluster prior" - bb_alpha_between = if(is.null(p$beta_bernoulli_alpha_between)) { - -1.0 - } else { - p$beta_bernoulli_alpha_between - } - bb_beta_between = if(is.null(p$beta_bernoulli_beta_between)) { - -1.0 - } else { - p$beta_bernoulli_beta_between - } + bb_alpha_between = bb_between_default(p$beta_bernoulli_alpha_between) + bb_beta_between = bb_between_default(p$beta_bernoulli_beta_between) out_raw = sample_ggm( inputFromR = list(X = d$x), @@ -91,17 +97,8 @@ run_sampler_omrf = function(spec) { p = spec$prior s = spec$sampler - # C++ expects -1 for "no between-cluster prior" - bb_alpha_between = if(is.null(p$beta_bernoulli_alpha_between)) { - -1.0 - } else { - p$beta_bernoulli_alpha_between - } - bb_beta_between = if(is.null(p$beta_bernoulli_beta_between)) { - -1.0 - } else { - p$beta_bernoulli_beta_between - } + bb_alpha_between = bb_between_default(p$beta_bernoulli_alpha_between) + bb_beta_between = bb_between_default(p$beta_bernoulli_beta_between) input_list = list( observations = d$x, @@ -147,6 +144,65 @@ run_sampler_omrf = function(spec) { } +# ============================================================================== +# run_sampler_mixed_mrf() +# ============================================================================== +run_sampler_mixed_mrf = function(spec) { + d = spec$data + v = spec$variables + m = spec$missing + p = spec$prior + s = spec$sampler + + bb_alpha_between = bb_between_default(p$beta_bernoulli_alpha_between) + bb_beta_between = bb_between_default(p$beta_bernoulli_beta_between) + + input_list = list( + discrete_observations = d$x_discrete, + continuous_observations = d$x_continuous, + num_categories = d$num_categories, + is_ordinal_variable = as.integer(v$is_ordinal), + baseline_category = v$baseline_category, + main_alpha = p$main_alpha, + main_beta = p$main_beta, + pairwise_scale = p$pairwise_scale, + pseudolikelihood = p$pseudolikelihood + ) + + out_raw = sample_mixed_mrf( + inputFromR = input_list, + prior_inclusion_prob = p$inclusion_probability, + initial_edge_indicators = matrix(1L, + nrow = d$num_variables, + ncol = d$num_variables + ), + no_iter = s$iter, + no_warmup = s$warmup, + no_chains = s$chains, + edge_selection = p$edge_selection, + seed = s$seed, + no_threads = s$cores, + progress_type = s$progress_type, + edge_prior = p$edge_prior, + beta_bernoulli_alpha = p$beta_bernoulli_alpha, + beta_bernoulli_beta = p$beta_bernoulli_beta, + beta_bernoulli_alpha_between = bb_alpha_between, + beta_bernoulli_beta_between = bb_beta_between, + dirichlet_alpha = p$dirichlet_alpha, + lambda = p$lambda, + sampler_type = s$update_method, + target_acceptance = s$target_accept, + max_tree_depth = s$nuts_max_depth, + num_leapfrogs = s$hmc_num_leapfrogs, + na_impute = m$na_impute, + missing_index_discrete_nullable = m$missing_index_discrete, + missing_index_continuous_nullable = m$missing_index_continuous + ) + + out_raw +} + + # ============================================================================== # run_sampler_compare() # ============================================================================== diff --git a/R/simulate_predict.R b/R/simulate_predict.R index b1b291b2..e6158064 100644 --- a/R/simulate_predict.R +++ b/R/simulate_predict.R @@ -9,13 +9,33 @@ # ============================================================================== +# ------------------------------------------------------------------ +# expand_variable_type +# ------------------------------------------------------------------ +# Recycles a scalar variable_type to length num_variables. +# +# @param variable_type Character vector (possibly length 1). +# @param num_variables Target length. +# +# Returns: Character vector of length num_variables. +# ------------------------------------------------------------------ +expand_variable_type = function(variable_type, num_variables) { + if(length(variable_type) == 1) { + rep(variable_type, num_variables) + } else { + variable_type + } +} + + #' Simulate Observations from a Markov Random Field #' #' @description -#' `simulate_mrf()` generates observations from a Markov Random Field using -#' user-specified parameters. For ordinal and Blume-Capel variables, observations -#' are generated via Gibbs sampling. For continuous variables (Gaussian graphical -#' model), observations are drawn directly from the multivariate normal +#' `simulate_mrf()` generates observations from a Markov Random +#' Field using user-specified parameters. For ordinal and +#' Blume-Capel variables, observations are generated via Gibbs +#' sampling. For continuous variables (Gaussian graphical model), +#' observations are drawn directly from the multivariate normal #' distribution implied by the precision matrix. #' #' @details @@ -41,7 +61,8 @@ #' have a special type of baseline_category category, such as the neutral #' category in a Likert scale. The Blume-Capel model specifies the following #' quadratic model for the threshold parameters: -#' \deqn{\mu_{\text{c}} = \alpha \times (\text{c} - \text{r}) + \beta \times (\text{c} - \text{r})^2,}{{\mu_{\text{c}} = \alpha \times (\text{c} - \text{r}) + \beta \times (\text{c} - \text{r})^2,}} +#' \deqn{\mu_{\text{c}} = \alpha (\text{c} - \text{r}) +#' + \beta (\text{c} - \text{r})^2} #' where \eqn{\mu_{\text{c}}}{\mu_{\text{c}}} is the threshold for category c #' (which now includes zero), \eqn{\alpha}{\alpha} offers a linear trend #' across categories (increasing threshold values if @@ -55,15 +76,18 @@ #' #' @param num_variables The number of variables in the MRF. #' -#' @param num_categories Either a positive integer or a vector of positive -#' integers of length \code{num_variables}. The number of response categories on top -#' of the base category: \code{num_categories = 1} generates binary states. +#' @param num_categories Either a positive integer or a vector +#' of positive integers of length \code{num_variables}. The +#' number of response categories on top of the base category: +#' \code{num_categories = 1} generates binary states. #' Only used for ordinal and Blume-Capel variables; ignored when #' \code{variable_type = "continuous"}. #' -#' @param pairwise A symmetric \code{num_variables} by \code{num_variables} matrix. -#' For ordinal and Blume-Capel variables, this contains the pairwise interaction -#' parameters; only the off-diagonal elements are used. For continuous variables, +#' @param pairwise A symmetric \code{num_variables} by +#' \code{num_variables} matrix. For ordinal and Blume-Capel +#' variables, this contains the pairwise interaction parameters; +#' only the off-diagonal elements are used. For continuous +#' variables, #' this is the precision matrix \eqn{\Omega}{Omega} (including diagonal) and #' must be positive definite. #' @@ -71,7 +95,8 @@ #' \code{num_variables} by \code{max(num_categories)} matrix of category #' thresholds. The elements in row \code{i} indicate the thresholds of #' variable \code{i}. If \code{num_categories} is a vector, only the first -#' \code{num_categories[i]} elements are used in row \code{i}. If the Blume-Capel +#' \code{num_categories[i]} elements are used in row \code{i}. +#' If the Blume-Capel #' model is used for the category thresholds for variable \code{i}, then row #' \code{i} requires two values (details below); the first is #' \eqn{\alpha}{\alpha}, the linear contribution of the Blume-Capel model and @@ -83,16 +108,18 @@ #' @param variable_type What kind of variables are simulated? Can be a single #' character string specifying the variable type of all \code{p} variables at #' once or a vector of character strings of length \code{p} specifying the type -#' for each variable separately. Currently, bgm supports ``ordinal'', -#' ``blume-capel'', and ``continuous''. Binary variables are automatically -#' treated as ``ordinal''. Ordinal and Blume-Capel variables can be mixed +#' for each variable separately. Currently, bgm supports \code{"ordinal"}, +#' \code{"blume-capel"}, and \code{"continuous"}. Binary variables are automatically +#' treated as \code{"ordinal"}. Ordinal and Blume-Capel variables can be mixed #' freely, but continuous variables cannot be mixed with ordinal or Blume-Capel #' variables. When \code{variable_type = "continuous"}, the function simulates #' from a Gaussian graphical model. #' Defaults to \code{variable_type = "ordinal"}. #' -#' @param baseline_category An integer vector of length \code{num_variables} specifying the -#' baseline_category category that is used for the Blume-Capel model (details below). +#' @param baseline_category An integer vector of length +#' \code{num_variables} specifying the baseline_category +#' category that is used for the Blume-Capel model +#' (details below). #' Can be any integer value between \code{0} and \code{num_categories} (or #' \code{num_categories[i]}). #' @@ -182,7 +209,7 @@ simulate_mrf = function(num_states, baseline_category, iter = 1e3, seed = NULL) { - # Check num_states, num_variables --------------------------------------------- + # Check num_states, num_variables ------ check_positive_integer(num_states, "num_states") check_positive_integer(num_variables, "num_variables") @@ -201,7 +228,8 @@ simulate_mrf = function(num_states, bc_binary = variable_type == "blume-capel" & num_categories < 2 if(any(bc_binary)) { stop(paste0( - "The Blume-Capel model only works for ordinal variables with more than two \n", + "The Blume-Capel model only works for ordinal ", + "variables with more than two \n", "response options. But variables ", paste(which(bc_binary), collapse = ", "), " are binary variables." @@ -223,7 +251,10 @@ simulate_mrf = function(num_states, stop("The matrix 'pairwise' needs to be symmetric.") } if(nrow(pairwise) != num_variables) { - stop("The matrix 'pairwise' needs to have 'num_variables' rows and columns.") + stop( + "The matrix 'pairwise' needs to have ", + "'num_variables' rows and columns." + ) } if(any(diag(pairwise) <= 0)) { stop("The diagonal of the precision matrix 'pairwise' must be positive.") @@ -266,19 +297,24 @@ simulate_mrf = function(num_states, # =========================================================================== check_positive_integer(iter, "iter") - # Check num_categories -------------------------------------------------------- + # Check num_categories ------ if(length(num_categories) == 1) { - if(num_categories <= 0 || - abs(num_categories - round(num_categories)) > .Machine$double.eps) { + not_pos_int = num_categories <= 0 || + abs(num_categories - round(num_categories)) > .Machine$double.eps + if(not_pos_int) { stop("``num_categories'' needs be a (vector of) positive integer(s).") } num_categories = rep(num_categories, num_variables) } else { for(variable in 1:num_variables) { - if(num_categories[variable] <= 0 || - abs(num_categories[variable] - round(num_categories[variable])) > - .Machine$double.eps) { - stop(paste("For variable", variable, "``num_categories'' was not a positive integer.")) + nc = num_categories[variable] + not_pos_int = nc <= 0 || abs(nc - round(nc)) > .Machine$double.eps + if(not_pos_int) { + stop(paste( + "For variable", variable, + "``num_categories'' was not a", + "positive integer." + )) } } } @@ -288,7 +324,10 @@ simulate_mrf = function(num_states, if(length(baseline_category) == 1) { baseline_category = rep(baseline_category, num_variables) } - if(any(baseline_category < 0) || any(abs(baseline_category - round(baseline_category)) > .Machine$double.eps)) { + bc_diff = abs(baseline_category - round(baseline_category)) + not_valid = any(baseline_category < 0) || + any(bc_diff > .Machine$double.eps) + if(not_valid) { stop(paste0( "For variables ", which(baseline_category < 0), @@ -299,7 +338,8 @@ simulate_mrf = function(num_states, stop(paste0( "For variables ", which(baseline_category - num_categories > 0), - " the ``baseline_category'' category was larger than the maximum category value." + " the ``baseline_category'' category was larger", + " than the maximum category value." )) } } @@ -312,7 +352,10 @@ simulate_mrf = function(num_states, stop("The matrix ``pairwise'' needs to be symmetric.") } if(nrow(pairwise) != num_variables) { - stop("The matrix ``pairwise'' needs to have ``num_variables'' rows and columns.") + stop( + "The matrix ``pairwise'' needs to have", + " ``num_variables'' rows and columns." + ) } # Check the threshold values ------------------------------------------------- @@ -341,16 +384,13 @@ simulate_mrf = function(num_states, for(variable in 1:num_variables) { if(variable_type[variable] != "blume-capel") { if(anyNA(main[variable, 1:num_categories[variable]])) { - tmp = which(is.na(main[variable, 1:num_categories[variable]])) - - string = paste(tmp, sep = ",") - + na_cats = which(is.na(main[variable, 1:num_categories[variable]])) stop(paste0( "The matrix ``main'' contains NA(s) for variable ", variable, " in category \n", "(categories) ", - paste(which(is.na(main[variable, 1:num_categories[variable]])), collapse = ", "), + paste(na_cats, collapse = ", "), ", where a numeric value is needed." )) } @@ -370,24 +410,31 @@ simulate_mrf = function(num_states, } else { if(anyNA(main[variable, 1:2])) { stop(paste0( - "The Blume-Capel model is chosen for the category thresholds of variable ", + "The Blume-Capel model is chosen for the ", + "category thresholds of variable ", variable, ". \n", - "This model has two parameters that need to be placed in columns 1 and 2, row \n", + "This model has two parameters that need ", + "to be placed in columns 1 and 2, row \n", variable, - ", of the ``main'' input matrix. Currently, there are NA(s) in these \n", + ", of the ``main'' input matrix. ", + "Currently, there are NA(s) in these \n", "entries, where a numeric value is needed." )) } if(ncol(main) > 2) { if(!anyNA(main[variable, 3:ncol(main)])) { warning(paste0( - "The Blume-Capel model is chosen for the category thresholds of variable ", + "The Blume-Capel model is chosen for ", + "the category thresholds of variable ", variable, ". \n", - "This model has two parameters that need to be placed in columns 1 and 2, row \n", + "This model has two parameters that ", + "need to be placed in columns 1 and ", + "2, row \n", variable, - ", of the ``main'' input matrix. However, there are numeric values \n", + ", of the ``main'' input matrix. ", + "However, there are numeric values \n", "in higher categories. These values will be ignored." )) } @@ -510,7 +557,8 @@ mrfSampler = function(num_states, #' #' @description #' Generates new observations from the Markov Random Field model using the -#' estimated parameters from a fitted \code{bgms} object. +#' estimated parameters from a fitted \code{bgms} object. Supports ordinal, +#' Blume-Capel, continuous (GGM), and mixed MRF models. #' #' @param object An object of class \code{bgms}. #' @param nsim Number of observations to simulate. Default: \code{500}. @@ -524,7 +572,8 @@ mrfSampler = function(num_states, #' uses parallel processing when \code{cores > 1}.} #' } #' @param ndraws Number of posterior draws to use when -#' \code{method = "posterior-sample"}. If \code{NULL}, uses all available draws. +#' \code{method = "posterior-sample"}. If \code{NULL}, +#' uses all available draws. #' @param iter Number of Gibbs iterations for equilibration before collecting #' samples. Default: \code{1000}. #' @param cores Number of CPU cores for parallel execution when @@ -542,9 +591,14 @@ mrfSampler = function(num_states, #' If \code{method = "posterior-sample"}: A list of matrices, one per posterior #' draw, each with \code{nsim} rows and \code{p} columns. #' +#' For mixed MRF models, discrete columns contain non-negative integers and +#' continuous columns contain real-valued observations, ordered as in the +#' original data. +#' #' @details -#' This function uses the estimated interaction and threshold parameters to -#' generate new data via Gibbs sampling. When \code{method = "posterior-sample"}, +#' This function uses the estimated interaction and threshold +#' parameters to generate new data via Gibbs sampling. When +#' \code{method = "posterior-sample"}, parameter uncertainty is #' parameter uncertainty is propagated to the simulated data by using different #' posterior draws. Parallel processing is available for this method via the #' \code{cores} argument. @@ -562,7 +616,11 @@ mrfSampler = function(num_states, #' new_data = simulate(fit, nsim = 100) #' #' # Simulate with parameter uncertainty (10 datasets) -#' new_data_list = simulate(fit, nsim = 100, method = "posterior-sample", ndraws = 10) +#' new_data_list = simulate( +#' fit, +#' nsim = 100, +#' method = "posterior-sample", ndraws = 10 +#' ) #' #' # Use parallel processing for faster simulation #' new_data_list = simulate(fit, @@ -601,10 +659,7 @@ simulate.bgms = function(object, data_columnnames = arguments$data_columnnames # Handle variable_type - - if(length(variable_type) == 1) { - variable_type = rep(variable_type, num_variables) - } + variable_type = expand_variable_type(variable_type, num_variables) # Get baseline_category (for Blume-Capel variables) baseline_category = arguments$baseline_category @@ -629,6 +684,23 @@ simulate.bgms = function(object, )) } + # ============================================================================ + # Mixed MRF (discrete + continuous) path + # ============================================================================ + if(isTRUE(arguments$is_mixed)) { + return(simulate_bgms_mixed( + object = object, + nsim = nsim, + seed = seed, + method = method, + ndraws = ndraws, + arguments = arguments, + iter = iter, + cores = cores, + progress_type = progress_type + )) + } + # ============================================================================ # OMRF (ordinal / Blume-Capel) path # ============================================================================ @@ -696,9 +768,9 @@ simulate.bgms = function(object, } -# ============================================================================== -# simulate.bgmCompare() - S3 Method for Simulating from Group-Comparison Models -# ============================================================================== +# ============================================================ +# simulate.bgmCompare() - S3 Method for Group-Comparison +# ============================================================ #' Simulate Data from a Fitted bgmCompare Model #' @@ -762,15 +834,23 @@ simulate.bgmCompare = function(object, # Validate group argument if(missing(group)) { - stop("Argument 'group' is required. Specify which group to simulate from (1 to num_groups).") + stop( + "Argument 'group' is required. ", + "Specify which group to simulate from ", + "(1 to num_groups)." + ) } arguments = extract_arguments(object) num_groups = arguments$num_groups - if(!is.numeric(group) || length(group) != 1 || is.na(group) || - group < 1 || group > num_groups) { - stop(sprintf("Argument 'group' must be an integer between 1 and %d.", num_groups)) + invalid_group = !is.numeric(group) || length(group) != 1 || + is.na(group) || group < 1 || group > num_groups + if(invalid_group) { + stop(sprintf( + "Argument 'group' must be an integer between 1 and %d.", + num_groups + )) } group = as.integer(group) @@ -782,7 +862,6 @@ simulate.bgmCompare = function(object, num_categories = arguments$num_categories is_ordinal = arguments$is_ordinal_variable data_columnnames = arguments$data_columnnames - projection = arguments$projection # [num_groups x (num_groups-1)] # Determine variable_type from is_ordinal variable_type = ifelse(is_ordinal, "ordinal", "blume-capel") @@ -851,7 +930,8 @@ simulate.bgmCompare = function(object, #' #' @description #' Computes conditional probability distributions for one or more variables -#' given the observed values of other variables in the data. +#' given the observed values of other variables in the data. Supports ordinal, +#' Blume-Capel, continuous (GGM), and mixed MRF models. #' #' @param object An object of class \code{bgms}. #' @param newdata A matrix or data frame with \code{n} rows and \code{p} columns @@ -873,10 +953,12 @@ simulate.bgmCompare = function(object, #' @param method Character string specifying which parameter estimates to use: #' \describe{ #' \item{\code{"posterior-mean"}}{Use posterior mean parameters.} -#' \item{\code{"posterior-sample"}}{Average predictions over posterior draws.} +#' \item{\code{"posterior-sample"}}{Average predictions +#' over posterior draws.} #' } #' @param ndraws Number of posterior draws to use when -#' \code{method = "posterior-sample"}. If \code{NULL}, uses all available draws. +#' \code{method = "posterior-sample"}. If \code{NULL}, +#' uses all available draws. #' @param seed Optional random seed for reproducibility when #' \code{method = "posterior-sample"}. #' @param ... Additional arguments (currently ignored). @@ -886,8 +968,9 @@ simulate.bgmCompare = function(object, #' #' For \code{type = "probabilities"}: A named list with one element per #' predicted variable. Each element is a matrix with \code{n} rows and -#' \code{num_categories + 1} columns containing \eqn{P(X_j = c | X_{-j})}{P(X_j = c | X_-j)} for each -#' observation and category. +#' \code{num_categories + 1} columns containing +#' \eqn{P(X_j = c | X_{-j})}{P(X_j = c | X_-j)} +#' for each observation and category. #' #' For \code{type = "response"}: A matrix with \code{n} rows and #' \code{length(variables)} columns containing predicted categories. @@ -910,6 +993,13 @@ simulate.bgmCompare = function(object, #' When \code{method = "posterior-sample"}, conditional parameters are #' averaged over posterior draws, and an attribute \code{"sd"} is included. #' +#' \strong{Mixed MRF models:} +#' +#' For mixed models, the return list contains elements for both discrete and +#' continuous predicted variables. Discrete variables return probability +#' matrices (as in ordinal models); continuous variables return conditional +#' mean and SD matrices (as in GGM models). +#' #' @details #' For each observation, the function computes the conditional distribution #' of the target variable(s) given the observed values of all other variables. @@ -918,7 +1008,8 @@ simulate.bgmCompare = function(object, #' #' For GGM (continuous) models, the conditional distribution of #' \eqn{X_j | X_{-j}}{X_j | X_{-j}} is Gaussian with mean -#' \eqn{-\omega_{jj}^{-1} \sum_{k \neq j} \omega_{jk} x_k}{-omega_jj^{-1} sum_{k != j} omega_jk x_k} +#' \eqn{-\omega_{jj}^{-1} \sum_{k \neq j} +#' \omega_{jk} x_k}{-omega_jj^{-1} sum_{k != j} omega_jk x_k} #' and variance \eqn{\omega_{jj}^{-1}}{omega_jj^{-1}}, where \eqn{\Omega}{Omega} #' is the precision matrix. #' @@ -961,7 +1052,11 @@ predict.bgms = function(object, # Validate newdata if(missing(newdata)) { - stop("Argument 'newdata' is required. Provide the data for which to compute predictions.") + stop( + "Argument 'newdata' is required. ", + "Provide the data for which to ", + "compute predictions." + ) } if(!inherits(newdata, "matrix") && !inherits(newdata, "data.frame")) { @@ -983,15 +1078,14 @@ predict.bgms = function(object, if(ncol(newdata) != num_variables) { stop(paste0( - "'newdata' must have ", num_variables, " columns (same as fitted model), ", + "'newdata' must have ", num_variables, + " columns (same as fitted model), ", "but has ", ncol(newdata), "." )) } # Handle variable_type - if(length(variable_type) == 1) { - variable_type = rep(variable_type, num_variables) - } + variable_type = expand_variable_type(variable_type, num_variables) # Get baseline_category baseline_category = arguments$baseline_category @@ -1008,7 +1102,13 @@ predict.bgms = function(object, } else if(is.character(variables)) { predict_vars = match(variables, data_columnnames) if(anyNA(predict_vars)) { - stop("Variable names not found: ", paste(variables[is.na(predict_vars)], collapse = ", ")) + stop( + "Variable names not found: ", + paste( + variables[is.na(predict_vars)], + collapse = ", " + ) + ) } } else { predict_vars = as.integer(variables) @@ -1033,12 +1133,29 @@ predict.bgms = function(object, )) } + # ============================================================================ + # Mixed MRF (discrete + continuous) path + # ============================================================================ + if(isTRUE(arguments$is_mixed)) { + return(predict_bgms_mixed( + object = object, + newdata = newdata, + predict_vars = predict_vars, + arguments = arguments, + type = type, + method = method, + ndraws = ndraws + )) + } + # ============================================================================ # OMRF (ordinal) path # ============================================================================ # Recode data to 0-based integers (matching what bgm() does) - newdata_recoded = recode_data_for_prediction(newdata, num_categories, is_ordinal) + newdata_recoded = recode_data_for_prediction( + newdata, num_categories, is_ordinal + ) if(method == "posterior-mean") { # Use posterior mean parameters @@ -1183,9 +1300,10 @@ predict.bgms = function(object, #' @param ... Additional arguments (currently ignored). #' #' @return -#' For \code{type = "probabilities"}: A named list with one element per -#' predicted variable. Each element is a matrix with \code{n} rows and -#' \code{num_categories + 1} columns containing \eqn{P(X_j = c | X_{-j})}{P(X_j = c | X_-j)} +#' For \code{type = "probabilities"}: A named list with one +#' element per predicted variable. Each element is a matrix with +#' \code{n} rows and \code{num_categories + 1} columns containing +#' \eqn{P(X_j = c | X_{-j})}{P(X_j = c | X_-j)} #' for each observation and category. #' #' For \code{type = "response"}: A matrix with \code{n} rows and @@ -1197,8 +1315,10 @@ predict.bgms = function(object, #' The function then computes the conditional distribution of target variables #' given the observed values of all other variables. #' -#' @seealso \code{\link{predict.bgms}} for predicting from single-group models, -#' \code{\link{simulate.bgmCompare}} for simulating from group-comparison models. +#' @seealso \code{\link{predict.bgms}} for predicting +#' from single-group models, +#' \code{\link{simulate.bgmCompare}} for simulating +#' from group-comparison models. #' @family prediction #' #' @examples @@ -1228,21 +1348,33 @@ predict.bgmCompare = function(object, # Validate group argument if(missing(group)) { - stop("Argument 'group' is required. Specify which group's parameters to use (1 to num_groups).") + stop( + "Argument 'group' is required. ", + "Specify which group's parameters ", + "to use (1 to num_groups)." + ) } arguments = extract_arguments(object) num_groups = arguments$num_groups - if(!is.numeric(group) || length(group) != 1 || is.na(group) || - group < 1 || group > num_groups) { - stop(sprintf("Argument 'group' must be an integer between 1 and %d.", num_groups)) + invalid_group = !is.numeric(group) || length(group) != 1 || + is.na(group) || group < 1 || group > num_groups + if(invalid_group) { + stop(sprintf( + "Argument 'group' must be an integer between 1 and %d.", + num_groups + )) } group = as.integer(group) # Validate newdata if(missing(newdata)) { - stop("Argument 'newdata' is required. Provide the data for which to compute predictions.") + stop( + "Argument 'newdata' is required. ", + "Provide the data for which to ", + "compute predictions." + ) } if(!inherits(newdata, "matrix") && !inherits(newdata, "data.frame")) { @@ -1258,12 +1390,12 @@ predict.bgmCompare = function(object, num_categories = arguments$num_categories is_ordinal = arguments$is_ordinal_variable data_columnnames = arguments$data_columnnames - projection = arguments$projection # Validate dimensions if(ncol(newdata) != num_variables) { stop(paste0( - "'newdata' must have ", num_variables, " columns (same as fitted model), ", + "'newdata' must have ", num_variables, + " columns (same as fitted model), ", "but has ", ncol(newdata), "." )) } @@ -1283,7 +1415,13 @@ predict.bgmCompare = function(object, } else if(is.character(variables)) { predict_vars = match(variables, data_columnnames) if(anyNA(predict_vars)) { - stop("Variable names not found: ", paste(variables[is.na(predict_vars)], collapse = ", ")) + stop( + "Variable names not found: ", + paste( + variables[is.na(predict_vars)], + collapse = ", " + ) + ) } } else { predict_vars = as.integer(variables) @@ -1293,7 +1431,9 @@ predict.bgmCompare = function(object, } # Recode data to 0-based integers - newdata_recoded = recode_data_for_prediction(newdata, num_categories, is_ordinal) + newdata_recoded = recode_data_for_prediction( + newdata, num_categories, is_ordinal + ) if(method == "posterior-mean") { # Extract group-specific parameters using projection @@ -1364,11 +1504,9 @@ predict.bgmCompare = function(object, # ============================================================================== # Helper function to reconstruct threshold matrix from flat vector -reconstruct_main = function(main_vec, num_variables, num_categories, variable_type) { - if(length(variable_type) == 1) { - variable_type = rep(variable_type, num_variables) - } - +reconstruct_main = function(main_vec, num_variables, + num_categories, + variable_type) { max_cats = max(num_categories) main = matrix(NA, nrow = num_variables, ncol = max_cats) @@ -1415,16 +1553,14 @@ recode_data_for_prediction = function(x, num_categories, is_ordinal) { # Reconstruct the full precision matrix from posterior mean components. # # @param posterior_mean_pairwise p x p symmetric matrix with off-diagonal -# precision elements (diagonal is zero). -# @param posterior_mean_main p x 1 matrix with diagonal precision elements -# (column named "precision_diag"). +# precision elements. For GGM and mixed MRF models the precision matrix +# diagonal is already included on the matrix diagonal. # # @return p x p precision matrix (Omega). -reconstruct_precision = function(posterior_mean_pairwise, posterior_mean_main) { +reconstruct_precision = function(posterior_mean_pairwise) { omega = posterior_mean_pairwise # Excluded edges (NA) have zero precision omega[is.na(omega)] = 0 - diag(omega) = as.numeric(posterior_mean_main) return(omega) } @@ -1437,7 +1573,7 @@ reconstruct_precision = function(posterior_mean_pairwise, posterior_mean_main) { # @param p Number of variables. # # @return p x p precision matrix (Omega). -reconstruct_precision_from_draw = function(pairwise_vec, main_vec, p) { +build_precision_from_draw = function(pairwise_vec, main_vec, p) { omega = matrix(0, nrow = p, ncol = p) omega[lower.tri(omega)] = pairwise_vec omega = omega + t(omega) @@ -1468,8 +1604,7 @@ predict_bgms_ggm = function(object, newdata, predict_vars, data_columnnames, if(method == "posterior-mean") { # Reconstruct precision matrix from posterior means omega = reconstruct_precision( - object$posterior_mean_pairwise, - object$posterior_mean_main + object$posterior_mean_pairwise ) result = compute_conditional_ggm( @@ -1482,7 +1617,9 @@ predict_bgms_ggm = function(object, newdata, predict_vars, data_columnnames, names(result) = data_columnnames[predict_vars] for(v in seq_along(result)) { colnames(result[[v]]) = c("mean", "sd") - result[[v]][, "mean"] = result[[v]][, "mean"] + newdata_means[predict_vars[v]] + result[[v]][, "mean"] = + result[[v]][, "mean"] + + newdata_means[predict_vars[v]] } } else { # Use posterior samples @@ -1503,7 +1640,7 @@ predict_bgms_ggm = function(object, newdata, predict_vars, data_columnnames, for(i in seq_len(ndraws)) { idx = draw_indices[i] - omega = reconstruct_precision_from_draw( + omega = build_precision_from_draw( pairwise_vec = pairwise_samples[idx, ], main_vec = main_samples[idx, ], p = num_variables @@ -1581,10 +1718,9 @@ simulate_bgms_ggm = function(object, nsim, seed, method, ndraws, num_variables, data_columnnames, cores, progress_type) { if(method == "posterior-mean") { - # Reconstruct precision matrix: inject diagonal from posterior_mean_main + # Reconstruct precision matrix: diagonal is already on posterior_mean_pairwise precision = reconstruct_precision( - object$posterior_mean_pairwise, - object$posterior_mean_main + object$posterior_mean_pairwise ) # Call simulate_mrf with variable_type = "continuous" @@ -1635,3 +1771,558 @@ simulate_bgms_ggm = function(object, nsim, seed, method, ndraws, return(results) } } + + +# ============================================================================== +# Mixed MRF Simulation Helper +# ============================================================================== + +# ------------------------------------------------------------------ +# simulate_bgms_mixed +# ------------------------------------------------------------------ +# Simulation implementation for mixed MRF models (called from simulate.bgms). +# +# @param object Fitted bgms object (mixed MRF). +# @param nsim Number of observations to simulate. +# @param seed Random seed. +# @param method "posterior-mean" or "posterior-sample". +# @param ndraws Number of posterior draws (for posterior-sample). +# @param arguments Output of extract_arguments(). +# @param iter Gibbs burn-in iterations. +# @param cores Number of threads. +# @param progress_type Progress bar type. +# +# Returns: matrix (posterior-mean) or list of matrices (posterior-sample). +# ------------------------------------------------------------------ +simulate_bgms_mixed = function(object, nsim, seed, method, ndraws, + arguments, iter, cores, progress_type) { + p = arguments$num_discrete + q = arguments$num_continuous + data_columnnames = arguments$data_columnnames + disc_idx = arguments$discrete_indices + cont_idx = arguments$continuous_indices + num_categories = arguments$num_categories + is_ordinal = arguments$is_ordinal + baseline_category_disc = arguments$baseline_category + + disc_variable_type = ifelse(is_ordinal, "ordinal", "blume-capel") + + bc = integer(p) + for(s in seq_len(p)) { + if(is_ordinal[s]) { + bc[s] = 0L + } else { + bc[s] = as.integer(baseline_category_disc[s]) + } + } + + if(method == "posterior-mean") { + params = build_mixed_params_mean(object, arguments) + + seed = check_seed(seed) + + result = sample_mixed_mrf_gibbs( + num_states = as.integer(nsim), + Kxx_r = params$Kxx, + Kxy_r = params$Kxy, + Kyy_r = params$Kyy, + mux_r = params$mux, + muy_r = params$muy, + num_categories_r = as.integer(num_categories), + variable_type_r = disc_variable_type, + baseline_category_r = as.integer(bc), + iter = as.integer(iter), + seed = seed + ) + + out = combine_mixed_result(result, disc_idx, cont_idx, data_columnnames) + return(out) + } else { + sample_info = split_mixed_raw_samples(object, arguments) + + total_draws = sample_info$total_draws + if(is.null(ndraws)) ndraws = total_draws + ndraws = min(ndraws, total_draws) + + if(!is.null(seed)) set.seed(seed) + draw_indices = sample.int(total_draws, ndraws) + + results = run_mixed_simulation_parallel( + mux_samples = sample_info$mux_samples, + kxx_samples = sample_info$kxx_samples, + muy_samples = sample_info$muy_samples, + kyy_samples = sample_info$kyy_samples, + kxy_samples = sample_info$kxy_samples, + draw_indices = as.integer(draw_indices), + num_states = as.integer(nsim), + p = as.integer(p), + q = as.integer(q), + num_categories = as.integer(num_categories), + variable_type_r = disc_variable_type, + baseline_category = as.integer(bc), + iter = as.integer(iter), + nThreads = cores, + seed = check_seed(seed), + progress_type = progress_type + ) + + for(i in seq_along(results)) { + results[[i]] = combine_mixed_result( + results[[i]], disc_idx, cont_idx, data_columnnames + ) + } + + return(results) + } +} + + +# ============================================================================== +# Mixed MRF Prediction Helper +# ============================================================================== + +# ------------------------------------------------------------------ +# predict_bgms_mixed +# ------------------------------------------------------------------ +# Prediction implementation for mixed MRF models (called from predict.bgms). +# +# @param object Fitted bgms object (mixed MRF). +# @param newdata n x (p+q) matrix of observed data. +# @param predict_vars 1-based indices into the combined variable list. +# @param arguments Output of extract_arguments(). +# @param type "probabilities" or "response". +# @param method "posterior-mean" or "posterior-sample". +# @param ndraws Number of posterior draws (for posterior-sample). +# +# Returns: Named list of prediction matrices. +# ------------------------------------------------------------------ +predict_bgms_mixed = function(object, newdata, predict_vars, arguments, + type, method, ndraws) { + p = arguments$num_discrete + q = arguments$num_continuous + data_columnnames = arguments$data_columnnames + disc_idx = arguments$discrete_indices + cont_idx = arguments$continuous_indices + num_categories = arguments$num_categories + is_ordinal = arguments$is_ordinal + baseline_category_disc = arguments$baseline_category + + disc_variable_type = ifelse(is_ordinal, "ordinal", "blume-capel") + + bc = integer(p) + for(s in seq_len(p)) { + if(is_ordinal[s]) { + bc[s] = 0L + } else { + bc[s] = as.integer(baseline_category_disc[s]) + } + } + + # Split newdata into discrete and continuous parts + x_data = as.matrix(newdata[, disc_idx, drop = FALSE]) + storage.mode(x_data) = "integer" + y_data = as.matrix(newdata[, cont_idx, drop = FALSE]) + storage.mode(y_data) = "double" + + # Map user predict_vars (1-based, original order) to internal 0-based indices + # Internal layout: [discrete_0..p-1, continuous_p..p+q-1] + internal_predict_vars = integer(length(predict_vars)) + for(k in seq_along(predict_vars)) { + orig_col = predict_vars[k] + disc_match = match(orig_col, disc_idx) + cont_match = match(orig_col, cont_idx) + if(!is.na(disc_match)) { + internal_predict_vars[k] = disc_match - 1L + } else if(!is.na(cont_match)) { + internal_predict_vars[k] = p + cont_match - 1L + } else { + stop( + "Variable index ", orig_col, + " not found in discrete or continuous indices." + ) + } + } + + compute_one_draw = function(Kxx, Kxy, Kyy, mux, muy) { + compute_conditional_mixed( + x_observations = x_data, + y_observations = y_data, + predict_vars = as.integer(internal_predict_vars), + Kxx = Kxx, + Kxy = Kxy, + Kyy = Kyy, + mux = mux, + muy = muy, + num_categories = as.integer(num_categories), + variable_type = disc_variable_type, + baseline_category = as.integer(bc) + ) + } + + if(method == "posterior-mean") { + params = build_mixed_params_mean(object, arguments) + raw_result = compute_one_draw( + params$Kxx, params$Kxy, params$Kyy, params$mux, params$muy + ) + + probs = format_mixed_predictions( + raw_result, predict_vars, internal_predict_vars, + p, num_categories, data_columnnames + ) + } else { + sample_info = split_mixed_raw_samples(object, arguments) + total_draws = sample_info$total_draws + if(is.null(ndraws)) ndraws = total_draws + ndraws = min(ndraws, total_draws) + + draw_indices = sample.int(total_draws, ndraws) + + all_results = vector("list", ndraws) + for(i in seq_len(ndraws)) { + params_i = build_mixed_params_row( + sample_info, draw_indices[i], p, q, num_categories, is_ordinal + ) + all_results[[i]] = compute_one_draw( + params_i$Kxx, params_i$Kxy, params_i$Kyy, params_i$mux, params_i$muy + ) + } + + # Average predictions across draws + num_pv = length(predict_vars) + probs = vector("list", num_pv) + probs_sd = vector("list", num_pv) + names(probs) = data_columnnames[predict_vars] + names(probs_sd) = data_columnnames[predict_vars] + + for(k in seq_len(num_pv)) { + # Stack all draws into an array: n x ncol x ndraws + var_preds = lapply(all_results, `[[`, k) + pred_array = array(unlist(var_preds), + dim = c(nrow(var_preds[[1]]), ncol(var_preds[[1]]), ndraws) + ) + + probs[[k]] = apply(pred_array, c(1, 2), mean) + probs_sd[[k]] = apply(pred_array, c(1, 2), sd) + } + + probs = format_mixed_predictions( + probs, predict_vars, internal_predict_vars, + p, num_categories, data_columnnames + ) + names(probs_sd) = names(probs) + attr(probs, "sd") = probs_sd + } + + if(type == "response") { + return(format_mixed_response( + probs, predict_vars, internal_predict_vars, + p, data_columnnames + )) + } + + return(probs) +} + + +# ============================================================================== +# Mixed MRF Internal Helpers +# ============================================================================== + +# ------------------------------------------------------------------ +# build_mixed_params_mean +# ------------------------------------------------------------------ +# Reconstruct Kxx, Kxy, Kyy, mux, muy from posterior mean summaries. +# +# @param object Fitted bgms object (mixed MRF). +# @param arguments Output of extract_arguments(). +# +# Returns: list with Kxx, Kxy, Kyy, mux, muy. +# ------------------------------------------------------------------ +build_mixed_params_mean = function(object, arguments) { + p = arguments$num_discrete + q = arguments$num_continuous + disc_idx = arguments$discrete_indices + cont_idx = arguments$continuous_indices + + pmat = object$posterior_mean_pairwise + + Kxx = matrix(0, p, p) + for(i in seq_len(p)) { + for(j in seq_len(p)) { + if(i != j) Kxx[i, j] = pmat[disc_idx[i], disc_idx[j]] + } + } + + Kxy = matrix(0, p, q) + for(i in seq_len(p)) { + for(j in seq_len(q)) { + Kxy[i, j] = pmat[disc_idx[i], cont_idx[j]] + } + } + + Kyy = matrix(0, q, q) + for(i in seq_len(q)) { + for(j in seq_len(q)) { + Kyy[i, j] = pmat[cont_idx[i], cont_idx[j]] + } + } + + mux = object$posterior_mean_main$discrete + mux[is.na(mux)] = 0 + + muy = as.numeric(object$posterior_mean_main$continuous[, "mean"]) + + list(Kxx = Kxx, Kxy = Kxy, Kyy = Kyy, mux = mux, muy = muy) +} + + +# ------------------------------------------------------------------ +# split_mixed_raw_samples +# ------------------------------------------------------------------ +# Split raw main and pairwise sample matrices into separate component +# matrices for the C++ parallel simulation worker. +# +# @param object Fitted bgms object (mixed MRF). +# @param arguments Output of extract_arguments(). +# +# Returns: list with mux_samples, kxx_samples, muy_samples, +# kyy_samples, kxy_samples, total_draws. +# ------------------------------------------------------------------ +split_mixed_raw_samples = function(object, arguments) { + p = arguments$num_discrete + q = arguments$num_continuous + num_categories = arguments$num_categories + is_ordinal = arguments$is_ordinal + + main_all = do.call(rbind, object$raw_samples$main) + pairwise_all = do.call(rbind, object$raw_samples$pairwise) + total_draws = nrow(main_all) + + # Main layout: [mux_flat | muy | kyy_diag] + num_mux = sum(ifelse(is_ordinal, num_categories, 2L)) + mux_cols = seq_len(num_mux) + muy_cols = num_mux + seq_len(q) + kyy_diag_cols = num_mux + q + seq_len(q) + + mux_samples = main_all[, mux_cols, drop = FALSE] + muy_samples = main_all[, muy_cols, drop = FALSE] + kyy_diag_values = main_all[, kyy_diag_cols, drop = FALSE] + + # Pairwise layout: [kxx_ut | kyy_offdiag | kxy] + nxx = as.integer(p * (p - 1) / 2) + nyy_offdiag = as.integer(q * (q - 1) / 2) + nxy = as.integer(p * q) + + kyy_off_end = nxx + nyy_offdiag + kxy_end = kyy_off_end + nxy + + kxx_samples = if(nxx > 0) { + pairwise_all[, seq_len(nxx), drop = FALSE] + } else { + matrix(0, nrow = total_draws, ncol = 0) + } + + kyy_offdiag_values = if(nyy_offdiag > 0) { + pairwise_all[, (nxx + 1):kyy_off_end, drop = FALSE] + } else { + matrix(0, nrow = total_draws, ncol = 0) + } + + kxy_samples = if(nxy > 0) { + pairwise_all[, (kyy_off_end + 1):kxy_end, drop = FALSE] + } else { + matrix(0, nrow = total_draws, ncol = 0) + } + + # Combine Kyy diagonal and off-diagonal into upper-triangle format + # C++ expects column-major upper-triangle including diagonal + nyy_total = as.integer(q * (q + 1) / 2) + kyy_samples = matrix(0, nrow = total_draws, ncol = nyy_total) + diag_pos = 0L + offdiag_pos = 0L + out_pos = 0L + for(col in seq_len(q)) { + for(row in col:q) { + out_pos = out_pos + 1L + if(row == col) { + diag_pos = diag_pos + 1L + kyy_samples[, out_pos] = kyy_diag_values[, diag_pos] + } else { + offdiag_pos = offdiag_pos + 1L + kyy_samples[, out_pos] = kyy_offdiag_values[, offdiag_pos] + } + } + } + + list( + mux_samples = mux_samples, + kxx_samples = kxx_samples, + muy_samples = muy_samples, + kyy_samples = kyy_samples, + kxy_samples = kxy_samples, + total_draws = total_draws + ) +} + + +# ------------------------------------------------------------------ +# build_mixed_params_row +# ------------------------------------------------------------------ +# Reconstruct Kxx, Kxy, Kyy, mux, muy from a single row of split +# sample matrices (used by predict posterior-sample). +# +# @param sample_info Output of split_mixed_raw_samples(). +# @param row_idx 1-based row index. +# @param p Number of discrete variables. +# @param q Number of continuous variables. +# @param num_categories Categories per discrete variable. +# @param is_ordinal Logical vector. +# +# Returns: list with Kxx, Kxy, Kyy, mux, muy. +# ------------------------------------------------------------------ +build_mixed_params_row = function(sample_info, row_idx, + p, q, num_categories, + is_ordinal) { + mux_vec = sample_info$mux_samples[row_idx, ] + num_params_disc = ifelse(is_ordinal, num_categories, 2L) + max_cats = max(num_params_disc) + mux = matrix(0, nrow = p, ncol = max_cats) + pos = 1L + for(s in seq_len(p)) { + nc = num_params_disc[s] + mux[s, seq_len(nc)] = mux_vec[pos:(pos + nc - 1L)] + pos = pos + nc + } + + Kxx = matrix(0, p, p) + if(p > 1) { + kxx_vec = sample_info$kxx_samples[row_idx, ] + idx = 0L + for(col in seq_len(p - 1)) { + for(row in (col + 1):p) { + idx = idx + 1L + Kxx[row, col] = kxx_vec[idx] + Kxx[col, row] = kxx_vec[idx] + } + } + } + + muy = sample_info$muy_samples[row_idx, ] + + kyy_vec = sample_info$kyy_samples[row_idx, ] + Kyy = matrix(0, q, q) + idx = 0L + for(col in seq_len(q)) { + for(row in col:q) { + idx = idx + 1L + Kyy[row, col] = kyy_vec[idx] + if(row != col) Kyy[col, row] = kyy_vec[idx] + } + } + + Kxy = matrix(0, p, q) + if(p > 0 && q > 0) { + kxy_vec = sample_info$kxy_samples[row_idx, ] + idx = 0L + for(s in seq_len(p)) { + for(j in seq_len(q)) { + idx = idx + 1L + Kxy[s, j] = kxy_vec[idx] + } + } + } + + list(Kxx = Kxx, Kxy = Kxy, Kyy = Kyy, mux = mux, muy = muy) +} + + +# ------------------------------------------------------------------ +# combine_mixed_result +# ------------------------------------------------------------------ +# Combine $x (n x p integer) and $y (n x q double) into a single +# n x (p+q) matrix in the original column order. +# +# @param result List with $x and $y matrices. +# @param disc_idx Original column indices for discrete variables. +# @param cont_idx Original column indices for continuous variables. +# @param colnames Original data column names. +# +# Returns: n x (p+q) numeric matrix. +# ------------------------------------------------------------------ +combine_mixed_result = function(result, disc_idx, cont_idx, colnames) { + n = nrow(result$x) + num_vars = length(disc_idx) + length(cont_idx) + out = matrix(NA_real_, nrow = n, ncol = num_vars) + out[, disc_idx] = result$x + out[, cont_idx] = result$y + colnames(out) = colnames + out +} + + +# ------------------------------------------------------------------ +# format_mixed_predictions +# ------------------------------------------------------------------ +# Add names and column labels to C++ prediction output. +# +# @param raw_result List from C++ compute_conditional_mixed. +# @param predict_vars 1-based user-facing variable indices. +# @param internal_predict_vars 0-based internal indices. +# @param p Number of discrete variables. +# @param num_categories Categories per discrete variable. +# @param data_columnnames Original data column names. +# +# Returns: Named list of prediction matrices. +# ------------------------------------------------------------------ +format_mixed_predictions = function(raw_result, predict_vars, + internal_predict_vars, p, + num_categories, data_columnnames) { + probs = raw_result + names(probs) = data_columnnames[predict_vars] + + for(k in seq_along(predict_vars)) { + int_idx = internal_predict_vars[k] + if(int_idx < p) { + s = int_idx + 1L + n_cats = num_categories[s] + 1 + colnames(probs[[k]]) = paste0("cat_", 0:(n_cats - 1)) + } else { + colnames(probs[[k]]) = c("mean", "sd") + } + } + + probs +} + + +# ------------------------------------------------------------------ +# format_mixed_response +# ------------------------------------------------------------------ +# Convert probability predictions to point predictions for mixed models. +# +# @param probs Named list of prediction matrices. +# @param predict_vars 1-based user-facing variable indices. +# @param internal_predict_vars 0-based internal indices. +# @param p Number of discrete variables. +# @param data_columnnames Original data column names. +# +# Returns: n x length(predict_vars) matrix of predicted values. +# ------------------------------------------------------------------ +format_mixed_response = function(probs, predict_vars, + internal_predict_vars, p, + data_columnnames) { + n = nrow(probs[[1]]) + out = matrix(NA_real_, nrow = n, ncol = length(predict_vars)) + colnames(out) = data_columnnames[predict_vars] + + for(k in seq_along(predict_vars)) { + int_idx = internal_predict_vars[k] + if(int_idx < p) { + out[, k] = apply(probs[[k]], 1, which.max) - 1L + } else { + out[, k] = probs[[k]][, 1] + } + } + + out +} diff --git a/R/validate_model.R b/R/validate_model.R index 49580c73..7593930c 100644 --- a/R/validate_model.R +++ b/R/validate_model.R @@ -26,6 +26,7 @@ validate_variable_types = function(variable_type, num_variables, allow_continuous = TRUE, + allow_mixed = FALSE, caller = "bgm") { valid_choices = if(allow_continuous) { c("ordinal", "blume-capel", "continuous") @@ -36,6 +37,7 @@ validate_variable_types = function(variable_type, supported_str = paste(valid_choices, collapse = ", ") is_continuous = FALSE + is_mixed = FALSE if(length(variable_type) == 1) { # --- Single string: replicate to all variables --- @@ -70,15 +72,58 @@ validate_variable_types = function(variable_type, } has_continuous = any(variable_type == "continuous") - if(has_continuous && !all(variable_type == "continuous")) { + has_discrete = any(variable_type %in% c("ordinal", "blume-capel")) + is_mixed = has_continuous && has_discrete + + if(has_continuous && !has_discrete) { + invalid_if_cont = setdiff(unique(variable_type), "continuous") + if(length(invalid_if_cont) > 0) { + stop(paste0( + "When using continuous variables, all variables must be of type ", + "'continuous' or mixed with ordinal/blume-capel variables." + )) + } + } + + if(has_continuous && !allow_continuous) { + stop(paste0( + "The ", caller, " function supports variables of type ", supported_str, + ", but not of type continuous." + )) + } + + if(is_mixed && !allow_mixed) { stop(paste0( "When using continuous variables, all variables must be of type ", "'continuous'. Mixtures of continuous and ordinal/blume-capel ", - "variables are not supported." + "variables are not supported by ", caller, "()." )) } - if(has_continuous) { + if(is_mixed) { + # Mixed: validate each entry individually + variable_type_checked = try( + match.arg( + arg = variable_type, + choices = valid_choices, + several.ok = TRUE + ), + silent = TRUE + ) + + if(inherits(variable_type_checked, what = "try-error")) { + invalid = setdiff(unique(variable_type), valid_choices) + stop(paste0( + "The ", caller, " function supports variables of type ", supported_str, + ", but not of type ", paste0(invalid, collapse = ", "), "." + )) + } + + variable_type = variable_type_checked + # variable_bool: TRUE = ordinal; FALSE = blume-capel/continuous + variable_bool = (variable_type == "ordinal") + is_continuous = FALSE + } else if(has_continuous) { if(!allow_continuous) { stop(paste0( "The ", caller, " function supports variables of type ", supported_str, @@ -144,7 +189,8 @@ validate_variable_types = function(variable_type, list( variable_type = variable_type, variable_bool = variable_bool, - is_continuous = is_continuous + is_continuous = is_continuous, + is_mixed = is_mixed ) } diff --git a/README.md b/README.md index 808b9799..b2c6927e 100644 --- a/README.md +++ b/README.md @@ -30,21 +30,23 @@ stable](https://img.shields.io/badge/lifecycle-stable-brightgreen.svg)](https:// **Bayesian analysis of graphical models** The **bgms** package implements Bayesian estimation and model comparison -for graphical models of binary, ordinal, and continuous variables +for graphical models of binary, ordinal, continuous, and mixed variables (Marsman, van den Bergh, et al., 2025). It -supports **ordinal Markov random fields (MRFs)** for discrete data and -**Gaussian graphical models (GGMs)** for continuous data. The likelihood -is approximated with a pseudolikelihood, and Markov chain Monte Carlo -(MCMC) methods are used to sample from the corresponding pseudoposterior -distribution of the model parameters. +supports **ordinal Markov random fields (MRFs)** for discrete data, +**Gaussian graphical models (GGMs)** for continuous data, and **mixed +MRFs** that combine discrete and continuous variables in a single +network. The likelihood is approximated with a pseudolikelihood, and +Markov chain Monte Carlo (MCMC) methods are used to sample from the +corresponding pseudoposterior distribution of the model parameters. ## Main functions The package has two main entry points: - `bgm()` – estimates a single network in a one-sample design. Use - `variable_type = "continuous"` for a GGM, or `"ordinal"` (default) for - an MRF. + `variable_type = "ordinal"` for an MRF, `"continuous"` for a GGM, or a + per-variable vector mixing `"ordinal"`, `"blume-capel"`, and + `"continuous"` for a mixed MRF. - `bgmCompare()` – compares networks between groups in an independent-sample design. diff --git a/Readme.Rmd b/Readme.Rmd index 57e93e0a..eb302f3f 100644 --- a/Readme.Rmd +++ b/Readme.Rmd @@ -38,10 +38,11 @@ library(bgms) **Bayesian analysis of graphical models** The **bgms** package implements Bayesian estimation and model comparison for -graphical models of binary, ordinal, and continuous variables +graphical models of binary, ordinal, continuous, and mixed variables [@MarsmanVandenBerghHaslbeck_2025]. -It supports **ordinal Markov random fields (MRFs)** for discrete data and -**Gaussian graphical models (GGMs)** for continuous data. +It supports **ordinal Markov random fields (MRFs)** for discrete data, +**Gaussian graphical models (GGMs)** for continuous data, and **mixed MRFs** +that combine discrete and continuous variables in a single network. The likelihood is approximated with a pseudolikelihood, and Markov chain Monte Carlo (MCMC) methods are used to sample from the corresponding pseudoposterior distribution of the model parameters. @@ -51,8 +52,9 @@ distribution of the model parameters. The package has two main entry points: - `bgm()` – estimates a single network in a one-sample design. - Use `variable_type = "continuous"` for a GGM, or `"ordinal"` (default) - for an MRF. + Use `variable_type = "ordinal"` for an MRF, `"continuous"` for a GGM, + or a per-variable vector mixing `"ordinal"`, `"blume-capel"`, and + `"continuous"` for a mixed MRF. - `bgmCompare()` – compares networks between groups in an independent-sample design. diff --git a/dev/README.md b/dev/README.md deleted file mode 100644 index 30593bb3..00000000 --- a/dev/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# dev/ - -Developer-only materials excluded from the built R package (via -`.Rbuildignore`). Contains documentation strategy, audit plans, -numerical analyses, test fixtures, and design notes. Nothing in -this directory is shipped to CRAN or installed by users. diff --git a/dev/benchmark_memoization.R b/dev/benchmark_memoization.R deleted file mode 100644 index 6d7b050f..00000000 --- a/dev/benchmark_memoization.R +++ /dev/null @@ -1,65 +0,0 @@ -# ============================================================================= -# Benchmark: Hash-map cache vs Single-entry cache for NUTS memoization -# ============================================================================= -# -# Purpose: Compare two memoization strategies in src/mcmc/mcmc_memoization.h -# 1. Hash-map cache (uses std::unordered_map) -# 2. Single-entry cache (uses memcmp on last theta) -# -# Instructions: -# 1. Install the version you want to test (modify mcmc_memoization.h) -# 2. Run: R CMD INSTALL . -# 3. Run this script: Rscript dev/benchmark_memoization.R -# -# ============================================================================= - -library(bgms) -library(psych) - -# ----------------------------------------------------------------------------- -# Configuration -# ----------------------------------------------------------------------------- -N_REPLICATES <- 5 -SEED <- 123 -WARMUP <- 1000 -ITER <- 500 - -# ----------------------------------------------------------------------------- -# Data setup -# ----------------------------------------------------------------------------- -data(bfi) -bfi_data <- bfi[, 1:25] # 25 personality items, NO na.omit - -cat("=== Memoization Benchmark ===\n") -cat("Data: psych::bfi[, 1:25]\n") -cat("Dimensions:", nrow(bfi_data), "x", ncol(bfi_data), "\n") -cat("Warmup:", WARMUP, "\n") -cat("Iterations:", ITER, "\n") -cat("Replicates:", N_REPLICATES, "\n") -cat("Seed:", SEED, "\n") -cat("==============================\n\n") - -# ----------------------------------------------------------------------------- -# Run benchmark -# ----------------------------------------------------------------------------- -set.seed(SEED) -times <- numeric(N_REPLICATES) - -for (i in 1:N_REPLICATES) { - cat("Replicate", i, "of", N_REPLICATES, "... ") - t0 <- Sys.time() - fit <- bgm(bfi_data, warmup = WARMUP, iter = ITER, verbose = FALSE) - times[i] <- as.numeric(difftime(Sys.time(), t0, units = "secs")) - cat(round(times[i], 1), "s\n") -} - -# ----------------------------------------------------------------------------- -# Results -# ----------------------------------------------------------------------------- -cat("\n=== RESULTS ===\n") -cat("Mean:", round(mean(times), 2), "s\n") -cat("SD:", round(sd(times), 2), "s\n") -cat("Min:", round(min(times), 2), "s\n") -cat("Max:", round(max(times), 2), "s\n") -cat("Times:", paste(round(times, 1), collapse = ", "), "\n") -cat("===============\n") diff --git a/dev/fixtures/compliance/bgm_adhd_mh_bernoulli.rds b/dev/fixtures/compliance/bgm_adhd_mh_bernoulli.rds deleted file mode 100644 index 4946eaf7..00000000 Binary files a/dev/fixtures/compliance/bgm_adhd_mh_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_adhd_nuts_bernoulli.rds b/dev/fixtures/compliance/bgm_adhd_nuts_bernoulli.rds deleted file mode 100644 index ceeb24b0..00000000 Binary files a/dev/fixtures/compliance/bgm_adhd_nuts_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_adhd_nuts_no_edgesel.rds b/dev/fixtures/compliance/bgm_adhd_nuts_no_edgesel.rds deleted file mode 100644 index 2a9cd1d5..00000000 Binary files a/dev/fixtures/compliance/bgm_adhd_nuts_no_edgesel.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_adhd_nuts_sbm.rds b/dev/fixtures/compliance/bgm_adhd_nuts_sbm.rds deleted file mode 100644 index a78c233d..00000000 Binary files a/dev/fixtures/compliance/bgm_adhd_nuts_sbm.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_boredom_hmc_bernoulli.rds b/dev/fixtures/compliance/bgm_boredom_hmc_bernoulli.rds deleted file mode 100644 index 4b21b570..00000000 Binary files a/dev/fixtures/compliance/bgm_boredom_hmc_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_boredom_mh_betabern.rds b/dev/fixtures/compliance/bgm_boredom_mh_betabern.rds deleted file mode 100644 index 598d6042..00000000 Binary files a/dev/fixtures/compliance/bgm_boredom_mh_betabern.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_boredom_nuts_bernoulli.rds b/dev/fixtures/compliance/bgm_boredom_nuts_bernoulli.rds deleted file mode 100644 index 64e50a6e..00000000 Binary files a/dev/fixtures/compliance/bgm_boredom_nuts_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_boredom_nuts_no_edgesel.rds b/dev/fixtures/compliance/bgm_boredom_nuts_no_edgesel.rds deleted file mode 100644 index f74c66ec..00000000 Binary files a/dev/fixtures/compliance/bgm_boredom_nuts_no_edgesel.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_hmc_bernoulli.rds b/dev/fixtures/compliance/bgm_wenchuan_hmc_bernoulli.rds deleted file mode 100644 index c8af4f3d..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_hmc_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_hmc_blumecapel.rds b/dev/fixtures/compliance/bgm_wenchuan_hmc_blumecapel.rds deleted file mode 100644 index 58f412fe..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_hmc_blumecapel.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_mh_bernoulli.rds b/dev/fixtures/compliance/bgm_wenchuan_mh_bernoulli.rds deleted file mode 100644 index f2913c08..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_mh_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_mh_blumecapel.rds b/dev/fixtures/compliance/bgm_wenchuan_mh_blumecapel.rds deleted file mode 100644 index 66984110..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_mh_blumecapel.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_nuts_bernoulli.rds b/dev/fixtures/compliance/bgm_wenchuan_nuts_bernoulli.rds deleted file mode 100644 index be04aa43..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_nuts_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_nuts_betabern.rds b/dev/fixtures/compliance/bgm_wenchuan_nuts_betabern.rds deleted file mode 100644 index f3ea5378..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_nuts_betabern.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_nuts_blumecapel.rds b/dev/fixtures/compliance/bgm_wenchuan_nuts_blumecapel.rds deleted file mode 100644 index f21e41fe..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_nuts_blumecapel.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_nuts_blumecapel_baseline1.rds b/dev/fixtures/compliance/bgm_wenchuan_nuts_blumecapel_baseline1.rds deleted file mode 100644 index 21565765..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_nuts_blumecapel_baseline1.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_nuts_blumecapel_impute.rds b/dev/fixtures/compliance/bgm_wenchuan_nuts_blumecapel_impute.rds deleted file mode 100644 index e640cd60..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_nuts_blumecapel_impute.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_nuts_blumecapel_no_edgesel.rds b/dev/fixtures/compliance/bgm_wenchuan_nuts_blumecapel_no_edgesel.rds deleted file mode 100644 index 1a6fcb24..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_nuts_blumecapel_no_edgesel.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_nuts_impute.rds b/dev/fixtures/compliance/bgm_wenchuan_nuts_impute.rds deleted file mode 100644 index 9af79375..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_nuts_impute.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_nuts_no_edgesel.rds b/dev/fixtures/compliance/bgm_wenchuan_nuts_no_edgesel.rds deleted file mode 100644 index 7fc50a89..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_nuts_no_edgesel.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_nuts_sbm.rds b/dev/fixtures/compliance/bgm_wenchuan_nuts_sbm.rds deleted file mode 100644 index 82781879..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_nuts_sbm.rds and /dev/null differ diff --git a/dev/fixtures/compliance/bgm_wenchuan_nuts_scaled_prior.rds b/dev/fixtures/compliance/bgm_wenchuan_nuts_scaled_prior.rds deleted file mode 100644 index 9db1d7b1..00000000 Binary files a/dev/fixtures/compliance/bgm_wenchuan_nuts_scaled_prior.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_adhd_mh_bernoulli.rds b/dev/fixtures/compliance/cmp_adhd_mh_bernoulli.rds deleted file mode 100644 index 08089085..00000000 Binary files a/dev/fixtures/compliance/cmp_adhd_mh_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_adhd_nuts_bernoulli.rds b/dev/fixtures/compliance/cmp_adhd_nuts_bernoulli.rds deleted file mode 100644 index bd6b5079..00000000 Binary files a/dev/fixtures/compliance/cmp_adhd_nuts_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_boredom_mh_bernoulli.rds b/dev/fixtures/compliance/cmp_boredom_mh_bernoulli.rds deleted file mode 100644 index a892ca47..00000000 Binary files a/dev/fixtures/compliance/cmp_boredom_mh_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_boredom_nuts_bernoulli.rds b/dev/fixtures/compliance/cmp_boredom_nuts_bernoulli.rds deleted file mode 100644 index 79cd2a6a..00000000 Binary files a/dev/fixtures/compliance/cmp_boredom_nuts_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_wenchuan_hmc_bernoulli.rds b/dev/fixtures/compliance/cmp_wenchuan_hmc_bernoulli.rds deleted file mode 100644 index e80965e8..00000000 Binary files a/dev/fixtures/compliance/cmp_wenchuan_hmc_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_wenchuan_mh_bernoulli.rds b/dev/fixtures/compliance/cmp_wenchuan_mh_bernoulli.rds deleted file mode 100644 index 78498e34..00000000 Binary files a/dev/fixtures/compliance/cmp_wenchuan_mh_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_wenchuan_mh_blumecapel.rds b/dev/fixtures/compliance/cmp_wenchuan_mh_blumecapel.rds deleted file mode 100644 index d01bd2e7..00000000 Binary files a/dev/fixtures/compliance/cmp_wenchuan_mh_blumecapel.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_wenchuan_nuts_bernoulli.rds b/dev/fixtures/compliance/cmp_wenchuan_nuts_bernoulli.rds deleted file mode 100644 index 0af4a410..00000000 Binary files a/dev/fixtures/compliance/cmp_wenchuan_nuts_bernoulli.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_wenchuan_nuts_betabern.rds b/dev/fixtures/compliance/cmp_wenchuan_nuts_betabern.rds deleted file mode 100644 index e84593ed..00000000 Binary files a/dev/fixtures/compliance/cmp_wenchuan_nuts_betabern.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_wenchuan_nuts_blumecapel.rds b/dev/fixtures/compliance/cmp_wenchuan_nuts_blumecapel.rds deleted file mode 100644 index 1f7bb111..00000000 Binary files a/dev/fixtures/compliance/cmp_wenchuan_nuts_blumecapel.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_wenchuan_nuts_blumecapel_no_diffsel.rds b/dev/fixtures/compliance/cmp_wenchuan_nuts_blumecapel_no_diffsel.rds deleted file mode 100644 index 13e7302d..00000000 Binary files a/dev/fixtures/compliance/cmp_wenchuan_nuts_blumecapel_no_diffsel.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_wenchuan_nuts_impute.rds b/dev/fixtures/compliance/cmp_wenchuan_nuts_impute.rds deleted file mode 100644 index df1f81f1..00000000 Binary files a/dev/fixtures/compliance/cmp_wenchuan_nuts_impute.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_wenchuan_nuts_main_diffsel.rds b/dev/fixtures/compliance/cmp_wenchuan_nuts_main_diffsel.rds deleted file mode 100644 index e130f413..00000000 Binary files a/dev/fixtures/compliance/cmp_wenchuan_nuts_main_diffsel.rds and /dev/null differ diff --git a/dev/fixtures/compliance/cmp_wenchuan_nuts_no_diffsel.rds b/dev/fixtures/compliance/cmp_wenchuan_nuts_no_diffsel.rds deleted file mode 100644 index c11a045f..00000000 Binary files a/dev/fixtures/compliance/cmp_wenchuan_nuts_no_diffsel.rds and /dev/null differ diff --git a/dev/fixtures/compliance/manifest.rds b/dev/fixtures/compliance/manifest.rds deleted file mode 100644 index b58e1f97..00000000 Binary files a/dev/fixtures/compliance/manifest.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/bgm_ggm_bernoulli_listwise.rds b/dev/fixtures/scaffolding/bgm_ggm_bernoulli_listwise.rds deleted file mode 100644 index e4fb35a7..00000000 Binary files a/dev/fixtures/scaffolding/bgm_ggm_bernoulli_listwise.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/bgm_ggm_betabern_listwise_na.rds b/dev/fixtures/scaffolding/bgm_ggm_betabern_listwise_na.rds deleted file mode 100644 index 9cf914c6..00000000 Binary files a/dev/fixtures/scaffolding/bgm_ggm_betabern_listwise_na.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/bgm_omrf_blumecapel_bernoulli_listwise.rds b/dev/fixtures/scaffolding/bgm_omrf_blumecapel_bernoulli_listwise.rds deleted file mode 100644 index 2ec80f3a..00000000 Binary files a/dev/fixtures/scaffolding/bgm_omrf_blumecapel_bernoulli_listwise.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/bgm_omrf_blumecapel_betabern_impute_na.rds b/dev/fixtures/scaffolding/bgm_omrf_blumecapel_betabern_impute_na.rds deleted file mode 100644 index 5b71c261..00000000 Binary files a/dev/fixtures/scaffolding/bgm_omrf_blumecapel_betabern_impute_na.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/bgm_omrf_mixed_bernoulli_listwise.rds b/dev/fixtures/scaffolding/bgm_omrf_mixed_bernoulli_listwise.rds deleted file mode 100644 index 1384922c..00000000 Binary files a/dev/fixtures/scaffolding/bgm_omrf_mixed_bernoulli_listwise.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/bgm_omrf_ordinal_bernoulli_listwise.rds b/dev/fixtures/scaffolding/bgm_omrf_ordinal_bernoulli_listwise.rds deleted file mode 100644 index 6cb07a8c..00000000 Binary files a/dev/fixtures/scaffolding/bgm_omrf_ordinal_bernoulli_listwise.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/bgm_omrf_ordinal_betabern_impute_na.rds b/dev/fixtures/scaffolding/bgm_omrf_ordinal_betabern_impute_na.rds deleted file mode 100644 index b727db19..00000000 Binary files a/dev/fixtures/scaffolding/bgm_omrf_ordinal_betabern_impute_na.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/bgm_omrf_ordinal_no_edgesel.rds b/dev/fixtures/scaffolding/bgm_omrf_ordinal_no_edgesel.rds deleted file mode 100644 index 5323a8cb..00000000 Binary files a/dev/fixtures/scaffolding/bgm_omrf_ordinal_no_edgesel.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/bgm_omrf_ordinal_sbm_listwise.rds b/dev/fixtures/scaffolding/bgm_omrf_ordinal_sbm_listwise.rds deleted file mode 100644 index fa99518b..00000000 Binary files a/dev/fixtures/scaffolding/bgm_omrf_ordinal_sbm_listwise.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/compare_blumecapel_listwise_2groups.rds b/dev/fixtures/scaffolding/compare_blumecapel_listwise_2groups.rds deleted file mode 100644 index d355878f..00000000 Binary files a/dev/fixtures/scaffolding/compare_blumecapel_listwise_2groups.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/compare_mixed_missing_category_bc.rds b/dev/fixtures/scaffolding/compare_mixed_missing_category_bc.rds deleted file mode 100644 index f7fb07e0..00000000 Binary files a/dev/fixtures/scaffolding/compare_mixed_missing_category_bc.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/compare_ordinal_betabern_diff.rds b/dev/fixtures/scaffolding/compare_ordinal_betabern_diff.rds deleted file mode 100644 index 6b4bb92b..00000000 Binary files a/dev/fixtures/scaffolding/compare_ordinal_betabern_diff.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/compare_ordinal_impute_2groups_na.rds b/dev/fixtures/scaffolding/compare_ordinal_impute_2groups_na.rds deleted file mode 100644 index 40a5e9a9..00000000 Binary files a/dev/fixtures/scaffolding/compare_ordinal_impute_2groups_na.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/compare_ordinal_listwise_2groups.rds b/dev/fixtures/scaffolding/compare_ordinal_listwise_2groups.rds deleted file mode 100644 index 9a1f80f1..00000000 Binary files a/dev/fixtures/scaffolding/compare_ordinal_listwise_2groups.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/compare_ordinal_listwise_3groups.rds b/dev/fixtures/scaffolding/compare_ordinal_listwise_3groups.rds deleted file mode 100644 index aabe7cc5..00000000 Binary files a/dev/fixtures/scaffolding/compare_ordinal_listwise_3groups.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/compare_ordinal_missing_category.rds b/dev/fixtures/scaffolding/compare_ordinal_missing_category.rds deleted file mode 100644 index c75cd924..00000000 Binary files a/dev/fixtures/scaffolding/compare_ordinal_missing_category.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/compare_ordinal_no_diffsel.rds b/dev/fixtures/scaffolding/compare_ordinal_no_diffsel.rds deleted file mode 100644 index 25879707..00000000 Binary files a/dev/fixtures/scaffolding/compare_ordinal_no_diffsel.rds and /dev/null differ diff --git a/dev/fixtures/scaffolding/manifest.rds b/dev/fixtures/scaffolding/manifest.rds deleted file mode 100644 index bb156dd1..00000000 Binary files a/dev/fixtures/scaffolding/manifest.rds and /dev/null differ diff --git a/dev/generate_scaffolding_fixtures.R b/dev/generate_scaffolding_fixtures.R deleted file mode 100644 index dec8003d..00000000 --- a/dev/generate_scaffolding_fixtures.R +++ /dev/null @@ -1,778 +0,0 @@ -# ============================================================================== -# Generate Golden-Snapshot Fixtures for R Scaffolding Refactor -# ============================================================================== -# -# Phase A-0 of the scaffolding refactor (dev/scaffolding/plan.md). -# -# This script captures the INTERMEDIATE outputs of the current validation and -# preprocessing functions — check_model(), check_compare_model(), -# reformat_data(), and compare_reformat_data() — for representative inputs. -# These fixtures are mechanical oracles: every refactored validator must -# reproduce exactly the same outputs. -# -# Unlike the refactor fixtures (which capture full model fits), these fixtures -# are FAST — no sampling is involved. They test the code we're about to -# restructure. -# -# Output: dev/fixtures/scaffolding/ -# - One .rds per fixture case -# - A manifest.rds listing all cases -# -# Usage: -# Rscript dev/generate_scaffolding_fixtures.R -# -# ============================================================================== - -library(bgms) - -# These are internal functions — access via ::: -check_model <- bgms:::check_model -check_compare_model <- bgms:::check_compare_model -reformat_data <- bgms:::reformat_data -compare_reformat_data <- bgms:::compare_reformat_data - -fixture_dir <- file.path("dev", "fixtures", "scaffolding") -dir.create(fixture_dir, recursive = TRUE, showWarnings = FALSE) - -set.seed(42) - -# ============================================================================== -# Helper: generate small synthetic datasets -# ============================================================================== - -# Small ordinal dataset (5 variables, 3 categories each: 0, 1, 2) -make_ordinal_data <- function(n = 50, p = 5, max_cat = 2) { - x <- matrix(sample(0:max_cat, n * p, replace = TRUE), nrow = n, ncol = p) - colnames(x) <- paste0("V", seq_len(p)) - x -} - -# Small continuous dataset -make_continuous_data <- function(n = 50, p = 5) { - x <- matrix(rnorm(n * p), nrow = n, ncol = p) - colnames(x) <- paste0("V", seq_len(p)) - x -} - -# Inject NAs into a dataset -inject_nas <- function(x, prop = 0.05) { - n_na <- max(1, floor(nrow(x) * ncol(x) * prop)) - idx <- sample(length(x), n_na) - x[idx] <- NA - x -} - -# Make a dataset where one category is missing from one group -make_missing_category_data <- function() { - # Group 1: categories 0, 1, 2 for all variables - x1 <- matrix(sample(0:2, 30 * 4, replace = TRUE), nrow = 30, ncol = 4) - # Group 2: variable 1 only has categories 0, 1 (no 2) - x2 <- matrix(sample(0:2, 30 * 4, replace = TRUE), nrow = 30, ncol = 4) - x2[, 1] <- sample(0:1, 30, replace = TRUE) - - x <- rbind(x1, x2) - colnames(x) <- paste0("V", 1:4) - group <- c(rep(1L, 30), rep(2L, 30)) - list(x = x, group = group) -} - -# ============================================================================== -# Fixture definition -# ============================================================================== - -# Each fixture stores: -# $id - unique name -# $desc - human-readable description -# $type - "bgm" or "compare" -# $input - the exact arguments passed to the functions -# $check_model - return value of check_model() or check_compare_model() -# $reformat_data - return value of reformat_data() or compare_reformat_data() - -fixtures <- list() - -# --------------------------------------------------------------------------- -# 1. bgm / GGM / continuous / Bernoulli / listwise -# --------------------------------------------------------------------------- -x_cont <- make_continuous_data(n = 50, p = 5) -cm <- check_model( - x = x_cont, - variable_type = "continuous", - baseline_category = 0L, - edge_selection = TRUE, - edge_prior = "Bernoulli", - inclusion_probability = 0.5 -) -fixtures[["bgm_ggm_bernoulli_listwise"]] <- list( - id = "bgm_ggm_bernoulli_listwise", - desc = "bgm / GGM continuous / Bernoulli / listwise / no NAs", - type = "bgm", - input = list( - x = x_cont, - variable_type = "continuous", - baseline_category = 0L, - na_action = "listwise", - edge_selection = TRUE, - edge_prior = "Bernoulli", - inclusion_probability = 0.5 - ), - check_model = cm, - reformat_data = NULL # GGM path doesn't call reformat_data() -) - -# --------------------------------------------------------------------------- -# 2. bgm / GGM / continuous / Beta-Bernoulli / listwise / with NAs -# --------------------------------------------------------------------------- -x_cont_na <- inject_nas(make_continuous_data(n = 60, p = 4)) -cm2 <- check_model( - x = x_cont_na, - variable_type = "continuous", - baseline_category = 0L, - edge_selection = TRUE, - edge_prior = "Beta-Bernoulli", - beta_bernoulli_alpha = 1, - beta_bernoulli_beta = 1, - inclusion_probability = 0.5 -) -fixtures[["bgm_ggm_betabern_listwise_na"]] <- list( - id = "bgm_ggm_betabern_listwise_na", - desc = "bgm / GGM continuous / Beta-Bernoulli / listwise / with NAs", - type = "bgm", - input = list( - x = x_cont_na, - variable_type = "continuous", - baseline_category = 0L, - na_action = "listwise", - edge_selection = TRUE, - edge_prior = "Beta-Bernoulli", - beta_bernoulli_alpha = 1, - beta_bernoulli_beta = 1 - ), - check_model = cm2, - reformat_data = NULL # GGM path doesn't call reformat_data() -) - -# --------------------------------------------------------------------------- -# 3. bgm / OMRF / ordinal / Bernoulli / listwise -# --------------------------------------------------------------------------- -x_ord <- make_ordinal_data(n = 50, p = 5, max_cat = 3) -cm3 <- check_model( - x = x_ord, - variable_type = "ordinal", - baseline_category = 0L, - edge_selection = TRUE, - edge_prior = "Bernoulli", - inclusion_probability = 0.5 -) -rd3 <- reformat_data( - x = x_ord, - na_action = "listwise", - variable_bool = cm3$variable_bool, - baseline_category = cm3$baseline_category -) -fixtures[["bgm_omrf_ordinal_bernoulli_listwise"]] <- list( - id = "bgm_omrf_ordinal_bernoulli_listwise", - desc = "bgm / OMRF ordinal / Bernoulli / listwise", - type = "bgm", - input = list( - x = x_ord, - variable_type = "ordinal", - baseline_category = 0L, - na_action = "listwise", - edge_selection = TRUE, - edge_prior = "Bernoulli", - inclusion_probability = 0.5 - ), - check_model = cm3, - reformat_data = rd3 -) - -# --------------------------------------------------------------------------- -# 4. bgm / OMRF / ordinal / Beta-Bernoulli / impute / with NAs -# --------------------------------------------------------------------------- -x_ord_na <- inject_nas(make_ordinal_data(n = 60, p = 5, max_cat = 3)) -cm4 <- check_model( - x = x_ord_na, - variable_type = "ordinal", - baseline_category = 0L, - edge_selection = TRUE, - edge_prior = "Beta-Bernoulli", - beta_bernoulli_alpha = 1, - beta_bernoulli_beta = 1, - inclusion_probability = 0.5 -) -# Impute path: need a fresh copy since reformat_data mutates x -x_ord_na_copy <- x_ord_na -rd4 <- reformat_data( - x = x_ord_na_copy, - na_action = "impute", - variable_bool = cm4$variable_bool, - baseline_category = cm4$baseline_category -) -fixtures[["bgm_omrf_ordinal_betabern_impute_na"]] <- list( - id = "bgm_omrf_ordinal_betabern_impute_na", - desc = "bgm / OMRF ordinal / Beta-Bernoulli / impute / with NAs", - type = "bgm", - input = list( - x = x_ord_na, - variable_type = "ordinal", - baseline_category = 0L, - na_action = "impute", - edge_selection = TRUE, - edge_prior = "Beta-Bernoulli", - beta_bernoulli_alpha = 1, - beta_bernoulli_beta = 1 - ), - check_model = cm4, - reformat_data = rd4 -) - -# --------------------------------------------------------------------------- -# 5. bgm / OMRF / ordinal / SBM / listwise -# --------------------------------------------------------------------------- -x_ord5 <- make_ordinal_data(n = 50, p = 6, max_cat = 3) -cm5 <- check_model( - x = x_ord5, - variable_type = "ordinal", - baseline_category = 0L, - edge_selection = TRUE, - edge_prior = "Stochastic-Block", - beta_bernoulli_alpha = 1, - beta_bernoulli_beta = 1, - beta_bernoulli_alpha_between = 1, - beta_bernoulli_beta_between = 1, - dirichlet_alpha = 1, - lambda = 1, - inclusion_probability = 0.5 -) -rd5 <- reformat_data( - x = x_ord5, - na_action = "listwise", - variable_bool = cm5$variable_bool, - baseline_category = cm5$baseline_category -) -fixtures[["bgm_omrf_ordinal_sbm_listwise"]] <- list( - id = "bgm_omrf_ordinal_sbm_listwise", - desc = "bgm / OMRF ordinal / SBM / listwise", - type = "bgm", - input = list( - x = x_ord5, - variable_type = "ordinal", - baseline_category = 0L, - na_action = "listwise", - edge_selection = TRUE, - edge_prior = "Stochastic-Block", - beta_bernoulli_alpha = 1, - beta_bernoulli_beta = 1, - beta_bernoulli_alpha_between = 1, - beta_bernoulli_beta_between = 1, - dirichlet_alpha = 1, - lambda = 1 - ), - check_model = cm5, - reformat_data = rd5 -) - -# --------------------------------------------------------------------------- -# 6. bgm / OMRF / blume-capel / Bernoulli / listwise / custom baseline -# --------------------------------------------------------------------------- -# BC variables with scores starting at 1 (not 0) — triggers recoding -x_bc <- make_ordinal_data(n = 50, p = 4, max_cat = 4) -x_bc <- x_bc + 1L # shift to 1-based scores -vtype6 <- rep("blume-capel", 4) -cm6 <- check_model( - x = x_bc, - variable_type = vtype6, - baseline_category = 3L, - edge_selection = TRUE, - edge_prior = "Bernoulli", - inclusion_probability = 0.5 -) -rd6 <- reformat_data( - x = x_bc, - na_action = "listwise", - variable_bool = cm6$variable_bool, - baseline_category = cm6$baseline_category -) -fixtures[["bgm_omrf_blumecapel_bernoulli_listwise"]] <- list( - id = "bgm_omrf_blumecapel_bernoulli_listwise", - desc = "bgm / OMRF blume-capel / Bernoulli / listwise / custom baseline", - type = "bgm", - input = list( - x = x_bc, - variable_type = vtype6, - baseline_category = 3L, - na_action = "listwise", - edge_selection = TRUE, - edge_prior = "Bernoulli", - inclusion_probability = 0.5 - ), - check_model = cm6, - reformat_data = rd6 -) - -# --------------------------------------------------------------------------- -# 7. bgm / OMRF / blume-capel / Beta-Bernoulli / impute / with NAs -# --------------------------------------------------------------------------- -x_bc_na <- inject_nas(x_bc) -cm7 <- check_model( - x = x_bc_na, - variable_type = vtype6, - baseline_category = 3L, - edge_selection = TRUE, - edge_prior = "Beta-Bernoulli", - beta_bernoulli_alpha = 1, - beta_bernoulli_beta = 1, - inclusion_probability = 0.5 -) -x_bc_na_copy <- x_bc_na -rd7 <- reformat_data( - x = x_bc_na_copy, - na_action = "impute", - variable_bool = cm7$variable_bool, - baseline_category = cm7$baseline_category -) -fixtures[["bgm_omrf_blumecapel_betabern_impute_na"]] <- list( - id = "bgm_omrf_blumecapel_betabern_impute_na", - desc = "bgm / OMRF blume-capel / Beta-Bernoulli / impute / with NAs", - type = "bgm", - input = list( - x = x_bc_na, - variable_type = vtype6, - baseline_category = 3L, - na_action = "impute", - edge_selection = TRUE, - edge_prior = "Beta-Bernoulli", - beta_bernoulli_alpha = 1, - beta_bernoulli_beta = 1 - ), - check_model = cm7, - reformat_data = rd7 -) - -# --------------------------------------------------------------------------- -# 8. bgm / OMRF / mixed ordinal+BC / Bernoulli / listwise -# --------------------------------------------------------------------------- -x_mixed <- make_ordinal_data(n = 50, p = 5, max_cat = 4) -vtype8 <- c("ordinal", "ordinal", "blume-capel", "ordinal", "blume-capel") -bcat8 <- c(0L, 0L, 2L, 0L, 2L) -cm8 <- check_model( - x = x_mixed, - variable_type = vtype8, - baseline_category = bcat8, - edge_selection = TRUE, - edge_prior = "Bernoulli", - inclusion_probability = 0.5 -) -rd8 <- reformat_data( - x = x_mixed, - na_action = "listwise", - variable_bool = cm8$variable_bool, - baseline_category = cm8$baseline_category -) -fixtures[["bgm_omrf_mixed_bernoulli_listwise"]] <- list( - id = "bgm_omrf_mixed_bernoulli_listwise", - desc = "bgm / OMRF mixed ordinal+BC / Bernoulli / listwise", - type = "bgm", - input = list( - x = x_mixed, - variable_type = vtype8, - baseline_category = bcat8, - na_action = "listwise", - edge_selection = TRUE, - edge_prior = "Bernoulli", - inclusion_probability = 0.5 - ), - check_model = cm8, - reformat_data = rd8 -) - -# --------------------------------------------------------------------------- -# 9. bgmCompare / ordinal / listwise / 2 groups -# --------------------------------------------------------------------------- -x_comp9 <- make_ordinal_data(n = 60, p = 4, max_cat = 2) -group9 <- rep(1:2, each = 30) -cm9 <- check_compare_model( - x = x_comp9, - y = NULL, - group_indicator = group9, - difference_selection = TRUE, - variable_type = "ordinal", - baseline_category = 0L, - difference_scale = 1, - difference_prior = "Bernoulli", - difference_probability = 0.5 -) -rd9 <- compare_reformat_data( - x = cm9$x, - group = cm9$group_indicator, - na_action = "listwise", - variable_bool = cm9$variable_bool, - baseline_category = cm9$baseline_category -) -fixtures[["compare_ordinal_listwise_2groups"]] <- list( - id = "compare_ordinal_listwise_2groups", - desc = "bgmCompare / ordinal / listwise / 2 groups", - type = "compare", - input = list( - x = x_comp9, - group_indicator = group9, - variable_type = "ordinal", - baseline_category = 0L, - na_action = "listwise", - difference_selection = TRUE, - difference_prior = "Bernoulli", - difference_probability = 0.5 - ), - check_model = cm9, - reformat_data = rd9 -) - -# --------------------------------------------------------------------------- -# 10. bgmCompare / ordinal / impute / 2 groups + NAs -# --------------------------------------------------------------------------- -x_comp10 <- inject_nas(make_ordinal_data(n = 60, p = 4, max_cat = 2)) -group10 <- rep(1:2, each = 30) -cm10 <- check_compare_model( - x = x_comp10, - y = NULL, - group_indicator = group10, - difference_selection = TRUE, - variable_type = "ordinal", - baseline_category = 0L, - difference_scale = 1, - difference_prior = "Bernoulli", - difference_probability = 0.5 -) -x_comp10_copy <- cm10$x -rd10 <- compare_reformat_data( - x = x_comp10_copy, - group = cm10$group_indicator, - na_action = "impute", - variable_bool = cm10$variable_bool, - baseline_category = cm10$baseline_category -) -fixtures[["compare_ordinal_impute_2groups_na"]] <- list( - id = "compare_ordinal_impute_2groups_na", - desc = "bgmCompare / ordinal / impute / 2 groups + NAs", - type = "compare", - input = list( - x = x_comp10, - group_indicator = group10, - variable_type = "ordinal", - baseline_category = 0L, - na_action = "impute", - difference_selection = TRUE, - difference_prior = "Bernoulli", - difference_probability = 0.5 - ), - check_model = cm10, - reformat_data = rd10 -) - -# --------------------------------------------------------------------------- -# 11. bgmCompare / blume-capel / listwise / 2 groups -# --------------------------------------------------------------------------- -x_comp11 <- make_ordinal_data(n = 60, p = 4, max_cat = 4) -group11 <- rep(1:2, each = 30) -vtype11 <- rep("blume-capel", 4) -cm11 <- check_compare_model( - x = x_comp11, - y = NULL, - group_indicator = group11, - difference_selection = TRUE, - variable_type = vtype11, - baseline_category = 2L, - difference_scale = 1, - difference_prior = "Bernoulli", - difference_probability = 0.5 -) -rd11 <- compare_reformat_data( - x = cm11$x, - group = cm11$group_indicator, - na_action = "listwise", - variable_bool = cm11$variable_bool, - baseline_category = cm11$baseline_category -) -fixtures[["compare_blumecapel_listwise_2groups"]] <- list( - id = "compare_blumecapel_listwise_2groups", - desc = "bgmCompare / blume-capel / listwise / 2 groups", - type = "compare", - input = list( - x = x_comp11, - group_indicator = group11, - variable_type = vtype11, - baseline_category = 2L, - na_action = "listwise", - difference_selection = TRUE, - difference_prior = "Bernoulli", - difference_probability = 0.5 - ), - check_model = cm11, - reformat_data = rd11 -) - -# --------------------------------------------------------------------------- -# 12. bgmCompare / ordinal / listwise / >2 groups -# --------------------------------------------------------------------------- -x_comp12 <- make_ordinal_data(n = 90, p = 4, max_cat = 2) -group12 <- rep(1:3, each = 30) -cm12 <- check_compare_model( - x = x_comp12, - y = NULL, - group_indicator = group12, - difference_selection = TRUE, - variable_type = "ordinal", - baseline_category = 0L, - difference_scale = 1, - difference_prior = "Bernoulli", - difference_probability = 0.5 -) -rd12 <- compare_reformat_data( - x = cm12$x, - group = cm12$group_indicator, - na_action = "listwise", - variable_bool = cm12$variable_bool, - baseline_category = cm12$baseline_category -) -fixtures[["compare_ordinal_listwise_3groups"]] <- list( - id = "compare_ordinal_listwise_3groups", - desc = "bgmCompare / ordinal / listwise / 3 groups", - type = "compare", - input = list( - x = x_comp12, - group_indicator = group12, - variable_type = "ordinal", - baseline_category = 0L, - na_action = "listwise", - difference_selection = TRUE, - difference_prior = "Bernoulli", - difference_probability = 0.5 - ), - check_model = cm12, - reformat_data = rd12 -) - -# --------------------------------------------------------------------------- -# 13. bgmCompare / ordinal / listwise / categories missing in 1 group -# --------------------------------------------------------------------------- -mc_data <- make_missing_category_data() -cm13 <- check_compare_model( - x = mc_data$x, - y = NULL, - group_indicator = mc_data$group, - difference_selection = TRUE, - variable_type = "ordinal", - baseline_category = 0L, - difference_scale = 1, - difference_prior = "Bernoulli", - difference_probability = 0.5 -) -rd13 <- compare_reformat_data( - x = cm13$x, - group = cm13$group_indicator, - na_action = "listwise", - variable_bool = cm13$variable_bool, - baseline_category = cm13$baseline_category -) -fixtures[["compare_ordinal_missing_category"]] <- list( - id = "compare_ordinal_missing_category", - desc = "bgmCompare / ordinal / listwise / category missing in group 2", - type = "compare", - input = list( - x = mc_data$x, - group_indicator = mc_data$group, - variable_type = "ordinal", - baseline_category = 0L, - na_action = "listwise", - difference_selection = TRUE, - difference_prior = "Bernoulli", - difference_probability = 0.5 - ), - check_model = cm13, - reformat_data = rd13 -) - -# --------------------------------------------------------------------------- -# 14. bgmCompare / mixed ordinal+BC / listwise / missing categories + BC -# --------------------------------------------------------------------------- -x_comp14 <- mc_data$x -# Extend to 5 vars with an extra BC variable -x_extra <- matrix(sample(0:4, 60, replace = TRUE), nrow = 60, ncol = 1) -x_comp14 <- cbind(x_comp14, x_extra) -colnames(x_comp14) <- paste0("V", 1:5) -group14 <- mc_data$group -vtype14 <- c("ordinal", "ordinal", "ordinal", "ordinal", "blume-capel") -bcat14 <- c(0L, 0L, 0L, 0L, 2L) -cm14 <- check_compare_model( - x = x_comp14, - y = NULL, - group_indicator = group14, - difference_selection = TRUE, - variable_type = vtype14, - baseline_category = bcat14, - difference_scale = 1, - difference_prior = "Bernoulli", - difference_probability = 0.5 -) -rd14 <- compare_reformat_data( - x = cm14$x, - group = cm14$group_indicator, - na_action = "listwise", - variable_bool = cm14$variable_bool, - baseline_category = cm14$baseline_category -) -fixtures[["compare_mixed_missing_category_bc"]] <- list( - id = "compare_mixed_missing_category_bc", - desc = "bgmCompare / mixed ord+BC / listwise / missing categories + BC", - type = "compare", - input = list( - x = x_comp14, - group_indicator = group14, - variable_type = vtype14, - baseline_category = bcat14, - na_action = "listwise", - difference_selection = TRUE, - difference_prior = "Bernoulli", - difference_probability = 0.5 - ), - check_model = cm14, - reformat_data = rd14 -) - -# --------------------------------------------------------------------------- -# 15. bgm / OMRF / ordinal / no edge selection -# --------------------------------------------------------------------------- -x_ord15 <- make_ordinal_data(n = 50, p = 4, max_cat = 2) -cm15 <- check_model( - x = x_ord15, - variable_type = "ordinal", - baseline_category = 0L, - edge_selection = FALSE, - edge_prior = "Bernoulli", - inclusion_probability = 0.5 -) -rd15 <- reformat_data( - x = x_ord15, - na_action = "listwise", - variable_bool = cm15$variable_bool, - baseline_category = cm15$baseline_category -) -fixtures[["bgm_omrf_ordinal_no_edgesel"]] <- list( - id = "bgm_omrf_ordinal_no_edgesel", - desc = "bgm / OMRF ordinal / no edge selection", - type = "bgm", - input = list( - x = x_ord15, - variable_type = "ordinal", - baseline_category = 0L, - na_action = "listwise", - edge_selection = FALSE - ), - check_model = cm15, - reformat_data = rd15 -) - -# --------------------------------------------------------------------------- -# 16. bgmCompare / ordinal / Beta-Bernoulli difference prior -# --------------------------------------------------------------------------- -x_comp16 <- make_ordinal_data(n = 60, p = 4, max_cat = 2) -group16 <- rep(1:2, each = 30) -cm16 <- check_compare_model( - x = x_comp16, - y = NULL, - group_indicator = group16, - difference_selection = TRUE, - variable_type = "ordinal", - baseline_category = 0L, - difference_scale = 1, - difference_prior = "Beta-Bernoulli", - difference_probability = 0.5, - beta_bernoulli_alpha = 1, - beta_bernoulli_beta = 1 -) -rd16 <- compare_reformat_data( - x = cm16$x, - group = cm16$group_indicator, - na_action = "listwise", - variable_bool = cm16$variable_bool, - baseline_category = cm16$baseline_category -) -fixtures[["compare_ordinal_betabern_diff"]] <- list( - id = "compare_ordinal_betabern_diff", - desc = "bgmCompare / ordinal / Beta-Bernoulli difference prior", - type = "compare", - input = list( - x = x_comp16, - group_indicator = group16, - variable_type = "ordinal", - baseline_category = 0L, - na_action = "listwise", - difference_selection = TRUE, - difference_prior = "Beta-Bernoulli", - beta_bernoulli_alpha = 1, - beta_bernoulli_beta = 1 - ), - check_model = cm16, - reformat_data = rd16 -) - -# --------------------------------------------------------------------------- -# 17. bgmCompare / no difference selection -# --------------------------------------------------------------------------- -x_comp17 <- make_ordinal_data(n = 60, p = 4, max_cat = 2) -group17 <- rep(1:2, each = 30) -cm17 <- check_compare_model( - x = x_comp17, - y = NULL, - group_indicator = group17, - difference_selection = FALSE, - variable_type = "ordinal", - baseline_category = 0L, - difference_scale = 1, - difference_prior = "Bernoulli", - difference_probability = 0.5 -) -rd17 <- compare_reformat_data( - x = cm17$x, - group = cm17$group_indicator, - na_action = "listwise", - variable_bool = cm17$variable_bool, - baseline_category = cm17$baseline_category -) -fixtures[["compare_ordinal_no_diffsel"]] <- list( - id = "compare_ordinal_no_diffsel", - desc = "bgmCompare / ordinal / no difference selection", - type = "compare", - input = list( - x = x_comp17, - group_indicator = group17, - variable_type = "ordinal", - baseline_category = 0L, - na_action = "listwise", - difference_selection = FALSE - ), - check_model = cm17, - reformat_data = rd17 -) - -# ============================================================================== -# Save all fixtures -# ============================================================================== - -manifest <- data.frame( - id = vapply(fixtures, `[[`, character(1), "id"), - desc = vapply(fixtures, `[[`, character(1), "desc"), - type = vapply(fixtures, `[[`, character(1), "type"), - stringsAsFactors = FALSE -) - -cat("\nSaving", nrow(manifest), "scaffolding fixtures to", fixture_dir, "\n\n") - -for (f in fixtures) { - path <- file.path(fixture_dir, paste0(f$id, ".rds")) - saveRDS(f, path) - cat(sprintf(" [%s] %s\n", f$id, f$desc)) -} - -saveRDS(manifest, file.path(fixture_dir, "manifest.rds")) -cat("\nManifest saved. Done.\n") diff --git a/dev/mixedmrf/implementation_plan.md b/dev/mixedmrf/implementation_plan.md deleted file mode 100644 index 562a665c..00000000 --- a/dev/mixedmrf/implementation_plan.md +++ /dev/null @@ -1,1138 +0,0 @@ -# Mixed MRF — Implementation Plan - -**Date:** 2026-02-25 (updated 2026-03-04; review amendments applied 2026-03-04) -**Branch:** `ggm_mixed` (PR #78) -**Goal:** Build a monolithic `MixedMRFModel` in C++ that supports both -conditional and marginal pseudo-likelihood, with and without edge selection. - -**Reference code:** `MaartenMarsman/mixedGM` repository (R + Rcpp prototype) -**Theory:** `dev/plans/mixedMRF/A_Mixed_Graphical_Model_for_Continuous_and_Ordinal_Variables (1).pdf` - ---- - -## Prototype Status (mixedGM) - -The `mixedGM` package (`/Users/maartenmarsman/Documents/GitHub/mixedGM`) provides -a complete R + Rcpp prototype with: - -### Implemented & Tested Components - -| Component | Location | Status | -|-----------|----------|--------| -| Conditional OMRF likelihood | `src/log_likelihoods.cpp` | ✅ C++ (Rcpp) | -| Marginal OMRF likelihood | `src/log_likelihoods.cpp` | ✅ C++ (Rcpp) | -| Conditional GGM likelihood | `src/log_likelihoods.cpp` | ✅ C++ (Rcpp) | -| Θ computation | `src/log_likelihoods.cpp` | ✅ C++ (Rcpp) | -| Gibbs data generator | `src/mixed_gibbs.cpp` | ✅ C++ (Rcpp) | -| Cholesky update/downdate | `src/cholupdate.cpp` | ✅ C++ (Rcpp) | -| MH parameter updates | `R/cond_*_mh_update_functions.R` | R only | -| Edge selection | `R/cond_*_mh_update_functions.R` | R only | -| Main sampler loop | `R/mixed_sampler.R` | R only | -| Stan exact-likelihood model | `inst/stan/mixed_mrf_exact.stan` | ✅ Stan | -| Unit tests (likelihood) | `tests/testthat/test-likelihood-correctness.R` | ✅ | -| Parameter recovery tests | `tests/testthat/test-parameter-recovery.R` | ✅ | -| Edge selection tests | `tests/testthat/test-edge-selection.R` | ✅ | - -### Bug Fixes Already Applied - -The following issues identified in the plan reviews have been fixed in mixedGM: - -| Issue | Status | Location | -|-------|--------|----------| -| Missing factor of 2 in Θ | ✅ Fixed | `rcpp_compute_Theta()`: `Kxx + 2.0 * Kxy * Sigma_yy * Kxy.t()` | -| `dnorm` without `log=TRUE` | ✅ Fixed | All edge selection functions use `log = TRUE` | -| Cache invalidation | ✅ Fixed | `update_Kyy_cache()` called after every Kyy change | -| Marginal PL Kyy acceptance | ✅ Fixed | Includes all p marginal OMRF terms | - -### Validation Infrastructure - -- **Stan exact model** (`inst/stan/mixed_mrf_exact.stan`): Enumerates all ordinal - configurations for p ≤ 5 as gold-standard posterior. Serves as non-circular - validation against the pseudolikelihood. -- **Simulation study** (`dev/simulation_study_plan.md`): 54-cell design covering - p ∈ {5,10,15}, q ∈ {5,10,15}, n ∈ {250,500,1000}. - ---- - -## Table of contents - -1. [Model overview](#1-model-overview) -2. [Two pseudo-likelihood approaches](#2-two-pseudo-likelihood-approaches) -3. [Parameter groups and update schedule](#3-parameter-groups-and-update-schedule) -4. [File layout](#4-file-layout) -5. [Implementation phases](#5-implementation-phases) -6. [Phase A — Skeleton and data structures](#phase-a--skeleton-and-data-structures) -7. [Phase B — Conditional pseudo-likelihood](#phase-b--conditional-pseudo-likelihood) -8. [Phase C — Marginal pseudo-likelihood](#phase-c--marginal-pseudo-likelihood) -9. [Phase D — Edge selection](#phase-d--edge-selection) -10. [Phase E — R interface and integration](#phase-e--r-interface-and-integration) -11. [Phase F — Warmup, adaptation, and diagnostics](#phase-f--warmup-adaptation-and-diagnostics) -12. [Phase G — Simulation and prediction](#phase-g--simulation-and-prediction) -13. [Testing strategy](#testing-strategy) -14. [Reuse inventory](#reuse-inventory) -15. [Risk register](#risk-register) - ---- - -## 1. Model overview - -The mixed MRF models the joint distribution of $p$ ordinal variables -$x$ ($x_s \in \{0, 1, \ldots, C_s\}$) and $q$ continuous variables $y$: - -$$\log f(x, y) \propto \sum_s \mu_{x,s}(x_s) + x^\top K_{xx}\, x - - \tfrac{1}{2}(y - \mu_y)^\top K_{yy}\,(y - \mu_y) - + 2\, x^\top K_{xy}\, y$$ - -Parameters: - -| Symbol | Storage | Dimension | Role | -|--------|---------|-----------|------| -| $\mu_x$ | `mux_` | $p \times \max(C_s)$ | Ordinal thresholds (main effects) | -| $K_{xx}$ | `Kxx_` | $p \times p$ symmetric, zero diag | Discrete pairwise interactions | -| $K_{yy}$ | `Kyy_` | $q \times q$ SPD | Continuous precision matrix | -| $K_{xy}$ | `Kxy_` | $p \times q$ | Cross-type interactions | -| $\mu_y$ | `muy_` | $q$-vector | Continuous means | -| $G$ | `edge_indicators_` | $(p+q) \times (p+q)$ | Edge inclusion indicators | - -The factor 2 on $x^\top K_{xy}\, y$ reflects a symmetric parameterization: -the bilinear coupling between ordinal variable $x_s$ and continuous variable -$y_j$ contributes once as $(x_s, y_j)$ and once as $(y_j, x_s)$ in the -sufficient statistics sum over node pairs. Absorbing both contributions into -a single $K_{xy}$ matrix introduces the factor 2. This is a pure convention -(not a free parameter) and must be applied consistently in all likelihoods, -conditional means, and $\Theta$ computations. - -### Conventions carried into C++ - -- $K_{yy}$ stores the **positive-definite precision** $\Sigma_{yy}^{-1}$. The - LaTeX note writes $K_{yy} = -\tfrac{1}{2}\Sigma_{yy}^{-1}$; we absorb the - $-\tfrac{1}{2}$ into the log-density and always work with SPD matrices. -- The joint density is written as $-\tfrac{1}{2}(y - \mu_y)^\top K_{yy} - (y - \mu_y)$ so that $\mu_y$ is the literal continuous mean. This matches - the R prototype and keeps the conditional mean expression short. -- With this convention, all marginal rest-scores that arise from integrating - out $y$ must include $K_{yy}^{-1}$ explicitly. Whenever the LaTeX note shows - $\Sigma_{yy}$, substitute $K_{yy}^{-1}$. - ---- - -## 2. Two pseudo-likelihood approaches - -Both approaches approximate the intractable joint $f(x, y)$ using -pseudo-likelihoods. They share the same GGM part but differ in how the -discrete pseudo-likelihood handles the coupling to $y$. - -### 2.1 Conditional pseudo-likelihood - -$$\text{PL}_{\text{cond}}(x, y) = - \underbrace{f(y \mid x)}_{\text{conditional GGM}} \cdot - \prod_{s=1}^{p} \underbrace{f(x_s \mid x_{-s}, y)}_{\text{conditional OMRF}}$$ - -**Conditional OMRF** — full conditional of $x_s$ given $x_{-s}$ **and** $y$: - -$$r_s = x_{-s}^\top K_{xx,-s,s} + 2\,y^\top K_{yx,s}$$ -$$\log f(x_s = c \mid x_{-s}, y) = \mu_{x,s,c} + c \cdot r_s - - \log\!\Bigl(1 + \sum_{c'=1}^{C_s} \exp(\mu_{x,s,c'} + c' \cdot r_s)\Bigr)$$ - -The rest-score $r_s$ depends on $K_{xx}$ and $K_{xy}$ but **not** on $K_{yy}$. - -**Conditional GGM** — $y \mid x$ is multivariate Gaussian: - -$$y \mid x \sim N\bigl(\mu_y + 2\,x\,K_{xy}\,K_{yy}^{-1},\; K_{yy}^{-1}\bigr)$$ -$$\log f(y \mid x) = \frac{n}{2}\log|K_{yy}| - - \frac{1}{2}\sum_{v=1}^{n}(y_v - M_v)^\top K_{yy}\,(y_v - M_v)$$ - -where $M = \mathbf{1}\mu_y^\top + 2\,x\,K_{xy}\,K_{yy}^{-1}$ is the -$n \times q$ conditional mean matrix. - -### 2.2 Marginal pseudo-likelihood - -$$\text{PL}_{\text{marg}}(x, y) = - \underbrace{f(y \mid x)}_{\text{conditional GGM}} \cdot - \prod_{s=1}^{p} \underbrace{f(x_s \mid x_{-s})}_{\text{marginal OMRF}}$$ - -**Marginal OMRF** — full conditional of $x_s$ after integrating out $y$: - -$$\Theta = K_{xx} + 2\,K_{xy}\,K_{yy}^{-1}\,K_{yx}$$ -$$r_s = \bigl(x^\top \Theta_{\cdot,s}\bigr) - x_s\,\Theta_{ss} + 2\,(K_{xy})_{s\cdot}\,K_{yy}^{-1}\,\mu_y$$ - -The self-interaction $x_s\,\Theta_{ss}$ must be subtracted because the rest-score -conditions only on $x_{-s}$. In practice: -- Compute $r^{\text{row}} = x^\top \Theta_{\cdot,s}$ (matrix-vector product over all observations) -- Subtract `x_dbl.col(s) * Theta_(s,s)` from the result -- Add the scalar bias `2.0 * arma::dot(Kxy_.row(s), Kyy_inv_ * muy_)` (same for every observation) - -This matches the mixedGM implementation in `rcpp_log_pl_marginal_omrf` exactly. - -Same categorical form as conditional PL, but the effective interaction -matrix $\Theta$ absorbs the continuous variables. This means: -- Changing $K_{yy}$ or $K_{xy}$ requires recomputing $\Theta$ -- Changing $\mu_y$ changes all rest-scores through the - $K_{yy}^{-1}\,\mu_y$ term - -### 2.3 Which parameters affect which likelihoods - -| Parameter | Conditional OMRF | Conditional GGM | Marginal OMRF | -|-----------|:---:|:---:|:---:| -| $\mu_x$ | ✓ | | ✓ | -| $K_{xx}$ | ✓ | | ✓ (via $\Theta$) | -| $K_{yy}$ | | ✓ | ✓ (via $\Theta$) | -| $K_{xy}$ | ✓ | ✓ | ✓ (via $\Theta$) | -| $\mu_y$ | | ✓ | ✓ | - ---- - -## 3. Parameter groups and update schedule - -Each MCMC iteration updates 5 parameter groups in sequence, matching -the R prototype. The table shows which log-likelihood components enter -the MH acceptance ratio for each group. - -### 3.1 Conditional PL mode - -| Step | Parameter | Components in acceptance ratio | -|------|-----------|-------------------------------| -| 1 | $\mu_{x,s,c}$ (one threshold) | `cond_omrf(s)` + prior | -| 2 | $\mu_{y,j}$ (one mean) | `cond_ggm()` + prior | -| 3 | $K_{xx,ij}$ (one pair) | `cond_omrf(i) + cond_omrf(j)` + prior | -| 4 | $K_{yy}$ (one element) | `cond_ggm()` + prior (Cholesky proposal) | -| 5 | $K_{xy,ij}$ (one element) | `cond_omrf(i) + cond_ggm()` + prior | - -### 3.2 Marginal PL mode - -| Step | Parameter | Components in acceptance ratio | -|------|-----------|-------------------------------| -| 1 | $\mu_{x,s,c}$ | `marg_omrf(s)` + prior | -| 2 | $\mu_{y,j}$ | `cond_ggm()` + $\sum_s$ `marg_omrf(s)` + prior | -| 3 | $K_{xx,ij}$ | `marg_omrf(i) + marg_omrf(j)` + prior | -| 4 | $K_{yy}$ | `cond_ggm()` + $\sum_s$ `marg_omrf(s)` + prior (Cholesky proposal) | -| 5 | $K_{xy,ij}$ | $\sum_s$ `marg_omrf(s)` + `cond_ggm()` + prior | - -**Key difference:** In marginal mode, updating $\mu_y$ requires evaluating -ALL $p$ marginal OMRF terms because $\mu_y$ enters every rest-score. The -same is true for $K_{yy}$ and every $K_{xy,ij}$ because $\Theta$ changes -globally with those parameters. Marginal PL is therefore much more -expensive per iteration. - -### 3.3 Edge selection (post-warmup) - -Three independent edge-selection sweeps, same in both PL modes: - -| Edge type | Indicator | RJ proposal | Acceptance components | -|-----------|-----------|-------------|----------------------| -| Discrete-discrete | $G_{xx}$ | Toggle + spike-and-slab | `omrf(i) + omrf(j)` | -| Continuous-continuous | $G_{yy}$ | Cholesky-based toggle | `cond_ggm()` | -| Cross | $G_{xy}$ | Toggle + spike-and-slab | `omrf(i) + cond_ggm()` | - ---- - -## 4. File layout - -``` -src/ - models/ - mixed/ - mixed_mrf_model.h # Class declaration (follows GGMModel pattern) - mixed_mrf_model.cpp # Constructor, clone, vectorization - mixed_mrf_likelihoods.cpp # Port from mixedGM/src/log_likelihoods.cpp - mixed_mrf_metropolis.cpp # Port MH updates from mixedGM R code - mixed_mrf_edge_selection.cpp # Port edge selection from mixedGM R code - mixed_mrf_cholesky.cpp # Cholesky permute/R() (port from mixedGM R) - sample_mixed.cpp # Rcpp interface (copy sample_ggm.cpp pattern) - mrf_simulation.cpp # Extend with mixed_gibbs_generate() - mrf_prediction.cpp # Add mixed MRF prediction -R/ - bgm.R # Extend bgm() to dispatch mixed data - validate_data.R # Add mixed data validation - validate_model.R # Add mixed model validation -tests/ - testthat/ - test-mixed-mrf-likelihood.R # Likelihood correctness vs R prototype - test-mixed-mrf-sampling.R # Recovery tests - test-mixed-mrf-edge-sel.R # Edge selection tests -``` - -`configure` + `inst/generate_makevars_sources.R` already glob every `.cpp` -under `src/`, but Phase A must still run the script (or re-run `configure`) -so the new translation units show up in `sources.mk`/`Makevars`. Likewise, -add the new Rcpp export to `src/RcppExports.cpp`, `R/RcppExports.R`, and -`NAMESPACE` when `sample_mixed.cpp` lands. - ---- - -## 5. Implementation phases - -| Phase | What | Depends on | Deliverable | -|-------|------|------------|-------------| -| **A** | Skeleton: class, data structures, `BaseModel` overrides | — | Compiles, no sampling | -| **B** | Conditional PL: all 5 MH updates, no edge selection | A | Recovery test passes (cond PL, estimation only) | -| **C** | Marginal PL: $\Theta$ caching, marginal OMRF, $\mu_y$ full sweep | B | Recovery test passes (marg PL, estimation only) | -| **D** | Edge selection: 3 RJ sweeps | B | Structure recovery test passes | -| **E** | R interface: `bgm()` dispatch, output formatting | B | End-to-end `bgm(mixed_data)` works | -| **F** | Warmup schedule, adaptation, diagnostics | E | Full warmup pipeline | -| **G** | Simulation and prediction | E | `simulate.bgms` and `predict.bgms` for mixed | - ---- - -## Phase A — Skeleton and data structures - -### A.1 Create `mixed_mrf_model.h` - -```cpp -class MixedMRFModel : public BaseModel { -public: - // Construction - MixedMRFModel( - const arma::imat& x, // n × p discrete (0-based categories) - const arma::mat& y, // n × q continuous - const arma::ivec& num_categories, // p-vector - bool edge_selection, - const std::string& pseudolikelihood, // "conditional" or "marginal" - int seed - ); - - MixedMRFModel(const MixedMRFModel& other); - - // BaseModel overrides (all 13 pure virtuals) - void do_one_metropolis_step(int iteration = -1) override; - void update_edge_indicators() override; - size_t parameter_dimension() const override; - arma::vec get_vectorized_parameters() const override; - void set_vectorized_parameters(const arma::vec& params) override; - arma::ivec get_vectorized_indicator_parameters() override; - size_t full_parameter_dimension() const override; - arma::vec get_full_vectorized_parameters() const override; - void set_seed(int seed) override; - std::unique_ptr clone() const override; - SafeRNG& get_rng() override; - const arma::imat& get_edge_indicators() const override; - arma::mat& get_inclusion_probability() override; - int get_num_variables() const override; - int get_num_pairwise() const override; - void prepare_iteration() override; - void set_edge_selection_active(bool active) override; - void initialize_graph() override; - void init_metropolis_adaptation(const WarmupSchedule& schedule) override; - void tune_proposal_sd(int iteration, const WarmupSchedule& schedule) override; - bool has_missing_data() const override; - void impute_missing() override; - - // Capability queries - bool has_edge_selection() const override; - bool has_adaptive_metropolis() const override; - -private: - // --- Data --- - arma::imat x_; // n × p discrete observations - arma::mat y_; // n × q continuous observations - int n_, p_, q_; - arma::ivec num_categories_; // p-vector - int max_cats_; // max(num_categories) - - // --- Parameters --- - arma::mat mux_; // p × max_cats thresholds - // mux_(s, c) is the threshold for category c+1 of variable s; - // category 0 is the reference level (threshold fixed at 0). - // In C++ loops always use index c-1 when accessing mux_ for category c. - arma::vec muy_; // q-vector continuous means - arma::mat Kxx_; // p × p discrete interactions (diagonal always zero; - // enforced by construction — not a free parameter) - arma::mat Kyy_; // q × q SPD precision - arma::mat Kxy_; // p × q cross interactions - - // --- Edge indicators --- - // Single combined (p+q)×(p+q) matrix (Decision: Option A). - // Gxx block : rows [0,p), cols [0,p) — symmetric, zero diag - // Gyy block : rows [p,p+q), cols [p,p+q) — symmetric, zero diag - // Gxy block : rows [0,p), cols [p,p+q) — full p×q rectangle - // (lower-left mirror [p,p+q)×[0,p) unused; Gxy is not symmetric) - // - // Accessor helpers (prefer over raw index arithmetic throughout): - // int& gxx(int i, int j) { return edge_indicators_(i, j); } - // int& gyy(int i, int j) { return edge_indicators_(p_+i, p_+j); } - // int& gxy(int i, int j) { return edge_indicators_(i, p_+j); } - // - // Serialization order for get_vectorized_indicator_parameters(): - // 1. upper-tri(Gxx) row-major — length p(p-1)/2 - // 2. upper-tri(Gyy) row-major — length q(q-1)/2 - // 3. full Gxy row-major — length p*q - // Total length: p(p-1)/2 + q(q-1)/2 + p*q - arma::imat edge_indicators_; // (p+q) × (p+q) - arma::mat inclusion_probability_; // (p+q) × (p+q) - bool edge_selection_; - bool edge_selection_active_; - - // --- Proposal SDs (Robbins-Monro) --- - arma::mat prop_sd_Kxx_; // p × p - arma::mat prop_sd_Kyy_; // q × q - arma::mat prop_sd_Kxy_; // p × q - arma::mat prop_sd_mux_; // p × max_cats - arma::vec prop_sd_muy_; // q-vector - - // --- Cached quantities --- - arma::mat Kyy_inv_; // q × q inverse of Kyy (always maintained) - arma::mat Kyy_chol_; // q × q upper Cholesky of Kyy - double Kyy_log_det_; // log|Kyy| - arma::mat Theta_; // p × p Kxx + 2 * Kxy * Kyy_inv * Kyx - // (marginal PL only) - arma::mat conditional_mean_; // n × q muy + 2 * x * Kxy * Kyy_inv - - // --- Configuration --- - bool use_marginal_pl_; // true = marginal, false = conditional - SafeRNG rng_; - - // --- Edge update order --- - arma::uvec edge_order_; // shuffled pair indices -}; -``` - -### A.2 Implement trivial overrides - -Implement `parameter_dimension`, `get/set_vectorized_parameters`, -`get_full_vectorized_parameters`, `clone`, `set_seed`, `get_rng`, -`get_edge_indicators`, `get_inclusion_probability`, `get_num_variables`, -`get_num_pairwise`, `has_edge_selection`, `has_adaptive_metropolis`. - -**Parameter vectorization order (free parameters only):** -1. `mux_`: For each variable $s = 0,\ldots,p-1$, copy columns $0 \ldots C_s-1$ - of row $s$ (i.e., the thresholds for categories $1 \ldots C_s$; category 0 - is the fixed reference level). Each variable contributes $C_s$ entries; - the total count is $\sum_s C_s$. Variables with different $C_s$ produce - runs of different length — **do not use a fixed stride**. - ``` - idx = 0 - for s in 0..p-1: - for c in 0..num_categories_[s]-1: - out[idx++] = mux_(s, c) // threshold for category c+1 - ``` -2. `Kxx_`: strictly upper-triangular entries (symmetry supplies the lower - half). Count = $p(p-1)/2$. Diagonal entries are always zero and excluded. -3. `muy_`: all $q$ means. -4. `Kyy_`: upper-triangle **including** the diagonal. Count = $q(q+1)/2$. -5. `Kxy_`: all $p\times q$ cross entries, row-major. - -**`prepare_iteration()`** should be included in Phase A.2's list of trivial -overrides: it shuffles `edge_order_` (same pattern as `OMRFModel`) so that -RJ sweeps have no order bias. - -`parameter_dimension()` returns the number of currently **active** -parameters under edge selection (mirrors `OMRFModel`). -`full_parameter_dimension()` always returns the total count above so that -sample buffers have a fixed width even while RJ toggles edges on/off. - -`get_vectorized_indicator_parameters()` serializes in three contiguous -blocks (matching the layout documented in the class header comment): -1. Upper triangle of $G_{xx}$ (rows/cols $[0,p)$) — length $p(p-1)/2$ -2. Upper triangle of $G_{yy}$ (rows/cols $[p,p+q)$) — length $q(q-1)/2$ -3. Full $G_{xy}$ block (rows $[0,p)$, cols $[p,p+q)$) row-major — length $pq$ - -Total length: $p(p-1)/2 + q(q-1)/2 + pq$. -The diagonal of `edge_indicators_` is always zero and excluded throughout. - -### A.3 Testing checkpoint - -- Class compiles and links -- `parameter_dimension()` returns correct count -- Round-trip: `set_vectorized_parameters(get_vectorized_parameters())` is identity - ---- - -## Phase B — Conditional pseudo-likelihood - -### B.1 Implement log-likelihood functions - -In `mixed_mrf_likelihoods.cpp`: - -#### `log_conditional_omrf(int s)` — per-variable discrete PL - -Computes $\log f(x_s \mid x_{-s}, y)$ summed over all $n$ observations. - -``` -rest_score = x_[, -s] * Kxx_[-s, s] + 2 * y_ * Kxy_[s, ]^T - ^-- n×(p-1) * (p-1)×1 ^-- n×q * q×1 - = n-vector - -For each observation v: - For c = 1..C_s: - eta[v,c] = mux_[s, c] + c * rest_score[v] - log_Z[v] = log(1 + sum_c exp(eta[v,c])) (log-sum-exp stabilized) - ll += mux_[s, x_[v,s]] + x_[v,s] * rest_score[v] - log_Z[v] - (only if x_[v,s] > 0 for the threshold part) -``` - -**Reuse opportunity:** The inner loop (compute log-partition for -one ordinal variable given a rest-score vector) is identical to -`OMRFModel::ordinal_log_pseudolikelihood_ratio()`. Extract or -duplicate the stabilized log-sum-exp computation. - -#### `log_conditional_ggm()` — conditional GGM log-likelihood - -Computes $\log f(y \mid x)$ using cached `Kyy_inv_`, `Kyy_log_det_`, -and `conditional_mean_`. - -``` -conditional_mean_ = 1*muy_^T + 2 * x_ * Kxy_ * Kyy_inv_ -D = y_ - conditional_mean_ -quad_sum = sum((D * Kyy_) .* D) // trace of Kyy * D^T * D -ll = n/2 * (-q * log(2*pi) + Kyy_log_det_) - quad_sum / 2 -``` - -**Reuse opportunity:** This is structurally identical to `GGMModel::log_density_impl` -but with a non-zero, observation-dependent conditional mean. - -### B.2 Implement cache maintenance - -#### `recompute_conditional_mean()` -Recompute `conditional_mean_ = 1*muy^T + 2 * x_ * Kxy_ * Kyy_inv_`. -Called after any change to `muy_`, `Kxy_`, or `Kyy_`. - -#### `recompute_Kyy_decomposition()` -Recompute `Kyy_chol_`, `Kyy_inv_`, `Kyy_log_det_` from `Kyy_`. -Called after any change to `Kyy_`. - -#### `recompute_Theta()` (marginal PL only) -Recompute `Theta_ = Kxx_ + 2 * Kxy_ * Kyy_inv_ * Kxy_^T`. -Called after any change to `Kxx_`, `Kxy_`, or `Kyy_`. - -**Cache update order — proposed-state bookkeeping:** - -The guiding rule is: **all caches must be consistent with the proposed -parameter value before evaluating the proposed likelihood**. On rejection, -temporaries are discarded cheaply (Armadillo move semantics). - -1. **Kyy proposal:** Build proposed `Kyy_chol_prop`, `Kyy_inv_prop`, - `Kyy_log_det_prop`, and `conditional_mean_prop` (and `Theta_prop` if - marginal PL) into temporaries first. Evaluate proposed likelihood with - these temporaries. On acceptance, swap into cached members in the order: - `Kyy_` → `Kyy_chol_`/`Kyy_inv_`/`Kyy_log_det_` → `conditional_mean_` - → `Theta_` (if needed). On rejection, discard temporaries. -2. **Kxx proposal (marginal PL mode):** In marginal PL mode, `Theta_` - changes when `Kxx_` changes. Build `Theta_prop` from the proposed Kxx - into a temporary before evaluating `log_marginal_omrf`. On acceptance, - swap `Kxx_` and `Theta_`. On rejection, discard `Theta_prop`. In - conditional PL mode, no cache depends on `Kxx_`, so no temporary is - needed. -3. **Kxy proposal:** Build `conditional_mean_prop` (both modes) and - `Theta_prop` (marginal mode) into temporaries before evaluating the - proposed likelihood. On acceptance, swap `Kxy_`, `conditional_mean_`, - and (if marginal) `Theta_`. On rejection, discard temporaries. -4. **muy proposal (marginal PL mode):** `conditional_mean_` depends on - `muy_` directly. Build `conditional_mean_prop` (and re-evaluate all - rest-score bias terms `2 * dot(Kxy_.row(s), Kyy_inv_ * muy_prop)`) before - evaluating the proposed likelihood. Swap on acceptance. -5. **Every accepted RJ edge toggle (Phase D)** must trigger the same - cache-refresh logic as an MH acceptance for that parameter type. - -### B.3 Implement MH updates (conditional PL) - -Each update function follows the pattern: -1. Save current parameter value -2. Propose new value from $N(\text{current}, \text{prop\_sd})$ -3. Compute log acceptance ratio (log-likelihood change + log-prior change) -4. Accept/reject -5. Robbins-Monro update of proposal SD - -#### `update_threshold(int s, int c)` — one threshold $\mu_{x,s,c}$ -- Propose: $\mu'_{x,s,c} \sim N(\mu_{x,s,c}, \sigma_{\mu_x,s,c})$ -- Acceptance: `log_conditional_omrf(s)` at proposed vs current + prior ratio -- Prior: Normal(0, $\sigma^2$) - -#### `update_continuous_mean(int j)` — one mean $\mu_{y,j}$ -- Propose: $\mu'_{y,j} \sim N(\mu_{y,j}, \sigma_{\mu_y,j})$ -- Acceptance: `log_conditional_ggm()` at proposed vs current + prior ratio -- Must update `conditional_mean_` before evaluating proposed likelihood -- Prior: Normal(0, $\sigma^2$) - -#### `update_Kxx(int i, int j)` — one discrete interaction -- Propose: $K'_{xx,ij} \sim N(K_{xx,ij}, \sigma_{K_{xx},ij})$ -- Set both $K_{xx,ij}$ and $K_{xx,ji}$ (symmetric) -- Acceptance: `log_conditional_omrf(i) + log_conditional_omrf(j)` + prior -- Prior: Cauchy(0, scale) -- Only if $G_{xx,ij} = 1$ - -#### `update_Kyy_element(int i, int j)` — one precision element (Cholesky) -- Uses the same Cholesky-based proposal as `GGMModel`: - 1. Permute rows/cols so (i,j) maps to (q-1, q) position - 2. Compute Cholesky, extract constants - 3. Propose on the Cholesky scale: $\phi' \sim N(\phi, \sigma_{K_{yy},ij})$ - 4. Rebuild $K'_{yy}$ maintaining positive definiteness - 5. Unpermute -- Acceptance: `log_conditional_ggm()` at proposed vs current + priors -- Priors: Cauchy(0, scale) on off-diagonal; Gamma(shape, rate) on diagonal -- Must update all Kyy-dependent caches after acceptance - -Unlike `GGMModel`, no rank-2 determinant lemma shortcut is available: the -conditional mean depends on $K_{yy}^{-1}$, so each proposal evaluates the -full `log_conditional_ggm()` with freshly computed `Kyy_inv_` and -`conditional_mean_`. - -**Jacobian for diagonal Kyy proposals.** The diagonal element is proposed -on the log scale to guarantee positivity: -``` -// Diagonal element i: -theta_curr = log(L(i, i)) // L is the Cholesky factor -theta_prop = rnorm(theta_curr, prop_sd) -L_prop(i, i) = exp(theta_prop) -// Include log-Jacobian in acceptance: -ln_alpha += theta_prop - theta_curr -``` -Omitting this Jacobian biases the diagonal distribution. Off-diagonal -elements are proposed linearly and need no Jacobian. - -**Cholesky update strategy.** `GGMModel` uses rank-1 Cholesky -update/downdate (`cholupdate.h`) applied to a single off-diagonal entry. -The mixed MRF follows the `mixedGM` approach instead: **permute** the -target row/column pair to the last two positions, perform the 2×2 block -update, then unpermute. These are distinct algorithms. Place the permute -helpers in `mixed_mrf_cholesky.cpp`, ported from -`mixedGM/R/continuous_variable_helper.R`. Do **not** attempt to reuse -`GGMModel`'s rank-1 routines for this purpose. - -#### `update_Kxy(int i, int j)` — one cross interaction -- Propose: $K'_{xy,ij} \sim N(K_{xy,ij}, \sigma_{K_{xy},ij})$ -- Acceptance: `log_conditional_omrf(i) + log_conditional_ggm()` + prior -- Must update `conditional_mean_` after modifying `Kxy_` -- Prior: Cauchy(0, scale) -- Only if $G_{xy,ij} = 1$ - -### B.4 Implement `do_one_metropolis_step(int iteration)` - -```cpp -void MixedMRFModel::do_one_metropolis_step(int iteration) { - // Step 1: Update all thresholds - for (int s = 0; s < p_; ++s) - for (int c = 0; c < num_categories_[s]; ++c) - update_threshold(s, c, iteration); - - // Step 2: Update all continuous means - for (int j = 0; j < q_; ++j) - update_continuous_mean(j, iteration); - - // Step 3: Update Kxx (upper triangle, edge-gated) - for (int i = 0; i < p_ - 1; ++i) - for (int j = i + 1; j < p_; ++j) - if (!edge_selection_active_ || gxx(i, j) == 1) - update_Kxx(i, j, iteration); - - // Step 4: Update Kyy (off-diag + diagonal, edge-gated) - for (int i = 0; i < q_ - 1; ++i) - for (int j = i + 1; j < q_; ++j) - if (!edge_selection_active_ || gyy(i, j) == 1) - update_Kyy_offdiag(i, j, iteration); - for (int i = 0; i < q_; ++i) - update_Kyy_diag(i, iteration); // diagonal always active - - // Step 5: Update Kxy (edge-gated) - for (int i = 0; i < p_; ++i) - for (int j = 0; j < q_; ++j) - if (!edge_selection_active_ || gxy(i, j) == 1) - update_Kxy(i, j, iteration); -} -``` - -### B.5 Testing checkpoint — conditional PL estimation - -**Test 1: Likelihood agreement** -- Generate data from `mixed_gibbs_generate()` in R -- Compute `log_conditional_omrf(s)` and `log_conditional_ggm()` in both - R (prototype) and C++ at the same parameter values -- Assert agreement to machine precision - -**Test 2: Recovery (estimation only, no edge selection)** -- Generate data from known parameters (p=3 ordinal, q=2 continuous) -- Run mixed sampler with conditional PL, no edge selection -- Check posterior means recover true parameters (correlation > 0.9) -- Use `dev/plans/mixedMRF/mixedGM/dev/conditional_vs_marginal_pl.R` - as template - ---- - -## Phase C — Marginal pseudo-likelihood - -### C.1 Implement `log_marginal_omrf(int s)` - -Same structure as `log_conditional_omrf(s)` but uses $\Theta$ instead -of $K_{xx}$ and adds the $\mu_y$ bias term. The self-interaction must be -excluded explicitly: - -``` -Theta_ = Kxx_ + 2 * Kxy_ * Kyy_inv_ * Kxy_^T (cached; see C.3) - -// n-vector of rest-scores: -rest_score = x_dbl * Theta_.col(s) // include all n obs, all p vars - - x_dbl.col(s) * Theta_(s, s) // subtract self-interaction - + 2.0 * arma::dot(Kxy_.row(s), - Kyy_inv_ * muy_) // scalar bias, same for all obs -``` - -This matches `rcpp_log_pl_marginal_omrf` in `mixedGM/src/log_likelihoods.cpp` -exactly. The inner log-partition loop is then identical to `log_conditional_omrf`. - -### C.2 Modify `update_continuous_mean(int j)` for marginal mode - -In marginal PL mode, changing $\mu_y$ affects all $p$ marginal OMRF terms. -The acceptance ratio becomes: - -``` -sum_{s=1}^{p} [log_marginal_omrf(s, proposed) - log_marginal_omrf(s, current)] -+ [log_conditional_ggm(proposed) - log_conditional_ggm(current)] -+ log_prior_ratio -``` - -This is more expensive; patch the R prototype (factor-of-two fix + -`dnorm(log = TRUE)`) before generating fixtures so C++ and R target the -same expression. - -### C.3 Cache invalidation for marginal PL - -**Per-proposal proposed-Theta rule** (see Phase B.2): When proposing a -$K_{xx}$, $K_{xy}$, or $K_{yy}$ move in marginal PL mode, the proposed -marginal OMRF likelihood must be evaluated with a proposed $\Theta$ -computed from the proposed parameter values. This means each of those -update functions must build a local `Theta_prop` before calling -`log_marginal_omrf(s)` with the proposed value. See Phase B.2 for the -exact temporary-variable protocol. - -**Θ recompute granularity.** Recomputing $\Theta$ after every individual -element change is $O(p^2 q + pq^2)$ per move, which is expensive for -large $p, q$. The mitigation already noted in the risk register covers -the accepted-Theta path: after **accepting** a $K_{xx}$/$K_{xy}$/$K_{yy}$ -move, update the cached `Theta_` once from the new parameter state. Do -not recompute `Theta_` more than once per accepted move. - -For initial implementation, full recompute -is simpler and correct. Rank-1 shortcuts for single $K_{xy,ij}$ changes -are a future optimization. - -### C.4 Configuration dispatch - -`do_one_metropolis_step` dispatches between conditional and marginal -via `use_marginal_pl_`. The dispatch happens inside each update function, -not at the loop level — most updates have different acceptance ratios -in the two modes. - -### C.5 Testing checkpoint — marginal PL estimation - -**Test 3: Marginal likelihood agreement** -- Same data as Test 1, compute `log_marginal_omrf(s)` in R and C++ -- Assert agreement - -**Test 4: Recovery (marginal PL)** -- Same setup as Test 2 but `pseudolikelihood = "marginal"` -- Check posterior means recover true parameters - -**Test 5: Conditional vs marginal agreement** -- Run both modes on the same data, compare posterior means -- They should be similar (not identical — different approximations) - ---- - -## Phase D — Edge selection - -### D.1 Discrete edge selection (`update_edge_indicator_Kxx`) - -For each pair $(i, j)$ with $i < j$: -1. Propose $G'_{xx,ij} = 1 - G_{xx,ij}$ -2. If **birth** ($G_{xx,ij}=0 \to 1$): propose $K'_{xx,ij} \sim N(K_{xx,ij}, \sigma)$; - `k_curr = Kxx_(i,j)`, `k_prop = rnorm(k_curr, sigma)` (k_curr = 0 on a true birth). -3. If **death** ($G_{xx,ij}=1 \to 0$): set $K'_{xx,ij} = 0$. -4. Compute log acceptance ratio: - - Likelihood: `omrf(i) + omrf(j)` at proposed vs current parameters - - Slab prior: on birth add `log_slab_prior(k_prop)`, on death subtract `log_slab_prior(k_curr)` - - **Hastings ratio** (proposal asymmetry): - - On birth: subtract `dnorm(k_prop, k_curr, sigma, log=true)` (cost of generating the proposed value) - - On death: add `dnorm(k_curr, k_prop, sigma, log=true)` = `dnorm(k_curr, 0, sigma, log=true)` - (cost of the reverse birth, which would propose k_curr from a Normal centred on 0) - - Inclusion prior: $\log(\pi) - \log(1 - \pi)$ ratio (or reverse on death) -5. After an **accepted** move, refresh `Theta_` (marginal mode). - -All log-density calls must use `log = true` to keep all terms on the log scale. -Follows `cond_omrf_update_association_indicator_pair` in the R prototype. -Transplant the Hastings terms verbatim from that R code; do not re-derive. - -### D.2 Continuous edge selection (`update_edge_indicator_Kyy`) - -Cholesky-based birth/death as in `GGMModel`: -1. Permute, Cholesky, extract constants -2. If birth: propose $\phi' \sim N(0, \sigma)$, rebuild -3. If death: set off-diagonal to 0, rebuild -4. Acceptance: `cond_ggm()` + $\sum_s$ `marg_omrf(s)` (marginal mode only) - + inclusion prior + slab/proposal density -5. On acceptance, run the same cache pipeline as an MH precision update - (`Kyy` → decompositions → `conditional_mean_` → `Theta_`). - -Follows `cond_ggm_update_precision_indicator_pair` in the R prototype. - -**Reuse opportunity:** Directly reuse `GGMModel::update_edge_indicator_parameter_pair` -logic, adapted for the mixed model's conditional GGM likelihood. - -### D.3 Cross edge selection (`update_edge_indicator_Kxy`) - -Cross-edges $G_{xy,ij}$ share the same Bernoulli inclusion prior $\pi$ as -$G_{xx}$ and $G_{yy}$ (decided). A single $\pi$ keeps the SBM prior -uniform across all edge types; if different sparsity assumptions are needed -for cross-type edges in the future, a separate hyperparameter can be added -then. - -For each pair $(i, j)$ where $i \in \{1..p\}$, $j \in \{1..q\}$: -1. Propose $G'_{xy,ij} = 1 - G_{xy,ij}$ -2. If birth: propose $K'_{xy,ij} \sim N(K_{xy,ij}, \sigma)$ -3. If death: set $K'_{xy,ij} = 0$ -4. Acceptance: `omrf(i) + cond_ggm()` in conditional mode; - `cond_ggm() + \sum_s marg_omrf(s)` in marginal mode; spike-and-slab priors - plus log-dense Hastings terms. -5. On acceptance, update `conditional_mean_` (both modes) and `Theta_` - (marginal mode) before the next likelihood evaluation. - -Follows `cond_omrf_update_cross_association_indicator_pair` in the R prototype. - -### D.4 Implement `update_edge_indicators()` - -```cpp -void MixedMRFModel::update_edge_indicators() { - if (!edge_selection_active_) return; - - // Shuffle edge order - // ... - - // Discrete-discrete edges - for (int i = 0; i < p_ - 1; ++i) - for (int j = i + 1; j < p_; ++j) - update_edge_indicator_Kxx(i, j); - - // Continuous-continuous edges - for (int i = 0; i < q_ - 1; ++i) - for (int j = i + 1; j < q_; ++j) - update_edge_indicator_Kyy(i, j); - - // Cross edges - for (int i = 0; i < p_; ++i) - for (int j = 0; j < q_; ++j) - update_edge_indicator_Kxy(i, j); -} -``` - -### D.5 Testing checkpoint — edge selection - -**Test 6: Structure recovery** -- Generate data from a sparse mixed graph (some edges zero) -- Run with edge selection, check posterior inclusion probabilities - recover the true structure (true edges have high PIP, false edges low) - ---- - -## Phase E — R interface and integration - -### E.1 Create `sample_mixed.cpp` - -Rcpp interface function `sample_mixed_mrf_cpp(...)`: -- Takes R data (integer matrix `x`, numeric matrix `y`) -- Creates `MixedMRFModel` -- Creates edge prior -- Calls `run_mcmc_sampler()` -- Returns results as `Rcpp::List` - -Follow the pattern of `sample_ggm.cpp` and `sample_omrf.cpp`. - -### E.2 Extend `bgm()` in R - -The user interface uses **Option A** (decided): a single data frame plus a -`variable_type` argument: - -```r -bgm(data, variable_type = c("o", "o", "c", ...)) -``` - -- `data`: an $n \times (p+q)$ data frame or matrix with all variables. -- `variable_type`: character vector, length $p+q$; values `"o"` (ordinal) - or `"c"` (continuous). Column order must match `variable_type`. - -`bgm()` splits `data` into the integer matrix `x` (ordinal columns) and -numeric matrix `y` (continuous columns), then dispatches to -`sample_mixed_mrf_cpp()`. The split indices must be stored in the output -object so that `coef`, `predict`, and `simulate` methods can reconstruct -the original column order. - -**Missing data** is an explicit non-goal for this PR. -`has_missing_data()` returns `false` and `impute_missing()` is an empty -override so the shared sampler pipeline compiles. Future imputation support -for mixed data should be a separate phase (Phase H). - -### E.3 Extend `build_output.R` - -The output structure needs to accommodate: -- Separate interaction matrices: `Kxx`, `Kyy`, `Kxy` (or a combined one) -- Threshold samples: `mux` array -- Mean samples: `muy` matrix -- Edge indicators decomposed by type - -### E.4 Testing checkpoint — end-to-end - -**Test 7: `bgm()` with mixed data** -- Call `bgm()` with `variable_type` containing both ordinal and continuous -- Check output structure matches expected format -- Verify S3 methods work (`print`, `summary`, `coef`, `predict`) - ---- - -## Phase F — Warmup, adaptation, and diagnostics - -### F.1 Warmup schedule - -**Pre-condition check (before Phase F begins):** Verify that `WarmupSchedule` -and `ChainRunner` support a Metropolis-only model without NUTS mass-matrix -or step-size hooks. If `WarmupSchedule::stage_2_windows()` triggers -NUTS-specific adaptation that `MixedMRFModel` cannot satisfy, the class -must either no-op those hooks or a simplified schedule must be added. -Check `GGMModel`'s implementation of `init_metropolis_adaptation` and -`tune_proposal_sd` as the reference pattern. - -The mixed MRF is Metropolis-only (no NUTS/HMC), matching the GGM model. -Use the same warmup schedule: -- Stage 1: Initial fast adaptation (75 iterations) -- Stage 2: Doubling windows for covariance adaptation -- Stage 3a: Terminal fast adaptation -- Stage 3b: Proposal SD tuning with edge selection (if enabled) -- Stage 3c: Step-size re-adaptation with edge selection active - -### F.2 Robbins-Monro adaptation - -Per-parameter proposal SD adaptation, matching the R prototype: -``` -sigma_new = sigma_old + (acceptance - target) * weight -weight = (1/iter)^0.6 -target = 0.44 -``` - -Clamp to [0.001, 2.0]. - -### F.3 Init metropolis adaptation - -Override `init_metropolis_adaptation(const WarmupSchedule&)` to store -the schedule for use in `tune_proposal_sd()`. - -Override `tune_proposal_sd(int iteration, const WarmupSchedule&)` — the -Robbins-Monro is already embedded in each update function. - ---- - -## Phase G — Simulation and prediction - -### G.1 Mixed MRF simulation - -Extend `mrf_simulation.cpp` to support mixed data generation via -block Gibbs sampling, matching `mixed_gibbs_generate()`: - -1. Sample $x_s \mid x_{-s}, y$: categorical from log-sum-exp stabilized - probabilities -2. Sample $y \mid x$: multivariate Normal with conditional mean - $\mu_y + 2\,x\,K_{xy}\,K_{yy}^{-1}$ and covariance $K_{yy}^{-1}$ - -### G.2 Mixed MRF prediction - -Extend `mrf_prediction.cpp` for posterior predictive checks on mixed data. - ---- - -## Testing strategy - -### Existing test infrastructure (from mixedGM) - -The mixedGM prototype already provides comprehensive tests that can be adapted -for the bgms C++ implementation: - -| mixedGM Test File | What it covers | Adaptation for bgms | -|-------------------|----------------|---------------------| -| `test-likelihood-correctness.R` | C++ likelihoods vs hand-computed values | Use as reference for bgms unit tests | -| `test-parameter-recovery.R` | Posterior mean recovery (p=2,q=1 and larger) | Port test scenarios | -| `test-edge-selection.R` | Edge birth/death moves, PIP calibration | Port test scenarios | -| `test-edge-cases.R` | p=1, q=1, binary-only ordinal | Port edge cases | -| `test-cholesky-update.R` | Cholesky update/downdate correctness | Already covered by bgms GGM tests | -| `test-data-generator.R` | Gibbs generator sanity checks | Port | -| `test-mcmc-diagnostics.R` | ESS, R-hat convergence checks | Port | -| `test-pl-comparison.R` | Conditional vs marginal PL agreement | Port | - -### Test fixture availability - -The mixedGM prototype implements all C++ likelihood functions as Rcpp exports. -This means test fixtures can be generated by running mixedGM R code and -comparing against the bgms C++ port: - -```r -# Generate fixture in mixedGM -library(mixedGM) -set.seed(42) # ALWAYS use a documented seed -ll_cond_omrf = rcpp_log_pl_conditional_omrf(x, y, Kxx, Kxy, mux, num_categories, i) -# Compare against bgms MixedMRFModel::log_conditional_omrf(i) -``` - -All fixture generation must use `set.seed()` before any -RNG-dependent step. The seed and the mixedGM package version must be -recorded in a comment at the top of each fixture file. The mixedGM -tests use `set.seed(42)` consistently; use the same seed for bgms fixtures -unless a specific test requires otherwise. - -No need to patch the prototype — all math bugs have already been fixed. - -Cross-reference the existing mixedGM test files for each scenario: -- `tests/testthat/test-likelihood-correctness.R` (likelihood unit tests) -- `tests/testthat/test-parameter-recovery.R` (recovery scenarios) -- `tests/testthat/test-edge-selection.R` (structure recovery scenarios) -- `tests/testthat/test-edge-cases.R` (p=1, q=1, binary ordinal) - -### Unit tests (per-function correctness) - -| Test | What | How | -|------|------|-----| -| T1 | `log_conditional_omrf(s)` | Compare C++ vs R prototype at known parameters | -| T2 | `log_conditional_ggm()` | Compare C++ vs R prototype at known parameters | -| T3 | `log_marginal_omrf(s)` | Compare C++ vs R prototype at known parameters | -| T4 | `parameter_dimension()` | Check count for known p, q, num_categories | -| T5 | Vectorization round-trip | `set(get(params)) == params` | -| T6 | Cholesky permutation | Verify permute is involution, PD maintained | -| T7 | Analytic Gaussian check | Set $K_{xy}=0$ so `log_conditional_ggm()` reduces to a standard MVN and compare against closed form | -| T8 | Fixture replay | Load deterministic R fixtures and ensure C++ reproduces saved log-likelihood components bit-for-bit | -| T9 | Cache freshness | After each parameter tweak, recompute `conditional_mean_` and (if needed) `Theta_` from scratch and compare with cached copies | - -### Integration tests (sampling correctness) - -| Test | What | How | -|------|------|-----| -| T7 | Conditional PL recovery | Generate → estimate → check cor > 0.9 | -| T8 | Marginal PL recovery | Generate → estimate → check cor > 0.9 | -| T9 | Edge selection recovery | Sparse graph → check PIP > 0.5 for true, < 0.5 for false | -| T10 | Cond vs marginal agreement | Both modes → posterior means close | -| T11 | Reproducibility | Same seed → identical output | -| T12 | Multi-chain | 4 chains → R-hat < 1.1 | - -### Edge-case & stress tests - -| Test | What | How | -|------|------|-----| -| T13 | Kyy PD invariant | Run 1k iterations with aggressive proposals; Cholesky of $K_{yy}$ must succeed after every accepted move | -| T14 | Cache consistency under RJ | After edge births/deaths, recompute caches and assert equality (debug build) | -| T15 | Degenerate $p=1$ | Single ordinal variable, ensure sampler runs and recovers $K_{xy}, K_{yy}$ | -| T16 | Degenerate $q=1$ | Scalar precision, ensure Cholesky permutation code handles the trivial case | -| T17 | Binary-only ordinal | All $C_s=1$, verify conditional PL matches logistic rest-scores | -| T18 | Gibbs generator sanity | Compare empirical moments from `mixed_gibbs_generate()` against theoretical targets before using it for fixtures | - -### Regression tests - -| Test | What | -|------|------| -| R1 | Existing GGM tests still pass | -| R2 | Existing OMRF tests still pass | -| R3 | Existing bgmCompare tests still pass | - ---- - -## Reuse inventory - -### From mixedGM prototype (port to bgms) - -| Component | mixedGM Location | Target in bgms | Action | -|-----------|-----------------|----------------|--------| -| Conditional OMRF likelihood | `src/log_likelihoods.cpp` | `mixed_mrf_likelihoods.cpp` | Port (remove Rcpp exports) | -| Marginal OMRF likelihood | `src/log_likelihoods.cpp` | `mixed_mrf_likelihoods.cpp` | Port | -| Conditional GGM likelihood | `src/log_likelihoods.cpp` | `mixed_mrf_likelihoods.cpp` | Port | -| Θ computation | `src/log_likelihoods.cpp` | `mixed_mrf_likelihoods.cpp` | Port | -| Log-sum-exp helper | `src/log_likelihoods.cpp` | `mixed_mrf_likelihoods.cpp` | Port | -| Gibbs data generator | `src/mixed_gibbs.cpp` | `mrf_simulation.cpp` | Extend | -| Cholesky helpers | `src/cholupdate.cpp` | Already in bgms | N/A | -| Cholesky permute/R() | `R/continuous_variable_helper.R` | `mixed_mrf_cholesky.cpp` | Port to C++ | -| MH Kxx updates | `R/cond_omrf_mh_update_functions.R` | `mixed_mrf_metropolis.cpp` | Port to C++ | -| MH Kyy updates | `R/cond_ggm_mh_update_functions.R` | `mixed_mrf_metropolis.cpp` | Port to C++ | -| MH Kxy updates | `R/cond_omrf_mh_update_functions.R` | `mixed_mrf_metropolis.cpp` | Port to C++ | -| Edge selection (Kxx) | `R/cond_omrf_mh_update_functions.R` | `mixed_mrf_edge_selection.cpp` | Port to C++ | -| Edge selection (Kyy) | `R/cond_ggm_mh_update_functions.R` | `mixed_mrf_edge_selection.cpp` | Port to C++ | -| Edge selection (Kxy) | `R/cond_omrf_mh_update_functions.R` | `mixed_mrf_edge_selection.cpp` | Port to C++ | -| Test fixtures | `tests/testthat/test-*.R` | `tests/testthat/test-mixed-*.R` | Adapt | - -### From bgms existing infrastructure (direct reuse) - -| Component | Source | Reuse type | -|-----------|--------|------------| -| Cholesky update/downdate | `src/models/ggm/cholupdate.h` | Direct include | -| Log-sum-exp stabilization | `OMRFModel::compute_logZ_*` | Adapt pattern | -| Robbins-Monro adaptation | `OMRFModel::robbins_monro_*` | Direct reuse | -| Edge prior (SBM) | `src/priors/sbm_edge_prior.h` | Direct reuse | -| WarmupSchedule | `src/mcmc/execution/warmup_schedule.h` | Direct reuse | -| ChainRunner | `src/mcmc/execution/chain_runner.h` | Direct reuse | -| SafeRNG | `src/rng/rng_utils.h` | Direct reuse | -| Progress manager | `src/utils/progress_manager.h` | Direct reuse | -| R output builder | `R/build_output.R` | Extend | -| Validation functions | `R/validate_data.R` | Extend | -| Rcpp interface pattern | `src/sample_ggm.cpp` | Copy pattern | - ---- - -## Risk register - -| Risk | Impact | Mitigation | -|------|--------|------------| -| Kyy inversion per MH step is $O(q^3)$ | Slow for large $q$ | Cache `Kyy_inv_`; incremental rank-1 updates after Cholesky proposals | -| Marginal PL $\mu_y$ update evaluates all $p$ OMRF terms | Slow for large $p$ | Cache rest-scores; provide conditional PL as faster default | -| $\Theta$ recomputation after every $K_{yy}$/$K_{xy}$ change | $O(p^2 q + pq^2)$ | Defer $\Theta$ recompute to once per accepted move, not per-element; proposed-state uses a temporary (Phase B.2, C.3) | -| Edge selection order effects | Bias | Shuffle edge order each iteration (already done in OMRF) | -| **Sticky edge indicators** | Poor PIP mixing | Shuffle cross-edge and xx-edge update order each iteration; monitor PIP stability and autocorrelation across chains | -| Numerical instability in log-sum-exp | NaN/Inf | Use stabilized version from OMRF (subtract max) | -| Factor 2 convention mismatch | Wrong posteriors | Document consistently; unit-test against R prototype | -| PD violation during Kyy proposals | Crash | Cholesky-based proposals guarantee PD by construction | -| Large parameter space mixing | Poor ESS | Per-parameter Robbins-Monro; future: block updates or HMC | - ---- - -## Updated implementation workflow - -Given that mixedGM provides tested C++ likelihood implementations and -validated R MH updates, the recommended workflow is: - -### Phase 1: Port likelihoods (1-2 days) - -1. Create `src/models/mixed/` directory -2. Copy `mixedGM/src/log_likelihoods.cpp` → `mixed_mrf_likelihoods.cpp` -3. Remove Rcpp exports, convert to internal functions -4. Adapt to bgms coding conventions (SafeRNG, etc.) -5. **Validation**: Call mixedGM Rcpp functions from R and compare to bgms - -### Phase 2: Skeleton class (1-2 days) - -1. Create `mixed_mrf_model.h` following GGMModel pattern -2. Implement all BaseModel pure virtuals (stubs where needed) -3. Implement `parameter_dimension()`, `get/set_vectorized_parameters()` -4. Implement `clone()`, `set_seed()`, lifecycle hooks -5. **Validation**: Class compiles, vectorization round-trip passes - -### Phase 3: MH updates (2-3 days) - -1. Port `cond_omrf_update_*` functions from R to C++ -2. Port `cond_ggm_update_*` functions from R to C++ -3. Port `cond_update_cross_associations` from R to C++ -4. Implement `do_one_metropolis_step()` calling all update functions -5. **Validation**: Compare posterior samples vs mixedGM on same data - -### Phase 4: Edge selection (2-3 days) - -1. Port `*_update_*_indicator_pair` functions from R to C++ -2. Implement `update_edge_indicators()` calling all three edge types -3. Implement `initialize_graph()`, `prepare_iteration()` -4. **Validation**: Run edge selection tests from mixedGM - -### Phase 5: R interface (1-2 days) - -1. Create `sample_mixed.cpp` following `sample_ggm.cpp` pattern -2. Extend `bgm()` to detect and dispatch mixed data -3. Extend `build_output.R` for mixed model output structure -4. **Validation**: End-to-end `bgm()` call works - -### Phase 6: Gibbs generator (1 day) - -1. Port `mixed_gibbs_generate_cpp()` to `mrf_simulation.cpp` -2. Extend `simulate.bgms()` for mixed models -3. **Validation**: Generated data matches mixedGM output - -### Total estimated effort: 8-13 days - -The C++ likelihood implementations (Phase B-C in the original plan) are -largely complete via mixedGM, reducing the remaining work to porting MH -updates and wiring into the bgms infrastructure. diff --git a/dev/numerical_analyses/bgm_blumecapel_normalization_PL.R b/dev/numerical_analyses/bgm_blumecapel_normalization_PL.R deleted file mode 100644 index 26359859..00000000 --- a/dev/numerical_analyses/bgm_blumecapel_normalization_PL.R +++ /dev/null @@ -1,647 +0,0 @@ -# ============================================================================== -# Blume–Capel Numerical Stability Study (reparametrized) -# File: dev/numerical_analyses/BCvar_normalization_PL.r -# -# Goal -# ---- -# Compare numerical stability of four ways to compute the Blume–Capel -# normalizing constant across a range of residual scores r, using the -# reparametrized form -# -# Z(r) = sum_{s=0}^C exp( θ_part(s) + s * r ), -# -# where -# -# θ_part(s) = θ_lin * (s - ref) + θ_quad * (s - ref)^2. -# -# This corresponds to the reformulated denominator where: -# - scores s are in {0, 1, ..., C}, -# - the quadratic/linear θ-part is in terms of the centered score (s - ref), -# - the “residual” r enters only through s * r. -# -# Methods (exactly four): -# 1) Direct -# Unbounded sum of exp(θ_part(s) + s * r). -# -# 2) Preexp -# Unbounded “power-chain” over s, precomputing exp(θ_part(s)) and -# reusing exp(r): -# Z(r) = sum_s exp(θ_part(s)) * (exp(r))^s . -# -# 3) Direct + max-bound -# Per-r max-term bound M(r) = max_s (θ_part(s) + s * r), -# computing -# Z(r) = exp(M(r)) * sum_s exp(θ_part(s) + s * r - M(r)), -# but returning only the *scaled* sum: -# sum_s exp(θ_part(s) + s * r - M(r)). -# -# 4) Preexp + max-bound -# Same max-term bound M(r) as in (3), but using the power-chain: -# sum_s exp(θ_part(s)) * exp(s * r - M(r)). -# -# References (for error calculation): -# - ref_unscaled = MPFR sum_s exp(θ_part(s) + s * r) -# - ref_scaled = MPFR sum_s exp(θ_part(s) + s * r - M(r)), -# where M(r) = max_s (θ_part(s) + s * r) in MPFR. -# -# Dependencies -# ------------ -# - Rmpfr -# -# Outputs -# ------- -# compare_bc_all_methods(...) returns a data.frame with: -# r : grid of residual scores -# direct : numeric, Σ_s exp(θ_part(s) + s * r) -# preexp : numeric, Σ_s via power-chain (unbounded) -# direct_bound : numeric, Σ_s exp(θ_part(s) + s * r - M(r)) -# preexp_bound : numeric, Σ_s via power-chain with max-term bound -# err_direct : |(direct - ref_unscaled)/ref_unscaled| -# err_preexp : |(preexp - ref_unscaled)/ref_unscaled| -# err_direct_bound : |(direct_bound - ref_scaled )/ref_scaled | -# err_preexp_bound : |(preexp_bound - ref_scaled )/ref_scaled | -# ref_unscaled : numeric MPFR reference (unbounded) -# ref_scaled : numeric MPFR reference (max-term scaled) -# -# Plotting helpers (unchanged interface): -# - plot_bc_four(res, ...) -# - summarize_bc_four(res) -# -# ============================================================================== - -library(Rmpfr) - -# ------------------------------------------------------------------------------ -# compare_bc_all_methods -# ------------------------------------------------------------------------------ -# Compute all four methods and MPFR references over a vector of r-values -# for the reparametrized Blume–Capel normalizing constant -# -# Z(r) = sum_{s=0}^C exp( θ_lin * (s - ref) + θ_quad * (s - ref)^2 + s * r ). -# -# Args: -# max_cat : integer, max category C (scores are s = 0..C) -# ref : integer, baseline category index for centering (s - ref) -# r_vals : numeric vector of r values to scan -# theta_lin : numeric, linear θ parameter -# theta_quad : numeric, quadratic θ parameter -# mpfr_prec : integer, MPFR precision (bits) for reference calculations -# -# Returns: -# data.frame with columns described in the file header (see “Outputs”). -# ------------------------------------------------------------------------------ - -compare_bc_all_methods <- function(max_cat = 10, - ref = 3, - r_vals = seq(-70, 70, length.out = 2000), - theta_lin = 0.12, - theta_quad = -0.02, - mpfr_prec = 256) { - - # --- score grid and θ-part --------------------------------------------------- - scores <- 0:max_cat # s = 0..C - centered <- scores - ref # (s - ref) - - # θ_part(s) = θ_lin*(s - ref) + θ_quad*(s - ref)^2 - theta_part <- theta_lin * centered + theta_quad * centered^2 - - # For the unbounded power-chain: exp(θ_part(s)) - exp_m <- exp(theta_part) - - # Output container ------------------------------------------------------------ - res <- data.frame( - r = r_vals, - direct = NA_real_, - preexp = NA_real_, - direct_bound = NA_real_, - preexp_bound = NA_real_, - err_direct = NA_real_, - err_preexp = NA_real_, - err_direct_bound = NA_real_, - err_preexp_bound = NA_real_, - ref_unscaled = NA_real_, - ref_scaled = NA_real_, - bound = NA_real_, # term_max = M(r), puur ter inspectie - theta_lin = theta_lin, - theta_quad = theta_quad, - max_cat = max_cat, - ref = ref - ) - - # --- MPFR constants independent of r ---------------------------------------- - tl_mpfr <- mpfr(theta_lin, mpfr_prec) - tq_mpfr <- mpfr(theta_quad, mpfr_prec) - sc_center_mpfr <- mpfr(centered, mpfr_prec) # (s - ref) - sc_raw_mpfr <- mpfr(scores, mpfr_prec) # s - - # --- Main loop over r -------------------------------------------------------- - for (i in seq_along(r_vals)) { - r <- r_vals[i] - - # Standard double-precision exponents - term <- theta_part + scores * r - - # ---------- MPFR references ---------- - r_mpfr <- mpfr(r, mpfr_prec) - term_mpfr <- tl_mpfr * sc_center_mpfr + - tq_mpfr * sc_center_mpfr * sc_center_mpfr + - sc_raw_mpfr * r_mpfr - - term_max_mpfr <- mpfr(max(asNumeric(term_mpfr)), mpfr_prec) - ref_unscaled_mpfr <- sum(exp(term_mpfr)) - ref_scaled_mpfr <- sum(exp(term_mpfr - term_max_mpfr)) - - # Store numeric references - res$ref_unscaled[i] <- asNumeric(ref_unscaled_mpfr) - res$ref_scaled[i] <- asNumeric(ref_scaled_mpfr) - - # ---------- (1) Direct (unbounded) ---------- - v_direct <- sum(exp(term)) - res$direct[i] <- v_direct - - # ---------- (2) Preexp (unbounded) ---------- - # Power-chain on exp(r): s = 0..max_cat, so start at s=0 with pow = 1 - eR <- exp(r) - pow <- 1.0 - S_pre <- 0.0 - for (j in seq_along(scores)) { - S_pre <- S_pre + exp_m[j] * pow - pow <- pow * eR - } - res$preexp[i] <- S_pre - - # ---------- (3) Direct + max-bound ---------- - term_max <- max(term) # M(r) - res$bound[i] <- term_max - - sum_direct_bound <- 0.0 - for (j in seq_along(scores)) { - sum_direct_bound <- sum_direct_bound + - exp(theta_part[j] + scores[j] * r - term_max) - } - res$direct_bound[i] <- sum_direct_bound - - # ---------- (4) Preexp + max-bound ---------- - pow_b <- exp(-term_max) # s = 0 → exp(0*r - term_max) - S_pre_b <- 0.0 - for (j in seq_along(scores)) { - S_pre_b <- S_pre_b + exp_m[j] * pow_b - pow_b <- pow_b * eR - } - res$preexp_bound[i] <- S_pre_b - - # ---------- Errors (vs MPFR) ---------- - res$err_direct[i] <- - asNumeric(abs((mpfr(v_direct, mpfr_prec) - ref_unscaled_mpfr) / ref_unscaled_mpfr)) - res$err_preexp[i] <- - asNumeric(abs((mpfr(S_pre, mpfr_prec) - ref_unscaled_mpfr) / ref_unscaled_mpfr)) - res$err_direct_bound[i] <- - asNumeric(abs((mpfr(sum_direct_bound, mpfr_prec) - ref_scaled_mpfr) / ref_scaled_mpfr)) - res$err_preexp_bound[i] <- - asNumeric(abs((mpfr(S_pre_b, mpfr_prec) - ref_scaled_mpfr) / ref_scaled_mpfr)) - } - - res -} - - - -# ------------------------------------------------------------------------------ -# plot_bc_four -# ------------------------------------------------------------------------------ -# Plot the four relative error curves on a log y-axis. -# -# Args: -# res : data.frame produced by compare_bc_all_methods() -# draw_order : character vector with any ordering of: -# c("err_direct","err_direct_bound","err_preexp_bound","err_preexp") -# alpha : named numeric vector (0..1) alphas for the same names -# lwd : line width -# -# Returns: (invisible) NULL. Draws a plot. -# -plot_bc_four = function(res, - draw_order = c("err_direct","err_direct_bound", - "err_preexp_bound","err_preexp"), - alpha = c(err_direct = 0.00, - err_direct_bound = 0.00, - err_preexp_bound = 0.40, - err_preexp = 0.40), - lwd = 2) { - - base_cols = c(err_direct = "#000000", - err_preexp = "#D62728", - err_direct_bound = "#1F77B4", - err_preexp_bound = "#9467BD") - - to_rgba = function(hex, a) rgb(t(col2rgb(hex))/255, alpha = a) - - cols = mapply(to_rgba, base_cols[draw_order], alpha[draw_order], - SIMPLIFY = TRUE, USE.NAMES = TRUE) - - vals = unlist(res[draw_order]) - vals = vals[is.finite(vals)] - ylim = if (length(vals)) { - q = stats::quantile(vals, c(.01, .99), na.rm = TRUE) - c(q[1] / 10, q[2] * 10) - } else c(1e-20, 1e-12) - - first = draw_order[1] - plot(res$r, res[[first]], type = "l", log = "y", - col = cols[[1]], lwd = lwd, ylim = ylim, - xlab = "r", ylab = "Relative error (vs MPFR)", - main = "Blume–Capel: Direct / Preexp / (Split) Bound") - - if (length(draw_order) > 1) { - for (k in 2:length(draw_order)) { - lines(res$r, res[[draw_order[k]]], col = cols[[k]], lwd = lwd) - } - } - - abline(h = .Machine$double.eps, col = "gray70", lty = 2) - - ## --- Theoretical bound where max term hits exp(709) - scores <- 0:res$max_cat[1] - centered <- scores - res$ref[1] - - # θ_part(s) = θ_lin*(s-ref) + θ_quad*(s-ref)^2 - theta_part <- res$theta_lin[1] * centered + - res$theta_quad[1] * centered * centered - - U <- 709 - pos <- scores > 0 - - if (any(pos)) { - r_up_vec <- (U - theta_part[pos]) / scores[pos] - r_up <- min(r_up_vec) - } else { - r_up <- Inf - } - - # Geen zinvolle beneden-grens voor overflow met s >= 0 - r_low <- -Inf - - if (is.finite(r_up)) { - abline(v = r_up, col = "darkgreen", lty = 2, lwd = 2) - } - - print(r_low) - print(r_up) - - legend("top", - legend = c("Direct", - "Direct + bound (split)", - "Preexp + bound (split)", - "Preexp") - [match(draw_order, - c("err_direct","err_direct_bound", - "err_preexp_bound","err_preexp"))], - col = cols, lwd = lwd, bty = "n") - - invisible(NULL) -} - - -# ------------------------------------------------------------------------------ -# summarize_bc_four -# ------------------------------------------------------------------------------ -# Summarize accuracy per method. -# -# Args: -# res : data.frame from compare_bc_all_methods() -# -# Returns: -# data.frame with columns: Method, Mean, Median, Max, Finite -# -summarize_bc_four = function(res) { - cols = c("err_direct","err_direct_bound","err_preexp_bound","err_preexp") - labs = c("Direct","Direct+Bound(split)","Preexp+Bound(split)","Preexp") - mk = function(v){ - f = is.finite(v) & v > 0 - c(Mean=mean(v[f]), Median=median(v[f]), Max=max(v[f]), Finite=mean(f)) - } - out = t(sapply(cols, function(nm) mk(res[[nm]]))) - data.frame(Method=labs, out, row.names=NULL, check.names=FALSE) -} - -# ============================================================================== -# Example usage (uncomment to run locally) -# ------------------------------------------------------------------------------ -# res = compare_bc_all_methods( -# max_cat = 4, -# ref = 0, -# r_vals = seq(170, 175, length.out = 1000), -# theta_lin = 0, -# theta_quad = 1.00, -# mpfr_prec = 256 -# ) -# plot_bc_four(res, -# draw_order = c("err_direct","err_direct_bound","err_preexp_bound","err_preexp"), -# alpha = c(err_direct = 0.00, -# err_direct_bound = 1.00, -# err_preexp_bound = 1.00, -# err_preexp = 0.00), -# lwd = 1) -# print(summarize_bc_four(res), digits = 3) -# ============================================================================== - -scan_bc_configs <- function(max_cat_vec = c(4, 10), - ref_vec = c(0, 2), - theta_lin_vec = c(0.0, 0.12), - theta_quad_vec = c(-0.02, 0.0, 0.02), - r_vals = seq(-80, 80, length.out = 2000), - mpfr_prec = 256, - tol = 1e-12) { - - cfg_grid <- expand.grid( - max_cat = max_cat_vec, - ref = ref_vec, - theta_lin = theta_lin_vec, - theta_quad = theta_quad_vec, - KEEP.OUT.ATTRS = FALSE, - stringsAsFactors = FALSE - ) - - all_summaries <- vector("list", nrow(cfg_grid)) - - for (i in seq_len(nrow(cfg_grid))) { - cfg <- cfg_grid[i, ] - cat("Config", i, "of", nrow(cfg_grid), ":", - "max_cat =", cfg$max_cat, - "ref =", cfg$ref, - "theta_lin =", cfg$theta_lin, - "theta_quad =", cfg$theta_quad, "\n") - - res_i <- compare_bc_all_methods( - max_cat = cfg$max_cat, - ref = cfg$ref, - r_vals = r_vals, - theta_lin = cfg$theta_lin, - theta_quad = cfg$theta_quad, - mpfr_prec = mpfr_prec - ) - - summ_i <- summarize_bc_methods(res_i, tol = tol) - all_summaries[[i]] <- summ_i - } - - do.call(rbind, all_summaries) -} - -classify_bc_bound_methods <- function(res, tol = 1e-12, - eps_better = 1e-3) { - # tol : threshold for "good enough" relative error - # eps_better : multiplicative margin to call one method "better" when both good - - r <- res$r - eD <- res$err_direct_bound - eP <- res$err_preexp_bound - - finiteD <- is.finite(eD) & eD > 0 - finiteP <- is.finite(eP) & eP > 0 - - goodD <- finiteD & (eD < tol) - goodP <- finiteP & (eP < tol) - - state <- character(length(r)) - - for (i in seq_along(r)) { - if (!goodD[i] && !goodP[i]) { - state[i] <- "neither_good" - } else if (goodD[i] && !goodP[i]) { - state[i] <- "only_direct_good" - } else if (!goodD[i] && goodP[i]) { - state[i] <- "only_preexp_good" - } else { - # both good: compare which is better - # e.g. if preexp_bound error is at least eps_better times smaller than direct_bound - if (eP[i] <= eD[i] * (1 - eps_better)) { - state[i] <- "both_good_preexp_better" - } else if (eD[i] <= eP[i] * (1 - eps_better)) { - state[i] <- "both_good_direct_better" - } else { - # both good and within eps_better fraction: treat as "tie" - state[i] <- "both_good_similar" - } - } - } - - data.frame( - r = r, - err_direct_bound = eD, - err_preexp_bound = eP, - state = factor(state), - bound = res$bound, - max_cat = res$max_cat[1], - ref = res$ref[1], - theta_lin = res$theta_lin[1], - theta_quad = res$theta_quad[1], - stringsAsFactors = FALSE - ) -} - -summarize_bc_bound_classification <- function(class_df) { - # class_df is the output of classify_bc_bound_methods() - - r <- class_df$r - state <- as.character(class_df$state) - - if (length(r) == 0) { - return(class_df[FALSE, ]) # empty - } - - # Identify run boundaries where state changes - blocks <- list() - start_idx <- 1 - current_state <- state[1] - - for (i in 2:length(r)) { - if (state[i] != current_state) { - # close previous block - blocks[[length(blocks) + 1]] <- list( - state = current_state, - i_start = start_idx, - i_end = i - 1 - ) - # start new block - start_idx <- i - current_state <- state[i] - } - } - # close last block - blocks[[length(blocks) + 1]] <- list( - state = current_state, - i_start = start_idx, - i_end = length(r) - ) - - # Turn into a data.frame with r-intervals and some diagnostics - out_list <- vector("list", length(blocks)) - for (k in seq_along(blocks)) { - b <- blocks[[k]] - idx <- b$i_start:b$i_end - out_list[[k]] <- data.frame( - state = b$state, - r_min = min(r[idx]), - r_max = max(r[idx]), - # a few handy diagnostics per block: - max_err_direct_bound = max(class_df$err_direct_bound[idx], na.rm = TRUE), - max_err_preexp_bound = max(class_df$err_preexp_bound[idx], na.rm = TRUE), - min_bound = min(class_df$bound[idx], na.rm = TRUE), - max_bound = max(class_df$bound[idx], na.rm = TRUE), - n_points = length(idx), - max_cat = class_df$max_cat[1], - ref = class_df$ref[1], - theta_lin = class_df$theta_lin[1], - theta_quad = class_df$theta_quad[1], - stringsAsFactors = FALSE - ) - } - - do.call(rbind, out_list) -} - -# 1. Run the basic comparison -r_vals <- seq(0, 100, length.out = 2000) - -res4 <- compare_bc_all_methods( - max_cat = 4, - ref = 0, - r_vals = r_vals, - theta_lin = 0.12, - theta_quad = -0.02, - mpfr_prec = 256 -) - -# 2. Classify per-r which bound-method wins -class4 <- classify_bc_bound_methods(res4, tol = 1e-12, eps_better = 1e-3) - -# 3. Compress into r-intervals -summary4 <- summarize_bc_bound_classification(class4) -print(summary4, digits = 3) - - - - -simulate_bc_fast_safe <- function(param_grid, - r_vals = seq(-80, 80, length.out = 2000), - mpfr_prec = 256, - tol = 1e-12) { - # param_grid: data.frame with columns - # max_cat, ref, theta_lin, theta_quad - # r_vals : vector of residual r values - # tol : tolerance for "ok" numerics (relative error) - # - # Returns one big data.frame with columns: - # config_id, max_cat, ref, theta_lin, theta_quad, - # r, bound, fast_val, safe_val, - # err_fast, err_safe, ok_fast, ok_safe, - # ref_scaled (MPFR reference) - - if (!all(c("max_cat", "ref", "theta_lin", "theta_quad") %in% names(param_grid))) { - stop("param_grid must have columns: max_cat, ref, theta_lin, theta_quad") - } - - out_list <- vector("list", nrow(param_grid)) - - for (cfg_idx in seq_len(nrow(param_grid))) { - cfg <- param_grid[cfg_idx, ] - max_cat <- as.integer(cfg$max_cat) - ref <- as.integer(cfg$ref) - theta_lin <- as.numeric(cfg$theta_lin) - theta_quad <- as.numeric(cfg$theta_quad) - - # --- score grid and θ-part for this config -------------------------------- - scores <- 0:max_cat - centered <- scores - ref - - theta_part <- theta_lin * centered + theta_quad * centered^2 - exp_m <- exp(theta_part) # for fast method - - # MPFR constants - tl_mpfr <- mpfr(theta_lin, mpfr_prec) - tq_mpfr <- mpfr(theta_quad, mpfr_prec) - sc_center_mpfr <- mpfr(centered, mpfr_prec) - sc_raw_mpfr <- mpfr(scores, mpfr_prec) - - # Storage for this config - n_r <- length(r_vals) - res_cfg <- data.frame( - config_id = rep(cfg_idx, n_r), - max_cat = rep(max_cat, n_r), - ref = rep(ref, n_r), - theta_lin = rep(theta_lin, n_r), - theta_quad = rep(theta_quad, n_r), - r = r_vals, - bound = NA_real_, - fast_val = NA_real_, - safe_val = NA_real_, - err_fast = NA_real_, - err_safe = NA_real_, - ok_fast = NA, - ok_safe = NA, - ref_scaled = NA_real_, - stringsAsFactors = FALSE - ) - - # --- main loop over r for this config ------------------------------------- - for (i in seq_along(r_vals)) { - r <- r_vals[i] - - ## Double-precision exponents: - term <- theta_part + scores * r # θ_part(s) + s*r - term_max <- max(term) # M(r) = bound - res_cfg$bound[i] <- term_max - - ## MPFR reference (scaled with max-term): - r_mpfr <- mpfr(r, mpfr_prec) - term_mpfr <- tl_mpfr * sc_center_mpfr + - tq_mpfr * sc_center_mpfr * sc_center_mpfr + - sc_raw_mpfr * r_mpfr - term_max_mpfr <- mpfr(max(asNumeric(term_mpfr)), mpfr_prec) - ref_scaled_mpfr <- sum(exp(term_mpfr - term_max_mpfr)) - ref_scaled_num <- asNumeric(ref_scaled_mpfr) - res_cfg$ref_scaled[i] <- ref_scaled_num - - # --- SAFE: Direct + max-bound ------------------------------------------ - # Z_safe = sum_s exp(θ_part(s) + s*r - term_max) - safe_sum <- 0.0 - for (j in seq_along(scores)) { - safe_sum <- safe_sum + exp(theta_part[j] + scores[j] * r - term_max) - } - res_cfg$safe_val[i] <- safe_sum - - # --- FAST: Preexp + max-bound (power-chain) ---------------------------- - # Z_fast = sum_s exp(θ_part(s)) * exp(s*r - term_max) - eR <- exp(r) - pow_b <- exp(-term_max) # s = 0 → exp(0*r - term_max) - fast_sum <- 0.0 - for (j in seq_along(scores)) { - fast_sum <- fast_sum + exp_m[j] * pow_b - pow_b <- pow_b * eR - } - res_cfg$fast_val[i] <- fast_sum - - # --- Relative errors vs MPFR (scaled) ---------------------------------- - if (is.finite(ref_scaled_num) && ref_scaled_num > 0) { - res_cfg$err_safe[i] <- abs(safe_sum - ref_scaled_num) / ref_scaled_num - res_cfg$err_fast[i] <- abs(fast_sum - ref_scaled_num) / ref_scaled_num - } else { - res_cfg$err_safe[i] <- NA_real_ - res_cfg$err_fast[i] <- NA_real_ - } - - res_cfg$ok_safe[i] <- !is.na(res_cfg$err_safe[i]) && - is.finite(res_cfg$err_safe[i]) && - (res_cfg$err_safe[i] < tol) - - res_cfg$ok_fast[i] <- !is.na(res_cfg$err_fast[i]) && - is.finite(res_cfg$err_fast[i]) && - (res_cfg$err_fast[i] < tol) - } - - out_list[[cfg_idx]] <- res_cfg - } - - do.call(rbind, out_list) -} \ No newline at end of file diff --git a/dev/numerical_analyses/bgm_blumecapel_normalization_PL_extra.R b/dev/numerical_analyses/bgm_blumecapel_normalization_PL_extra.R deleted file mode 100644 index 43ebeb04..00000000 --- a/dev/numerical_analyses/bgm_blumecapel_normalization_PL_extra.R +++ /dev/null @@ -1,900 +0,0 @@ -############################################################ -# Blume–Capel normalization analysis: -# Numerical comparison of FAST vs SAFE exponentiation methods -# -# Objective -# --------- -# This script provides a full numerical investigation of two methods -# to compute the *scaled* Blume–Capel partition sum: -# -# Z_scaled(r) = sum_{s=0}^C exp( θ_part(s) + s*r - M(r) ) -# -# where -# θ_part(s) = θ_lin * (s - ref) + θ_quad * (s - ref)^2 -# and -# M(r) = max_s ( θ_part(s) + s*r ). -# -# We compare two computational approaches: -# -# SAFE = Direct computation : sum_s exp(θ_part + s*r - M(r)) -# FAST = Power-chain precompute : sum_s exp(θ_part(s)) * exp(s*r - M(r)) -# -# MPFR (256-bit) is used as the ground-truth reference. -# -# The goals are: -# -# 1. Determine each method's numerical stability across a wide range -# of (max_cat, ref, θ_lin, θ_quad, r). -# -# 2. Map all cases where FAST becomes inaccurate or produces NaN. -# -# 3. Identify the correct switching rule for the C++ implementation: -# if (bound <= ~709) use FAST -# else use SAFE -# -# where `bound = M(r)` is the maximum exponent before rescaling. -# -# 4. Produce plots and summary statistics to permanently document the -# reasoning behind this rule. -# -# Key numerical fact -# ------------------ -# exp(x) in IEEE double precision overflows at x ≈ 709.782712893. -# Therefore any exponent near ±709 is dangerous. -# -# Outcome summary -# --------------- -# - SAFE is stable across the entire tested range. -# - FAST is perfectly accurate **as long as bound ≤ ~709** -# - All FAST failures (NaN or large error) occur only when bound > ~709 -# - No FAST failures were observed below this threshold. -# -# This provides strong empirical justification for the C++ switching rule. -############################################################ - -library(Rmpfr) -library(dplyr) -library(ggplot2) - -############################################################ -# 1. Simulation function -# -# Simulates FAST vs SAFE across: -# - parameter grid (max_cat, ref, θ_lin, θ_quad) -# - range of r values -# -# Returns one large data.frame containing: -# - the computed bound M(r) -# - FAST and SAFE values -# - MPFR reference -# - relative errors -# - logical OK flags (err < tol) -############################################################ - -simulate_bc_fast_safe <- function(param_grid, - r_vals = seq(-80, 80, length.out = 2000), - mpfr_prec = 256, - tol = 1e-12) { - - if (!all(c("max_cat", "ref", "theta_lin", "theta_quad") %in% names(param_grid))) { - stop("param_grid must have columns: max_cat, ref, theta_lin, theta_quad") - } - - out_list <- vector("list", nrow(param_grid)) - - for (cfg_idx in seq_len(nrow(param_grid))) { - cfg <- param_grid[cfg_idx, ] - max_cat <- as.integer(cfg$max_cat) - ref <- as.integer(cfg$ref) - theta_lin <- as.numeric(cfg$theta_lin) - theta_quad <- as.numeric(cfg$theta_quad) - - # Score grid and θ(s) - scores <- 0:max_cat - centered <- scores - ref - theta_part <- theta_lin * centered + theta_quad * centered^2 - exp_m <- exp(theta_part) # used by FAST - - # Build MPFR constants - tl_mpfr <- mpfr(theta_lin, mpfr_prec) - tq_mpfr <- mpfr(theta_quad, mpfr_prec) - sc_center_mpfr <- mpfr(centered, mpfr_prec) - sc_raw_mpfr <- mpfr(scores, mpfr_prec) - - n_r <- length(r_vals) - - res_cfg <- data.frame( - config_id = rep(cfg_idx, n_r), - max_cat = rep(max_cat, n_r), - ref = rep(ref, n_r), - theta_lin = rep(theta_lin, n_r), - theta_quad = rep(theta_quad, n_r), - r = r_vals, - bound = NA_real_, - fast_val = NA_real_, - safe_val = NA_real_, - err_fast = NA_real_, - err_safe = NA_real_, - ok_fast = NA, - ok_safe = NA, - ref_scaled = NA_real_, - stringsAsFactors = FALSE - ) - - # Compute for all r - for (i in seq_along(r_vals)) { - r <- r_vals[i] - - term <- theta_part + scores * r # θ(s) + s*r - term_max <- max(term) # numerical bound - res_cfg$bound[i] <- term_max - - # MPFR scaled reference - r_mpfr <- mpfr(r, mpfr_prec) - term_mpfr <- tl_mpfr * sc_center_mpfr + - tq_mpfr * sc_center_mpfr * sc_center_mpfr + - sc_raw_mpfr * r_mpfr - term_max_mpfr <- mpfr(max(asNumeric(term_mpfr)), mpfr_prec) - ref_scaled_mpfr <- sum(exp(term_mpfr - term_max_mpfr)) - ref_scaled_num <- asNumeric(ref_scaled_mpfr) - res_cfg$ref_scaled[i] <- ref_scaled_num - - # SAFE method: direct evaluation - safe_sum <- 0.0 - for (j in seq_along(scores)) { - safe_sum <- safe_sum + exp(theta_part[j] + scores[j] * r - term_max) - } - res_cfg$safe_val[i] <- safe_sum - - # FAST method: preexp power-chain - eR <- exp(r) - pow_b <- exp(-term_max) - fast_sum <- 0.0 - for (j in seq_along(scores)) { - fast_sum <- fast_sum + exp_m[j] * pow_b - pow_b <- pow_b * eR - } - res_cfg$fast_val[i] <- fast_sum - - # Relative errors - if (is.finite(ref_scaled_num) && ref_scaled_num > 0) { - res_cfg$err_safe[i] <- abs(safe_sum - ref_scaled_num) / ref_scaled_num - res_cfg$err_fast[i] <- abs(fast_sum - ref_scaled_num) / ref_scaled_num - } - - res_cfg$ok_safe[i] <- !is.na(res_cfg$err_safe[i]) && - is.finite(res_cfg$err_safe[i]) && - (res_cfg$err_safe[i] < tol) - - res_cfg$ok_fast[i] <- !is.na(res_cfg$err_fast[i]) && - is.finite(res_cfg$err_fast[i]) && - (res_cfg$err_fast[i] < tol) - } - - out_list[[cfg_idx]] <- res_cfg - } - - do.call(rbind, out_list) -} - -############################################################ -# 2. Parameter grid and simulation -############################################################ - -param_grid <- expand.grid( - max_cat = c(10), - ref = c(0, 5, 10), - theta_lin = c(-0.5, 0.0, 0.5), - theta_quad = c(-0.2, 0.0, 0.2), - KEEP.OUT.ATTRS = FALSE, - stringsAsFactors = FALSE -) - -# Very wide r-range so that bound covers deep negative and deep positive -r_vals <- seq(-100, 100, length.out = 5001) - -tol <- 1e-12 - -sim_res <- simulate_bc_fast_safe( - param_grid = param_grid, - r_vals = r_vals, - mpfr_prec = 256, - tol = tol -) - -############################################################ -# 3. Post-processing: classify regions, log-errors, abs(bound) -############################################################ - -df <- sim_res %>% - mutate( - err_fast_clipped = pmax(err_fast, 1e-300), - err_safe_clipped = pmax(err_safe, 1e-300), - - log_err_fast = log10(err_fast_clipped), - log_err_safe = log10(err_safe_clipped), - - abs_bound = abs(bound), - - region = case_when( - ok_fast & ok_safe ~ "both_ok", - !ok_fast & ok_safe ~ "only_safe_ok", - ok_fast & !ok_safe ~ "only_fast_ok", - TRUE ~ "neither_ok" - ) - ) - -############################################################ -# 4. NaN analysis for FAST -# -# We explicitly check: -# -# Are there *any* NaN occurrences for FAST with |bound| < 709 ? -# -# This is essential: if NaN occurs for FAST even when |bound| is small, -# then the switching rule would fail. -############################################################ - -df_nan <- sim_res %>% filter(is.nan(err_fast)) - -nan_summary <- df_nan %>% - summarise( - n_nan = n(), - min_bound = min(bound, na.rm = TRUE), - max_bound = max(bound, na.rm = TRUE) - ) - -print(nan_summary) - -df_nan_inside <- df_nan %>% filter(abs(bound) < 709) - -cat("\nNumber of FAST NaN cases with |bound| < 709: ", - nrow(df_nan_inside), "\n\n") - -############################################################ -# 5. FAST and SAFE plots vs bound -# -# We also explicitly count how many cases fail (ok_* == FALSE) -# while |bound| < 709. If the switching rule is correct, this -# number should be zero for FAST in the region where we intend -# to use it. -############################################################ - -# Count failures for FAST and SAFE when |bound| < 709 -fast_fail_inside <- df %>% - filter(abs(bound) < 709, !ok_fast) %>% - nrow() - -safe_fail_inside <- df %>% - filter(abs(bound) < 709, !ok_safe) %>% - nrow() - -cat("\nFAST failures with |bound| < 709:", fast_fail_inside, "\n") -cat("SAFE failures with |bound| < 709:", safe_fail_inside, "\n\n") - -# FAST -ggplot(df, aes(x = bound, y = log_err_fast, colour = region)) + - geom_point(alpha = 0.3, size = 0.6, na.rm = TRUE) + - geom_hline(yintercept = log10(tol), linetype = 2) + - geom_vline(xintercept = 709, linetype = 2) + - geom_vline(xintercept = -709, linetype = 2) + - scale_color_manual(values = c( - both_ok = "darkgreen", - only_safe_ok = "orange", - only_fast_ok = "blue", - neither_ok = "red" - )) + - labs( - x = "bound = max_s (theta_part(s) + s*r)", - y = "log10(relative error) of FAST", - colour = "region", - subtitle = paste( - "FAST failures with |bound| < 709:", fast_fail_inside - ) - ) + - ggtitle("FAST method vs bound") + - theme_minimal() - -# SAFE -ggplot(df, aes(x = bound, y = log_err_safe, colour = region)) + - geom_point(alpha = 0.3, size = 0.6, na.rm = TRUE) + - geom_hline(yintercept = log10(tol), linetype = 2) + - geom_vline(xintercept = 709, linetype = 2) + - geom_vline(xintercept = -709, linetype = 2) + - scale_color_manual(values = c( - both_ok = "darkgreen", - only_safe_ok = "orange", - only_fast_ok = "blue", - neither_ok = "red" - )) + - labs( - x = "bound = max_s (theta_part(s) + s*r)", - y = "log10(relative error) of SAFE", - colour = "region", - subtitle = paste( - "SAFE failures with |bound| < 709:", safe_fail_inside - ) - ) + - ggtitle("SAFE method vs bound") + - theme_minimal() - - -############################################################ -# 6. Fraction of configurations per |bound|-bin -############################################################ - -df_bins <- df %>% - filter(is.finite(bound)) %>% - mutate( - abs_bound = abs(bound), - bound_bin = cut( - abs_bound, - breaks = seq(0, max(abs_bound, na.rm = TRUE) + 10, by = 10), - include_lowest = TRUE - ) - ) %>% - group_by(bound_bin) %>% - summarise( - mid_abs_bound = mean(abs_bound, na.rm = TRUE), - frac_fast_ok = mean(ok_fast, na.rm = TRUE), - frac_safe_ok = mean(ok_safe, na.rm = TRUE), - n = n(), - .groups = "drop" - ) - -ggplot(df_bins, aes(x = mid_abs_bound)) + - geom_line(aes(y = frac_fast_ok, colour = "FAST ok")) + - geom_line(aes(y = frac_safe_ok, colour = "SAFE ok")) + - geom_vline(xintercept = 709, linetype = 2) + - scale_colour_manual(values = c("FAST ok" = "blue", "SAFE ok" = "darkgreen")) + - labs( - x = "|bound| bin center", - y = "fraction of configurations with err < tol", - colour = "" - ) + - ggtitle("FAST vs SAFE numerical stability by |bound|") + - theme_minimal() - -############################################################ -# 7. Summary printed to console -############################################################ - -cat("\n================ SUMMARY =================\n") -print(nan_summary) - -cat("\nFAST NaN cases with |bound| < 709: ", - nrow(df_nan_inside), "\n\n") - -cat(" -Interpretation: --------------- -- The SAFE method (direct + bound) remains stable and accurate across the - entire tested parameter and residual range. - -- The FAST method (preexp + bound) is extremely accurate when the maximum - exponent before rescaling, `bound = M(r)`, satisfies: - - |bound| ≤ ~709 - -- As soon as bound exceeds approximately +709, FAST becomes unstable: - * large numerical error - * or NaN (observed systematically) - * No such failures appear below this threshold. - -C++ Implementation Rule (recommended): --------------------------------------- -if (bound <= 709.0) { - // FAST: preexp + bound (power-chain) -} else { - // SAFE: direct + bound -} - -This script constitutes the full reproducible analysis supporting the choice -of this switching threshold in the C++ Blume–Capel normalization code. -") - -############################################################ -# End of script -############################################################ - - - - - - - - -############################################################ -# Blume–Capel probability analysis: -# Numerical comparison of FAST vs SAFE probability evaluation -# -# Objective -# --------- -# This script provides a numerical investigation of two methods -# to compute the *probabilities* under the Blume–Capel -# pseudolikelihood: -# -# p_s(r) = exp( θ_part(s) + s*r - M(r) ) / Z_scaled(r) -# -# where -# θ_part(s) = θ_lin * (s - ref) + θ_quad * (s - ref)^2 -# M(r) = max_s ( θ_part(s) + s*r ) -# Z_scaled = sum_s exp( θ_part(s) + s*r - M(r) ) -# -# We compare two implementations: -# -# SAFE = direct exponentials with numerical bound M(r) -# FAST = preexp + power-chain for exp(s*r - M(r)) -# -# MPFR (256-bit) is used as the ground-truth reference. -# -# Goals -# ----- -# 1. Check numerical stability of SAFE vs FAST for probabilities -# across wide ranges of (max_cat, ref, θ_lin, θ_quad, r). -# 2. Confirm that the same switching rule used for the -# normalization carries over safely to probabilities: -# -# FAST is used only if -# -# |M(r)| <= EXP_BOUND AND -# pow_bound = max_cat * r - M(r) <= EXP_BOUND -# -# where EXP_BOUND ≈ 709. -# -# 3. Document the error behaviour in terms of: -# - max absolute difference per probability vector -# - max relative difference -# - KL divergence to MPFR reference. -# -# Outcome (to be checked empirically) -# ----------------------------------- -# - SAFE should be stable across the tested ranges. -# - FAST should exhibit negligible error whenever the -# switching bounds are satisfied. -# -############################################################ - -library(Rmpfr) -library(dplyr) -library(ggplot2) - -EXP_BOUND <- 709 # double overflow limit for exp() - - -############################################################ -# 1. Helper: MPFR probability reference for a single config -############################################################ - -bc_prob_ref_mpfr <- function(max_cat, ref, theta_lin, theta_quad, - r_vals, - mpfr_prec = 256) { - # Categories and centered scores - scores <- 0:max_cat - centered <- scores - ref - - # MPFR parameters - tl <- mpfr(theta_lin, mpfr_prec) - tq <- mpfr(theta_quad, mpfr_prec) - sc <- mpfr(scores, mpfr_prec) - s0 <- mpfr(centered, mpfr_prec) - - n_r <- length(r_vals) - n_s <- length(scores) - - # reference probability matrix (rows = r, cols = s) - P_ref <- matrix(NA_real_, nrow = n_r, ncol = n_s) - - for (i in seq_len(n_r)) { - r_mp <- mpfr(r_vals[i], mpfr_prec) - - # exponent(s) = θ_part(s) + s*r - term <- tl * s0 + tq * s0 * s0 + sc * r_mp - - # numeric bound M(r) - term_max <- max(asNumeric(term)) - term_max_mp <- mpfr(term_max, mpfr_prec) - - num <- exp(term - term_max_mp) # scaled numerators - Z <- sum(num) - p <- num / Z - - P_ref[i, ] <- asNumeric(p) - } - - P_ref -} - - -############################################################ -# 2. SAFE probabilities (double) for a single config -############################################################ - -bc_prob_safe <- function(max_cat, ref, theta_lin, theta_quad, - r_vals) { - scores <- 0:max_cat - centered <- scores - ref - theta_part <- theta_lin * centered + theta_quad * centered^2 - - n_r <- length(r_vals) - n_s <- length(scores) - - P_safe <- matrix(NA_real_, nrow = n_r, ncol = n_s) - - for (i in seq_len(n_r)) { - r <- r_vals[i] - - # exponents before scaling - exps <- theta_part + scores * r - b <- max(exps) - - numer <- exp(exps - b) - denom <- sum(numer) - - # NO fallback here; let denom=0 or non-finite propagate - p <- numer / denom - - P_safe[i, ] <- p - } - - P_safe -} - - - -############################################################ -# 3. FAST probabilities (double) for a single config -# -# This mirrors what a C++ compute_probs_blume_capel(FAST) -# implementation would do: precompute exp(theta_part), -# then use a power chain for exp(s*r - b). -############################################################ - -bc_prob_fast <- function(max_cat, ref, theta_lin, theta_quad, - r_vals) { - scores <- 0:max_cat - centered <- scores - ref - theta_part <- theta_lin * centered + theta_quad * centered^2 - exp_theta <- exp(theta_part) - - n_r <- length(r_vals) - n_s <- length(scores) - - P_fast <- matrix(NA_real_, nrow = n_r, ncol = n_s) - bounds <- numeric(n_r) - pow_bounds <- numeric(n_r) - - for (i in seq_len(n_r)) { - r <- r_vals[i] - - # exponents before scaling - exps <- theta_part + scores * r - b <- max(exps) - bounds[i] <- b - - # pow_bound = max_s (s*r - b) is attained at s = max_cat - pow_bounds[i] <- max_cat * r - b - - eR <- exp(r) - pow <- exp(-b) - - numer <- numeric(n_s) - denom <- 0.0 - - for (j in seq_along(scores)) { - numer[j] <- exp_theta[j] * pow - denom <- denom + numer[j] - pow <- pow * eR - } - - # Again: NO fallback, just divide and let problems show - p <- numer / denom - - P_fast[i, ] <- p - } - - list( - probs = P_fast, - bound = bounds, - pow_bound = pow_bounds - ) -} - - - -############################################################ -# 4. Main simulation: -# Explore param_grid × r_vals and compare: -# - P_ref (MPFR) -# - P_safe -# - P_fast -############################################################ - -simulate_bc_prob_fast_safe <- function(param_grid, - r_vals, - mpfr_prec = 256, - tol_prob = 1e-12) { - - if (!all(c("max_cat", "ref", "theta_lin", "theta_quad") %in% names(param_grid))) { - stop("param_grid must have columns: max_cat, ref, theta_lin, theta_quad") - } - - out_list <- vector("list", nrow(param_grid)) - - for (cfg_idx in seq_len(nrow(param_grid))) { - cfg <- param_grid[cfg_idx, ] - max_cat <- as.integer(cfg$max_cat) - ref <- as.integer(cfg$ref) - theta_lin <- as.numeric(cfg$theta_lin) - theta_quad <- as.numeric(cfg$theta_quad) - - n_r <- length(r_vals) - - # Reference - P_ref <- bc_prob_ref_mpfr(max_cat, ref, theta_lin, theta_quad, - r_vals, mpfr_prec = mpfr_prec) - # SAFE - P_safe <- bc_prob_safe(max_cat, ref, theta_lin, theta_quad, - r_vals) - # FAST (+ bounds) - fast_res <- bc_prob_fast(max_cat, ref, theta_lin, theta_quad, - r_vals) - P_fast <- fast_res$probs - bound <- fast_res$bound - pow_bound <- fast_res$pow_bound - - # Error metrics per r - max_abs_fast <- numeric(n_r) - max_rel_fast <- numeric(n_r) - kl_fast <- numeric(n_r) - - max_abs_safe <- numeric(n_r) - max_rel_safe <- numeric(n_r) - kl_safe <- numeric(n_r) - - # Helper: KL divergence D(p || q) - kl_div <- function(p, q) { - # If either vector has non-finite entries, KL is undefined → NA - if (!all(is.finite(p)) || !all(is.finite(q))) { - return(NA_real_) - } - - # Valid domain for KL: where both p and q are strictly positive - mask <- (p > 0) & (q > 0) - - # mask may contain NA → remove NA via na.rm=TRUE - if (!any(mask, na.rm = TRUE)) { - return(NA_real_) - } - - sum(p[mask] * (log(p[mask]) - log(q[mask]))) - } - - - for (i in seq_len(n_r)) { - p_ref <- P_ref[i, ] - p_safe <- P_safe[i, ] - p_fast <- P_fast[i, ] - - # max abs diff - max_abs_fast[i] <- max(abs(p_fast - p_ref)) - max_abs_safe[i] <- max(abs(p_safe - p_ref)) - - # max relative diff (avoid divide-by-zero) - rel_fast <- abs(p_fast - p_ref) - rel_safe <- abs(p_safe - p_ref) - - rel_fast[p_ref > 0] <- rel_fast[p_ref > 0] / p_ref[p_ref > 0] - rel_safe[p_ref > 0] <- rel_safe[p_ref > 0] / p_ref[p_ref > 0] - - rel_fast[p_ref == 0] <- 0 - rel_safe[p_ref == 0] <- 0 - - max_rel_fast[i] <- max(rel_fast) - max_rel_safe[i] <- max(rel_safe) - - # KL - kl_fast[i] <- kl_div(p_ref, p_fast) - kl_safe[i] <- kl_div(p_ref, p_safe) - } - - # "ok" flags using tol_prob on max_abs - ok_fast <- is.finite(max_abs_fast) & (max_abs_fast < tol_prob) - ok_safe <- is.finite(max_abs_safe) & (max_abs_safe < tol_prob) - - # FAST switching condition as in C++: - # use FAST only if |bound| <= EXP_BOUND and pow_bound <= EXP_BOUND - use_fast <- (abs(bound) <= EXP_BOUND) & (pow_bound <= EXP_BOUND) - - res_cfg <- data.frame( - config_id = rep(cfg_idx, n_r), - max_cat = rep(max_cat, n_r), - ref = rep(ref, n_r), - theta_lin = rep(theta_lin, n_r), - theta_quad = rep(theta_quad, n_r), - r = r_vals, - bound = bound, - pow_bound = pow_bound, - use_fast = use_fast, - max_abs_fast = max_abs_fast, - max_rel_fast = max_rel_fast, - kl_fast = kl_fast, - max_abs_safe = max_abs_safe, - max_rel_safe = max_rel_safe, - kl_safe = kl_safe, - ok_fast = ok_fast, - ok_safe = ok_safe, - stringsAsFactors = FALSE - ) - - out_list[[cfg_idx]] <- res_cfg - } - - do.call(rbind, out_list) -} - - -############################################################ -# 5. Example simulation setup -############################################################ - -# Parameter grid similar in spirit to the BC normalization script -param_grid <- expand.grid( - max_cat = c(4, 10), # Blume–Capel max categories (example) - ref = c(0, 2, 4, 5, 10), # include both interior & boundary refs - theta_lin = c(-0.5, 0.0, 0.5), - theta_quad = c(-0.2, 0.0, 0.2), - KEEP.OUT.ATTRS = FALSE, - stringsAsFactors = FALSE -) - -# Wide r-range; adjust as needed to match your empirical residuals -r_vals <- seq(-80, 80, length.out = 2001) - -tol_prob <- 1e-12 - -sim_probs <- simulate_bc_prob_fast_safe( - param_grid = param_grid, - r_vals = r_vals, - mpfr_prec = 256, - tol_prob = tol_prob -) - - -############################################################ -# 6. Post-processing and diagnostics -############################################################ - -df <- sim_probs %>% - mutate( - abs_bound = abs(bound), - region = case_when( - use_fast & ok_fast ~ "fast_ok_when_used", - use_fast & !ok_fast ~ "fast_bad_when_used", - !use_fast & ok_safe ~ "safe_ok_when_used", - !use_fast & !ok_safe ~ "safe_bad_when_used" - ) - ) - -# Check: any bad FAST cases *within* the intended FAST region? -fast_bad_inside <- df %>% - filter(use_fast, !ok_fast) - -cat("\nNumber of FAST probability failures where use_fast == TRUE: ", - nrow(fast_bad_inside), "\n\n") - -# Also track purely based on bounds (even if not marked use_fast) -fast_bad_bound_region <- df %>% - filter(abs(bound) <= EXP_BOUND, - pow_bound <= EXP_BOUND, - !ok_fast) - -cat("Number of FAST probability failures with |bound| <= 709 & pow_bound <= 709: ", - nrow(fast_bad_bound_region), "\n\n") - - -############################################################ -# 7. Plots: error vs bound (FAST only) -############################################################ - -df_fast <- df %>% - filter(use_fast) %>% - mutate( - log10_max_abs_fast = log10(pmax(max_abs_fast, 1e-300)) - ) - -ggplot(df_fast, aes(x = bound, y = log10_max_abs_fast)) + - geom_point(alpha = 0.3, size = 0.6) + - geom_hline(yintercept = log10(tol_prob), linetype = 2, colour = "darkgreen") + - geom_vline(xintercept = EXP_BOUND, linetype = 2, colour = "red") + - geom_vline(xintercept = -EXP_BOUND, linetype = 2, colour = "red") + - labs( - x = "bound = max_s (θ_part(s) + s*r)", - y = "log10(max absolute error) of FAST p_s(r)", - title = "FAST Blume–Capel probabilities vs bound (used region only)", - subtitle = paste( - "FAST failures in use_fast region:", nrow(fast_bad_inside) - ) - ) + - theme_minimal() - - -############################################################ -# 8. Binned summary by |bound| -############################################################ - -df_bins <- df %>% - mutate( - abs_bound = abs(bound), - bound_bin = cut( - abs_bound, - breaks = seq(0, max(abs_bound, na.rm = TRUE) + 10, by = 10), - include_lowest = TRUE - ) - ) %>% - group_by(bound_bin) %>% - summarise( - mid_abs_bound = mean(abs_bound, na.rm = TRUE), - frac_fast_ok = mean(ok_fast[use_fast], na.rm = TRUE), - frac_safe_ok = mean(ok_safe[!use_fast], na.rm = TRUE), - max_abs_fast_99 = quantile(max_abs_fast[use_fast], 0.99, na.rm = TRUE), - max_abs_safe_99 = quantile(max_abs_safe[!use_fast], 0.99, na.rm = TRUE), - n = n(), - .groups = "drop" - ) - -ggplot(df_bins, aes(x = mid_abs_bound)) + - geom_line(aes(y = frac_fast_ok, colour = "FAST ok (used)"), na.rm = TRUE) + - geom_line(aes(y = frac_safe_ok, colour = "SAFE ok (used)"), na.rm = TRUE) + - geom_vline(xintercept = EXP_BOUND, linetype = 2) + - scale_colour_manual(values = c( - "FAST ok (used)" = "blue", - "SAFE ok (used)" = "darkgreen" - )) + - labs( - x = "|bound| bin center", - y = "fraction of configurations with max_abs_error < tol_prob", - colour = "", - title = "Numerical stability of Blume–Capel probabilities by |bound|" - ) + - theme_minimal() - - -############################################################ -# 9. Console summary -############################################################ - -cat("\n================ PROBABILITY SUMMARY =================\n") - -cat("Total rows in simulation:", nrow(df), "\n\n") - -cat("FAST probability failures where use_fast == TRUE: ", - nrow(fast_bad_inside), "\n") -cat("FAST probability failures with |bound| <= 709 & pow_bound <= 709: ", - nrow(fast_bad_bound_region), "\n\n") - -cat("Typical 99th percentile max_abs_error per |bound|-bin (FAST used):\n") -print( - df_bins %>% - select(bound_bin, mid_abs_bound, max_abs_fast_99) %>% - arrange(mid_abs_bound), - digits = 4 -) - -cat(" -Interpretation guide --------------------- -- `ok_fast`/`ok_safe` are defined by max absolute error vs MPFR reference - being below tol_prob (default 1e-12). - -- `use_fast` encodes the **intended** C++ switching rule: - use_fast = (|bound| <= 709) & (pow_bound <= 709) - -- Ideally: - * `fast_bad_inside` should be empty or extremely rare, - showing that FAST is safe whenever used. - * errors for SAFE should be negligible everywhere. - -You can tighten the switching margin if needed (e.g. require -`pow_bound <= 700`) by adjusting `use_fast` in the code above. -") \ No newline at end of file diff --git a/dev/numerical_analyses/bgm_blumecapel_probs_PL.R b/dev/numerical_analyses/bgm_blumecapel_probs_PL.R deleted file mode 100644 index 6d887ba4..00000000 --- a/dev/numerical_analyses/bgm_blumecapel_probs_PL.R +++ /dev/null @@ -1,248 +0,0 @@ -############################################################ -# Blume–Capel probabilities: -# Numerical comparison of 4 methods vs MPFR reference -# -# Methods: -# - direct_unscaled : naive softmax -# - direct_bound : softmax with subtraction of M(r) -# - preexp_unscaled : preexp(theta_part) + power chain (no bound) -# - preexp_bound : preexp(theta_part) + power chain (with bound) -# -# Reference: -# - MPFR softmax with scaling by M(r) -############################################################ - -library(Rmpfr) -library(dplyr) -library(ggplot2) - -EXP_BOUND <- 709 - -############################################################ -# 1. Compare 4 methods for one BC configuration -############################################################ - -compare_bc_prob_4methods_one <- function(max_cat, - ref, - theta_lin, - theta_quad, - r_vals, - mpfr_prec = 256) { - - s_vals <- 0:max_cat - c_vals <- s_vals - ref - n_s <- length(s_vals) - n_r <- length(r_vals) - - # theta_part(s) - theta_part_num <- theta_lin * c_vals + theta_quad * c_vals^2 - - # MPFR parameters - tl_mp <- mpfr(theta_lin, mpfr_prec) - tq_mp <- mpfr(theta_quad, mpfr_prec) - s_mp <- mpfr(s_vals, mpfr_prec) - c_mp <- mpfr(c_vals, mpfr_prec) - - # Precompute for preexp methods - exp_theta <- exp(theta_part_num) - - res <- data.frame( - r = r_vals, - bound = NA_real_, - pow_bound = NA_real_, - err_direct = NA_real_, - err_bound = NA_real_, - err_preexp = NA_real_, - err_preexp_bound= NA_real_ - ) - - for (i in seq_len(n_r)) { - r <- r_vals[i] - r_mp <- mpfr(r, mpfr_prec) - - ## MPFR reference probabilities (softmax with scaling) - term_mp <- tl_mp * c_mp + - tq_mp * c_mp * c_mp + - s_mp * r_mp - - M_num <- max(asNumeric(term_mp)) - M_mp <- mpfr(M_num, mpfr_prec) - - num_ref_mp <- exp(term_mp - M_mp) - Z_ref_mp <- sum(num_ref_mp) - p_ref_mp <- num_ref_mp / Z_ref_mp - p_ref <- asNumeric(p_ref_mp) - - ## Double: exponents - term_num <- theta_part_num + s_vals * r - M <- max(term_num) - res$bound[i] <- M - res$pow_bound[i] <- max_cat * r - M - - ## (1) direct_unscaled - num_dir <- exp(term_num) - den_dir <- sum(num_dir) - p_dir <- num_dir / den_dir - - ## (2) direct_bound - num_b <- exp(term_num - M) - den_b <- sum(num_b) - p_b <- num_b / den_b - - ## (3) preexp_unscaled - eR <- exp(r) - pow <- eR - num_pre <- numeric(n_s) - den_pre <- 0.0 - - # s = 0 term - num_pre[1] <- exp_theta[1] * 1.0 - den_pre <- den_pre + num_pre[1] - - if (max_cat >= 1) { - for (s in 1:max_cat) { - num_pre[s + 1] <- exp_theta[s + 1] * pow - den_pre <- den_pre + num_pre[s + 1] - pow <- pow * eR - } - } - p_pre <- num_pre / den_pre - - ## (4) preexp_bound - eR2 <- exp(r) - pow_b <- exp(-M) - num_preB <- numeric(n_s) - den_preB <- 0.0 - - for (s in 0:max_cat) { - idx <- s + 1 - num_preB[idx] <- exp_theta[idx] * pow_b - den_preB <- den_preB + num_preB[idx] - pow_b <- pow_b * eR2 - } - p_preB <- num_preB / den_preB - - ## Relative errors vs MPFR reference on non-negligible support - tau <- 1e-15 # <-- tweak this - - support_mask <- p_ref >= tau - if (!any(support_mask)) { - support_mask <- p_ref == max(p_ref) # degenerate case: all tiny, pick the max - } - - rel_direct <- abs(p_dir - p_ref)[support_mask] / p_ref[support_mask] - rel_bound <- abs(p_b - p_ref)[support_mask] / p_ref[support_mask] - rel_preexp <- abs(p_pre - p_ref)[support_mask] / p_ref[support_mask] - rel_preB <- abs(p_preB - p_ref)[support_mask] / p_ref[support_mask] - - res$err_direct[i] <- max(rel_direct) - res$err_bound[i] <- max(rel_bound) - res$err_preexp[i] <- max(rel_preexp) - res$err_preexp_bound[i] <- max(rel_preB) - - - - } - - res -} - -############################################################ -# 2. Sweep across param_grid × r_vals -############################################################ - -simulate_bc_prob_4methods <- function(param_grid, - r_vals, - mpfr_prec = 256, - tol = 1e-12) { - - if (!all(c("max_cat", "ref", "theta_lin", "theta_quad") %in% names(param_grid))) { - stop("param_grid must have columns: max_cat, ref, theta_lin, theta_quad") - } - - out_list <- vector("list", nrow(param_grid)) - - for (cfg_idx in seq_len(nrow(param_grid))) { - cfg <- param_grid[cfg_idx, ] - - res_cfg <- compare_bc_prob_4methods_one( - max_cat = cfg$max_cat, - ref = cfg$ref, - theta_lin = cfg$theta_lin, - theta_quad = cfg$theta_quad, - r_vals = r_vals, - mpfr_prec = mpfr_prec - ) - - res_cfg$config_id <- cfg_idx - res_cfg$max_cat <- cfg$max_cat - res_cfg$ref <- cfg$ref - res_cfg$theta_lin <- cfg$theta_lin - res_cfg$theta_quad <- cfg$theta_quad - - # simple ok flags - res_cfg$ok_direct <- is.finite(res_cfg$err_direct) & (res_cfg$err_direct < tol) - res_cfg$ok_bound <- is.finite(res_cfg$err_bound) & (res_cfg$err_bound < tol) - res_cfg$ok_preexp <- is.finite(res_cfg$err_preexp) & (res_cfg$err_preexp < tol) - res_cfg$ok_preexp_bound <- is.finite(res_cfg$err_preexp_bound) & (res_cfg$err_preexp_bound < tol) - - out_list[[cfg_idx]] <- res_cfg - } - - do.call(rbind, out_list) -} - -############################################################ -# 3. Example broad analysis (you can adjust this) -############################################################ - -param_grid <- expand.grid( - max_cat = c(4, 10), - ref = c(0, 2, 4, 5, 10), - theta_lin = c(-0.5, 0.0, 0.5), - theta_quad = c(-0.2, 0.0, 0.2), - KEEP.OUT.ATTRS = FALSE, - stringsAsFactors = FALSE -) - -r_vals <- seq(-80, 80, length.out = 2001) -tol <- 1e-12 - -sim4 <- simulate_bc_prob_4methods( - param_grid = param_grid, - r_vals = r_vals, - mpfr_prec = 256, - tol = tol -) - -############################################################ -# 4. Summaries: where each method fails, as a function of bound/pow_bound -############################################################ - -df4 <- sim4 %>% - mutate( - abs_bound = abs(bound), - err_direct_cl = pmax(err_direct, 1e-300), - err_bound_cl = pmax(err_bound, 1e-300), - err_preexp_cl = pmax(err_preexp, 1e-300), - err_preexp_bound_cl = pmax(err_preexp_bound, 1e-300), - log_err_direct = log10(err_direct_cl), - log_err_bound = log10(err_bound_cl), - log_err_preexp = log10(err_preexp_cl), - log_err_preexp_bound= log10(err_preexp_bound_cl) - ) - -# Example: failures for each method inside |bound| <= 709 & pow_bound <= 709 -inside <- df4 %>% - filter(abs(bound) <= EXP_BOUND, pow_bound <= EXP_BOUND) - -n_direct_fail <- sum(!inside$ok_direct) -n_bound_fail <- sum(!inside$ok_bound) -n_preexp_fail <- sum(!inside$ok_preexp) -n_preexp_bound_fail <- sum(!inside$ok_preexp_bound) - -cat("\nFailures inside fast region (|bound| <= 709 & pow_bound <= 709):\n") -cat(" direct_unscaled :", n_direct_fail, "\n") -cat(" direct_bound :", n_bound_fail, "\n") -cat(" preexp_unscaled :", n_preexp_fail, "\n") -cat(" preexp_bound (FAST) :", n_preexp_bound_fail, "\n\n") diff --git a/dev/numerical_analyses/bgm_regularordinal_normalization_PL.R b/dev/numerical_analyses/bgm_regularordinal_normalization_PL.R deleted file mode 100644 index 8a6219b1..00000000 --- a/dev/numerical_analyses/bgm_regularordinal_normalization_PL.R +++ /dev/null @@ -1,695 +0,0 @@ -################################################################################ -# Reference: Numerical stability study for bounded vs. unbounded exponential sums -# Author: [Your Name] -# Date: [YYYY-MM-DD] -# -# Purpose: -# Evaluate and compare four ways to compute the sum -# -# S = 1 + Σ_{c=1..K} exp( m_c + (c+1)*r ) -# -# where r may vary widely. The goal is to identify numerically stable and -# computationally efficient formulations for use in gradient calculations. -# -# Methods compared: -# (1) direct – naive computation using raw exp() -# (2) bounded – stabilized by subtracting a "bound" (i.e., scaled domain) -# (3) preexp – precomputes exp(m_c) and exp(r) to replace repeated calls -# (4) preexp_bound – preexp variant with the same "bound" scaling -# -# For each method, we compute both unscaled and scaled variants where relevant, -# and compare them against a high-precision MPFR reference. -# -# Key insight: -# - For large negative r, preexp can lose precision (tiny multiplicative updates). -# - For large positive r, bounded scaling avoids overflow. -# - The combination (preexp + bound) gives the best general stability. -# -# Output: -# - res: data frame with per-r results and relative errors -# - Diagnostic plots and summary tables for numerical accuracy -################################################################################ - -library(Rmpfr) # for arbitrary precision reference computations - - -################################################################################ -# 1. Core comparison function -################################################################################ -compare_all_methods <- function(K = 5, - r_vals = seq(-10, 10, length.out = 200), - m_vals = NULL, - mpfr_prec = 256) { - # --------------------------------------------------------------------------- - # Parameters: - # K – number of categories (terms in the sum) - # r_vals – vector of r values to evaluate over - # m_vals – optional vector of m_c values; random if NULL - # mpfr_prec – bits of precision for the high-precision reference - # - # Returns: - # A data.frame containing per-r computed values, reference values, - # relative errors, and failure flags. - # --------------------------------------------------------------------------- - - if (is.null(m_vals)) m_vals <- runif(K, -1, 1) - - results <- data.frame( - r = r_vals, - direct = NA_real_, - bounded = NA_real_, # scaled-domain computation (exp(-bound) factor) - preexp = NA_real_, - preexp_bound = NA_real_, # scaled-domain computation - ref = NA_real_, # unscaled MPFR reference - ref_scaled = NA_real_, # scaled reference - err_direct = NA_real_, - err_bounded = NA_real_, - err_preexp = NA_real_, - err_preexp_bound = NA_real_, - ref_failed_unscaled = FALSE, - ref_failed_scaled = FALSE - ) - - # Loop over all r-values - for (i in seq_along(r_vals)) { - r <- r_vals[i] - bound <- K * r # can be unclipped; use max(0, K*r) for the clipped version - - # --- (0) High-precision MPFR reference ----------------------------------- - r_mp <- mpfr(r, precBits = mpfr_prec) - m_mp <- mpfr(m_vals, precBits = mpfr_prec) - b_mp <- mpfr(bound, precBits = mpfr_prec) - - ref_unscaled_mp <- 1 + sum(exp(m_mp + (1:K) * r_mp)) - ref_scaled_mp <- exp(-b_mp) * ref_unscaled_mp - - # Convert to doubles for inspection - ref_unscaled_num <- asNumeric(ref_unscaled_mp) - ref_scaled_num <- asNumeric(ref_scaled_mp) - results$ref_failed_unscaled[i] <- !is.finite(ref_unscaled_num) - results$ref_failed_scaled[i] <- !is.finite(ref_scaled_num) - results$ref[i] <- if (is.finite(ref_unscaled_num)) ref_unscaled_num else NA_real_ - results$ref_scaled[i] <- if (is.finite(ref_scaled_num)) ref_scaled_num else NA_real_ - - # --- (1) Direct exponential sum (unscaled) ------------------------------- - results$direct[i] <- 1 + sum(exp(m_vals + (1:K) * r)) - - # --- (2) Current bounded implementation (scaled) ------------------------- - eB <- exp(-bound) - results$bounded[i] <- eB + sum(exp(m_vals + (1:K) * r - bound)) - - # --- (3) Precomputed exp only (unscaled) --------------------------------- - exp_r <- exp(r) - exp_m <- exp(m_vals) - powE <- exp_r - S_pre <- 1.0 - for (c in 1:K) { - S_pre <- S_pre + exp_m[c] * powE - powE <- powE * exp_r - } - results$preexp[i] <- S_pre - - # --- (4) Precomputed exp + bound scaling (scaled) ------------------------ - exp_r <- exp(r) - exp_m <- exp(m_vals) - powE <- exp_r - S_preB <- eB - for (c in 1:K) { - S_preB <- S_preB + exp_m[c] * powE * eB - powE <- powE * exp_r - } - results$preexp_bound[i] <- S_preB - - # --- (5) Relative errors vs references ----------------------------------- - # Unscaled methods - for (m in c("direct", "preexp")) { - val <- results[[m]][i] - if (is.finite(val)) { - val_mp <- mpfr(val, precBits = mpfr_prec) - err_mp <- abs((val_mp - ref_unscaled_mp) / ref_unscaled_mp) - results[[paste0("err_", m)]][i] <- asNumeric(err_mp) - } - } - - # Scaled methods - for (m in c("bounded", "preexp_bound")) { - val <- results[[m]][i] - if (is.finite(val)) { - val_mp <- mpfr(val, precBits = mpfr_prec) - err_mp <- abs((val_mp - ref_scaled_mp) / ref_scaled_mp) - results[[paste0("err_", m)]][i] <- asNumeric(err_mp) - } - } - } - - msg_a <- mean(results$ref_failed_unscaled) - msg_b <- mean(results$ref_failed_scaled) - message(sprintf("Ref (unscaled) non-finite in %.1f%%; Ref (scaled) non-finite in %.1f%% of r-values", - 100 * msg_a, 100 * msg_b)) - results -} - - -################################################################################ -# 2. Plotting: log-scale accuracy with failure marking -################################################################################ -plot_errors <- function(res) { - err_cols <- c("err_bounded", "err_direct", "err_preexp", "err_preexp_bound") - cols <- c("gray", "black", "red", "blue") - names(cols) <- err_cols - - # Compute a robust ylim (1st–99th percentile) - finite_vals <- unlist(res[err_cols]) - finite_vals <- finite_vals[is.finite(finite_vals) & finite_vals > 0] - if (length(finite_vals) > 0) { - lower <- quantile(finite_vals, 0.01, na.rm = TRUE) - upper <- quantile(finite_vals, 0.99, na.rm = TRUE) - ylim <- c(lower / 10, upper * 10) - } else { - ylim <- c(1e-20, 1e-12) - } - - # Baseline curve: bounded - plot(res$r, res$err_bounded, type = "l", log = "y", - col = cols["err_bounded"], lwd = 2, - ylim = ylim, - xlab = "r", ylab = "Relative error", - main = "Accuracy and failure regions") - - # Add other methods - for (e in setdiff(err_cols, "err_bounded")) - lines(res$r, res[[e]], col = cols[e], lwd = 2) - - abline(h = .Machine$double.eps, col = "darkgray", lty = 2) - - legend("bottomright", - legend = c("Current bounded", "Direct exp", - "Preexp only", "Preexp + bound"), - col = cols, lwd = 2, bty = "n") - - # Mark numeric failures - for (e in err_cols) { - bad <- which(!is.finite(res[[e]]) | res[[e]] <= 0) - if (length(bad) > 0) - points(res$r[bad], rep(ylim[1], length(bad)), - pch = 21, col = cols[e], bg = cols[e], cex = 0.6) - } - - legend("bottomleft", legend = "dots = 0/Inf/NaN failures", bty = "n") -} - - -################################################################################ -# 3. Summarize accuracy across r -################################################################################ -summarize_accuracy <- function(res) { - err_cols <- c("err_direct", "err_bounded", "err_preexp", "err_preexp_bound") - - summary <- data.frame( - Method = c("Direct exp", "Current bounded", - "Preexp only", "Preexp + bound"), - Mean_error = NA_real_, - Median_error = NA_real_, - Max_error = NA_real_, - Finite_fraction = NA_real_, - Zero_or_Inf_fraction = NA_real_ - ) - - for (j in seq_along(err_cols)) { - e <- res[[err_cols[j]]] - finite_mask <- is.finite(e) & e > 0 - summary$Mean_error[j] <- mean(e[finite_mask], na.rm = TRUE) - summary$Median_error[j] <- median(e[finite_mask], na.rm = TRUE) - summary$Max_error[j] <- max(e[finite_mask], na.rm = TRUE) - summary$Finite_fraction[j] <- mean(finite_mask) - summary$Zero_or_Inf_fraction[j] <- 1 - mean(finite_mask) - } - - summary -} - - -################################################################################ -# 4. Alternate jitter plot for fine-scale comparison -################################################################################ -plot_errors_jitter <- function(res, offset_for_visibility = TRUE) { - err_cols <- c("err_bounded", "err_direct", "err_preexp", "err_preexp_bound") - cols <- c("gray", "black", "red", "blue") - - message("Plotting columns:") - for (i in seq_along(err_cols)) - message(sprintf(" %-15s -> %s", err_cols[i], cols[i])) - - offset_factor <- if (offset_for_visibility) c(1, 5, 100, 1e4) else rep(1, 4) - - finite_vals <- unlist(res[err_cols]) - finite_vals <- finite_vals[is.finite(finite_vals) & finite_vals > 0] - if (length(finite_vals) > 0) { - lower <- quantile(finite_vals, 0.01, na.rm = TRUE) - upper <- quantile(finite_vals, 0.99, na.rm = TRUE) - ylim <- c(lower / 10, upper * 10) - } else ylim <- c(1e-20, 1e-12) - - plot(res$r, res$err_bounded * offset_factor[1], - type = "l", log = "y", lwd = 2, col = cols[1], - ylim = ylim, - xlab = "r", ylab = "Relative error", - main = "Accuracy (offset for visibility)") - - for (j in 2:length(err_cols)) - lines(res$r, res[[err_cols[j]]] * offset_factor[j], col = cols[j], lwd = 2) - - abline(h = .Machine$double.eps, col = "darkgray", lty = 2) - legend("bottomright", - legend = c("Current bounded", "Direct exp", "Preexp only", "Preexp + bound"), - col = cols, lwd = 2) -} - - -################################################################################ -# 5. Example usage -################################################################################ -# Run test for a moderate K and r-range. -# Expand range (e.g. seq(-100, 80, 1)) to probe overflow/underflow limits. -# res <- compare_all_methods(K = 10, r_vals = seq(-71, 71, length.out = 1e4)) -# -# # Plot and summarize -# plot_errors(res) -# summary_table <- summarize_accuracy(res) -# print(summary_table, digits = 3) -# plot_errors_jitter(res) # optional visualization with offsets -################################################################################ - - -################################################################################ -# 6. Ratio stability check (direct vs preexp) × (bound vs clipped) -################################################################################ -compare_prob_ratios <- function(K = 5, - r_vals = seq(-20, 20, length.out = 200), - m_vals = NULL, - mpfr_prec = 256) { - - if (!requireNamespace("Rmpfr", quietly = TRUE)) - stop("Please install Rmpfr: install.packages('Rmpfr')") - - if (is.null(m_vals)) m_vals <- runif(K, -1, 1) - - res <- data.frame( - r = numeric(length(r_vals)), - err_direct_bound = numeric(length(r_vals)), - err_direct_clip = numeric(length(r_vals)), - err_preexp_bound = numeric(length(r_vals)), - err_preexp_clip = numeric(length(r_vals)) - ) - - for (i in seq_along(r_vals)) { - r <- r_vals[i] - b_raw <- K * r - b_clip <- max(0, b_raw) - - # --- High-precision reference --------------------------------------------- - r_mp <- Rmpfr::mpfr(r, precBits = mpfr_prec) - m_mp <- Rmpfr::mpfr(m_vals, precBits = mpfr_prec) - exp_terms_ref <- exp(m_mp + (1:K) * r_mp) - denom_ref <- 1 + sum(exp_terms_ref) - p_ref_num <- as.numeric(exp_terms_ref / denom_ref) - - # --- (1) Direct, un-clipped bound ---------------------------------------- - exp_terms_dB <- exp(m_vals + (1:K) * r - b_raw) - denom_dB <- exp(-b_raw) + sum(exp_terms_dB) - p_dB <- exp_terms_dB / denom_dB - res$err_direct_bound[i] <- max(abs(p_dB - p_ref_num) / p_ref_num) - - # --- (2) Direct, clipped bound ------------------------------------------- - exp_terms_dC <- exp(m_vals + (1:K) * r - b_clip) - denom_dC <- exp(-b_clip) + sum(exp_terms_dC) - p_dC <- exp_terms_dC / denom_dC - res$err_direct_clip[i] <- max(abs(p_dC - p_ref_num) / p_ref_num) - - # --- (3) Preexp, un-clipped bound --------------------------------------- - eR <- exp(r) - eM <- exp(m_vals) - eB <- exp(-b_raw) - powE <- eR - S_preB <- eB - terms_preB <- numeric(K) - for (c in 1:K) { - term <- eM[c] * powE * eB - terms_preB[c] <- term - S_preB <- S_preB + term - powE <- powE * eR - } - p_preB <- terms_preB / S_preB - res$err_preexp_bound[i] <- max(abs(p_preB - p_ref_num) / p_ref_num) - - # --- (4) Preexp, clipped bound ------------------------------------------ - eR <- exp(r) - eM <- exp(m_vals) - eB <- exp(-b_clip) - powE <- eR - S_preC <- eB - terms_preC <- numeric(K) - for (c in 1:K) { - term <- eM[c] * powE * eB - terms_preC[c] <- term - S_preC <- S_preC + term - powE <- powE * eR - } - p_preC <- terms_preC / S_preC - res$err_preexp_clip[i] <- max(abs(p_preC - p_ref_num) / p_ref_num) - - res$r[i] <- r - } - - return(res) -} - - -################################################################################ -# 7. Example usage: compare probability ratio stability -################################################################################ - -# K <- 10 -# r_vals <- seq(-75, 75, length.out = 1e4) -# set.seed(123) -# m_vals <- runif(K, -1, 1) -# -# res_ratio <- compare_prob_ratios(K = K, r_vals = r_vals, m_vals = m_vals) -# -# eps <- .Machine$double.eps -# plot(res_ratio$r, pmax(res_ratio$err_direct_bound, eps), -# type = "l", log = "y", lwd = 2, col = "red", -# xlab = "r", ylab = "Relative error (vs MPFR reference)", -# main = "Numerical stability of p_c ratio computations — 4 variants") -# -# lines(res_ratio$r, pmax(res_ratio$err_direct_clip, eps), col = "blue", lwd = 2) -# lines(res_ratio$r, pmax(res_ratio$err_preexp_bound, eps), col = "orange", lwd = 2) -# lines(res_ratio$r, pmax(res_ratio$err_preexp_clip, eps), col = "purple", lwd = 2) -# -# abline(h = .Machine$double.eps, col = "darkgray", lty = 2) -# legend("top", -# legend = c("Direct + Bound", "Direct + Clipped Bound", -# "Preexp + Bound", "Preexp + Clipped Bound"), -# col = c("red", "blue", "orange", "purple"), -# lwd = 2, bty = "n") -# -# abline(v = -70) -# abline(v = 70) -# -# # Summarize numeric accuracy -# summary_df <- data.frame( -# Method = c("Direct + Bound", "Direct + Clipped Bound", -# "Preexp + Bound", "Preexp + Clipped Bound"), -# Mean_error = c(mean(res_ratio$err_direct_bound, na.rm = TRUE), -# mean(res_ratio$err_direct_clip, na.rm = TRUE), -# mean(res_ratio$err_preexp_bound, na.rm = TRUE), -# mean(res_ratio$err_preexp_clip, na.rm = TRUE)), -# Median_error = c(median(res_ratio$err_direct_bound, na.rm = TRUE), -# median(res_ratio$err_direct_clip, na.rm = TRUE), -# median(res_ratio$err_preexp_bound, na.rm = TRUE), -# median(res_ratio$err_preexp_clip, na.rm = TRUE)), -# Max_error = c(max(res_ratio$err_direct_bound, na.rm = TRUE), -# max(res_ratio$err_direct_clip, na.rm = TRUE), -# max(res_ratio$err_preexp_bound, na.rm = TRUE), -# max(res_ratio$err_preexp_clip, na.rm = TRUE)) -# ) -# print(summary_df, digits = 3) -################################################################################ - -############################################################ -# Blume–Capel probabilities: -# Numerical comparison of FAST vs SAFE methods -# -# Objective -# --------- -# For a single Blume–Capel configuration (max_cat, ref, theta_lin, theta_quad), -# and a grid of residual scores r, we compare -# -# p_s(r) ∝ exp( theta_part(s) + s * r ), s = 0..max_cat -# -# with -# -# theta_part(s) = theta_lin * (s - ref) + theta_quad * (s - ref)^2 -# -# computed three ways: -# -# (1) MPFR reference softmax (high precision) -# (2) SAFE : double, direct exponentials with bound (subtract M(r)) -# (3) FAST : double, preexp(theta_part) + power chain for exp(s*r - M(r)) -# -# We record, for each r: -# -# - numeric bound M(r) = max_s [theta_part(s) + s * r] -# - pow_bound = max_cat * r - M(r) -# - max relative error of SAFE -# - max relative error of FAST -# -# No fallbacks, no patching of non-finite values: we let under/overflow -# show up as Inf/NaN in the errors and inspect those. -############################################################ - -library(Rmpfr) # for high-precision reference - -############################################################ -# 1. Reference probabilities using MPFR -############################################################ - -bc_prob_ref_mpfr <- function(max_cat, ref, theta_lin, theta_quad, - r_vals, - mpfr_prec = 256) { - # categories and centered scores - s_vals <- 0:max_cat - c_vals <- s_vals - ref - - # MPFR parameters - tl <- mpfr(theta_lin, precBits = mpfr_prec) - tq <- mpfr(theta_quad, precBits = mpfr_prec) - s_mp <- mpfr(s_vals, precBits = mpfr_prec) - c_mp <- mpfr(c_vals, precBits = mpfr_prec) - - n_r <- length(r_vals) - n_s <- length(s_vals) - - P_ref <- matrix(NA_real_, nrow = n_r, ncol = n_s) - - for (i in seq_len(n_r)) { - r_mp <- mpfr(r_vals[i], precBits = mpfr_prec) - - # exponent(s) = theta_part(s) + s * r - term_mp <- tl * c_mp + tq * c_mp * c_mp + s_mp * r_mp - - # numeric bound M(r) - M_num <- max(asNumeric(term_mp)) - M_mp <- mpfr(M_num, precBits = mpfr_prec) - - # scaled numerators - num_mp <- exp(term_mp - M_mp) - Z_mp <- sum(num_mp) - p_mp <- num_mp / Z_mp - - P_ref[i, ] <- asNumeric(p_mp) - } - - P_ref -} - -############################################################ -# 2. SAFE probabilities (double, direct + bound) -############################################################ - -bc_prob_safe <- function(max_cat, ref, theta_lin, theta_quad, - r_vals) { - s_vals <- 0:max_cat - c_vals <- s_vals - ref - - theta_part <- theta_lin * c_vals + theta_quad * c_vals^2 - - n_r <- length(r_vals) - n_s <- length(s_vals) - - P_safe <- matrix(NA_real_, nrow = n_r, ncol = n_s) - bound <- numeric(n_r) - - for (i in seq_len(n_r)) { - r <- r_vals[i] - - exps <- theta_part + s_vals * r - M <- max(exps) - bound[i] <- M - - numer <- exp(exps - M) - denom <- sum(numer) - - # no fallback here; denom can be 0 or Inf - P_safe[i, ] <- numer / denom - } - - list( - probs = P_safe, - bound = bound - ) -} - -############################################################ -# 3. FAST probabilities (double, preexp + power chain) -############################################################ - -bc_prob_fast <- function(max_cat, ref, theta_lin, theta_quad, - r_vals) { - s_vals <- 0:max_cat - c_vals <- s_vals - ref - - theta_part <- theta_lin * c_vals + theta_quad * c_vals^2 - exp_theta <- exp(theta_part) - - n_r <- length(r_vals) - n_s <- length(s_vals) - - P_fast <- matrix(NA_real_, nrow = n_r, ncol = n_s) - bound <- numeric(n_r) - pow_bound <- numeric(n_r) - - for (i in seq_len(n_r)) { - r <- r_vals[i] - - # exponents before scaling - exps <- theta_part + s_vals * r - M <- max(exps) - bound[i] <- M - - # pow_bound = max_s (s*r - M) attained at s = max_cat - pow_bound[i] <- max_cat * r - M - - eR <- exp(r) - pow <- exp(-M) - - numer <- numeric(n_s) - denom <- 0 - - for (j in seq_len(n_s)) { - numer[j] <- exp_theta[j] * pow - denom <- denom + numer[j] - pow <- pow * eR - } - - # again: no fallback; denom can be 0/Inf - P_fast[i, ] <- numer / denom - } - - list( - probs = P_fast, - bound = bound, - pow_bound = pow_bound - ) -} - -############################################################ -# 4. Core comparison function (one BC config) -############################################################ - -compare_bc_prob_methods <- function(max_cat = 4, - ref = 2, - theta_lin = 0.0, - theta_quad = 0.0, - r_vals = seq(-20, 20, length.out = 200), - mpfr_prec = 256) { - # MPFR reference - P_ref <- bc_prob_ref_mpfr( - max_cat = max_cat, - ref = ref, - theta_lin = theta_lin, - theta_quad = theta_quad, - r_vals = r_vals, - mpfr_prec = mpfr_prec - ) - - # SAFE - safe_res <- bc_prob_safe( - max_cat = max_cat, - ref = ref, - theta_lin = theta_lin, - theta_quad = theta_quad, - r_vals = r_vals - ) - P_safe <- safe_res$probs - bound_safe <- safe_res$bound - - # FAST - fast_res <- bc_prob_fast( - max_cat = max_cat, - ref = ref, - theta_lin = theta_lin, - theta_quad = theta_quad, - r_vals = r_vals - ) - P_fast <- fast_res$probs - bound_fast <- fast_res$bound - pow_bound <- fast_res$pow_bound - - stopifnot(all.equal(bound_safe, bound_fast)) - - n_r <- length(r_vals) - - res <- data.frame( - r = r_vals, - bound = bound_fast, - pow_bound = pow_bound, - err_safe = NA_real_, - err_fast = NA_real_ - ) - - for (i in seq_len(n_r)) { - p_ref <- P_ref[i, ] - p_safe <- P_safe[i, ] - p_fast <- P_fast[i, ] - - # max relative error vs MPFR reference - # (this is exactly in the spirit of compare_prob_ratios) - res$err_safe[i] <- max(abs(p_safe - p_ref) / p_ref) - res$err_fast[i] <- max(abs(p_fast - p_ref) / p_ref) - } - - res -} - -############################################################ -# 5. Example usage -############################################################ - -# Example: small BC variable -# max_cat <- 4 -# ref <- 2 -# theta_lin <- 0.3 -# theta_quad <- -0.1 -# r_vals <- seq(-80, 80, length.out = 2000) -# -# res_bc <- compare_bc_prob_methods( -# max_cat = max_cat, -# ref = ref, -# theta_lin = theta_lin, -# theta_quad = theta_quad, -# r_vals = r_vals, -# mpfr_prec = 256 -# ) -# -# # Quick inspection: log10 errors -# eps <- .Machine$double.eps -# plot(res_bc$r, pmax(res_bc$err_safe, eps), -# type = "l", log = "y", col = "black", lwd = 2, -# xlab = "r", ylab = "Relative error (vs MPFR)", -# main = "Blume–Capel probabilities: SAFE vs FAST") -# lines(res_bc$r, pmax(res_bc$err_fast, eps), col = "red", lwd = 2) -# abline(h = eps, col = "darkgray", lty = 2) -# legend("topright", -# legend = c("SAFE (direct + bound)", "FAST (preexp + power chain)"), -# col = c("black", "red"), -# lwd = 2, bty = "n") -# -# # You can then condition on bound/pow_bound just like in the -# # Blume–Capel normalization script to decide where FAST is safe. -############################################################ - - - - - diff --git a/dev/tests/test-simulation-recovery.R b/dev/tests/test-simulation-recovery.R deleted file mode 100644 index 3a168607..00000000 --- a/dev/tests/test-simulation-recovery.R +++ /dev/null @@ -1,685 +0,0 @@ -# ============================================================================== -# Simulation-Recovery Tests (Correctness Tests) -# ============================================================================== -# -# EXTENDS: test-tolerance.R (stochastic-robust testing approach) -# PATTERN: Self-consistency between estimation and simulation -# -# These tests verify that the estimation and simulation code are consistent: -# 1. Fit model on observed data → estimates A -# 2. Simulate new data from the fitted model -# 3. Refit model on simulated data → estimates B -# 4. Check: cor(A, B) should be high (model can reproduce its own structure) -# -# This approach: -# - Does NOT require knowing "true" parameters -# - Tests consistency between bgm()/bgmCompare() and simulate_mrf()/simulate.bgms() -# - Detects bugs in likelihood, posterior, or simulation code -# -# These tests are computationally expensive and skipped on CRAN. -# ============================================================================== - - -# ------------------------------------------------------------------------------ -# Helper Functions for Simulation-Recovery Tests -# ------------------------------------------------------------------------------ - -#' Run simulation-recovery test for a bgms fit -#' -#' @param fit A fitted bgms object -#' @param n_sim Number of observations to simulate (use >= 500 to avoid constant columns) -#' @param mcmc_args List of MCMC arguments for refitting -#' @param min_correlation Minimum acceptable correlation between estimates -#' @param seed Random seed for reproducibility -#' -#' @return List with correlation values and pass/fail status -run_simrec_test <- function(fit, n_sim = 350, mcmc_args = NULL, - min_correlation = 0.80, seed = 12345) { - - if (is.null(mcmc_args)) { - mcmc_args <- list(iter = 1000, warmup = 1000, chains = 1, - display_progress = "none") - } - - # Extract estimates from original fit - original_pairwise <- colMeans(extract_pairwise_interactions(fit)) - original_main <- fit$posterior_summary_main$mean - - # Simulate data from the fitted model (use large n to avoid constant columns) - set.seed(seed) - simulated_data <- simulate(fit, nsim = n_sim, method = "posterior-mean", - seed = seed) - - # Validate: check for constant columns (would cause bgm to fail) - # This can happen when the model predicts extreme probabilities for some categories - col_vars <- apply(simulated_data, 2, function(x) length(unique(x))) - if (any(col_vars < 2)) { - # Return skipped result - model predictions are too extreme for this test - return(list( - cor_pairwise = NA_real_, - cor_main = NA_real_, - passed = NA, - skipped = TRUE, - reason = sprintf("Model produces degenerate predictions for variable(s): %s", - paste(which(col_vars < 2), collapse = ", ")) - )) - } - - # Refit on simulated data - args <- extract_arguments(fit) - refit_args <- c( - list(x = simulated_data, edge_selection = FALSE), - mcmc_args - ) - - # Add variable_type if Blume-Capel - if (any(args$variable_type == "blume-capel")) { - refit_args$variable_type <- args$variable_type - refit_args$baseline_category <- args$baseline_category - } - - refit <- do.call(bgm, refit_args) - - # Extract estimates from refit - refit_pairwise <- colMeans(extract_pairwise_interactions(refit)) - refit_main <- refit$posterior_summary_main$mean - - # Handle potential length mismatch in main parameters - # (can happen when simulated data has fewer categories than original) - n_main <- min(length(original_main), length(refit_main)) - original_main <- original_main[1:n_main] - refit_main <- refit_main[1:n_main] - - # Calculate correlations - # Use Pearson (not Spearman): with few parameters and many near zero, - # Spearman rank correlation is dominated by noise in the ordering of - # near-zero values, while Pearson correctly captures the linear agreement. - cor_pairwise <- cor(original_pairwise, refit_pairwise) - cor_main <- if (n_main >= 3) cor(original_main, refit_main) else NA_real_ - - # If correlation is NA (zero variance or too few params), treat gracefully - if (is.na(cor_pairwise)) cor_pairwise <- 0 - main_testable <- !is.na(cor_main) - if (!main_testable) cor_main <- NA_real_ - - list( - cor_pairwise = cor_pairwise, - cor_main = cor_main, - passed = cor_pairwise >= min_correlation && - (!main_testable || cor_main >= min_correlation) - ) -} - - -#' Run simulation-recovery test for a GGM (continuous) bgms fit -#' -#' @param fit A fitted bgms object (GGM) -#' @param n_sim Number of observations to simulate -#' @param mcmc_args List of MCMC arguments for refitting -#' @param min_correlation Minimum acceptable correlation between estimates -#' @param seed Random seed for reproducibility -#' -#' @return List with correlation values and pass/fail status -run_simrec_test_ggm <- function(fit, n_sim = 500, mcmc_args = NULL, - min_correlation = 0.80, seed = 12345) { - - if(is.null(mcmc_args)) { - mcmc_args <- list(iter = 1000, warmup = 1000, chains = 1, - display_progress = "none") - } - - # Extract estimates from original fit - # For GGM, pairwise contains full precision matrix (including diagonal) - original_pairwise <- colMeans(extract_pairwise_interactions(fit)) - original_main <- diag(fit$posterior_mean_pairwise) - - # Simulate data from the fitted model - set.seed(seed) - simulated_data <- simulate(fit, nsim = n_sim, method = "posterior-mean", - seed = seed) - - # Refit on simulated data (must specify variable_type for GGM) - refit_args <- c( - list(x = simulated_data, variable_type = "continuous", - edge_selection = FALSE), - mcmc_args - ) - - refit <- do.call(bgm, refit_args) - - # Extract estimates from refit - refit_pairwise <- colMeans(extract_pairwise_interactions(refit)) - refit_main <- diag(refit$posterior_mean_pairwise) - - # Calculate correlations - cor_pairwise <- cor(original_pairwise, refit_pairwise) - cor_main <- cor(original_main, refit_main) - - if(is.na(cor_pairwise)) cor_pairwise <- 0 - if(is.na(cor_main)) cor_main <- 0 - - list( - cor_pairwise = cor_pairwise, - cor_main = cor_main, - passed = cor_pairwise >= min_correlation && cor_main >= min_correlation - ) -} - - -#' Run simulation-recovery test for a bgmCompare fit -#' -#' @param fit A fitted bgmCompare object -#' @param n_per_group Number of observations per group to simulate (use >= 250) -#' @param mcmc_args List of MCMC arguments for refitting -#' @param min_correlation Minimum acceptable correlation -#' @param seed Random seed -#' -#' @return List with correlation values and pass/fail status -run_simrec_test_compare <- function(fit, n_per_group = 250, mcmc_args = NULL, - min_correlation = 0.75, seed = 12345) { - - if (is.null(mcmc_args)) { - mcmc_args <- list(iter = 1000, warmup = 1000, chains = 1, - display_progress = "none") - } - - args <- extract_arguments(fit) - n_groups <- args$num_groups - - # Extract baseline pairwise estimates - original_pairwise <- colMeans(extract_pairwise_interactions(fit)) - - # Simulate data for each group using group-specific parameters - # For now, use baseline parameters (this is a simplification) - interactions <- fit$posterior_mean_pairwise_baseline - thresholds <- fit$posterior_mean_main_baseline - - set.seed(seed) - simulated_datasets <- list() - for (g in seq_len(n_groups)) { - simulated_datasets[[g]] <- simulate_mrf( - num_states = n_per_group, - num_variables = args$num_variables, - num_categories = args$num_categories, - pairwise = interactions, - main = thresholds, - seed = seed + g - ) - colnames(simulated_datasets[[g]]) <- args$data_columnnames - } - - # Combine into single dataset with group indicator - combined_data <- do.call(rbind, simulated_datasets) - group_indicator <- rep(seq_len(n_groups), each = n_per_group) - - # Validate: check for constant columns (would cause bgmCompare to fail) - col_vars <- apply(combined_data, 2, function(x) length(unique(x))) - if (any(col_vars < 2)) { - stop(sprintf("Simulated data has constant column(s): %s. Increase n_per_group or use different seed.", - paste(which(col_vars < 2), collapse = ", "))) - } - - # Refit - refit_args <- c( - list(x = combined_data, group_indicator = group_indicator, - difference_selection = FALSE), - mcmc_args - ) - - refit <- do.call(bgmCompare, refit_args) - - # Extract estimates from refit - refit_pairwise <- colMeans(extract_pairwise_interactions(refit)) - - # Calculate correlation (handle zero variance edge case) - cor_pairwise <- cor(original_pairwise, refit_pairwise) - - # If correlation is NA (zero variance), treat as failed - if (is.na(cor_pairwise)) cor_pairwise <- 0 - - list( - cor_pairwise = cor_pairwise, - passed = cor_pairwise >= min_correlation - ) -} - - -# ------------------------------------------------------------------------------ -# bgm() Simulation-Recovery Tests -# ------------------------------------------------------------------------------ -# These tests fit fresh models on larger datasets (matching data dimensions) -# rather than using the small session-cached fixtures. -# This takes longer but provides proper correctness validation. - -test_that("bgm simulation-recovery: ordinal variables (NUTS)", { - skip_on_cran() - - # Use full Wenchuan data - data("Wenchuan", package = "bgms") - x <- na.omit(Wenchuan[, 1:5]) - n_obs <- nrow(x) - - # Fit with adequate MCMC (1000 iter, 1000 warmup for proper convergence) - fit <- bgm(x, iter = 1000, warmup = 1000, chains = 1, - edge_selection = FALSE, seed = 11111, - display_progress = "none") - - result <- run_simrec_test( - fit, - n_sim = n_obs, # Match original sample size - mcmc_args = list(iter = 1000, warmup = 1000, chains = 1, - display_progress = "none"), - min_correlation = 0.80, - seed = 11111 - ) - - # Handle skipped case (model produces degenerate predictions) - if (isTRUE(result$skipped)) { - skip(result$reason) - } - - expect_true( - result$cor_pairwise >= 0.80, - info = sprintf("Pairwise correlation = %.3f (expected >= 0.80)", - result$cor_pairwise) - ) - expect_true( - result$cor_main >= 0.80, - info = sprintf("Main effects correlation = %.3f (expected >= 0.80)", - result$cor_main) - ) -}) - - -test_that("bgm simulation-recovery: binary variables (NUTS)", { - skip_on_cran() - - # Use full ADHD data - data("ADHD", package = "bgms") - x <- ADHD[, 2:6] - n_obs <- nrow(x) - - fit <- bgm(x, iter = 1000, warmup = 1000, chains = 1, - edge_selection = FALSE, seed = 22222, - display_progress = "none") - - result <- run_simrec_test( - fit, - n_sim = n_obs, - mcmc_args = list(iter = 1000, warmup = 1000, chains = 1, - display_progress = "none"), - min_correlation = 0.75, - seed = 22222 - ) - - # Handle skipped case - if (isTRUE(result$skipped)) { - skip(result$reason) - } - - expect_true( - result$cor_pairwise >= 0.75, - info = sprintf("Pairwise correlation = %.3f (expected >= 0.75)", - result$cor_pairwise) - ) -}) - - -test_that("bgm simulation-recovery: Blume-Capel variables", { - skip_on_cran() - - # Start from BC-simulated data to avoid category collapse: - # Fitting BC on 5-category Wenchuan data then simulating produces only - # 3 categories (the quadratic potential concentrates mass near baseline). - # That makes the refit incomparable to the original. Using BC-simulated - # data as the starting point keeps original and simulated data in the - # same distributional regime. - # - # Uses a non-zero baseline_category to exercise the centering logic - # in the OMRF C++ backend (observations_double_ centered around baseline). - p <- 5 - n_obs <- 500 - pairwise <- matrix(0, p, p) - pairwise[1, 2] <- pairwise[2, 1] <- 0.5 - pairwise[2, 3] <- pairwise[3, 2] <- 0.3 - pairwise[4, 5] <- pairwise[5, 4] <- -0.25 - - main <- matrix(0, p, 2) - main[, 1] <- c(-0.5, 0.0, 0.3, -0.2, 0.1) # linear - main[, 2] <- c(-0.3, -0.5, -0.4, -0.2, -0.6) # quadratic - - x <- simulate_mrf( - num_states = n_obs, num_variables = p, - num_categories = rep(3, p), - pairwise = pairwise, main = main, - variable_type = "blume-capel", - baseline_category = 1, - seed = 33333 - ) - colnames(x) <- paste0("V", 1:p) - - fit <- bgm(x, iter = 5000, warmup = 1000, chains = 2, - variable_type = "blume-capel", baseline_category = 1, - edge_selection = FALSE, seed = 33333, - display_progress = "none") - - result <- run_simrec_test( - fit, - n_sim = n_obs, - mcmc_args = list(iter = 5000, warmup = 1000, chains = 2, - display_progress = "none"), - min_correlation = 0.80, - seed = 33333 - ) - - # Handle skipped case - if (isTRUE(result$skipped)) { - skip(result$reason) - } - - expect_true( - result$cor_pairwise >= 0.80, - info = sprintf("Pairwise correlation = %.3f (expected >= 0.80)", - result$cor_pairwise) - ) -}) - - -test_that("bgm simulation-recovery: adaptive-metropolis", { - skip_on_cran() - - # Use ADHD data with adaptive-metropolis sampler - data("ADHD", package = "bgms") - x <- ADHD[, 2:6] - n_obs <- nrow(x) - - fit <- bgm(x, iter = 1000, warmup = 1000, chains = 1, - update_method = "adaptive-metropolis", - edge_selection = FALSE, seed = 44444, - display_progress = "none") - - result <- run_simrec_test( - fit, - n_sim = n_obs, - mcmc_args = list(iter = 1000, warmup = 1000, chains = 1, - update_method = "adaptive-metropolis", - display_progress = "none"), - min_correlation = 0.80, - seed = 44444 - ) - - # Handle skipped case - if (isTRUE(result$skipped)) { - skip(result$reason) - } - - expect_true( - result$cor_pairwise >= 0.80, - info = sprintf("Pairwise correlation = %.3f (expected >= 0.80)", - result$cor_pairwise) - ) -}) - - -test_that("bgm simulation-recovery: GGM (continuous variables)", { - skip_on_cran() - - # Generate continuous data from a known precision matrix - p <- 5 - omega_true <- diag(p) - omega_true[1, 2] <- omega_true[2, 1] <- 0.4 - omega_true[2, 3] <- omega_true[3, 2] <- 0.3 - omega_true[4, 5] <- omega_true[5, 4] <- -0.25 - omega_true[1, 3] <- omega_true[3, 1] <- 0.15 - - n_obs <- 500 - x <- simulate_mrf( - num_states = n_obs, - num_variables = p, - pairwise = omega_true, - variable_type = "continuous", - seed = 99999 - ) - colnames(x) <- paste0("V", 1:p) - - # Fit GGM - fit <- bgm(x, variable_type = "continuous", - iter = 1000, warmup = 1000, chains = 1, - edge_selection = FALSE, seed = 99999, - display_progress = "none") - - result <- run_simrec_test_ggm( - fit, - n_sim = n_obs, - mcmc_args = list(iter = 1000, warmup = 1000, chains = 1, - display_progress = "none"), - min_correlation = 0.80, - seed = 99999 - ) - - expect_true( - result$cor_pairwise >= 0.80, - info = sprintf("Pairwise correlation = %.3f (expected >= 0.80)", - result$cor_pairwise) - ) - expect_true( - result$cor_main >= 0.80, - info = sprintf("Main effects (diagonal precision) correlation = %.3f (expected >= 0.80)", - result$cor_main) - ) -}) - - -# ------------------------------------------------------------------------------ -# bgmCompare() Simulation-Recovery Tests -# ------------------------------------------------------------------------------ - -test_that("bgmCompare simulation-recovery: ordinal variables", { - skip_on_cran() - - # Use Boredom split into 2 groups - data("Boredom", package = "bgms") - x <- na.omit(Boredom[, 2:6]) - n_obs <- nrow(x) - group_ind <- 1 * (Boredom[, 1] == "fr") - - fit <- bgmCompare(x, group_indicator = group_ind, - iter = 1000, warmup = 1000, chains = 1, - difference_selection = FALSE, seed = 55555, - display_progress = "none") - - result <- run_simrec_test_compare( - fit, - n_per_group = sum(group_ind), - mcmc_args = list(iter = 1000, warmup = 1000, chains = 1, - display_progress = "none"), - min_correlation = 0.80, - seed = 55555 - ) - - expect_true( - result$cor_pairwise >= 0.80, - info = sprintf("Pairwise correlation = %.3f (expected >= 0.80)", - result$cor_pairwise) - ) -}) - - -test_that("bgmCompare simulation-recovery: binary variables", { - skip_on_cran() - - # Use ADHD data with diagnosis group - data("ADHD", package = "bgms") - x <- ADHD[, 2:6] - group_ind <- ADHD[, "group"] - - fit <- bgmCompare(x, group_indicator = group_ind, - iter = 1000, warmup = 1000, chains = 1, - difference_selection = FALSE, seed = 66666, - display_progress = "none") - - # Get group sizes for simulation - n_per_group <- min(table(group_ind)) - - result <- run_simrec_test_compare( - fit, - n_per_group = n_per_group, - mcmc_args = list(iter = 1000, warmup = 1000, chains = 1, - display_progress = "none"), - min_correlation = 0.80, - seed = 66666 - ) - - expect_true( - result$cor_pairwise >= 0.80, - info = sprintf("Pairwise correlation = %.3f (expected >= 0.80)", - result$cor_pairwise) - ) -}) - - -test_that("bgmCompare simulation-recovery: adaptive-metropolis", { - skip_on_cran() - - # Use ADHD data with adaptive-metropolis - data("ADHD", package = "bgms") - x <- ADHD[, 2:6] - group_ind <- ADHD[, "group"] - - fit <- bgmCompare(x, group_indicator = group_ind, - iter = 1000, warmup = 1000, chains = 1, - update_method = "adaptive-metropolis", - difference_selection = FALSE, seed = 77777, - display_progress = "none") - - n_per_group <- min(table(group_ind)) - - result <- run_simrec_test_compare( - fit, - n_per_group = n_per_group, - mcmc_args = list(iter = 10000, warmup = 1000, chains = 1, - update_method = "adaptive-metropolis", - display_progress = "none"), - min_correlation = 0.80, - seed = 77777 - ) - - expect_true( - result$cor_pairwise >= 0.80, - info = sprintf("Pairwise correlation = %.3f (expected >= 0.80)", - result$cor_pairwise) - ) -}) - - -# ------------------------------------------------------------------------------ -# Cross-Method Consistency Tests -# ------------------------------------------------------------------------------ - -test_that("NUTS and adaptive-metropolis produce consistent estimates", -{ - skip_on_cran() - - # Use larger dataset for meaningful comparison - data("Wenchuan", package = "bgms") - x <- na.omit(Wenchuan[, 1:5]) - - # Fit with NUTS - fit_nuts <- bgm(x, iter = 1000, warmup = 1000, chains = 1, - update_method = "nuts", edge_selection = FALSE, - seed = 88888, display_progress = "none") - - # Fit with adaptive-metropolis - fit_am <- bgm(x, iter = 10000, warmup = 1000, chains = 1, - update_method = "adaptive-metropolis", edge_selection = FALSE, - seed = 88888, display_progress = "none") - - # Compare posterior means - nuts_pairwise <- colMeans(extract_pairwise_interactions(fit_nuts)) - am_pairwise <- colMeans(extract_pairwise_interactions(fit_am)) - - cor_val <- cor(nuts_pairwise, am_pairwise) - - expect_true( - cor_val >= 0.80, - info = sprintf("NUTS vs AM correlation = %.3f (expected >= 0.80)", cor_val) - ) -}) - - -# ============================================================================== -# GGM Posterior Recovery Test -# ============================================================================== -# -# Verifies that the GGM sampler recovers known precision matrix parameters. - -test_that("GGM posterior recovers parameters from simulated data", { - - n <- 1000 - p <- 10 - ne <- p * (p - 1) / 2 - - # Fixed precision matrix (avoids BDgraph dependency) - omega <- structure(c(6.240119, 0, 0, -0.370239, 0, 0, 0, 0, -1.622902, - 0, 0, 1.905013, 0, -0.194995, 0, 0, -2.468628, -0.557277, 0, - 0, 0, 0, 5.509142, -7.942389, 1.40081, 0, 0, -0.76775, 0, 0, - -0.370239, -0.194995, -7.942389, 15.521405, -3.537489, 0, 4.60785, - 0, 3.278511, 0, 0, 0, 1.40081, -3.537489, 2.78257, 0, 0, 1.374641, - 0, -1.198092, 0, 0, 0, 0, 0, 1.350879, 0, 0.230677, -1.357952, - 0, 0, -2.468628, 0, 4.60785, 0, 0, 15.88698, 0, 1.20017, -1.973919, - 0, -0.557277, -0.76775, 0, 1.374641, 0.230677, 0, 7.007312, 1.597035, - 0, -1.622902, 0, 0, 3.278511, 0, -1.357952, 1.20017, 1.597035, - 13.378039, -4.769958, 0, 0, 0, 0, -1.198092, 0, -1.973919, 0, - -4.769958, 5.536877), dim = c(10L, 10L)) - adj <- omega != 0 - diag(adj) <- 0 - covmat <- solve(omega) - chol_cov <- chol(covmat) - - set.seed(43) - x <- matrix(rnorm(n * p), nrow = n, ncol = p) %*% chol_cov - - # Without edge selection - fit_no_vs <- bgm( - x = x, variable_type = "continuous", - edge_selection = FALSE, - iter = 3000, warmup = 500, chains = 2, - display_progress = "none", seed = 42 - ) - - expect_true(cor(fit_no_vs$posterior_summary_main$mean, diag(omega)) > 0.9) - expect_true(cor(fit_no_vs$posterior_summary_pairwise$mean, omega[lower.tri(omega)]) > 0.9) - - # With edge selection (Bernoulli prior) - fit_vs <- bgm( - x = x, variable_type = "continuous", - edge_selection = TRUE, - iter = 5000, warmup = 500, chains = 2, - display_progress = "none", seed = 42 - ) - - expect_true(cor(fit_vs$posterior_summary_main$mean, diag(omega)) > 0.9) - expect_true(cor(fit_vs$posterior_summary_pairwise$mean, omega[lower.tri(omega)]) > 0.9) - expect_true(cor(fit_vs$posterior_summary_indicator$mean, adj[lower.tri(adj)]) > 0.85) - - # With edge selection (SBM prior) - fit_vs_sbm <- bgm( - x = x, variable_type = "continuous", - edge_selection = TRUE, - edge_prior = "Stochastic-Block", - iter = 5000, warmup = 500, chains = 2, - display_progress = "none", seed = 42 - ) - - expect_true(cor(fit_vs_sbm$posterior_summary_main$mean, diag(omega)) > 0.9) - expect_true(cor(fit_vs_sbm$posterior_summary_pairwise$mean, omega[lower.tri(omega)]) > 0.9) - expect_true(cor(fit_vs_sbm$posterior_summary_indicator$mean, adj[lower.tri(adj)]) > 0.85) - - # SBM-specific output - expect_false(is.null(fit_vs_sbm$posterior_mean_coclustering_matrix)) - expect_equal(nrow(fit_vs_sbm$posterior_mean_coclustering_matrix), p) - expect_equal(ncol(fit_vs_sbm$posterior_mean_coclustering_matrix), p) - expect_false(is.null(fit_vs_sbm$posterior_num_blocks)) - expect_false(is.null(fit_vs_sbm$posterior_mode_allocations)) - expect_false(is.null(fit_vs_sbm$raw_samples$allocations)) -}) diff --git a/man/ADHD.Rd b/man/ADHD.Rd index bd9886c1..bf26c4b0 100644 --- a/man/ADHD.Rd +++ b/man/ADHD.Rd @@ -7,36 +7,36 @@ \format{ A matrix with 355 rows and 19 columns. \describe{ - \item{group}{ADHD diagnosis: 1 = diagnosed, 0 = not diagnosed} - \item{avoid}{Often avoids, dislikes, or is reluctant to engage in tasks - that require sustained mental effort (I)} - \item{closeatt}{Often fails to give close attention to details or makes - careless mistakes in schoolwork, work, or other activities (I)} - \item{distract}{Is often easily distracted by extraneous stimuli (I)} - \item{forget}{Is often forgetful in daily activities (I)} - \item{instruct}{Often does not follow through on instructions and fails to - finish schoolwork, chores, or duties in the workplace (I)} - \item{listen}{Often does not seem to listen when spoken to directly - (I)} - \item{loses}{Often loses things necessary for tasks or activities (I)} - \item{org}{Often has difficulty organizing tasks and activities (I)} - \item{susatt}{Often has difficulty sustaining attention in tasks or play - activities (I)} - \item{blurts}{Often blurts out answers before questions have been completed - (HI)} - \item{fidget}{Often fidgets with hands or feet or squirms in seat - (HI)} - \item{interrupt}{Often interrupts or intrudes on others (HI)} - \item{motor}{Is often "on the go" or often acts as if "driven by a motor" - (HI)} - \item{quiet}{Often has difficulty playing or engaging in leisure activities - quietly (HI)} - \item{runs}{Often runs about or climbs excessively in situations in which - it is inappropriate (HI)} - \item{seat}{Often leaves seat in classroom or in other situations in which - remaining seated is expected (HI)} - \item{talks}{Often talks excessively (HI)} - \item{turn}{Often has difficulty awaiting turn (HI)} +\item{group}{ADHD diagnosis: 1 = diagnosed, 0 = not diagnosed} +\item{avoid}{Often avoids, dislikes, or is reluctant to engage in tasks +that require sustained mental effort (I)} +\item{closeatt}{Often fails to give close attention to details or makes +careless mistakes in schoolwork, work, or other activities (I)} +\item{distract}{Is often easily distracted by extraneous stimuli (I)} +\item{forget}{Is often forgetful in daily activities (I)} +\item{instruct}{Often does not follow through on instructions and fails to +finish schoolwork, chores, or duties in the workplace (I)} +\item{listen}{Often does not seem to listen when spoken to directly +(I)} +\item{loses}{Often loses things necessary for tasks or activities (I)} +\item{org}{Often has difficulty organizing tasks and activities (I)} +\item{susatt}{Often has difficulty sustaining attention in tasks or play +activities (I)} +\item{blurts}{Often blurts out answers before questions have been completed +(HI)} +\item{fidget}{Often fidgets with hands or feet or squirms in seat +(HI)} +\item{interrupt}{Often interrupts or intrudes on others (HI)} +\item{motor}{Is often "on the go" or often acts as if "driven by a motor" +(HI)} +\item{quiet}{Often has difficulty playing or engaging in leisure activities +quietly (HI)} +\item{runs}{Often runs about or climbs excessively in situations in which +it is inappropriate (HI)} +\item{seat}{Often leaves seat in classroom or in other situations in which +remaining seated is expected (HI)} +\item{talks}{Often talks excessively (HI)} +\item{turn}{Often has difficulty awaiting turn (HI)} } } \source{ diff --git a/man/Boredom.Rd b/man/Boredom.Rd index 7ed3a86f..b28ec419 100644 --- a/man/Boredom.Rd +++ b/man/Boredom.Rd @@ -7,19 +7,19 @@ \format{ A matrix with 986 rows and 9 columns. Each row corresponds to a respondent. \describe{ - \item{language}{Language in which the SBPS was administered: "en" = English, "fr" = French} - \item{loose_ends}{I often find myself at “loose ends,” not knowing what to - do.} - \item{entertain}{I find it hard to entertain myself.} - \item{repetitive}{Many things I have to do are repetitive and monotonous.} - \item{stimulation}{It takes more stimulation to get me going than most - people.} - \item{motivated}{I don't feel motivated by most things that I do.} - \item{keep_interest}{In most situations, it is hard for me to find - something to do or see to keep me interested.} - \item{sit_around}{Much of the time, I just sit around doing nothing.} - \item{half_dead_dull}{Unless I am doing something exciting, even dangerous, - I feel half-dead and dull.} +\item{language}{Language in which the SBPS was administered: "en" = English, "fr" = French} +\item{loose_ends}{I often find myself at “loose ends,” not knowing what to +do.} +\item{entertain}{I find it hard to entertain myself.} +\item{repetitive}{Many things I have to do are repetitive and monotonous.} +\item{stimulation}{It takes more stimulation to get me going than most +people.} +\item{motivated}{I don't feel motivated by most things that I do.} +\item{keep_interest}{In most situations, it is hard for me to find +something to do or see to keep me interested.} +\item{sit_around}{Much of the time, I just sit around doing nothing.} +\item{half_dead_dull}{Unless I am doing something exciting, even dangerous, +I feel half-dead and dull.} } } \source{ diff --git a/man/Wenchuan.Rd b/man/Wenchuan.Rd index 7a523736..c6bfb0cf 100644 --- a/man/Wenchuan.Rd +++ b/man/Wenchuan.Rd @@ -7,33 +7,33 @@ \format{ A matrix with 362 rows and 17 columns. Each row represents a participant. \describe{ - \item{intrusion}{Repeated, disturbing memories, thoughts, or images of a - stressful experience from the past?} - \item{dreams}{Repeated, disturbing dreams of a stressful experience from - the past?} - \item{flash}{Suddenly acting or feeling as if a stressful experience were - happening again (as if you were reliving it)?} - \item{upset}{Feeling very upset when something reminded you of a stressful - experience from the past?} - \item{physior}{Having physical reactions (e.g., heart pounding, trouble - breathing, sweating) when something reminded you of a stressful experience - from the past?} - \item{avoidth}{Avoiding thinking about or talking about a stressful - experience from the past or avoiding having feelings related to it?} - \item{avoidact}{Avoiding activities or situations because they reminded you - of a stressful experience from the past?} - \item{amnesia}{Trouble remembering important parts of a stressful - experience from the past?} - \item{lossint}{Loss of interest in activities that you used to enjoy?} - \item{distant}{Feeling distant or cut off from other people?} - \item{numb}{Feeling emotionally numb or being unable to have loving - feelings for those close to you?} - \item{future}{Feeling as if your future will somehow be cut short?} - \item{sleep}{Trouble falling or staying asleep?} - \item{anger}{Feeling irritable or having angry outbursts?} - \item{concen}{Having difficulty concentrating?} - \item{hyper}{Being "super-alert" or watchful or on guard?} - \item{startle}{Feeling jumpy or easily startled?} +\item{intrusion}{Repeated, disturbing memories, thoughts, or images of a +stressful experience from the past?} +\item{dreams}{Repeated, disturbing dreams of a stressful experience from +the past?} +\item{flash}{Suddenly acting or feeling as if a stressful experience were +happening again (as if you were reliving it)?} +\item{upset}{Feeling very upset when something reminded you of a stressful +experience from the past?} +\item{physior}{Having physical reactions (e.g., heart pounding, trouble +breathing, sweating) when something reminded you of a stressful experience +from the past?} +\item{avoidth}{Avoiding thinking about or talking about a stressful +experience from the past or avoiding having feelings related to it?} +\item{avoidact}{Avoiding activities or situations because they reminded you +of a stressful experience from the past?} +\item{amnesia}{Trouble remembering important parts of a stressful +experience from the past?} +\item{lossint}{Loss of interest in activities that you used to enjoy?} +\item{distant}{Feeling distant or cut off from other people?} +\item{numb}{Feeling emotionally numb or being unable to have loving +feelings for those close to you?} +\item{future}{Feeling as if your future will somehow be cut short?} +\item{sleep}{Trouble falling or staying asleep?} +\item{anger}{Feeling irritable or having angry outbursts?} +\item{concen}{Having difficulty concentrating?} +\item{hyper}{Being "super-alert" or watchful or on guard?} +\item{startle}{Feeling jumpy or easily startled?} } } \source{ diff --git a/man/bgm.Rd b/man/bgm.Rd index 92826f3e..9d8a0ea7 100644 --- a/man/bgm.Rd +++ b/man/bgm.Rd @@ -33,6 +33,7 @@ bgm( display_progress = c("per-chain", "total", "none"), seed = NULL, standardize = FALSE, + pseudolikelihood = c("conditional", "marginal"), verbose = getOption("bgms.verbose", TRUE), interaction_scale, burnin, @@ -49,8 +50,11 @@ ordinal variables, unobserved categories are collapsed; for Blume–Capel variables, all categories are retained.} \item{variable_type}{Character or character vector. Specifies the type of -each variable in \code{x}. Allowed values: \code{"ordinal"} or -\code{"blume-capel"}. Binary variables are automatically treated as +each variable in \code{x}. Allowed values: \code{"ordinal"}, +\code{"blume-capel"}, or \code{"continuous"}. A single string applies +to all variables. A per-variable vector that mixes discrete +(\code{"ordinal"} / \code{"blume-capel"}) and \code{"continuous"} +types fits a mixed MRF. Binary variables are automatically treated as \code{"ordinal"}. Default: \code{"ordinal"}.} \item{baseline_category}{Integer or vector. Baseline category used in @@ -108,12 +112,12 @@ number of clusters in the Stochastic Block Model. Default: \code{1}.} \item{update_method}{Character. Specifies how the MCMC sampler updates the model parameters: \describe{ - \item{"adaptive-metropolis"}{Componentwise adaptive Metropolis–Hastings - with Robbins–Monro proposal adaptation.} - \item{"hamiltonian-mc"}{Hamiltonian Monte Carlo with fixed path length - (number of leapfrog steps set by \code{hmc_num_leapfrogs}).} - \item{"nuts"}{The No-U-Turn Sampler, an adaptive form of HMC with - dynamically chosen trajectory lengths.} +\item{"adaptive-metropolis"}{Componentwise adaptive Metropolis–Hastings +with Robbins–Monro proposal adaptation.} +\item{"hamiltonian-mc"}{Hamiltonian Monte Carlo with fixed path length +(number of leapfrog steps set by \code{hmc_num_leapfrogs}).} +\item{"nuts"}{The No-U-Turn Sampler, an adaptive form of HMC with +dynamically chosen trajectory lengths.} } Default: \code{"nuts"}.} @@ -166,14 +170,29 @@ raw score endpoints \eqn{(0, m)} and Blume-Capel variables use centered score endpoints \eqn{(-b, m-b)}. Default: \code{FALSE}.} +\item{pseudolikelihood}{Character. Specifies the pseudo-likelihood +approximation used for mixed MRF models (ignored for pure ordinal or +pure continuous data). Options: +\describe{ +\item{\code{"conditional"}}{Conditions on the observed continuous +variables when computing the discrete full conditionals. Faster +because the discrete pseudo-likelihood does not depend on the +continuous precision matrix.} +\item{\code{"marginal"}}{Integrates out the continuous variables, +giving discrete full conditionals that account for induced +interactions through the continuous block. More expensive per +iteration.} +} +Default: \code{"conditional"}.} + \item{verbose}{Logical. If \code{TRUE}, prints informational messages during data processing (e.g., missing data handling, variable recoding). Defaults to \code{getOption("bgms.verbose", TRUE)}. Set \code{options(bgms.verbose = FALSE)} to suppress messages globally.} -\item{interaction_scale, burnin, save, threshold_alpha, threshold_beta}{`r lifecycle::badge("deprecated")` +\item{interaction_scale, burnin, save, threshold_alpha, threshold_beta}{\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} Deprecated arguments as of \strong{bgms 0.1.6.0}. -Use `pairwise_scale`, `warmup`, `main_alpha`, and `main_beta` instead.} +Use \code{pairwise_scale}, \code{warmup}, \code{main_alpha}, and \code{main_beta} instead.} } \value{ A list of class \code{"bgms"} with posterior summaries, posterior mean @@ -182,56 +201,66 @@ matrices, and access to raw MCMC draws. The object can be passed to Main components include: \itemize{ - \item \code{posterior_summary_main}: Data frame with posterior summaries - (mean, sd, MCSE, ESS, Rhat) for category threshold parameters. - \item \code{posterior_summary_pairwise}: Data frame with posterior - summaries for pairwise interaction parameters. - \item \code{posterior_summary_indicator}: Data frame with posterior - summaries for edge inclusion indicators (if \code{edge_selection = TRUE}). - - \item \code{posterior_mean_main}: Matrix of posterior mean thresholds - (rows = variables, cols = categories or parameters). - \item \code{posterior_mean_pairwise}: Symmetric matrix of posterior mean - pairwise interaction strengths. - \item \code{posterior_mean_indicator}: Symmetric matrix of posterior mean - inclusion probabilities (if edge selection was enabled). - - \item Additional summaries returned when - \code{edge_prior = "Stochastic-Block"}. For more details about this prior - see \insertCite{SekulovskiEtAl_2025;textual}{bgms}. - \itemize{ - \item \code{posterior_summary_pairwise_allocations}: Data frame with - posterior summaries (mean, sd, MCSE, ESS, Rhat) for the pairwise - cluster co-occurrence of the nodes. This serves to indicate - whether the estimated posterior allocations,co-clustering matrix - and posterior cluster probabilities (see blow) have converged. - \item \code{posterior_coclustering_matrix}: a symmetric matrix of - pairwise proportions of occurrence of every variable. This matrix - can be plotted to visually inspect the estimated number of clusters - and visually inspect nodes that tend to switch clusters. - \item \code{posterior_mean_allocations}: A vector with the posterior mean - of the cluster allocations of the nodes. This is calculated using the method - proposed in \insertCite{Dahl2009;textual}{bgms}. - \item \code{posterior_mode_allocations}: A vector with the posterior - mode of the cluster allocations of the nodes. - \item \code{posterior_num_blocks}: A data frame with the estimated - posterior inclusion probabilities for all the possible number of clusters. - } - \item \code{raw_samples}: A list of raw MCMC draws per chain: - \describe{ - \item{\code{main}}{List of main effect samples.} - \item{\code{pairwise}}{List of pairwise effect samples.} - \item{\code{indicator}}{List of indicator samples - (if edge selection enabled).} - \item{\code{allocations}}{List of cluster allocations - (if SBM prior used).} - \item{\code{nchains}}{Number of chains.} - \item{\code{niter}}{Number of post–warmup iterations per chain.} - \item{\code{parameter_names}}{Named lists of parameter labels.} - } - - \item \code{arguments}: A list of function call arguments and metadata - (e.g., number of variables, warmup, sampler settings, package version). +\item \code{posterior_summary_main}: Data frame with posterior summaries +(mean, sd, MCSE, ESS, Rhat) for main-effect parameters. +For OMRF models these are category thresholds; +for mixed MRF models these are discrete thresholds and +continuous means. \code{NULL} for GGM models (no main effects). +\item \code{posterior_summary_quadratic}: Data frame with posterior +summaries for the precision matrix diagonal. Present for GGM and +mixed MRF models; \code{NULL} for OMRF models. +\item \code{posterior_summary_pairwise}: Data frame with posterior +summaries for pairwise interaction parameters. +\item \code{posterior_summary_indicator}: Data frame with posterior +summaries for edge inclusion indicators (if \code{edge_selection = TRUE}). + +\item \code{posterior_mean_main}: Posterior mean of main-effect +parameters. \code{NULL} for GGM models. For OMRF: a matrix +(p x max_categories) of category thresholds. For mixed MRF: a list +with \code{$discrete} (threshold matrix) and \code{$continuous} +(q x 1 matrix of means). +\item \code{posterior_mean_pairwise}: Symmetric matrix of posterior mean +pairwise interaction strengths. For GGM and mixed MRF models the +precision matrix diagonal is included on the matrix diagonal. +\item \code{posterior_mean_indicator}: Symmetric matrix of posterior mean +inclusion probabilities (if edge selection was enabled). + +\item Additional summaries returned when +\code{edge_prior = "Stochastic-Block"}. For more details about this prior +see \insertCite{SekulovskiEtAl_2025;textual}{bgms}. +\itemize{ +\item \code{posterior_summary_pairwise_allocations}: Data frame with +posterior summaries (mean, sd, MCSE, ESS, Rhat) for the pairwise +cluster co-occurrence of the nodes. This serves to indicate +whether the estimated posterior allocations,co-clustering matrix +and posterior cluster probabilities (see blow) have converged. +\item \code{posterior_coclustering_matrix}: a symmetric matrix of +pairwise proportions of occurrence of every variable. This matrix +can be plotted to visually inspect the estimated number of clusters +and visually inspect nodes that tend to switch clusters. +\item \code{posterior_mean_allocations}: A vector with the posterior mean +of the cluster allocations of the nodes. This is calculated using the method +proposed in \insertCite{Dahl2009;textual}{bgms}. +\item \code{posterior_mode_allocations}: A vector with the posterior +mode of the cluster allocations of the nodes. +\item \code{posterior_num_blocks}: A data frame with the estimated +posterior inclusion probabilities for all the possible number of clusters. +} +\item \code{raw_samples}: A list of raw MCMC draws per chain: +\describe{ +\item{\code{main}}{List of main effect samples.} +\item{\code{pairwise}}{List of pairwise effect samples.} +\item{\code{indicator}}{List of indicator samples +(if edge selection enabled).} +\item{\code{allocations}}{List of cluster allocations +(if SBM prior used).} +\item{\code{nchains}}{Number of chains.} +\item{\code{niter}}{Number of post–warmup iterations per chain.} +\item{\code{parameter_names}}{Named lists of parameter labels.} +} + +\item \code{arguments}: A list of function call arguments and metadata +(e.g., number of variables, warmup, sampler settings, package version). } The \code{summary()} method prints formatted posterior summaries, and @@ -241,16 +270,18 @@ NUTS diagnostics (tree depth, divergences, energy, E-BFMI) are included in \code{fit$nuts_diag} if \code{update_method = "nuts"}. } \description{ -The \code{bgm} function estimates the pseudoposterior distribution of -category thresholds (main effects) and pairwise interaction parameters of a -Markov Random Field (MRF) model for binary and/or ordinal variables. -Optionally, it performs Bayesian edge selection using spike-and-slab -priors to infer the network structure. +The \code{bgm} function estimates the pseudoposterior distribution of the +parameters of a Markov Random Field (MRF) for binary, ordinal, continuous, +or mixed (discrete and continuous) variables. Depending on the variable +types, the model is an ordinal MRF, a Gaussian graphical model (GGM), or a +mixed MRF. Optionally, it performs Bayesian edge selection using +spike-and-slab priors to infer the network structure. } \details{ -This function models the joint distribution of binary and ordinal variables -using a Markov Random Field, with support for edge selection through Bayesian -variable selection. The statistical foundation of the model is described in +This function models the joint distribution of binary, ordinal, continuous, +or mixed variables using a Markov Random Field, with support for edge +selection through Bayesian variable selection. The statistical foundation +of the model is described in \insertCite{MarsmanVandenBerghHaslbeck_2025;textual}{bgms}, where the ordinal MRF model and its Bayesian estimation procedure were first introduced. While the implementation in \pkg{bgms} has since been extended and updated (e.g., @@ -276,16 +307,16 @@ by distance from this baseline. Category thresholds are modeled as: where: \itemize{ - \item \eqn{\mu_{c}}: category threshold for category \eqn{c} - \item \eqn{\alpha}: linear trend across categories - \item \eqn{\beta}: preference toward or away from the baseline - \itemize{ - \item If \eqn{\beta < 0}, the model favors responses near the baseline - category; - \item if \eqn{\beta > 0}, it favors responses farther away (i.e., - extremes). - } - \item \eqn{b}: baseline category +\item \eqn{\mu_{c}}: category threshold for category \eqn{c} +\item \eqn{\alpha}: linear trend across categories +\item \eqn{\beta}: preference toward or away from the baseline +\itemize{ +\item If \eqn{\beta < 0}, the model favors responses near the baseline +category; +\item if \eqn{\beta > 0}, it favors responses farther away (i.e., +extremes). +} +\item \eqn{b}: baseline category } Accordingly, pairwise interactions between Blume-Capel variables are modeled in terms of \eqn{c-b} scores. @@ -299,11 +330,11 @@ spike-and-slab priors. Supported priors for edge inclusion: \itemize{ - \item \strong{Bernoulli}: Fixed inclusion probability across edges. - \item \strong{Beta-Bernoulli}: Inclusion probability is assigned a Beta - prior distribution. - \item \strong{Stochastic-Block}: Cluster-based edge priors with Beta, - Dirichlet, and Poisson hyperpriors. +\item \strong{Bernoulli}: Fixed inclusion probability across edges. +\item \strong{Beta-Bernoulli}: Inclusion probability is assigned a Beta +prior distribution. +\item \strong{Stochastic-Block}: Cluster-based edge priors with Beta, +Dirichlet, and Poisson hyperpriors. } All priors operate via binary indicator variables controlling the inclusion @@ -314,11 +345,11 @@ or exclusion of each edge in the MRF. \itemize{ - \item \strong{Pairwise effects}: Modeled with a Cauchy (slab) prior. - \item \strong{Main effects}: Modeled using a beta-prime - distribution. - \item \strong{Edge indicators}: Use either a Bernoulli, Beta-Bernoulli, or - Stochastic-Block prior (as above). +\item \strong{Pairwise effects}: Modeled with a Cauchy (slab) prior. +\item \strong{Main effects}: Modeled using a beta-prime +distribution. +\item \strong{Edge indicators}: Use either a Bernoulli, Beta-Bernoulli, or +Stochastic-Block prior (as above). } } @@ -328,20 +359,20 @@ or exclusion of each edge in the MRF. Parameters are updated within a Gibbs framework, but the conditional updates can be carried out using different algorithms: \itemize{ - \item \strong{Adaptive Metropolis–Hastings}: Componentwise random–walk - updates for main effects and pairwise effects. Proposal standard - deviations are adapted during burn–in via Robbins–Monro updates - toward a target acceptance rate. - - \item \strong{Hamiltonian Monte Carlo (HMC)}: Joint updates of all - parameters using fixed–length leapfrog trajectories. Step size is - tuned during warmup via dual–averaging; the diagonal mass matrix can - also be adapted if \code{learn_mass_matrix = TRUE}. - - \item \strong{No–U–Turn Sampler (NUTS)}: An adaptive extension of HMC - that dynamically chooses trajectory lengths. Warmup uses a staged - adaptation schedule (fast–slow–fast) to stabilize step size and, if - enabled, the mass matrix. +\item \strong{Adaptive Metropolis–Hastings}: Componentwise random–walk +updates for main effects and pairwise effects. Proposal standard +deviations are adapted during burn–in via Robbins–Monro updates +toward a target acceptance rate. + +\item \strong{Hamiltonian Monte Carlo (HMC)}: Joint updates of all +parameters using fixed–length leapfrog trajectories. Step size is +tuned during warmup via dual–averaging; the diagonal mass matrix can +also be adapted if \code{learn_mass_matrix = TRUE}. + +\item \strong{No–U–Turn Sampler (NUTS)}: An adaptive extension of HMC +that dynamically chooses trajectory lengths. Warmup uses a staged +adaptation schedule (fast–slow–fast) to stabilize step size and, if +enabled, the mass matrix. } When \code{edge_selection = TRUE}, updates of edge–inclusion indicators @@ -362,28 +393,28 @@ schedule \insertCite{stan-manual}{bgms}. Warmup iterations are split into several phases: \itemize{ - \item \strong{Stage 1 (fast adaptation)}: A short initial interval - where only step size (for HMC/NUTS) is adapted, allowing the chain - to move quickly toward the typical set. - - \item \strong{Stage 2 (slow windows)}: A sequence of expanding, - memoryless windows where both step size and, if - \code{learn_mass_matrix = TRUE}, the diagonal mass matrix are - adapted. Each window ends with a reset of the dual–averaging scheme - for improved stability. - - \item \strong{Stage 3a (final fast interval)}: A short interval at the - end of the core warmup where the step size is adapted one final time. - - \item \strong{Stage 3b (proposal–SD tuning)}: Only active when - \code{edge_selection = TRUE} under HMC/NUTS. In this phase, - Robbins–Monro adaptation of proposal standard deviations is - performed for the Metropolis steps used in edge–selection moves. - - \item \strong{Stage 3c (graph selection warmup)}: Also only relevant - when \code{edge_selection = TRUE}. At the start of this phase, a - random graph structure is initialized, and Metropolis–Hastings - updates for edge inclusion indicators are switched on. +\item \strong{Stage 1 (fast adaptation)}: A short initial interval +where only step size (for HMC/NUTS) is adapted, allowing the chain +to move quickly toward the typical set. + +\item \strong{Stage 2 (slow windows)}: A sequence of expanding, +memoryless windows where both step size and, if +\code{learn_mass_matrix = TRUE}, the diagonal mass matrix are +adapted. Each window ends with a reset of the dual–averaging scheme +for improved stability. + +\item \strong{Stage 3a (final fast interval)}: A short interval at the +end of the core warmup where the step size is adapted one final time. + +\item \strong{Stage 3b (proposal–SD tuning)}: Only active when +\code{edge_selection = TRUE} under HMC/NUTS. In this phase, +Robbins–Monro adaptation of proposal standard deviations is +performed for the Metropolis steps used in edge–selection moves. + +\item \strong{Stage 3c (graph selection warmup)}: Also only relevant +when \code{edge_selection = TRUE}. At the start of this phase, a +random graph structure is initialized, and Metropolis–Hastings +updates for edge inclusion indicators are switched on. } When \code{edge_selection = FALSE}, the total number of warmup iterations diff --git a/man/bgmCompare.Rd b/man/bgmCompare.Rd index 5a45a3e9..db4cecab 100644 --- a/man/bgmCompare.Rd +++ b/man/bgmCompare.Rd @@ -149,30 +149,30 @@ during data processing (e.g., missing data handling, variable recoding). Defaults to \code{getOption("bgms.verbose", TRUE)}. Set \code{options(bgms.verbose = FALSE)} to suppress messages globally.} -\item{main_difference_model, reference_category, pairwise_difference_scale, main_difference_scale, pairwise_difference_prior, main_difference_prior, pairwise_difference_probability, main_difference_probability, pairwise_beta_bernoulli_alpha, pairwise_beta_bernoulli_beta, main_beta_bernoulli_alpha, main_beta_bernoulli_beta, interaction_scale, threshold_alpha, threshold_beta, burnin, save}{`r lifecycle::badge("deprecated")` +\item{main_difference_model, reference_category, pairwise_difference_scale, main_difference_scale, pairwise_difference_prior, main_difference_prior, pairwise_difference_probability, main_difference_probability, pairwise_beta_bernoulli_alpha, pairwise_beta_bernoulli_beta, main_beta_bernoulli_alpha, main_beta_bernoulli_beta, interaction_scale, threshold_alpha, threshold_beta, burnin, save}{\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} Deprecated arguments as of \strong{bgms 0.1.6.0}. -Use `difference_scale`, `difference_prior`, `difference_probability`, -`beta_bernoulli_alpha`, `beta_bernoulli_beta`, `baseline_category`, -`pairwise_scale`, and `warmup` instead.} +Use \code{difference_scale}, \code{difference_prior}, \code{difference_probability}, +\code{beta_bernoulli_alpha}, \code{beta_bernoulli_beta}, \code{baseline_category}, +\code{pairwise_scale}, and \code{warmup} instead.} } \value{ A list of class \code{"bgmCompare"} containing posterior summaries, posterior mean matrices, and raw MCMC samples: \itemize{ - \item \code{posterior_summary_main_baseline}, - \code{posterior_summary_pairwise_baseline}: summaries of baseline - thresholds and pairwise interactions. - \item \code{posterior_summary_main_differences}, - \code{posterior_summary_pairwise_differences}: summaries of group - differences in thresholds and pairwise interactions. - \item \code{posterior_summary_indicator}: summaries of inclusion - indicators (if \code{difference_selection = TRUE}). - \item \code{posterior_mean_main_baseline}, - \code{posterior_mean_pairwise_baseline}: posterior mean matrices - (legacy style). - \item \code{raw_samples}: list of raw draws per chain for main, - pairwise, and indicator parameters. - \item \code{arguments}: list of function call arguments and metadata. +\item \code{posterior_summary_main_baseline}, +\code{posterior_summary_pairwise_baseline}: summaries of baseline +thresholds and pairwise interactions. +\item \code{posterior_summary_main_differences}, +\code{posterior_summary_pairwise_differences}: summaries of group +differences in thresholds and pairwise interactions. +\item \code{posterior_summary_indicator}: summaries of inclusion +indicators (if \code{difference_selection = TRUE}). +\item \code{posterior_mean_main_baseline}, +\code{posterior_mean_pairwise_baseline}: posterior mean matrices +(legacy style). +\item \code{raw_samples}: list of raw draws per chain for main, +pairwise, and indicator parameters. +\item \code{arguments}: list of function call arguments and metadata. } The \code{summary()} method prints formatted summaries, and @@ -226,8 +226,8 @@ baseline plus group differences. When \code{difference_selection = TRUE}, spike-and-slab priors are applied to difference parameters: \itemize{ - \item \strong{Bernoulli}: fixed prior inclusion probability. - \item \strong{Beta–Bernoulli}: inclusion probability given a Beta prior. +\item \strong{Bernoulli}: fixed prior inclusion probability. +\item \strong{Beta–Bernoulli}: inclusion probability given a Beta prior. } } @@ -237,14 +237,14 @@ Parameters are updated within a Gibbs framework, using the same sampling algorithms and staged warmup scheme described in \code{\link{bgm}}: \itemize{ - \item \strong{Adaptive Metropolis–Hastings}: componentwise random–walk - proposals with Robbins–Monro adaptation of proposal SDs. - \item \strong{Hamiltonian Monte Carlo (HMC)}: joint updates with fixed - leapfrog trajectories; step size and optionally the mass matrix are - adapted during warmup. - \item \strong{No–U–Turn Sampler (NUTS)}: an adaptive HMC variant with - dynamic trajectory lengths; warmup uses the same staged adaptation - schedule as HMC. +\item \strong{Adaptive Metropolis–Hastings}: componentwise random–walk +proposals with Robbins–Monro adaptation of proposal SDs. +\item \strong{Hamiltonian Monte Carlo (HMC)}: joint updates with fixed +leapfrog trajectories; step size and optionally the mass matrix are +adapted during warmup. +\item \strong{No–U–Turn Sampler (NUTS)}: an adaptive HMC variant with +dynamic trajectory lengths; warmup uses the same staged adaptation +schedule as HMC. } For details on the staged adaptation schedule (fast–slow–fast phases), diff --git a/man/bgms-package.Rd b/man/bgms-package.Rd index d1f16dd7..311a378d 100644 --- a/man/bgms-package.Rd +++ b/man/bgms-package.Rd @@ -4,50 +4,58 @@ \name{bgms-package} \alias{bgms} \alias{bgms-package} -\title{bgms: Bayesian Analysis of Networks of Binary and/or Ordinal Variables} +\title{bgms: Bayesian Analysis of Graphical Models} \description{ The \code{R} package \strong{bgms} provides tools for Bayesian analysis of -the ordinal Markov random field (MRF), a graphical model describing networks -of binary and/or ordinal variables \insertCite{MarsmanVandenBerghHaslbeck_2025}{bgms}. -The likelihood is approximated via a pseudolikelihood, and Markov chain Monte -Carlo (MCMC) methods are used to sample from the corresponding pseudoposterior -distribution of model parameters. +graphical models describing networks of binary, ordinal, continuous, and +mixed variables +\insertCite{MarsmanVandenBerghHaslbeck_2025}{bgms}. +Supported model families include ordinal Markov random fields (MRFs), +Gaussian graphical models (GGMs), and mixed MRFs that combine discrete +and continuous variables in a single network. The likelihood is approximated +via a pseudolikelihood, and Markov chain Monte Carlo (MCMC) methods are used +to sample from the corresponding pseudoposterior distribution of model +parameters. The main entry points are: \itemize{ - \item \strong{bgm}: estimation in a one-sample design. - \item \strong{bgmCompare}: estimation and group comparison in an - independent-sample design. +\item \strong{bgm}: estimation in a one-sample design. +Use \code{variable_type = "ordinal"} for an MRF, +\code{"continuous"} for a GGM, or a per-variable vector +mixing \code{"ordinal"}, \code{"blume-capel"}, and +\code{"continuous"} for a mixed MRF. +\item \strong{bgmCompare}: estimation and group comparison in an +independent-sample design. } Both functions support Bayesian effect selection with spike-and-slab priors. \itemize{ - \item In one-sample designs, \code{bgm} models the presence or absence of - edges between variables. Posterior inclusion probabilities quantify the - plausibility of each edge and can be converted into Bayes factors for - conditional independence tests. +\item In one-sample designs, \code{bgm} models the presence or absence of +edges between variables. Posterior inclusion probabilities quantify the +plausibility of each edge and can be converted into Bayes factors for +conditional independence tests. - \item \code{bgm} can also model communities (clusters) of variables. The - posterior distribution of the number of clusters provides evidence for or - against clustering \insertCite{SekulovskiEtAl_2025}{bgms}. +\item \code{bgm} can also model communities (clusters) of variables. The +posterior distribution of the number of clusters provides evidence for or +against clustering \insertCite{SekulovskiEtAl_2025}{bgms}. - \item In independent-sample designs, \code{bgmCompare} estimates group - differences in edge weights and category thresholds. Posterior inclusion - probabilities quantify the evidence for differences and can be converted - into Bayes factors for parameter equivalence tests - \insertCite{MarsmanWaldorpSekulovskiHaslbeck_2024}{bgms}. +\item In independent-sample designs, \code{bgmCompare} estimates group +differences in edge weights and category thresholds. Posterior inclusion +probabilities quantify the evidence for differences and can be converted +into Bayes factors for parameter equivalence tests +\insertCite{MarsmanWaldorpSekulovskiHaslbeck_2024}{bgms}. } } \section{Tools}{ The package also provides: \enumerate{ - \item Simulation of response data from MRFs with a Gibbs sampler - (\code{\link{simulate_mrf}}). - \item Posterior estimation and edge selection in one-sample designs - (\code{\link{bgm}}). - \item Posterior estimation and group-difference selection in - independent-sample designs (\code{\link{bgmCompare}}). +\item Simulation of response data from MRFs with a Gibbs sampler +(\code{\link{simulate_mrf}}). +\item Posterior estimation and edge selection in one-sample designs +(\code{\link{bgm}}). +\item Posterior estimation and group-difference selection in +independent-sample designs (\code{\link{bgmCompare}}). } } @@ -55,10 +63,10 @@ The package also provides: For tutorials and worked examples, see: \itemize{ - \item \code{vignette("intro", package = "bgms")} — Getting started. - \item \code{vignette("comparison", package = "bgms")} — Model comparison. - \item \code{vignette("diagnostics", package = "bgms")} — Diagnostics and - spike-and-slab summaries. +\item \code{vignette("intro", package = "bgms")} — Getting started. +\item \code{vignette("comparison", package = "bgms")} — Model comparison. +\item \code{vignette("diagnostics", package = "bgms")} — Diagnostics and +spike-and-slab summaries. } } diff --git a/man/coef.bgmCompare.Rd b/man/coef.bgmCompare.Rd index fc16bff9..45ea3189 100644 --- a/man/coef.bgmCompare.Rd +++ b/man/coef.bgmCompare.Rd @@ -14,16 +14,16 @@ \value{ A list with components: \describe{ - \item{main_effects_raw}{Posterior means of the raw main-effect parameters - (variables x [baseline + differences]).} - \item{pairwise_effects_raw}{Posterior means of the raw pairwise-effect parameters - (pairs x [baseline + differences]).} - \item{main_effects_groups}{Posterior means of group-specific main effects - (variables x groups), computed as baseline plus projected differences.} - \item{pairwise_effects_groups}{Posterior means of group-specific pairwise effects - (pairs x groups), computed as baseline plus projected differences.} - \item{indicators}{Posterior mean inclusion probabilities as a symmetric matrix, - with diagonals corresponding to main effects and off-diagonals to pairwise effects.} +\item{main_effects_raw}{Posterior means of the raw main-effect parameters +(variables x (baseline + differences)).} +\item{pairwise_effects_raw}{Posterior means of the raw pairwise-effect parameters +(pairs x (baseline + differences)).} +\item{main_effects_groups}{Posterior means of group-specific main effects +(variables x groups), computed as baseline plus projected differences.} +\item{pairwise_effects_groups}{Posterior means of group-specific pairwise effects +(pairs x groups), computed as baseline plus projected differences.} +\item{indicators}{Posterior mean inclusion probabilities as a symmetric matrix, +with diagonals corresponding to main effects and off-diagonals to pairwise effects.} } } \description{ @@ -37,7 +37,7 @@ and group-specific effects from a \code{bgmCompare} fit, as well as inclusion in } \seealso{ -[bgmCompare()], [print.bgmCompare()], [summary.bgmCompare()] +\code{\link[=bgmCompare]{bgmCompare()}}, \code{\link[=print.bgmCompare]{print.bgmCompare()}}, \code{\link[=summary.bgmCompare]{summary.bgmCompare()}} Other posterior-methods: \code{\link{coef.bgms}()}, diff --git a/man/coef.bgms.Rd b/man/coef.bgms.Rd index ae663fdc..e6ab834f 100644 --- a/man/coef.bgms.Rd +++ b/man/coef.bgms.Rd @@ -14,13 +14,19 @@ \value{ A list with the following components: \describe{ - \item{main}{Posterior mean of the category threshold parameters.} - \item{pairwise}{Posterior mean of the pairwise interaction matrix.} - \item{indicator}{Posterior mean of the edge inclusion indicators (if available).} +\item{main}{Posterior mean of the main-effect parameters. \code{NULL} for +GGM models (no main effects). For OMRF models this is a numeric matrix +(p x max_categories) of category thresholds. For mixed MRF models this +is a list with \code{$discrete} (p x max_categories matrix) and +\code{$continuous} (q x 1 matrix of means).} +\item{pairwise}{Posterior mean of the pairwise interaction matrix. For GGM +and mixed MRF models the precision matrix diagonal is included on the +matrix diagonal.} +\item{indicator}{Posterior mean of the edge inclusion indicators (if available).} } } \description{ -Returns the posterior mean thresholds, pairwise effects, and edge inclusion indicators from a \code{bgms} model fit. +Returns the posterior mean main effects, pairwise effects, and edge inclusion indicators from a \code{bgms} model fit. } \examples{ \donttest{ @@ -30,7 +36,7 @@ coef(fit) } \seealso{ -[bgm()], [print.bgms()], [summary.bgms()] +\code{\link[=bgm]{bgm()}}, \code{\link[=print.bgms]{print.bgms()}}, \code{\link[=summary.bgms]{summary.bgms()}} Other posterior-methods: \code{\link{coef.bgmCompare}()}, diff --git a/man/extract_arguments.Rd b/man/extract_arguments.Rd index f4212df3..24b0a5ff 100644 --- a/man/extract_arguments.Rd +++ b/man/extract_arguments.Rd @@ -7,20 +7,20 @@ extract_arguments(bgms_object) } \arguments{ -\item{bgms_object}{A fitted model object of class `bgms` (from [bgm()]) -or `bgmCompare` (from [bgmCompare()]).} +\item{bgms_object}{A fitted model object of class \code{bgms} (from \code{\link[=bgm]{bgm()}}) +or \code{bgmCompare} (from \code{\link[=bgmCompare]{bgmCompare()}}).} } \value{ A named list containing all arguments passed to the fitting - function, including data dimensions, prior settings, and MCMC - configuration. +function, including data dimensions, prior settings, and MCMC +configuration. } \description{ -Retrieves the arguments used when fitting a model with [bgm()] or -[bgmCompare()]. +Retrieves the arguments used when fitting a model with \code{\link[=bgm]{bgm()}} or +\code{\link[=bgmCompare]{bgmCompare()}}. } \seealso{ -[bgm()], [bgmCompare()], [summary.bgms()], [summary.bgmCompare()] +\code{\link[=bgm]{bgm()}}, \code{\link[=bgmCompare]{bgmCompare()}}, \code{\link[=summary.bgms]{summary.bgms()}}, \code{\link[=summary.bgmCompare]{summary.bgmCompare()}} Other extractors: \code{\link{extract_category_thresholds}()}, @@ -28,6 +28,7 @@ Other extractors: \code{\link{extract_group_params}()}, \code{\link{extract_indicator_priors}()}, \code{\link{extract_indicators}()}, +\code{\link{extract_main_effects}()}, \code{\link{extract_pairwise_interactions}()}, \code{\link{extract_posterior_inclusion_probabilities}()}, \code{\link{extract_rhat}()}, diff --git a/man/extract_category_thresholds.Rd b/man/extract_category_thresholds.Rd index ef1b7bf2..a4c3470d 100644 --- a/man/extract_category_thresholds.Rd +++ b/man/extract_category_thresholds.Rd @@ -7,23 +7,24 @@ extract_category_thresholds(bgms_object) } \arguments{ -\item{bgms_object}{A fitted model object of class `bgms` (from [bgm()]) -or `bgmCompare` (from [bgmCompare()]).} +\item{bgms_object}{A fitted model object of class \code{bgms} (from \code{\link[=bgm]{bgm()}}) +or \code{bgmCompare} (from \code{\link[=bgmCompare]{bgmCompare()}}).} } \value{ -\describe{ - \item{bgms}{A matrix with one row per variable and one column per - category threshold, containing posterior means.} - \item{bgmCompare}{A matrix with one row per post-warmup iteration, - containing posterior samples of baseline threshold parameters.} - } +See \code{\link[=extract_main_effects]{extract_main_effects()}} for details. } \description{ -Retrieves category threshold parameters from a model fitted with -[bgm()] or [bgmCompare()]. +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} + +\code{extract_category_thresholds()} was renamed to \code{\link[=extract_main_effects]{extract_main_effects()}} to +reflect that main effects include continuous means and precisions +(mixed MRF), not only category thresholds. +} +\details{ +Extract Category Threshold Estimates } \seealso{ -[bgm()], [bgmCompare()], [extract_pairwise_interactions()] +\code{\link[=extract_main_effects]{extract_main_effects()}} Other extractors: \code{\link{extract_arguments}()}, @@ -31,6 +32,7 @@ Other extractors: \code{\link{extract_group_params}()}, \code{\link{extract_indicator_priors}()}, \code{\link{extract_indicators}()}, +\code{\link{extract_main_effects}()}, \code{\link{extract_pairwise_interactions}()}, \code{\link{extract_posterior_inclusion_probabilities}()}, \code{\link{extract_rhat}()}, diff --git a/man/extract_ess.Rd b/man/extract_ess.Rd index de2c1f8c..7bed7472 100644 --- a/man/extract_ess.Rd +++ b/man/extract_ess.Rd @@ -7,19 +7,19 @@ extract_ess(bgms_object) } \arguments{ -\item{bgms_object}{A fitted model object of class `bgms` (from [bgm()]) -or `bgmCompare` (from [bgmCompare()]).} +\item{bgms_object}{A fitted model object of class \code{bgms} (from \code{\link[=bgm]{bgm()}}) +or \code{bgmCompare} (from \code{\link[=bgmCompare]{bgmCompare()}}).} } \value{ A named list with ESS values for each parameter type present in - the model (e.g., `main`, `pairwise`, `indicator`). +the model (e.g., \code{main}, \code{pairwise}, \code{indicator}). } \description{ Retrieves effective sample size estimates for all parameters from a -model fitted with [bgm()] or [bgmCompare()]. +model fitted with \code{\link[=bgm]{bgm()}} or \code{\link[=bgmCompare]{bgmCompare()}}. } \seealso{ -[bgm()], [bgmCompare()], [extract_rhat()] +\code{\link[=bgm]{bgm()}}, \code{\link[=bgmCompare]{bgmCompare()}}, \code{\link[=extract_rhat]{extract_rhat()}} Other extractors: \code{\link{extract_arguments}()}, @@ -27,6 +27,7 @@ Other extractors: \code{\link{extract_group_params}()}, \code{\link{extract_indicator_priors}()}, \code{\link{extract_indicators}()}, +\code{\link{extract_main_effects}()}, \code{\link{extract_pairwise_interactions}()}, \code{\link{extract_posterior_inclusion_probabilities}()}, \code{\link{extract_rhat}()}, diff --git a/man/extract_group_params.Rd b/man/extract_group_params.Rd index 1d4ca459..ed5ad754 100644 --- a/man/extract_group_params.Rd +++ b/man/extract_group_params.Rd @@ -7,20 +7,20 @@ extract_group_params(bgms_object) } \arguments{ -\item{bgms_object}{A fitted model object of class `bgmCompare` -(from [bgmCompare()]).} +\item{bgms_object}{A fitted model object of class \code{bgmCompare} +(from \code{\link[=bgmCompare]{bgmCompare()}}).} } \value{ -A list with elements `main_effects_groups` (main effects per - group) and `pairwise_effects_groups` (pairwise effects per group). +A list with elements \code{main_effects_groups} (main effects per +group) and \code{pairwise_effects_groups} (pairwise effects per group). } \description{ Computes group-specific parameter estimates by combining baseline -parameters and group differences from a model fitted with [bgmCompare()]. +parameters and group differences from a model fitted with \code{\link[=bgmCompare]{bgmCompare()}}. } \seealso{ -[bgmCompare()], [extract_pairwise_interactions()], - [extract_category_thresholds()] +\code{\link[=bgmCompare]{bgmCompare()}}, \code{\link[=extract_pairwise_interactions]{extract_pairwise_interactions()}}, +\code{\link[=extract_main_effects]{extract_main_effects()}} Other extractors: \code{\link{extract_arguments}()}, @@ -28,6 +28,7 @@ Other extractors: \code{\link{extract_ess}()}, \code{\link{extract_indicator_priors}()}, \code{\link{extract_indicators}()}, +\code{\link{extract_main_effects}()}, \code{\link{extract_pairwise_interactions}()}, \code{\link{extract_posterior_inclusion_probabilities}()}, \code{\link{extract_rhat}()}, diff --git a/man/extract_indicator_priors.Rd b/man/extract_indicator_priors.Rd index bba68c32..44f6f91b 100644 --- a/man/extract_indicator_priors.Rd +++ b/man/extract_indicator_priors.Rd @@ -7,27 +7,27 @@ extract_indicator_priors(bgms_object) } \arguments{ -\item{bgms_object}{A fitted model object of class `bgms` (from [bgm()]) -or `bgmCompare` (from [bgmCompare()]).} +\item{bgms_object}{A fitted model object of class \code{bgms} (from \code{\link[=bgm]{bgm()}}) +or \code{bgmCompare} (from \code{\link[=bgmCompare]{bgmCompare()}}).} } \value{ A named list describing the prior structure, including the prior - type and any hyperparameters. - \describe{ - \item{bgms}{Requires `edge_selection = TRUE`. Returns a list with the - prior type (`"Bernoulli"`, `"Beta-Bernoulli"`, or - `"Stochastic-Block"`) and associated hyperparameters.} - \item{bgmCompare}{Requires `difference_selection = TRUE`. Returns the - difference prior specification.} - } +type and any hyperparameters. +\describe{ +\item{bgms}{Requires \code{edge_selection = TRUE}. Returns a list with the +prior type (\code{"Bernoulli"}, \code{"Beta-Bernoulli"}, or +\code{"Stochastic-Block"}) and associated hyperparameters.} +\item{bgmCompare}{Requires \code{difference_selection = TRUE}. Returns the +difference prior specification.} +} } \description{ Retrieves the prior specification used for inclusion indicators in a -model fitted with [bgm()] (edge indicators) or [bgmCompare()] +model fitted with \code{\link[=bgm]{bgm()}} (edge indicators) or \code{\link[=bgmCompare]{bgmCompare()}} (difference indicators). } \seealso{ -[bgm()], [bgmCompare()], [extract_indicators()] +\code{\link[=bgm]{bgm()}}, \code{\link[=bgmCompare]{bgmCompare()}}, \code{\link[=extract_indicators]{extract_indicators()}} Other extractors: \code{\link{extract_arguments}()}, @@ -35,6 +35,7 @@ Other extractors: \code{\link{extract_ess}()}, \code{\link{extract_group_params}()}, \code{\link{extract_indicators}()}, +\code{\link{extract_main_effects}()}, \code{\link{extract_pairwise_interactions}()}, \code{\link{extract_posterior_inclusion_probabilities}()}, \code{\link{extract_rhat}()}, diff --git a/man/extract_indicators.Rd b/man/extract_indicators.Rd index 41b1cc81..5127cd00 100644 --- a/man/extract_indicators.Rd +++ b/man/extract_indicators.Rd @@ -7,26 +7,26 @@ extract_indicators(bgms_object) } \arguments{ -\item{bgms_object}{A fitted model object of class `bgms` (from [bgm()]) -or `bgmCompare` (from [bgmCompare()]).} +\item{bgms_object}{A fitted model object of class \code{bgms} (from \code{\link[=bgm]{bgm()}}) +or \code{bgmCompare} (from \code{\link[=bgmCompare]{bgmCompare()}}).} } \value{ A matrix with one row per post-warmup iteration and one column per - indicator, containing binary (0/1) samples. - \describe{ - \item{bgms}{One column per edge. Requires `edge_selection = TRUE`.} - \item{bgmCompare}{Columns for main-effect and pairwise difference - indicators. Requires `difference_selection = TRUE`.} - } +indicator, containing binary (0/1) samples. +\describe{ +\item{bgms}{One column per edge. Requires \code{edge_selection = TRUE}.} +\item{bgmCompare}{Columns for main-effect and pairwise difference +indicators. Requires \code{difference_selection = TRUE}.} +} } \description{ Retrieves posterior samples of inclusion indicators from a model fitted -with [bgm()] (edge inclusion indicators) or [bgmCompare()] (difference +with \code{\link[=bgm]{bgm()}} (edge inclusion indicators) or \code{\link[=bgmCompare]{bgmCompare()}} (difference indicators). } \seealso{ -[bgm()], [bgmCompare()], - [extract_posterior_inclusion_probabilities()] +\code{\link[=bgm]{bgm()}}, \code{\link[=bgmCompare]{bgmCompare()}}, +\code{\link[=extract_posterior_inclusion_probabilities]{extract_posterior_inclusion_probabilities()}} Other extractors: \code{\link{extract_arguments}()}, @@ -34,6 +34,7 @@ Other extractors: \code{\link{extract_ess}()}, \code{\link{extract_group_params}()}, \code{\link{extract_indicator_priors}()}, +\code{\link{extract_main_effects}()}, \code{\link{extract_pairwise_interactions}()}, \code{\link{extract_posterior_inclusion_probabilities}()}, \code{\link{extract_rhat}()}, diff --git a/man/extract_main_effects.Rd b/man/extract_main_effects.Rd new file mode 100644 index 00000000..00cf9ee4 --- /dev/null +++ b/man/extract_main_effects.Rd @@ -0,0 +1,64 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extractor_functions.R +\name{extract_main_effects} +\alias{extract_main_effects} +\title{Extract Main Effect Estimates} +\usage{ +extract_main_effects(bgms_object) +} +\arguments{ +\item{bgms_object}{A fitted model object of class \code{bgms} (from \code{\link[=bgm]{bgm()}}) +or \code{bgmCompare} (from \code{\link[=bgmCompare]{bgmCompare()}}).} +} +\value{ +The structure depends on the model type: +\describe{ +\item{GGM (bgms)}{\code{NULL} (invisibly). GGM models have no main effects; +the precision matrix diagonal is on \code{coef(fit)$pairwise}.} +\item{OMRF (bgms)}{A numeric matrix with one row per variable and one +column per category threshold, containing posterior means. Columns +beyond the number of categories for a variable are \code{NA}.} +\item{Mixed MRF (bgms)}{A list with two elements: +\describe{ +\item{discrete}{A numeric matrix (p rows x max_categories columns) +of posterior mean thresholds for discrete variables.} +\item{continuous}{A numeric matrix (q rows x 1 column) of +posterior mean continuous variable means.} +}} +\item{bgmCompare}{A matrix with one row per post-warmup iteration, +containing posterior samples of baseline main-effect parameters.} +} +} +\description{ +Retrieves posterior mean main-effect parameters from a model fitted with +\code{\link[=bgm]{bgm()}} or \code{\link[=bgmCompare]{bgmCompare()}}. For OMRF models these are category thresholds; +for mixed MRF models these include discrete thresholds and continuous +means. GGM models have no main effects and return \code{NULL}. +} +\details{ +Extract Main Effect Estimates +} +\examples{ +\donttest{ +fit = bgm(x = Wenchuan[, 1:3]) +extract_main_effects(fit) +} + +} +\seealso{ +\code{\link[=bgm]{bgm()}}, \code{\link[=bgmCompare]{bgmCompare()}}, \code{\link[=extract_pairwise_interactions]{extract_pairwise_interactions()}}, +\code{\link[=extract_category_thresholds]{extract_category_thresholds()}} + +Other extractors: +\code{\link{extract_arguments}()}, +\code{\link{extract_category_thresholds}()}, +\code{\link{extract_ess}()}, +\code{\link{extract_group_params}()}, +\code{\link{extract_indicator_priors}()}, +\code{\link{extract_indicators}()}, +\code{\link{extract_pairwise_interactions}()}, +\code{\link{extract_posterior_inclusion_probabilities}()}, +\code{\link{extract_rhat}()}, +\code{\link{extract_sbm}()} +} +\concept{extractors} diff --git a/man/extract_pairwise_interactions.Rd b/man/extract_pairwise_interactions.Rd index 70e02886..0ce5b42a 100644 --- a/man/extract_pairwise_interactions.Rd +++ b/man/extract_pairwise_interactions.Rd @@ -7,24 +7,24 @@ extract_pairwise_interactions(bgms_object) } \arguments{ -\item{bgms_object}{A fitted model object of class `bgms` (from [bgm()]) -or `bgmCompare` (from [bgmCompare()]).} +\item{bgms_object}{A fitted model object of class \code{bgms} (from \code{\link[=bgm]{bgm()}}) +or \code{bgmCompare} (from \code{\link[=bgmCompare]{bgmCompare()}}).} } \value{ A matrix with one row per post-warmup iteration and one column per - edge, containing posterior samples of interaction strengths. - \describe{ - \item{bgms}{Columns correspond to all unique variable pairs.} - \item{bgmCompare}{Columns correspond to the baseline pairwise - interaction parameters.} - } +edge, containing posterior samples of interaction strengths. +\describe{ +\item{bgms}{Columns correspond to all unique variable pairs.} +\item{bgmCompare}{Columns correspond to the baseline pairwise +interaction parameters.} +} } \description{ Retrieves posterior samples of pairwise interaction parameters from a -model fitted with [bgm()] or [bgmCompare()]. +model fitted with \code{\link[=bgm]{bgm()}} or \code{\link[=bgmCompare]{bgmCompare()}}. } \seealso{ -[bgm()], [bgmCompare()], [extract_category_thresholds()] +\code{\link[=bgm]{bgm()}}, \code{\link[=bgmCompare]{bgmCompare()}}, \code{\link[=extract_main_effects]{extract_main_effects()}} Other extractors: \code{\link{extract_arguments}()}, @@ -33,6 +33,7 @@ Other extractors: \code{\link{extract_group_params}()}, \code{\link{extract_indicator_priors}()}, \code{\link{extract_indicators}()}, +\code{\link{extract_main_effects}()}, \code{\link{extract_posterior_inclusion_probabilities}()}, \code{\link{extract_rhat}()}, \code{\link{extract_sbm}()} diff --git a/man/extract_pairwise_thresholds.Rd b/man/extract_pairwise_thresholds.Rd index f6a8920f..58e0b330 100644 --- a/man/extract_pairwise_thresholds.Rd +++ b/man/extract_pairwise_thresholds.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/extractor_functions.R \name{extract_pairwise_thresholds} \alias{extract_pairwise_thresholds} -\title{Deprecated: Use extract_category_thresholds instead} +\title{Deprecated: Use extract_main_effects instead} \usage{ extract_pairwise_thresholds(bgms_object) } @@ -10,6 +10,6 @@ extract_pairwise_thresholds(bgms_object) \item{bgms_object}{A bgms or bgmCompare object.} } \description{ -Deprecated: Use extract_category_thresholds instead +Deprecated: Use extract_main_effects instead } \keyword{internal} diff --git a/man/extract_posterior_inclusion_probabilities.Rd b/man/extract_posterior_inclusion_probabilities.Rd index 059232f1..aac217ed 100644 --- a/man/extract_posterior_inclusion_probabilities.Rd +++ b/man/extract_posterior_inclusion_probabilities.Rd @@ -7,26 +7,26 @@ extract_posterior_inclusion_probabilities(bgms_object) } \arguments{ -\item{bgms_object}{A fitted model object of class `bgms` (from [bgm()]) -or `bgmCompare` (from [bgmCompare()]).} +\item{bgms_object}{A fitted model object of class \code{bgms} (from \code{\link[=bgm]{bgm()}}) +or \code{bgmCompare} (from \code{\link[=bgmCompare]{bgmCompare()}}).} } \value{ A symmetric p x p matrix of posterior inclusion probabilities, - with variable names as row and column names. - \describe{ - \item{bgms}{Off-diagonal entries are edge inclusion probabilities. - Requires `edge_selection = TRUE`.} - \item{bgmCompare}{Diagonal entries are main-effect inclusion - probabilities; off-diagonal entries are pairwise difference - inclusion probabilities. Requires `difference_selection = TRUE`.} - } +with variable names as row and column names. +\describe{ +\item{bgms}{Off-diagonal entries are edge inclusion probabilities. +Requires \code{edge_selection = TRUE}.} +\item{bgmCompare}{Diagonal entries are main-effect inclusion +probabilities; off-diagonal entries are pairwise difference +inclusion probabilities. Requires \code{difference_selection = TRUE}.} +} } \description{ Computes posterior inclusion probabilities from a model fitted with -[bgm()] (edge inclusion) or [bgmCompare()] (difference inclusion). +\code{\link[=bgm]{bgm()}} (edge inclusion) or \code{\link[=bgmCompare]{bgmCompare()}} (difference inclusion). } \seealso{ -[bgm()], [bgmCompare()], [extract_indicators()] +\code{\link[=bgm]{bgm()}}, \code{\link[=bgmCompare]{bgmCompare()}}, \code{\link[=extract_indicators]{extract_indicators()}} Other extractors: \code{\link{extract_arguments}()}, @@ -35,6 +35,7 @@ Other extractors: \code{\link{extract_group_params}()}, \code{\link{extract_indicator_priors}()}, \code{\link{extract_indicators}()}, +\code{\link{extract_main_effects}()}, \code{\link{extract_pairwise_interactions}()}, \code{\link{extract_rhat}()}, \code{\link{extract_sbm}()} diff --git a/man/extract_rhat.Rd b/man/extract_rhat.Rd index b923c584..96d0f338 100644 --- a/man/extract_rhat.Rd +++ b/man/extract_rhat.Rd @@ -7,19 +7,19 @@ extract_rhat(bgms_object) } \arguments{ -\item{bgms_object}{A fitted model object of class `bgms` (from [bgm()]) -or `bgmCompare` (from [bgmCompare()]).} +\item{bgms_object}{A fitted model object of class \code{bgms} (from \code{\link[=bgm]{bgm()}}) +or \code{bgmCompare} (from \code{\link[=bgmCompare]{bgmCompare()}}).} } \value{ A named list with R-hat values for each parameter type present in - the model (e.g., `main`, `pairwise`, `indicator`). +the model (e.g., \code{main}, \code{pairwise}, \code{indicator}). } \description{ Retrieves R-hat convergence diagnostics for all parameters from a -model fitted with [bgm()] or [bgmCompare()]. +model fitted with \code{\link[=bgm]{bgm()}} or \code{\link[=bgmCompare]{bgmCompare()}}. } \seealso{ -[bgm()], [bgmCompare()], [extract_ess()] +\code{\link[=bgm]{bgm()}}, \code{\link[=bgmCompare]{bgmCompare()}}, \code{\link[=extract_ess]{extract_ess()}} Other extractors: \code{\link{extract_arguments}()}, @@ -28,6 +28,7 @@ Other extractors: \code{\link{extract_group_params}()}, \code{\link{extract_indicator_priors}()}, \code{\link{extract_indicators}()}, +\code{\link{extract_main_effects}()}, \code{\link{extract_pairwise_interactions}()}, \code{\link{extract_posterior_inclusion_probabilities}()}, \code{\link{extract_sbm}()} diff --git a/man/extract_sbm.Rd b/man/extract_sbm.Rd index 987153b7..23759049 100644 --- a/man/extract_sbm.Rd +++ b/man/extract_sbm.Rd @@ -7,21 +7,21 @@ extract_sbm(bgms_object) } \arguments{ -\item{bgms_object}{A fitted model object of class `bgms` (from [bgm()]).} +\item{bgms_object}{A fitted model object of class \code{bgms} (from \code{\link[=bgm]{bgm()}}).} } \value{ -A list with elements `posterior_num_blocks`, - `posterior_mean_allocations`, `posterior_mode_allocations`, and - `posterior_mean_coclustering_matrix`. Requires `edge_selection = TRUE` - and `edge_prior = "Stochastic-Block"`. +A list with elements \code{posterior_num_blocks}, +\code{posterior_mean_allocations}, \code{posterior_mode_allocations}, and +\code{posterior_mean_coclustering_matrix}. Requires \code{edge_selection = TRUE} +and \code{edge_prior = "Stochastic-Block"}. } \description{ -Retrieves posterior summaries from a model fitted with [bgm()] using +Retrieves posterior summaries from a model fitted with \code{\link[=bgm]{bgm()}} using the Stochastic Block prior on edge inclusion. } \seealso{ -[bgm()], [extract_indicators()], - [extract_posterior_inclusion_probabilities()] +\code{\link[=bgm]{bgm()}}, \code{\link[=extract_indicators]{extract_indicators()}}, +\code{\link[=extract_posterior_inclusion_probabilities]{extract_posterior_inclusion_probabilities()}} Other extractors: \code{\link{extract_arguments}()}, @@ -30,6 +30,7 @@ Other extractors: \code{\link{extract_group_params}()}, \code{\link{extract_indicator_priors}()}, \code{\link{extract_indicators}()}, +\code{\link{extract_main_effects}()}, \code{\link{extract_pairwise_interactions}()}, \code{\link{extract_posterior_inclusion_probabilities}()}, \code{\link{extract_rhat}()} diff --git a/man/mrfSampler.Rd b/man/mrfSampler.Rd index 8ff00c14..310dbbeb 100644 --- a/man/mrfSampler.Rd +++ b/man/mrfSampler.Rd @@ -21,15 +21,18 @@ mrfSampler( \item{num_variables}{The number of variables in the MRF.} -\item{num_categories}{Either a positive integer or a vector of positive -integers of length \code{num_variables}. The number of response categories on top -of the base category: \code{num_categories = 1} generates binary states. +\item{num_categories}{Either a positive integer or a vector +of positive integers of length \code{num_variables}. The +number of response categories on top of the base category: +\code{num_categories = 1} generates binary states. Only used for ordinal and Blume-Capel variables; ignored when \code{variable_type = "continuous"}.} -\item{pairwise}{A symmetric \code{num_variables} by \code{num_variables} matrix. -For ordinal and Blume-Capel variables, this contains the pairwise interaction -parameters; only the off-diagonal elements are used. For continuous variables, +\item{pairwise}{A symmetric \code{num_variables} by +\code{num_variables} matrix. For ordinal and Blume-Capel +variables, this contains the pairwise interaction parameters; +only the off-diagonal elements are used. For continuous +variables, this is the precision matrix \eqn{\Omega}{Omega} (including diagonal) and must be positive definite.} @@ -37,7 +40,8 @@ must be positive definite.} \code{num_variables} by \code{max(num_categories)} matrix of category thresholds. The elements in row \code{i} indicate the thresholds of variable \code{i}. If \code{num_categories} is a vector, only the first -\code{num_categories[i]} elements are used in row \code{i}. If the Blume-Capel +\code{num_categories[i]} elements are used in row \code{i}. +If the Blume-Capel model is used for the category thresholds for variable \code{i}, then row \code{i} requires two values (details below); the first is \eqn{\alpha}{\alpha}, the linear contribution of the Blume-Capel model and @@ -49,16 +53,18 @@ if not supplied or if all values are zero.} \item{variable_type}{What kind of variables are simulated? Can be a single character string specifying the variable type of all \code{p} variables at once or a vector of character strings of length \code{p} specifying the type -for each variable separately. Currently, bgm supports ``ordinal'', -``blume-capel'', and ``continuous''. Binary variables are automatically -treated as ``ordinal''. Ordinal and Blume-Capel variables can be mixed +for each variable separately. Currently, bgm supports \code{"ordinal"}, +\code{"blume-capel"}, and \code{"continuous"}. Binary variables are automatically +treated as \code{"ordinal"}. Ordinal and Blume-Capel variables can be mixed freely, but continuous variables cannot be mixed with ordinal or Blume-Capel variables. When \code{variable_type = "continuous"}, the function simulates from a Gaussian graphical model. Defaults to \code{variable_type = "ordinal"}.} -\item{baseline_category}{An integer vector of length \code{num_variables} specifying the -baseline_category category that is used for the Blume-Capel model (details below). +\item{baseline_category}{An integer vector of length +\code{num_variables} specifying the baseline_category +category that is used for the Blume-Capel model +(details below). Can be any integer value between \code{0} and \code{num_categories} (or \code{num_categories[i]}).} @@ -72,15 +78,15 @@ a seed is generated from R's random number generator (so \code{set.seed()} can be used before calling this function).} } \value{ -A matrix of simulated observations (see [simulate_mrf()]). +A matrix of simulated observations (see \code{\link[=simulate_mrf]{simulate_mrf()}}). } \description{ -`r lifecycle::badge("deprecated")` +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} -`mrfSampler()` was renamed to [simulate_mrf()] as of bgms 0.1.6.3 to +\code{mrfSampler()} was renamed to \code{\link[=simulate_mrf]{simulate_mrf()}} as of bgms 0.1.6.3 to follow the package's naming conventions. } \seealso{ -[simulate_mrf()] for the current function. +\code{\link[=simulate_mrf]{simulate_mrf()}} for the current function. } \keyword{internal} diff --git a/man/predict.bgmCompare.Rd b/man/predict.bgmCompare.Rd index 0d407dd1..f8aacee2 100644 --- a/man/predict.bgmCompare.Rd +++ b/man/predict.bgmCompare.Rd @@ -26,30 +26,31 @@ prediction (1 to number of groups). Required argument.} \item{variables}{Which variables to predict. Can be: \itemize{ - \item A character vector of variable names - \item An integer vector of column indices - \item \code{NULL} (default) to predict all variables +\item A character vector of variable names +\item An integer vector of column indices +\item \code{NULL} (default) to predict all variables }} \item{type}{Character string specifying the type of prediction: \describe{ - \item{\code{"probabilities"}}{Return the full conditional probability - distribution for each variable and observation.} - \item{\code{"response"}}{Return the predicted category (mode of the - conditional distribution).} +\item{\code{"probabilities"}}{Return the full conditional probability +distribution for each variable and observation.} +\item{\code{"response"}}{Return the predicted category (mode of the +conditional distribution).} }} \item{method}{Character string specifying which parameter estimates to use: \describe{ - \item{\code{"posterior-mean"}}{Use posterior mean parameters.} +\item{\code{"posterior-mean"}}{Use posterior mean parameters.} }} \item{...}{Additional arguments (currently ignored).} } \value{ -For \code{type = "probabilities"}: A named list with one element per -predicted variable. Each element is a matrix with \code{n} rows and -\code{num_categories + 1} columns containing \eqn{P(X_j = c | X_{-j})}{P(X_j = c | X_-j)} +For \code{type = "probabilities"}: A named list with one +element per predicted variable. Each element is a matrix with +\code{n} rows and \code{num_categories + 1} columns containing +\eqn{P(X_j = c | X_{-j})}{P(X_j = c | X_-j)} for each observation and category. For \code{type = "response"}: A matrix with \code{n} rows and @@ -82,8 +83,10 @@ pred_g2 = predict(fit, newdata = y[1:10, ], group = 2, type = "response") } \seealso{ -\code{\link{predict.bgms}} for predicting from single-group models, - \code{\link{simulate.bgmCompare}} for simulating from group-comparison models. +\code{\link{predict.bgms}} for predicting +from single-group models, +\code{\link{simulate.bgmCompare}} for simulating +from group-comparison models. Other prediction: \code{\link{predict.bgms}()}, diff --git a/man/predict.bgms.Rd b/man/predict.bgms.Rd index 37e16319..285a34cb 100644 --- a/man/predict.bgms.Rd +++ b/man/predict.bgms.Rd @@ -24,27 +24,29 @@ the original data used to fit the model.} \item{variables}{Which variables to predict. Can be: \itemize{ - \item A character vector of variable names - \item An integer vector of column indices - \item \code{NULL} (default) to predict all variables +\item A character vector of variable names +\item An integer vector of column indices +\item \code{NULL} (default) to predict all variables }} \item{type}{Character string specifying the type of prediction: \describe{ - \item{\code{"probabilities"}}{Return the full conditional probability - distribution for each variable and observation.} - \item{\code{"response"}}{Return the predicted category (mode of the - conditional distribution).} +\item{\code{"probabilities"}}{Return the full conditional probability +distribution for each variable and observation.} +\item{\code{"response"}}{Return the predicted category (mode of the +conditional distribution).} }} \item{method}{Character string specifying which parameter estimates to use: \describe{ - \item{\code{"posterior-mean"}}{Use posterior mean parameters.} - \item{\code{"posterior-sample"}}{Average predictions over posterior draws.} +\item{\code{"posterior-mean"}}{Use posterior mean parameters.} +\item{\code{"posterior-sample"}}{Average predictions +over posterior draws.} }} \item{ndraws}{Number of posterior draws to use when -\code{method = "posterior-sample"}. If \code{NULL}, uses all available draws.} +\code{method = "posterior-sample"}. If \code{NULL}, +uses all available draws.} \item{seed}{Optional random seed for reproducibility when \code{method = "posterior-sample"}.} @@ -56,8 +58,9 @@ the original data used to fit the model.} For \code{type = "probabilities"}: A named list with one element per predicted variable. Each element is a matrix with \code{n} rows and -\code{num_categories + 1} columns containing \eqn{P(X_j = c | X_{-j})}{P(X_j = c | X_-j)} for each -observation and category. +\code{num_categories + 1} columns containing +\eqn{P(X_j = c | X_{-j})}{P(X_j = c | X_-j)} +for each observation and category. For \code{type = "response"}: A matrix with \code{n} rows and \code{length(variables)} columns containing predicted categories. @@ -79,10 +82,18 @@ For \code{type = "response"}: A matrix with \code{n} rows and When \code{method = "posterior-sample"}, conditional parameters are averaged over posterior draws, and an attribute \code{"sd"} is included. + +\strong{Mixed MRF models:} + +For mixed models, the return list contains elements for both discrete and +continuous predicted variables. Discrete variables return probability +matrices (as in ordinal models); continuous variables return conditional +mean and SD matrices (as in GGM models). } \description{ Computes conditional probability distributions for one or more variables -given the observed values of other variables in the data. +given the observed values of other variables in the data. Supports ordinal, +Blume-Capel, continuous (GGM), and mixed MRF models. } \details{ For each observation, the function computes the conditional distribution @@ -92,7 +103,8 @@ sampler. For GGM (continuous) models, the conditional distribution of \eqn{X_j | X_{-j}}{X_j | X_{-j}} is Gaussian with mean -\eqn{-\omega_{jj}^{-1} \sum_{k \neq j} \omega_{jk} x_k}{-omega_jj^{-1} sum_{k != j} omega_jk x_k} +\eqn{-\omega_{jj}^{-1} \sum_{k \neq j} +\omega_{jk} x_k}{-omega_jj^{-1} sum_{k != j} omega_jk x_k} and variance \eqn{\omega_{jj}^{-1}}{omega_jj^{-1}}, where \eqn{\Omega}{Omega} is the precision matrix. } diff --git a/man/print.bgmCompare.Rd b/man/print.bgmCompare.Rd index 7c4fd715..334b1988 100644 --- a/man/print.bgmCompare.Rd +++ b/man/print.bgmCompare.Rd @@ -2,20 +2,20 @@ % Please edit documentation in R/bgmcompare-methods.r \name{print.bgmCompare} \alias{print.bgmCompare} -\title{Print method for `bgmCompare` objects} +\title{Print method for \code{bgmCompare} objects} \usage{ \method{print}{bgmCompare}(x, ...) } \arguments{ -\item{x}{An object of class `bgmCompare`.} +\item{x}{An object of class \code{bgmCompare}.} \item{...}{Ignored.} } \value{ -Invisibly returns `x`. +Invisibly returns \code{x}. } \description{ -Minimal console output for `bgmCompare` fit objects. +Minimal console output for \code{bgmCompare} fit objects. } \examples{ \donttest{ @@ -24,7 +24,7 @@ Minimal console output for `bgmCompare` fit objects. } \seealso{ -[bgmCompare()], [summary.bgmCompare()], [coef.bgmCompare()] +\code{\link[=bgmCompare]{bgmCompare()}}, \code{\link[=summary.bgmCompare]{summary.bgmCompare()}}, \code{\link[=coef.bgmCompare]{coef.bgmCompare()}} Other posterior-methods: \code{\link{coef.bgmCompare}()}, diff --git a/man/print.bgms.Rd b/man/print.bgms.Rd index 156bfc92..52a9a823 100644 --- a/man/print.bgms.Rd +++ b/man/print.bgms.Rd @@ -2,20 +2,20 @@ % Please edit documentation in R/bgms-methods.R \name{print.bgms} \alias{print.bgms} -\title{Print method for `bgms` objects} +\title{Print method for \code{bgms} objects} \usage{ \method{print}{bgms}(x, ...) } \arguments{ -\item{x}{An object of class `bgms`.} +\item{x}{An object of class \code{bgms}.} \item{...}{Ignored.} } \value{ -Invisibly returns `x`. +Invisibly returns \code{x}. } \description{ -Minimal console output for `bgms` fit objects. +Minimal console output for \code{bgms} fit objects. } \examples{ \donttest{ @@ -25,7 +25,7 @@ print(fit) } \seealso{ -[bgm()], [summary.bgms()], [coef.bgms()] +\code{\link[=bgm]{bgm()}}, \code{\link[=summary.bgms]{summary.bgms()}}, \code{\link[=coef.bgms]{coef.bgms()}} Other posterior-methods: \code{\link{coef.bgmCompare}()}, diff --git a/man/simulate.bgmCompare.Rd b/man/simulate.bgmCompare.Rd index 6a16c1dd..718315dd 100644 --- a/man/simulate.bgmCompare.Rd +++ b/man/simulate.bgmCompare.Rd @@ -26,8 +26,8 @@ number of groups). Required argument.} \item{method}{Character string specifying which parameter estimates to use: \describe{ - \item{\code{"posterior-mean"}}{Use posterior mean parameters (faster, - single simulation).} +\item{\code{"posterior-mean"}}{Use posterior mean parameters (faster, +single simulation).} }} \item{iter}{Number of Gibbs iterations for equilibration before collecting @@ -37,7 +37,7 @@ samples. Default: \code{1000}.} } \value{ A matrix with \code{nsim} rows and \code{p} columns containing - simulated observations for the specified group. +simulated observations for the specified group. } \description{ Generates new observations from the Markov Random Field model for a @@ -69,7 +69,7 @@ new_data_g2 = simulate(fit, nsim = 100, group = 2) } \seealso{ \code{\link{simulate.bgms}} for simulating from single-group models, - \code{\link{predict.bgmCompare}} for computing conditional probabilities. +\code{\link{predict.bgmCompare}} for computing conditional probabilities. Other prediction: \code{\link{predict.bgmCompare}()}, diff --git a/man/simulate.bgms.Rd b/man/simulate.bgms.Rd index e988e3a2..9185e5ff 100644 --- a/man/simulate.bgms.Rd +++ b/man/simulate.bgms.Rd @@ -25,15 +25,16 @@ \item{method}{Character string specifying which parameter estimates to use: \describe{ - \item{\code{"posterior-mean"}}{Use posterior mean parameters (faster, - single simulation).} - \item{\code{"posterior-sample"}}{Sample from posterior draws, producing - one dataset per draw (accounts for parameter uncertainty). This method - uses parallel processing when \code{cores > 1}.} +\item{\code{"posterior-mean"}}{Use posterior mean parameters (faster, +single simulation).} +\item{\code{"posterior-sample"}}{Sample from posterior draws, producing +one dataset per draw (accounts for parameter uncertainty). This method +uses parallel processing when \code{cores > 1}.} }} \item{ndraws}{Number of posterior draws to use when -\code{method = "posterior-sample"}. If \code{NULL}, uses all available draws.} +\code{method = "posterior-sample"}. If \code{NULL}, +uses all available draws.} \item{iter}{Number of Gibbs iterations for equilibration before collecting samples. Default: \code{1000}.} @@ -54,14 +55,20 @@ If \code{method = "posterior-mean"}: A matrix with \code{nsim} rows and If \code{method = "posterior-sample"}: A list of matrices, one per posterior draw, each with \code{nsim} rows and \code{p} columns. + +For mixed MRF models, discrete columns contain non-negative integers and +continuous columns contain real-valued observations, ordered as in the +original data. } \description{ Generates new observations from the Markov Random Field model using the -estimated parameters from a fitted \code{bgms} object. +estimated parameters from a fitted \code{bgms} object. Supports ordinal, +Blume-Capel, continuous (GGM), and mixed MRF models. } \details{ -This function uses the estimated interaction and threshold parameters to -generate new data via Gibbs sampling. When \code{method = "posterior-sample"}, +This function uses the estimated interaction and threshold +parameters to generate new data via Gibbs sampling. When +\code{method = "posterior-sample"}, parameter uncertainty is parameter uncertainty is propagated to the simulated data by using different posterior draws. Parallel processing is available for this method via the \code{cores} argument. @@ -75,7 +82,11 @@ fit = bgm(x = Wenchuan[, 1:5], chains = 2) new_data = simulate(fit, nsim = 100) # Simulate with parameter uncertainty (10 datasets) -new_data_list = simulate(fit, nsim = 100, method = "posterior-sample", ndraws = 10) +new_data_list = simulate( + fit, + nsim = 100, + method = "posterior-sample", ndraws = 10 +) # Use parallel processing for faster simulation new_data_list = simulate(fit, @@ -87,7 +98,7 @@ new_data_list = simulate(fit, } \seealso{ \code{\link{predict.bgms}} for computing conditional probabilities, - \code{\link{simulate_mrf}} for simulation with user-specified parameters. +\code{\link{simulate_mrf}} for simulation with user-specified parameters. Other prediction: \code{\link{predict.bgmCompare}()}, diff --git a/man/simulate_mrf.Rd b/man/simulate_mrf.Rd index 6124a965..a2260235 100644 --- a/man/simulate_mrf.Rd +++ b/man/simulate_mrf.Rd @@ -21,15 +21,18 @@ simulate_mrf( \item{num_variables}{The number of variables in the MRF.} -\item{num_categories}{Either a positive integer or a vector of positive -integers of length \code{num_variables}. The number of response categories on top -of the base category: \code{num_categories = 1} generates binary states. +\item{num_categories}{Either a positive integer or a vector +of positive integers of length \code{num_variables}. The +number of response categories on top of the base category: +\code{num_categories = 1} generates binary states. Only used for ordinal and Blume-Capel variables; ignored when \code{variable_type = "continuous"}.} -\item{pairwise}{A symmetric \code{num_variables} by \code{num_variables} matrix. -For ordinal and Blume-Capel variables, this contains the pairwise interaction -parameters; only the off-diagonal elements are used. For continuous variables, +\item{pairwise}{A symmetric \code{num_variables} by +\code{num_variables} matrix. For ordinal and Blume-Capel +variables, this contains the pairwise interaction parameters; +only the off-diagonal elements are used. For continuous +variables, this is the precision matrix \eqn{\Omega}{Omega} (including diagonal) and must be positive definite.} @@ -37,7 +40,8 @@ must be positive definite.} \code{num_variables} by \code{max(num_categories)} matrix of category thresholds. The elements in row \code{i} indicate the thresholds of variable \code{i}. If \code{num_categories} is a vector, only the first -\code{num_categories[i]} elements are used in row \code{i}. If the Blume-Capel +\code{num_categories[i]} elements are used in row \code{i}. +If the Blume-Capel model is used for the category thresholds for variable \code{i}, then row \code{i} requires two values (details below); the first is \eqn{\alpha}{\alpha}, the linear contribution of the Blume-Capel model and @@ -49,16 +53,18 @@ if not supplied or if all values are zero.} \item{variable_type}{What kind of variables are simulated? Can be a single character string specifying the variable type of all \code{p} variables at once or a vector of character strings of length \code{p} specifying the type -for each variable separately. Currently, bgm supports ``ordinal'', -``blume-capel'', and ``continuous''. Binary variables are automatically -treated as ``ordinal''. Ordinal and Blume-Capel variables can be mixed +for each variable separately. Currently, bgm supports \code{"ordinal"}, +\code{"blume-capel"}, and \code{"continuous"}. Binary variables are automatically +treated as \code{"ordinal"}. Ordinal and Blume-Capel variables can be mixed freely, but continuous variables cannot be mixed with ordinal or Blume-Capel variables. When \code{variable_type = "continuous"}, the function simulates from a Gaussian graphical model. Defaults to \code{variable_type = "ordinal"}.} -\item{baseline_category}{An integer vector of length \code{num_variables} specifying the -baseline_category category that is used for the Blume-Capel model (details below). +\item{baseline_category}{An integer vector of length +\code{num_variables} specifying the baseline_category +category that is used for the Blume-Capel model +(details below). Can be any integer value between \code{0} and \code{num_categories} (or \code{num_categories[i]}).} @@ -77,10 +83,11 @@ observations. For ordinal/Blume-Capel variables, entries are non-negative integers. For continuous variables, entries are real-valued. } \description{ -`simulate_mrf()` generates observations from a Markov Random Field using -user-specified parameters. For ordinal and Blume-Capel variables, observations -are generated via Gibbs sampling. For continuous variables (Gaussian graphical -model), observations are drawn directly from the multivariate normal +\code{simulate_mrf()} generates observations from a Markov Random +Field using user-specified parameters. For ordinal and +Blume-Capel variables, observations are generated via Gibbs +sampling. For continuous variables (Gaussian graphical model), +observations are drawn directly from the multivariate normal distribution implied by the precision matrix. } \details{ @@ -92,8 +99,8 @@ conditional distribution given the other variable states. \strong{Continuous variables (GGM):} Observations are drawn from \eqn{N(\mu, \Omega^{-1})}{N(mu, Omega^{-1})} where \eqn{\Omega}{Omega} is the precision matrix specified via -`pairwise` and \eqn{\mu}{mu} is the means vector specified via `main`. -No Gibbs sampling is needed; `iter` is ignored. +\code{pairwise} and \eqn{\mu}{mu} is the means vector specified via \code{main}. +No Gibbs sampling is needed; \code{iter} is ignored. There are two modeling options for the category thresholds. The default option assumes that the category thresholds are free, except that the first @@ -106,7 +113,8 @@ The Blume-Capel option is specifically designed for ordinal variables that have a special type of baseline_category category, such as the neutral category in a Likert scale. The Blume-Capel model specifies the following quadratic model for the threshold parameters: -\deqn{\mu_{\text{c}} = \alpha \times (\text{c} - \text{r}) + \beta \times (\text{c} - \text{r})^2,}{{\mu_{\text{c}} = \alpha \times (\text{c} - \text{r}) + \beta \times (\text{c} - \text{r})^2,}} +\deqn{\mu_{\text{c}} = \alpha (\text{c} - \text{r}) + + \beta (\text{c} - \text{r})^2} where \eqn{\mu_{\text{c}}}{\mu_{\text{c}}} is the threshold for category c (which now includes zero), \eqn{\alpha}{\alpha} offers a linear trend across categories (increasing threshold values if diff --git a/man/summary.bgmCompare.Rd b/man/summary.bgmCompare.Rd index 039a4aa6..5b93a0fc 100644 --- a/man/summary.bgmCompare.Rd +++ b/man/summary.bgmCompare.Rd @@ -2,20 +2,20 @@ % Please edit documentation in R/bgmcompare-methods.r \name{summary.bgmCompare} \alias{summary.bgmCompare} -\title{Summary method for `bgmCompare` objects} +\title{Summary method for \code{bgmCompare} objects} \usage{ \method{summary}{bgmCompare}(object, ...) } \arguments{ -\item{object}{An object of class `bgmCompare`.} +\item{object}{An object of class \code{bgmCompare}.} \item{...}{Currently ignored.} } \value{ -An object of class `summary.bgmCompare` with posterior summaries. +An object of class \code{summary.bgmCompare} with posterior summaries. } \description{ -Returns posterior summaries and diagnostics for a fitted `bgmCompare` model. +Returns posterior summaries and diagnostics for a fitted \code{bgmCompare} model. } \examples{ \donttest{ @@ -24,7 +24,7 @@ Returns posterior summaries and diagnostics for a fitted `bgmCompare` model. } \seealso{ -[bgmCompare()], [print.bgmCompare()], [coef.bgmCompare()] +\code{\link[=bgmCompare]{bgmCompare()}}, \code{\link[=print.bgmCompare]{print.bgmCompare()}}, \code{\link[=coef.bgmCompare]{coef.bgmCompare()}} Other posterior-methods: \code{\link{coef.bgmCompare}()}, diff --git a/man/summary.bgms.Rd b/man/summary.bgms.Rd index 5dea0fb8..e1b4b84c 100644 --- a/man/summary.bgms.Rd +++ b/man/summary.bgms.Rd @@ -2,20 +2,20 @@ % Please edit documentation in R/bgms-methods.R \name{summary.bgms} \alias{summary.bgms} -\title{Summary method for `bgms` objects} +\title{Summary method for \code{bgms} objects} \usage{ \method{summary}{bgms}(object, ...) } \arguments{ -\item{object}{An object of class `bgms`.} +\item{object}{An object of class \code{bgms}.} \item{...}{Currently ignored.} } \value{ -An object of class `summary.bgms` with posterior summaries. +An object of class \code{summary.bgms} with posterior summaries. } \description{ -Returns posterior summaries and diagnostics for a fitted `bgms` model. +Returns posterior summaries and diagnostics for a fitted \code{bgms} model. } \examples{ \donttest{ @@ -25,7 +25,7 @@ summary(fit) } \seealso{ -[bgm()], [print.bgms()], [coef.bgms()] +\code{\link[=bgm]{bgm()}}, \code{\link[=print.bgms]{print.bgms()}}, \code{\link[=coef.bgms]{coef.bgms()}} Other posterior-methods: \code{\link{coef.bgmCompare}()}, diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 805000be..477ea174 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -121,6 +121,27 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// compute_conditional_mixed +Rcpp::List compute_conditional_mixed(const arma::imat& x_observations, const arma::mat& y_observations, const arma::ivec& predict_vars, const arma::mat& Kxx, const arma::mat& Kxy, const arma::mat& Kyy, const arma::mat& mux, const arma::vec& muy, const arma::ivec& num_categories, const Rcpp::StringVector& variable_type, const arma::ivec& baseline_category); +RcppExport SEXP _bgms_compute_conditional_mixed(SEXP x_observationsSEXP, SEXP y_observationsSEXP, SEXP predict_varsSEXP, SEXP KxxSEXP, SEXP KxySEXP, SEXP KyySEXP, SEXP muxSEXP, SEXP muySEXP, SEXP num_categoriesSEXP, SEXP variable_typeSEXP, SEXP baseline_categorySEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::imat& >::type x_observations(x_observationsSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type y_observations(y_observationsSEXP); + Rcpp::traits::input_parameter< const arma::ivec& >::type predict_vars(predict_varsSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type Kxx(KxxSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type Kxy(KxySEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type Kyy(KyySEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type mux(muxSEXP); + Rcpp::traits::input_parameter< const arma::vec& >::type muy(muySEXP); + Rcpp::traits::input_parameter< const arma::ivec& >::type num_categories(num_categoriesSEXP); + Rcpp::traits::input_parameter< const Rcpp::StringVector& >::type variable_type(variable_typeSEXP); + Rcpp::traits::input_parameter< const arma::ivec& >::type baseline_category(baseline_categorySEXP); + rcpp_result_gen = Rcpp::wrap(compute_conditional_mixed(x_observations, y_observations, predict_vars, Kxx, Kxy, Kyy, mux, muy, num_categories, variable_type, baseline_category)); + return rcpp_result_gen; +END_RCPP +} // sample_omrf_gibbs IntegerMatrix sample_omrf_gibbs(int num_states, int num_variables, IntegerVector num_categories, NumericMatrix pairwise, NumericMatrix main, int iter, int seed); RcppExport SEXP _bgms_sample_omrf_gibbs(SEXP num_statesSEXP, SEXP num_variablesSEXP, SEXP num_categoriesSEXP, SEXP pairwiseSEXP, SEXP mainSEXP, SEXP iterSEXP, SEXP seedSEXP) { @@ -212,6 +233,53 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// sample_mixed_mrf_gibbs +Rcpp::List sample_mixed_mrf_gibbs(int num_states, NumericMatrix Kxx_r, NumericMatrix Kxy_r, NumericMatrix Kyy_r, NumericMatrix mux_r, NumericVector muy_r, IntegerVector num_categories_r, Rcpp::StringVector variable_type_r, IntegerVector baseline_category_r, int iter, int seed); +RcppExport SEXP _bgms_sample_mixed_mrf_gibbs(SEXP num_statesSEXP, SEXP Kxx_rSEXP, SEXP Kxy_rSEXP, SEXP Kyy_rSEXP, SEXP mux_rSEXP, SEXP muy_rSEXP, SEXP num_categories_rSEXP, SEXP variable_type_rSEXP, SEXP baseline_category_rSEXP, SEXP iterSEXP, SEXP seedSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< int >::type num_states(num_statesSEXP); + Rcpp::traits::input_parameter< NumericMatrix >::type Kxx_r(Kxx_rSEXP); + Rcpp::traits::input_parameter< NumericMatrix >::type Kxy_r(Kxy_rSEXP); + Rcpp::traits::input_parameter< NumericMatrix >::type Kyy_r(Kyy_rSEXP); + Rcpp::traits::input_parameter< NumericMatrix >::type mux_r(mux_rSEXP); + Rcpp::traits::input_parameter< NumericVector >::type muy_r(muy_rSEXP); + Rcpp::traits::input_parameter< IntegerVector >::type num_categories_r(num_categories_rSEXP); + Rcpp::traits::input_parameter< Rcpp::StringVector >::type variable_type_r(variable_type_rSEXP); + Rcpp::traits::input_parameter< IntegerVector >::type baseline_category_r(baseline_category_rSEXP); + Rcpp::traits::input_parameter< int >::type iter(iterSEXP); + Rcpp::traits::input_parameter< int >::type seed(seedSEXP); + rcpp_result_gen = Rcpp::wrap(sample_mixed_mrf_gibbs(num_states, Kxx_r, Kxy_r, Kyy_r, mux_r, muy_r, num_categories_r, variable_type_r, baseline_category_r, iter, seed)); + return rcpp_result_gen; +END_RCPP +} +// run_mixed_simulation_parallel +Rcpp::List run_mixed_simulation_parallel(const arma::mat& mux_samples, const arma::mat& kxx_samples, const arma::mat& muy_samples, const arma::mat& kyy_samples, const arma::mat& kxy_samples, const arma::ivec& draw_indices, int num_states, int p, int q, const arma::ivec& num_categories, const Rcpp::StringVector& variable_type_r, const arma::ivec& baseline_category, int iter, int nThreads, int seed, int progress_type); +RcppExport SEXP _bgms_run_mixed_simulation_parallel(SEXP mux_samplesSEXP, SEXP kxx_samplesSEXP, SEXP muy_samplesSEXP, SEXP kyy_samplesSEXP, SEXP kxy_samplesSEXP, SEXP draw_indicesSEXP, SEXP num_statesSEXP, SEXP pSEXP, SEXP qSEXP, SEXP num_categoriesSEXP, SEXP variable_type_rSEXP, SEXP baseline_categorySEXP, SEXP iterSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP progress_typeSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::mat& >::type mux_samples(mux_samplesSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type kxx_samples(kxx_samplesSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type muy_samples(muy_samplesSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type kyy_samples(kyy_samplesSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type kxy_samples(kxy_samplesSEXP); + Rcpp::traits::input_parameter< const arma::ivec& >::type draw_indices(draw_indicesSEXP); + Rcpp::traits::input_parameter< int >::type num_states(num_statesSEXP); + Rcpp::traits::input_parameter< int >::type p(pSEXP); + Rcpp::traits::input_parameter< int >::type q(qSEXP); + Rcpp::traits::input_parameter< const arma::ivec& >::type num_categories(num_categoriesSEXP); + Rcpp::traits::input_parameter< const Rcpp::StringVector& >::type variable_type_r(variable_type_rSEXP); + Rcpp::traits::input_parameter< const arma::ivec& >::type baseline_category(baseline_categorySEXP); + Rcpp::traits::input_parameter< int >::type iter(iterSEXP); + Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP); + Rcpp::traits::input_parameter< int >::type seed(seedSEXP); + Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP); + rcpp_result_gen = Rcpp::wrap(run_mixed_simulation_parallel(mux_samples, kxx_samples, muy_samples, kyy_samples, kxy_samples, draw_indices, num_states, p, q, num_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type)); + return rcpp_result_gen; +END_RCPP +} // sample_ggm Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const bool na_impute, const Rcpp::Nullable missing_index_nullable); RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP) { @@ -241,6 +309,40 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// sample_mixed_mrf +Rcpp::List sample_mixed_mrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const std::string& sampler_type, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const bool na_impute, const Rcpp::Nullable missing_index_discrete_nullable, const Rcpp::Nullable missing_index_continuous_nullable); +RcppExport SEXP _bgms_sample_mixed_mrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP sampler_typeSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP na_imputeSEXP, SEXP missing_index_discrete_nullableSEXP, SEXP missing_index_continuous_nullableSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::List& >::type inputFromR(inputFromRSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type prior_inclusion_prob(prior_inclusion_probSEXP); + Rcpp::traits::input_parameter< const arma::imat& >::type initial_edge_indicators(initial_edge_indicatorsSEXP); + Rcpp::traits::input_parameter< const int >::type no_iter(no_iterSEXP); + Rcpp::traits::input_parameter< const int >::type no_warmup(no_warmupSEXP); + Rcpp::traits::input_parameter< const int >::type no_chains(no_chainsSEXP); + Rcpp::traits::input_parameter< const bool >::type edge_selection(edge_selectionSEXP); + Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); + Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP); + Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP); + Rcpp::traits::input_parameter< const std::string& >::type edge_prior(edge_priorSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_alpha(beta_bernoulli_alphaSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_beta(beta_bernoulli_betaSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_alpha_between(beta_bernoulli_alpha_betweenSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_beta_between(beta_bernoulli_beta_betweenSEXP); + Rcpp::traits::input_parameter< const double >::type dirichlet_alpha(dirichlet_alphaSEXP); + Rcpp::traits::input_parameter< const double >::type lambda(lambdaSEXP); + Rcpp::traits::input_parameter< const std::string& >::type sampler_type(sampler_typeSEXP); + Rcpp::traits::input_parameter< const double >::type target_acceptance(target_acceptanceSEXP); + Rcpp::traits::input_parameter< const int >::type max_tree_depth(max_tree_depthSEXP); + Rcpp::traits::input_parameter< const int >::type num_leapfrogs(num_leapfrogsSEXP); + Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP); + Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_discrete_nullable(missing_index_discrete_nullableSEXP); + Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_continuous_nullable(missing_index_continuous_nullableSEXP); + rcpp_result_gen = Rcpp::wrap(sample_mixed_mrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable)); + return rcpp_result_gen; +END_RCPP +} // sample_omrf Rcpp::List sample_omrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const bool na_impute, const Rcpp::Nullable missing_index_nullable, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const Rcpp::Nullable pairwise_scaling_factors_nullable); RcppExport SEXP _bgms_sample_omrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP pairwise_scaling_factors_nullableSEXP) { @@ -297,12 +399,16 @@ static const R_CallMethodDef CallEntries[] = { {"_bgms_rcpp_ieee754_log", (DL_FUNC) &_bgms_rcpp_ieee754_log, 1}, {"_bgms_compute_conditional_ggm", (DL_FUNC) &_bgms_compute_conditional_ggm, 3}, {"_bgms_compute_conditional_probs", (DL_FUNC) &_bgms_compute_conditional_probs, 7}, + {"_bgms_compute_conditional_mixed", (DL_FUNC) &_bgms_compute_conditional_mixed, 11}, {"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 7}, {"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 9}, {"_bgms_sample_ggm_direct", (DL_FUNC) &_bgms_sample_ggm_direct, 4}, {"_bgms_run_simulation_parallel", (DL_FUNC) &_bgms_run_simulation_parallel, 12}, {"_bgms_run_ggm_simulation_parallel", (DL_FUNC) &_bgms_run_ggm_simulation_parallel, 9}, + {"_bgms_sample_mixed_mrf_gibbs", (DL_FUNC) &_bgms_sample_mixed_mrf_gibbs, 11}, + {"_bgms_run_mixed_simulation_parallel", (DL_FUNC) &_bgms_run_mixed_simulation_parallel, 16}, {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 19}, + {"_bgms_sample_mixed_mrf", (DL_FUNC) &_bgms_sample_mixed_mrf, 24}, {"_bgms_sample_omrf", (DL_FUNC) &_bgms_sample_omrf, 24}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, {NULL, NULL, 0} diff --git a/src/bgmCompare/bgmCompare_helper.cpp b/src/bgmCompare/bgmCompare_helper.cpp index 9f7ee37d..9f87246a 100644 --- a/src/bgmCompare/bgmCompare_helper.cpp +++ b/src/bgmCompare/bgmCompare_helper.cpp @@ -5,33 +5,31 @@ -/** - * Computes group-specific main effects for a given variable (bgmCompare model). - * - * For a variable, the main-effect parameters are stored with: - * - One "baseline" column (shared across groups). - * - Additional columns for group-specific deviations. - * - * This function extracts the rows corresponding to the variable’s categories - * and combines them with the group projection vector to yield the - * group-specific main effects. - * - * Inputs: - * - variable: Index of the variable of interest. - * - num_groups: Number of groups in the model. - * - main_effects: Matrix of main-effect parameters for all variables. - * - main_effect_indices: Index matrix giving [start_row, end_row] for each variable. - * - proj_group: Projection vector selecting the group (length = num_groups − 1). - * - * Returns: - * - A vector of group-specific main effects for the categories of the variable. - * - * Notes: - * - The projection vector should match the encoding used for group effects - * (e.g. dummy or contrast coding). - * - This function is used in likelihood evaluations where group-specific - * parameters are required. - */ +// Computes group-specific main effects for a given variable (bgmCompare model). +// +// For a variable, the main-effect parameters are stored with: +// - One "baseline" column (shared across groups). +// - Additional columns for group-specific deviations. +// +// This function extracts the rows corresponding to the variable’s categories +// and combines them with the group projection vector to yield the +// group-specific main effects. +// +// Inputs: +// - variable: Index of the variable of interest. +// - num_groups: Number of groups in the model. +// - main_effects: Matrix of main-effect parameters for all variables. +// - main_effect_indices: Index matrix giving [start_row, end_row] for each variable. +// - proj_group: Projection vector selecting the group (length = num_groups − 1). +// +// Returns: +// - A vector of group-specific main effects for the categories of the variable. +// +// Notes: +// - The projection vector should match the encoding used for group effects +// (e.g. dummy or contrast coding). +// - This function is used in likelihood evaluations where group-specific +// parameters are required. arma::vec compute_group_main_effects( const int variable, const int num_groups, @@ -57,33 +55,31 @@ arma::vec compute_group_main_effects( -/** - * Computes the group-specific pairwise effect for a variable pair (bgmCompare model). - * - * For each variable pair, the pairwise-effect parameters are stored with: - * - One "baseline" column (shared across groups). - * - Additional columns for group-specific deviations. - * - * This function extracts the baseline effect and, if the edge is active - * (per the inclusion indicator), adds the group-specific deviation obtained - * from the projection vector. - * - * Inputs: - * - var1, var2: Indices of the two variables forming the pair. - * - num_groups: Number of groups in the model. - * - pairwise_effects: Matrix of pairwise-effect parameters (rows = pairs). - * - pairwise_effect_indices: Lookup matrix mapping (var1, var2) → row index. - * - inclusion_indicator: Symmetric binary matrix of active edges. - * - proj_group: Projection vector selecting the group (length = num_groups − 1). - * - * Returns: - * - The group-specific pairwise effect for (var1, var2). - * - * Notes: - * - The index matrix must match the storage convention (typically var1 < var2). - * - If `inclusion_indicator(var1, var2) == 0`, only the baseline effect is used. - * - This function is used in likelihood evaluations and Gibbs updates. - */ +// Computes the group-specific pairwise effect for a variable pair (bgmCompare model). +// +// For each variable pair, the pairwise-effect parameters are stored with: +// - One "baseline" column (shared across groups). +// - Additional columns for group-specific deviations. +// +// This function extracts the baseline effect and, if the edge is active +// (per the inclusion indicator), adds the group-specific deviation obtained +// from the projection vector. +// +// Inputs: +// - var1, var2: Indices of the two variables forming the pair. +// - num_groups: Number of groups in the model. +// - pairwise_effects: Matrix of pairwise-effect parameters (rows = pairs). +// - pairwise_effect_indices: Lookup matrix mapping (var1, var2) → row index. +// - inclusion_indicator: Symmetric binary matrix of active edges. +// - proj_group: Projection vector selecting the group (length = num_groups − 1). +// +// Returns: +// - The group-specific pairwise effect for (var1, var2). +// +// Notes: +// - The index matrix must match the storage convention (typically var1 < var2). +// - If `inclusion_indicator(var1, var2) == 0`, only the baseline effect is used. +// - This function is used in likelihood evaluations and Gibbs updates. double compute_group_pairwise_effects( const int var1, const int var2, @@ -111,36 +107,34 @@ double compute_group_pairwise_effects( -/** - * Flattens main-effect and pairwise-effect parameters into a single vector (bgmCompare model). - * - * Layout of the output vector: - * 1. Main-effect overall parameters (column 0 of main_effects), stacked by variable. - * 2. Pairwise-effect overall parameters (column 0 of pairwise_effects), stacked by pair. - * 3. Main-effect group differences (columns 1..G-1), included only if - * the variable is marked active in inclusion_indicator(v,v). - * 4. Pairwise-effect group differences (columns 1..G-1), included only if - * the pair is marked active in inclusion_indicator(v1,v2). - * - * Inputs: - * - main_effects: Matrix of main-effect parameters (rows = categories, cols = groups). - * - pairwise_effects: Matrix of pairwise-effect parameters (rows = pairs, cols = groups). - * - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). - * - main_effect_indices: Index ranges [row_start, row_end] for each variable in main_effects. - * - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. - * - num_categories: Number of categories per variable. - * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). - * - * Returns: - * - A flat vector of parameters containing overall effects and (if active) group differences. - * - * Notes: - * - The order of pairs in `pairwise_effects` must match the upper-triangle order - * of (var1,var2) pairs as constructed in R. - * - The length of the output vector depends on both the number of groups - * and the active entries in inclusion_indicator. - * - This function is the inverse of `unvectorize_model_parameters_bgmcompare()`. - */ +// Flattens main-effect and pairwise-effect parameters into a single vector (bgmCompare model). +// +// Layout of the output vector: +// 1. Main-effect overall parameters (column 0 of main_effects), stacked by variable. +// 2. Pairwise-effect overall parameters (column 0 of pairwise_effects), stacked by pair. +// 3. Main-effect group differences (columns 1..G-1), included only if +// the variable is marked active in inclusion_indicator(v,v). +// 4. Pairwise-effect group differences (columns 1..G-1), included only if +// the pair is marked active in inclusion_indicator(v1,v2). +// +// Inputs: +// - main_effects: Matrix of main-effect parameters (rows = categories, cols = groups). +// - pairwise_effects: Matrix of pairwise-effect parameters (rows = pairs, cols = groups). +// - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). +// - main_effect_indices: Index ranges [row_start, row_end] for each variable in main_effects. +// - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. +// - num_categories: Number of categories per variable. +// - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). +// +// Returns: +// - A flat vector of parameters containing overall effects and (if active) group differences. +// +// Notes: +// - The order of pairs in `pairwise_effects` must match the upper-triangle order +// of (var1,var2) pairs as constructed in R. +// - The length of the output vector depends on both the number of groups +// and the active entries in inclusion_indicator. +// - This function is the inverse of `unvectorize_model_parameters_bgmcompare()`. arma::vec vectorize_model_parameters_bgmcompare( const arma::mat& main_effects, const arma::mat& pairwise_effects, @@ -217,38 +211,36 @@ arma::vec vectorize_model_parameters_bgmcompare( -/** - * Reconstructs main-effect and pairwise-effect matrices from a flat parameter vector (bgmCompare model). - * - * The input vector must follow the layout produced by `vectorize_model_parameters_bgmcompare()`: - * 1. Main-effect overall parameters (column 0 of main_effects), stacked by variable. - * 2. Pairwise-effect overall parameters (column 0 of pairwise_effects), stacked by pair. - * 3. Main-effect group differences (columns 1..G-1), included only if - * the variable is active in inclusion_indicator(v,v). - * 4. Pairwise-effect group differences (columns 1..G-1), included only if - * the pair is active in inclusion_indicator(v1,v2). - * - * Inputs: - * - param_vec: Flattened parameter vector. - * - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). - * - main_effect_indices: Index ranges [row_start, row_end] for each variable in main_effects. - * - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. - * - num_groups: Number of groups (columns in main_effects / pairwise_effects). - * - num_categories: Number of categories per variable. - * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). - * - * Outputs: - * - main_effects_out: Matrix of main-effect parameters (rows = categories, cols = groups). - * - pairwise_effects_out: Matrix of pairwise-effect parameters (rows = pairs, cols = groups). - * - * Notes: - * - The vector must have exactly the length returned by `vectorize_model_parameters_bgmcompare()`. - * - Diagonal entries in `inclusion_indicator` determine whether main-effect - * group differences are included. - * - Off-diagonal entries in `inclusion_indicator` determine whether - * pairwise-effect group differences are included. - * - This function is the inverse of `vectorize_model_parameters_bgmcompare()`. - */ +// Reconstructs main-effect and pairwise-effect matrices from a flat parameter vector (bgmCompare model). +// +// The input vector must follow the layout produced by `vectorize_model_parameters_bgmcompare()`: +// 1. Main-effect overall parameters (column 0 of main_effects), stacked by variable. +// 2. Pairwise-effect overall parameters (column 0 of pairwise_effects), stacked by pair. +// 3. Main-effect group differences (columns 1..G-1), included only if +// the variable is active in inclusion_indicator(v,v). +// 4. Pairwise-effect group differences (columns 1..G-1), included only if +// the pair is active in inclusion_indicator(v1,v2). +// +// Inputs: +// - param_vec: Flattened parameter vector. +// - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). +// - main_effect_indices: Index ranges [row_start, row_end] for each variable in main_effects. +// - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. +// - num_groups: Number of groups (columns in main_effects / pairwise_effects). +// - num_categories: Number of categories per variable. +// - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). +// +// Outputs: +// - main_effects_out: Matrix of main-effect parameters (rows = categories, cols = groups). +// - pairwise_effects_out: Matrix of pairwise-effect parameters (rows = pairs, cols = groups). +// +// Notes: +// - The vector must have exactly the length returned by `vectorize_model_parameters_bgmcompare()`. +// - Diagonal entries in `inclusion_indicator` determine whether main-effect +// group differences are included. +// - Off-diagonal entries in `inclusion_indicator` determine whether +// pairwise-effect group differences are included. +// - This function is the inverse of `vectorize_model_parameters_bgmcompare()`. void unvectorize_model_parameters_bgmcompare( const arma::vec& param_vec, arma::mat& main_effects_out, // [n_main_rows × G] @@ -314,44 +306,42 @@ void unvectorize_model_parameters_bgmcompare( -/** - * Builds index maps linking matrix entries to positions in the vectorized parameter vector (bgmCompare model). - * - * The index maps are used to quickly locate where each parameter (main-effect or pairwise-effect, - * across groups) sits inside the flattened parameter vector produced by - * `vectorize_model_parameters_bgmcompare()`. - * - * Layout: - * 1. Main-effect overall parameters (col 0). - * 2. Pairwise-effect overall parameters (col 0). - * 3. Main-effect group differences (cols 1..G-1), included only if - * the variable is active in inclusion_indicator(v,v). - * 4. Pairwise-effect group differences (cols 1..G-1), included only if - * the pair is active in inclusion_indicator(v1,v2). - * - * Inputs: - * - main_effects: Matrix of main-effect parameters (used for dimension info). - * - pairwise_effects: Matrix of pairwise-effect parameters (used for dimension info). - * - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). - * - main_effect_indices: Index ranges [row_start, row_end] for each variable in main_effects. - * - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. - * - num_categories: Number of categories per variable. - * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). - * - * Returns: - * - A pair of integer matrices: - * * main_index: [num_main × num_groups], with each entry giving the position - * in the parameter vector for that main-effect parameter (or -1 if inactive). - * * pair_index: [num_pair × num_groups], with each entry giving the position - * in the parameter vector for that pairwise-effect parameter (or -1 if inactive). - * - * Notes: - * - Entries are set to -1 when the corresponding parameter is inactive. - * - The returned index maps must always be consistent with the ordering used - * in vectorization/unvectorization. - * - A final check (e.g. verifying that `off == param_vec.n_elem`) can help - * catch mismatches between index maps and vectorizer logic. - */ +// Builds index maps linking matrix entries to positions in the vectorized parameter vector (bgmCompare model). +// +// The index maps are used to quickly locate where each parameter (main-effect or pairwise-effect, +// across groups) sits inside the flattened parameter vector produced by +// `vectorize_model_parameters_bgmcompare()`. +// +// Layout: +// 1. Main-effect overall parameters (col 0). +// 2. Pairwise-effect overall parameters (col 0). +// 3. Main-effect group differences (cols 1..G-1), included only if +// the variable is active in inclusion_indicator(v,v). +// 4. Pairwise-effect group differences (cols 1..G-1), included only if +// the pair is active in inclusion_indicator(v1,v2). +// +// Inputs: +// - main_effects: Matrix of main-effect parameters (used for dimension info). +// - pairwise_effects: Matrix of pairwise-effect parameters (used for dimension info). +// - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). +// - main_effect_indices: Index ranges [row_start, row_end] for each variable in main_effects. +// - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. +// - num_categories: Number of categories per variable. +// - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). +// +// Returns: +// - A pair of integer matrices: +// * main_index: [num_main × num_groups], with each entry giving the position +// in the parameter vector for that main-effect parameter (or -1 if inactive). +// * pair_index: [num_pair × num_groups], with each entry giving the position +// in the parameter vector for that pairwise-effect parameter (or -1 if inactive). +// +// Notes: +// - Entries are set to -1 when the corresponding parameter is inactive. +// - The returned index maps must always be consistent with the ordering used +// in vectorization/unvectorization. +// - A final check (e.g. verifying that `off == param_vec.n_elem`) can help +// catch mismatches between index maps and vectorizer logic. std::pair build_index_maps( const arma::mat& main_effects, const arma::mat& pairwise_effects, @@ -411,37 +401,35 @@ std::pair build_index_maps( -/** - * Extracts entries of the inverse mass matrix corresponding to active parameters (bgmCompare model). - * - * If `selection` is false, the full diagonal vector is returned unchanged. - * If `selection` is true, the output is restricted to: - * 1. Main-effect overall parameters (column 0). - * 2. Pairwise-effect overall parameters (column 0). - * 3. Main-effect group differences (columns 1..G-1) for variables with - * inclusion_indicator(v,v) == 1. - * 4. Pairwise-effect group differences (columns 1..G-1) for pairs with - * inclusion_indicator(v1,v2) == 1. - * - * Inputs: - * - inv_diag: Full inverse mass diagonal (length = all parameters). - * - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). - * - num_groups: Number of groups. - * - num_categories: Number of categories per variable. - * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). - * - main_index: Index map for main effects (from build_index_maps()). - * - pair_index: Index map for pairwise effects (from build_index_maps()). - * - main_effect_indices: Index ranges [row_start, row_end] for each variable in main_effects. - * - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. - * - selection: If true, restrict to active parameters; if false, return full inv_diag. - * - * Returns: - * - A vector containing inverse mass entries for active parameters only. - * - * Notes: - * - Must be consistent with the layout in `vectorize_model_parameters_bgmcompare()`. - * - Index maps (`main_index`, `pair_index`) are required to locate group-difference entries. - */ +// Extracts entries of the inverse mass matrix corresponding to active parameters (bgmCompare model). +// +// If `selection` is false, the full diagonal vector is returned unchanged. +// If `selection` is true, the output is restricted to: +// 1. Main-effect overall parameters (column 0). +// 2. Pairwise-effect overall parameters (column 0). +// 3. Main-effect group differences (columns 1..G-1) for variables with +// inclusion_indicator(v,v) == 1. +// 4. Pairwise-effect group differences (columns 1..G-1) for pairs with +// inclusion_indicator(v1,v2) == 1. +// +// Inputs: +// - inv_diag: Full inverse mass diagonal (length = all parameters). +// - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). +// - num_groups: Number of groups. +// - num_categories: Number of categories per variable. +// - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). +// - main_index: Index map for main effects (from build_index_maps()). +// - pair_index: Index map for pairwise effects (from build_index_maps()). +// - main_effect_indices: Index ranges [row_start, row_end] for each variable in main_effects. +// - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. +// - selection: If true, restrict to active parameters; if false, return full inv_diag. +// +// Returns: +// - A vector containing inverse mass entries for active parameters only. +// +// Notes: +// - Must be consistent with the layout in `vectorize_model_parameters_bgmcompare()`. +// - Index maps (`main_index`, `pair_index`) are required to locate group-difference entries. arma::vec inv_mass_active( const arma::vec& inv_diag, const arma::imat& inclusion_indicator, diff --git a/src/bgmCompare/bgmCompare_logp_and_grad.cpp b/src/bgmCompare/bgmCompare_logp_and_grad.cpp index 49bb780d..3e8b0aac 100644 --- a/src/bgmCompare/bgmCompare_logp_and_grad.cpp +++ b/src/bgmCompare/bgmCompare_logp_and_grad.cpp @@ -8,34 +8,32 @@ -/** - * Compute the total length of the parameter vector in the bgmCompare model. - * - * The parameter vector consists of: - * 1. Main-effect overall parameters (column 0). - * 2. Pairwise-effect overall parameters (column 0). - * 3. Main-effect group-difference parameters (columns 1..G-1) for variables - * with inclusion_indicator(v,v) == 1. - * 4. Pairwise-effect group-difference parameters (columns 1..G-1) for pairs - * with inclusion_indicator(v1,v2) == 1. - * - * Inputs: - * - num_variables: Number of observed variables. - * - main_effect_indices: Row ranges [start,end] in main_effects for each variable. - * - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. - * - inclusion_indicator: Symmetric binary matrix; diagonal entries control main-effect - * differences, off-diagonal entries control pairwise-effect differences. - * - num_categories: Vector of category counts per variable. - * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel) for each variable. - * - num_groups: Number of groups in the model. - * - * Returns: - * - arma::uword: Total number of parameters in the vectorized model. - * - * Notes: - * - This function must be consistent with vectorize_model_parameters_bgmcompare(). - * - Used to allocate gradient vectors, prior vectors, and mass matrices. - */ +// Compute the total length of the parameter vector in the bgmCompare model. +// +// The parameter vector consists of: +// 1. Main-effect overall parameters (column 0). +// 2. Pairwise-effect overall parameters (column 0). +// 3. Main-effect group-difference parameters (columns 1..G-1) for variables +// with inclusion_indicator(v,v) == 1. +// 4. Pairwise-effect group-difference parameters (columns 1..G-1) for pairs +// with inclusion_indicator(v1,v2) == 1. +// +// Inputs: +// - num_variables: Number of observed variables. +// - main_effect_indices: Row ranges [start,end] in main_effects for each variable. +// - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. +// - inclusion_indicator: Symmetric binary matrix; diagonal entries control main-effect +// differences, off-diagonal entries control pairwise-effect differences. +// - num_categories: Vector of category counts per variable. +// - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel) for each variable. +// - num_groups: Number of groups in the model. +// +// Returns: +// - arma::uword: Total number of parameters in the vectorized model. +// +// Notes: +// - This function must be consistent with vectorize_model_parameters_bgmcompare(). +// - Used to allocate gradient vectors, prior vectors, and mass matrices. arma::uword total_length( const int num_variables, const arma::imat& main_effect_indices, @@ -71,49 +69,47 @@ arma::uword total_length( -/** - * Compute the observed-data contribution to the gradient vector - * in the bgmCompare model (active parameterization). - * - * This function accumulates observed sufficient statistics from the data - * and projects them into the parameter vector space. The output has the - * same length and ordering as `vectorize_model_parameters_bgmcompare()`, - * and includes: - * 1. Main-effect overall parameters (column 0). - * 2. Pairwise-effect overall parameters (column 0). - * 3. Main-effect group-difference parameters (columns 1..G-1) if - * inclusion_indicator(v,v) == 1. - * 4. Pairwise-effect group-difference parameters (columns 1..G-1) if - * inclusion_indicator(v1,v2) == 1. - * - * Inputs: - * - main_effect_indices: Row ranges [start,end] in main_effects for each variable. - * - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. - * - projection: Matrix of size (num_groups × (num_groups-1)) containing group projections. - * - observations: Matrix of observed variable values (N × V). - * - group_indices: Index ranges [start,end] defining which rows in `observations` - * belong to each group. - * - num_categories: Vector giving the number of categories per variable. - * - inclusion_indicator: Symmetric binary matrix; diagonal entries control inclusion - * of main-effect differences, off-diagonal entries control inclusion of pairwise - * differences. - * - counts_per_category_group: Per-group category count tables (list of matrices). - * - blume_capel_stats_group: Per-group Blume–Capel sufficient statistics (list of matrices). - * - pairwise_stats_group: Per-group pairwise sufficient statistics (list of matrices). - * - num_groups: Number of groups. - * - is_ordinal_variable: Indicator vector (1 = ordinal, 0 = Blume–Capel) per variable. - * - baseline_category: Vector of baseline categories per variable (Blume–Capel). - * - main_index: Index map for main effects (from build_index_maps()). - * - pair_index: Index map for pairwise effects (from build_index_maps()). - * - * Returns: - * - arma::vec: Observed-data contribution to the gradient (length = total_length()). - * - * Notes: - * - This function computes the *data-dependent* part of the gradient only; - * parameter-dependent expected statistics and priors must be added separately. - * - The output ordering must remain consistent with `vectorize_model_parameters_bgmcompare()`. - */ +// Compute the observed-data contribution to the gradient vector +// in the bgmCompare model (active parameterization). +// +// This function accumulates observed sufficient statistics from the data +// and projects them into the parameter vector space. The output has the +// same length and ordering as `vectorize_model_parameters_bgmcompare()`, +// and includes: +// 1. Main-effect overall parameters (column 0). +// 2. Pairwise-effect overall parameters (column 0). +// 3. Main-effect group-difference parameters (columns 1..G-1) if +// inclusion_indicator(v,v) == 1. +// 4. Pairwise-effect group-difference parameters (columns 1..G-1) if +// inclusion_indicator(v1,v2) == 1. +// +// Inputs: +// - main_effect_indices: Row ranges [start,end] in main_effects for each variable. +// - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. +// - projection: Matrix of size (num_groups × (num_groups-1)) containing group projections. +// - observations: Matrix of observed variable values (N × V). +// - group_indices: Index ranges [start,end] defining which rows in `observations` +// belong to each group. +// - num_categories: Vector giving the number of categories per variable. +// - inclusion_indicator: Symmetric binary matrix; diagonal entries control inclusion +// of main-effect differences, off-diagonal entries control inclusion of pairwise +// differences. +// - counts_per_category_group: Per-group category count tables (list of matrices). +// - blume_capel_stats_group: Per-group Blume–Capel sufficient statistics (list of matrices). +// - pairwise_stats_group: Per-group pairwise sufficient statistics (list of matrices). +// - num_groups: Number of groups. +// - is_ordinal_variable: Indicator vector (1 = ordinal, 0 = Blume–Capel) per variable. +// - baseline_category: Vector of baseline categories per variable (Blume–Capel). +// - main_index: Index map for main effects (from build_index_maps()). +// - pair_index: Index map for pairwise effects (from build_index_maps()). +// +// Returns: +// - arma::vec: Observed-data contribution to the gradient (length = total_length()). +// +// Notes: +// - This function computes the *data-dependent* part of the gradient only; +// parameter-dependent expected statistics and priors must be added separately. +// - The output ordering must remain consistent with `vectorize_model_parameters_bgmcompare()`. arma::vec gradient_observed_active( const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, @@ -218,63 +214,61 @@ arma::vec gradient_observed_active( -/** - * Computes the gradient of the log pseudoposterior for the bgmCompare model. - * - * The gradient combines three contributions: - * 1. Observed sufficient statistics (precomputed and supplied via `grad_obs`). - * 2. Expected sufficient statistics under the current parameter values - * (computed using softmax probabilities for ordinal or Blume–Capel variables). - * 3. Prior contributions on main effects, pairwise effects, and group differences. - * - * Procedure: - * - Initialize gradient with `grad_obs` (observed-data contribution). - * - Loop over groups: - * * Build group-specific main and pairwise effects using - * `compute_group_main_effects()` and `compute_group_pairwise_effects()`. - * * Compute expected sufficient statistics from residual scores and - * subtract them from the gradient. - * - Add prior contributions: - * * Logistic–Beta prior gradient for main-effect baseline parameters. - * * Cauchy prior gradient for group-difference parameters and pairwise effects. - * - * Inputs: - * - main_effects: Matrix of main-effect parameters (rows = categories, cols = groups). - * - pairwise_effects: Matrix of pairwise-effect parameters (rows = pairs, cols = groups). - * - main_effect_indices: Index ranges [row_start,row_end] for each variable in main_effects. - * - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. - * - projection: Group projection matrix (num_groups × (num_groups − 1)). - * - observations_double: Observation matrix (N × V), pre-converted to double. - * - group_indices: Row ranges [start,end] for each group in `observations_double`. - * - num_categories: Number of categories per variable. - * - counts_per_category_group: Per-group category counts (ordinal variables). - * - blume_capel_stats_group: Per-group sufficient statistics (Blume–Capel variables). - * - pairwise_stats_group: Per-group pairwise sufficient statistics. - * - num_groups: Number of groups. - * - inclusion_indicator: Symmetric binary matrix; diagonal entries control inclusion - * of main-effect differences, off-diagonal entries control inclusion of pairwise - * differences. - * - is_ordinal_variable: Indicator vector (1 = ordinal, 0 = Blume–Capel). - * - baseline_category: Reference categories for Blume–Capel variables. - * - main_alpha, main_beta: Hyperparameters for Beta priors on main effects. - * - interaction_scale: Scale parameter for Cauchy prior on baseline pairwise effects. - * - difference_scale: Scale parameter for Cauchy prior on group differences. - * - main_index: Index map for main-effect parameters (from build_index_maps()). - * - pair_index: Index map for pairwise-effect parameters (from build_index_maps()). - * - grad_obs: Precomputed observed-data contribution to the gradient - * (output of `gradient_observed_active()`). - * - * Returns: - * - arma::vec: Gradient of the log pseudoposterior with respect to all active - * parameters, in the layout defined by `vectorize_model_parameters_bgmcompare()`. - * - * Notes: - * - Must remain consistent with `vectorize_model_parameters_bgmcompare()` and - * `unvectorize_model_parameters_bgmcompare()`. - * - Expected sufficient statistics are computed on-the-fly, while observed - * statistics are passed in via `grad_obs`. - * - Priors are applied after observed and expected contributions. - */ +// Computes the gradient of the log pseudoposterior for the bgmCompare model. +// +// The gradient combines three contributions: +// 1. Observed sufficient statistics (precomputed and supplied via `grad_obs`). +// 2. Expected sufficient statistics under the current parameter values +// (computed using softmax probabilities for ordinal or Blume–Capel variables). +// 3. Prior contributions on main effects, pairwise effects, and group differences. +// +// Procedure: +// - Initialize gradient with `grad_obs` (observed-data contribution). +// - Loop over groups: +// * Build group-specific main and pairwise effects using +// `compute_group_main_effects()` and `compute_group_pairwise_effects()`. +// * Compute expected sufficient statistics from residual scores and +// subtract them from the gradient. +// - Add prior contributions: +// * Logistic–Beta prior gradient for main-effect baseline parameters. +// * Cauchy prior gradient for group-difference parameters and pairwise effects. +// +// Inputs: +// - main_effects: Matrix of main-effect parameters (rows = categories, cols = groups). +// - pairwise_effects: Matrix of pairwise-effect parameters (rows = pairs, cols = groups). +// - main_effect_indices: Index ranges [row_start,row_end] for each variable in main_effects. +// - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. +// - projection: Group projection matrix (num_groups × (num_groups − 1)). +// - observations_double: Observation matrix (N × V), pre-converted to double. +// - group_indices: Row ranges [start,end] for each group in `observations_double`. +// - num_categories: Number of categories per variable. +// - counts_per_category_group: Per-group category counts (ordinal variables). +// - blume_capel_stats_group: Per-group sufficient statistics (Blume–Capel variables). +// - pairwise_stats_group: Per-group pairwise sufficient statistics. +// - num_groups: Number of groups. +// - inclusion_indicator: Symmetric binary matrix; diagonal entries control inclusion +// of main-effect differences, off-diagonal entries control inclusion of pairwise +// differences. +// - is_ordinal_variable: Indicator vector (1 = ordinal, 0 = Blume–Capel). +// - baseline_category: Reference categories for Blume–Capel variables. +// - main_alpha, main_beta: Hyperparameters for Beta priors on main effects. +// - interaction_scale: Scale parameter for Cauchy prior on baseline pairwise effects. +// - difference_scale: Scale parameter for Cauchy prior on group differences. +// - main_index: Index map for main-effect parameters (from build_index_maps()). +// - pair_index: Index map for pairwise-effect parameters (from build_index_maps()). +// - grad_obs: Precomputed observed-data contribution to the gradient +// (output of `gradient_observed_active()`). +// +// Returns: +// - arma::vec: Gradient of the log pseudoposterior with respect to all active +// parameters, in the layout defined by `vectorize_model_parameters_bgmcompare()`. +// +// Notes: +// - Must remain consistent with `vectorize_model_parameters_bgmcompare()` and +// `unvectorize_model_parameters_bgmcompare()`. +// - Expected sufficient statistics are computed on-the-fly, while observed +// statistics are passed in via `grad_obs`. +// - Priors are applied after observed and expected contributions. arma::vec gradient( const arma::mat& main_effects, const arma::mat& pairwise_effects, @@ -505,18 +499,16 @@ arma::vec gradient( } -/** - * Computes both log pseudoposterior and gradient in a single pass. - * - * Fuses the computations of `log_pseudoposterior()` and `gradient()`, - * sharing intermediate results (group-specific effects, residual matrices, - * and probability computations) to avoid redundant work during NUTS sampling. - * - * Returns: - * - std::pair containing: - * - first: log pseudoposterior value (scalar) - * - second: gradient vector (same layout as gradient()) - */ +// Computes both log pseudoposterior and gradient in a single pass. +// +// Fuses the computations of `log_pseudoposterior()` and `gradient()`, +// sharing intermediate results (group-specific effects, residual matrices, +// and probability computations) to avoid redundant work during NUTS sampling. +// +// Returns: +// - std::pair containing: +// - first: log pseudoposterior value (scalar) +// - second: gradient vector (same layout as gradient()) std::pair logp_and_gradient( const arma::mat& main_effects, const arma::mat& pairwise_effects, @@ -779,54 +771,52 @@ std::pair logp_and_gradient( -/** - * Computes the log pseudoposterior contribution of a single main-effect parameter (bgmCompare model). - * - * This function isolates the contribution of one main-effect parameter, - * either the overall (baseline) effect or one of its group-specific differences. - * - * Procedure: - * - For each group: - * * Construct group-specific main effects for the selected variable - * with `compute_group_main_effects()`. - * * Construct group-specific pairwise effects for the variable. - * * Add linear contributions from sufficient statistics. - * * Subtract log normalizing constants from the group-specific likelihood. - * - Add prior contribution: - * * Logistic–Beta prior for baseline (h == 0). - * * Cauchy prior for group differences (h > 0), if included. - * - * Inputs: - * - main_effects: Matrix of main-effect parameters (rows = categories, cols = groups). - * - pairwise_effects: Matrix of pairwise-effect parameters (rows = pairs, cols = groups). - * - main_effect_indices: Index ranges [row_start,row_end] for each variable. - * - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. - * - projection: Group projection matrix (num_groups × (num_groups − 1)). - * - observations: Observation matrix (persons × variables). - * - group_indices: Row ranges [start,end] for each group in observations. - * - num_categories: Number of categories per variable. - * - counts_per_category_group: Per-group category counts (for ordinal variables). - * - blume_capel_stats_group: Per-group sufficient statistics (for Blume–Capel variables). - * - num_groups: Number of groups. - * - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). - * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). - * - baseline_category: Reference categories for Blume–Capel variables. - * - main_alpha, main_beta: Hyperparameters for Beta priors on main effects. - * - difference_scale: Scale parameter for Cauchy priors on group differences. - * - variable: Index of the variable of interest. - * - category: Category index (only used if variable is ordinal). - * - par: Parameter index (0 = linear, 1 = quadratic; used for Blume–Capel). - * - h: Column index (0 = overall baseline, >0 = group difference). - * - * Returns: - * - The scalar log pseudoposterior contribution of the selected parameter. - * - * Notes: - * - If h > 0 but inclusion_indicator(variable, variable) == 0, - * the function returns 0.0 (no contribution). - * - This component function is used in parameter-wise Metropolis updates. - * - Consistent with the full `log_pseudoposterior()` for bgmCompare. - */ +// Computes the log pseudoposterior contribution of a single main-effect parameter (bgmCompare model). +// +// This function isolates the contribution of one main-effect parameter, +// either the overall (baseline) effect or one of its group-specific differences. +// +// Procedure: +// - For each group: +// * Construct group-specific main effects for the selected variable +// with `compute_group_main_effects()`. +// * Construct group-specific pairwise effects for the variable. +// * Add linear contributions from sufficient statistics. +// * Subtract log normalizing constants from the group-specific likelihood. +// - Add prior contribution: +// * Logistic–Beta prior for baseline (h == 0). +// * Cauchy prior for group differences (h > 0), if included. +// +// Inputs: +// - main_effects: Matrix of main-effect parameters (rows = categories, cols = groups). +// - pairwise_effects: Matrix of pairwise-effect parameters (rows = pairs, cols = groups). +// - main_effect_indices: Index ranges [row_start,row_end] for each variable. +// - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. +// - projection: Group projection matrix (num_groups × (num_groups − 1)). +// - observations: Observation matrix (persons × variables). +// - group_indices: Row ranges [start,end] for each group in observations. +// - num_categories: Number of categories per variable. +// - counts_per_category_group: Per-group category counts (for ordinal variables). +// - blume_capel_stats_group: Per-group sufficient statistics (for Blume–Capel variables). +// - num_groups: Number of groups. +// - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). +// - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). +// - baseline_category: Reference categories for Blume–Capel variables. +// - main_alpha, main_beta: Hyperparameters for Beta priors on main effects. +// - difference_scale: Scale parameter for Cauchy priors on group differences. +// - variable: Index of the variable of interest. +// - category: Category index (only used if variable is ordinal). +// - par: Parameter index (0 = linear, 1 = quadratic; used for Blume–Capel). +// - h: Column index (0 = overall baseline, >0 = group difference). +// +// Returns: +// - The scalar log pseudoposterior contribution of the selected parameter. +// +// Notes: +// - If h > 0 but inclusion_indicator(variable, variable) == 0, +// the function returns 0.0 (no contribution). +// - This component function is used in parameter-wise Metropolis updates. +// - Consistent with the full `log_pseudoposterior()` for bgmCompare. double log_pseudoposterior_main_component( const arma::mat& main_effects, const arma::mat& pairwise_effects, @@ -959,53 +949,51 @@ double log_pseudoposterior_main_component( } -/** - * Computes the log pseudoposterior contribution of a single pairwise-effect parameter (bgmCompare model). - * - * Isolates the contribution of one pairwise-effect parameter between two variables, - * either the baseline effect (h == 0) or a group-specific difference (h > 0). - * - * Procedure: - * - For each group: - * * Construct group-specific main effects for the two variables. - * * Add linear contributions from the pairwise sufficient statistic. - * - Baseline (h == 0): contribution = 2 * suff_pair * proposed_value. - * - Difference (h > 0): scaled by projection value proj_g(h-1). - * * Subtract log normalizing constants from both variables' likelihoods. - * - Add prior contribution: - * * Cauchy prior for baseline (scale = interaction_scale). - * * Cauchy prior for group differences (scale = difference_scale). - * - * Inputs: - * - main_effects: Matrix of main-effect parameters (rows = categories, cols = groups). - * - pairwise_effects: Matrix of pairwise-effect parameters (rows = pairs, cols = groups). - * - main_effect_indices: Index ranges [row_start, row_end] for each variable. - * - pairwise_effect_indices: Lookup table mapping (var1, var2) to row in pairwise_effects. - * - projection: Group projection matrix (num_groups × (num_groups − 1)). - * - observations: Observation matrix (persons × variables). - * - group_indices: Row ranges [start, end] for each group in observations. - * - num_categories: Number of categories per variable. - * - pairwise_stats_group: Per-group pairwise sufficient statistics. - * - residual_matrices: Per-group residual matrices (persons × variables). - * - num_groups: Number of groups. - * - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). - * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). - * - baseline_category: Reference categories for Blume–Capel variables. - * - interaction_scale: Scale parameter for Cauchy prior on baseline pairwise effects. - * - pairwise_scaling_factors: Per-pair scaling factors for the prior. - * - difference_scale: Scale parameter for Cauchy prior on group differences. - * - variable1, variable2: Indices of the variable pair. - * - h: Column index (0 = baseline, > 0 = group difference). - * - delta: Parameter change (proposed - current). - * - * Returns: - * - The log pseudoposterior value at the proposed state. - * - * Notes: - * - If h > 0 but inclusion_indicator(variable1, variable2) == 0, returns 0.0. - * - The proposed value is computed as pairwise_effects(idx, h) + delta. - * - Residual scores are adjusted by delta without modifying residual_matrices. - */ +// Computes the log pseudoposterior contribution of a single pairwise-effect parameter (bgmCompare model). +// +// Isolates the contribution of one pairwise-effect parameter between two variables, +// either the baseline effect (h == 0) or a group-specific difference (h > 0). +// +// Procedure: +// - For each group: +// * Construct group-specific main effects for the two variables. +// * Add linear contributions from the pairwise sufficient statistic. +// - Baseline (h == 0): contribution = 2 * suff_pair * proposed_value. +// - Difference (h > 0): scaled by projection value proj_g(h-1). +// * Subtract log normalizing constants from both variables' likelihoods. +// - Add prior contribution: +// * Cauchy prior for baseline (scale = interaction_scale). +// * Cauchy prior for group differences (scale = difference_scale). +// +// Inputs: +// - main_effects: Matrix of main-effect parameters (rows = categories, cols = groups). +// - pairwise_effects: Matrix of pairwise-effect parameters (rows = pairs, cols = groups). +// - main_effect_indices: Index ranges [row_start, row_end] for each variable. +// - pairwise_effect_indices: Lookup table mapping (var1, var2) to row in pairwise_effects. +// - projection: Group projection matrix (num_groups × (num_groups − 1)). +// - observations: Observation matrix (persons × variables). +// - group_indices: Row ranges [start, end] for each group in observations. +// - num_categories: Number of categories per variable. +// - pairwise_stats_group: Per-group pairwise sufficient statistics. +// - residual_matrices: Per-group residual matrices (persons × variables). +// - num_groups: Number of groups. +// - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). +// - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). +// - baseline_category: Reference categories for Blume–Capel variables. +// - interaction_scale: Scale parameter for Cauchy prior on baseline pairwise effects. +// - pairwise_scaling_factors: Per-pair scaling factors for the prior. +// - difference_scale: Scale parameter for Cauchy prior on group differences. +// - variable1, variable2: Indices of the variable pair. +// - h: Column index (0 = baseline, > 0 = group difference). +// - delta: Parameter change (proposed - current). +// +// Returns: +// - The log pseudoposterior value at the proposed state. +// +// Notes: +// - If h > 0 but inclusion_indicator(variable1, variable2) == 0, returns 0.0. +// - The proposed value is computed as pairwise_effects(idx, h) + delta. +// - Residual scores are adjusted by delta without modifying residual_matrices. double log_pseudoposterior_pair_component( const arma::mat& main_effects, const arma::mat& pairwise_effects, @@ -1113,49 +1101,47 @@ double log_pseudoposterior_pair_component( -/** - * Computes the log-ratio of pseudolikelihood normalizing constants - * for a single variable under current vs. proposed parameters (bgmCompare model). - * - * This function is used in Metropolis–Hastings updates for main-effect parameters. - * It evaluates how the normalizing constant (denominator of the pseudolikelihood) - * changes when switching from the current to the proposed parameter values. - * - * Procedure: - * - For each group: - * * Construct group-specific main effects (current vs. proposed). - * * Construct group-specific pairwise weights for the variable. - * * Compute residual scores for observations under both models. - * * Calculate denominators with stability bounds (ordinal vs. Blume–Capel cases). - * * Accumulate the log-ratio contribution across all observations. - * - * Inputs: - * - current_main_effects, proposed_main_effects: Matrices of main-effect parameters - * (rows = categories, cols = groups). - * - current_pairwise_effects, proposed_pairwise_effects: Matrices of pairwise-effect parameters - * (rows = pairs, cols = groups). - * - main_effect_indices: Index ranges [row_start,row_end] for each variable. - * - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. - * - projection: Group projection matrix (num_groups × (num_groups − 1)). - * - observations: Observation matrix (persons × variables). - * - group_indices: Row ranges [start,end] for each group in observations. - * - num_categories: Number of categories per variable. - * - num_groups: Number of groups. - * - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). - * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). - * - baseline_category: Reference categories for Blume–Capel variables. - * - variable: Index of the variable being updated. - * - * Returns: - * - The scalar log-ratio of pseudolikelihood constants - * (current model vs. proposed model). - * - * Notes: - * - For ordinal variables, denominators include exp(-bound) and category terms. - * - For Blume–Capel variables, denominators use linear/quadratic scores - * with baseline centering. - * - Stability bounds (`bound_current`, `bound_proposed`) are applied to avoid overflow. - */ +// Computes the log-ratio of pseudolikelihood normalizing constants +// for a single variable under current vs. proposed parameters (bgmCompare model). +// +// This function is used in Metropolis–Hastings updates for main-effect parameters. +// It evaluates how the normalizing constant (denominator of the pseudolikelihood) +// changes when switching from the current to the proposed parameter values. +// +// Procedure: +// - For each group: +// * Construct group-specific main effects (current vs. proposed). +// * Construct group-specific pairwise weights for the variable. +// * Compute residual scores for observations under both models. +// * Calculate denominators with stability bounds (ordinal vs. Blume–Capel cases). +// * Accumulate the log-ratio contribution across all observations. +// +// Inputs: +// - current_main_effects, proposed_main_effects: Matrices of main-effect parameters +// (rows = categories, cols = groups). +// - current_pairwise_effects, proposed_pairwise_effects: Matrices of pairwise-effect parameters +// (rows = pairs, cols = groups). +// - main_effect_indices: Index ranges [row_start,row_end] for each variable. +// - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. +// - projection: Group projection matrix (num_groups × (num_groups − 1)). +// - observations: Observation matrix (persons × variables). +// - group_indices: Row ranges [start,end] for each group in observations. +// - num_categories: Number of categories per variable. +// - num_groups: Number of groups. +// - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). +// - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). +// - baseline_category: Reference categories for Blume–Capel variables. +// - variable: Index of the variable being updated. +// +// Returns: +// - The scalar log-ratio of pseudolikelihood constants +// (current model vs. proposed model). +// +// Notes: +// - For ordinal variables, denominators include exp(-bound) and category terms. +// - For Blume–Capel variables, denominators use linear/quadratic scores +// with baseline centering. +// - Stability bounds (`bound_current`, `bound_proposed`) are applied to avoid overflow. double log_ratio_pseudolikelihood_constant_variable( const arma::mat& current_main_effects, const arma::mat& current_pairwise_effects, @@ -1256,50 +1242,48 @@ double log_ratio_pseudolikelihood_constant_variable( -/** - * Computes the log pseudolikelihood ratio for updating a single main-effect parameter (bgmCompare model). - * - * This function is used in Metropolis–Hastings updates for main effects. - * It compares the likelihood of the data under the current vs. proposed - * value of a single variable’s main-effect parameter, while keeping - * all other parameters fixed. - * - * Procedure: - * - For each group: - * * Compute group-specific main effects for the variable (current vs. proposed). - * * Add contributions from observed sufficient statistics - * (category counts or Blume–Capel stats). - * - Add the ratio of pseudolikelihood normalizing constants by calling - * `log_ratio_pseudolikelihood_constant_variable()`. - * - * Inputs: - * - current_main_effects: Matrix of main-effect parameters (current state). - * - proposed_main_effects: Matrix of main-effect parameters (candidate state). - * - current_pairwise_effects: Matrix of pairwise-effect parameters (fixed at current state). - * - main_effect_indices: Index ranges [row_start,row_end] for each variable. - * - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. - * - projection: Group projection matrix (num_groups × (num_groups − 1)). - * - observations: Observation matrix (persons × variables). - * - group_indices: Row ranges [start,end] for each group in observations. - * - num_categories: Number of categories per variable. - * - counts_per_category_group: Per-group category counts (for ordinal variables). - * - blume_capel_stats_group: Per-group sufficient statistics (for Blume–Capel variables). - * - num_groups: Number of groups. - * - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). - * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). - * - baseline_category: Reference categories for Blume–Capel variables. - * - variable: Index of the variable being updated. - * - * Returns: - * - The scalar log pseudolikelihood ratio (proposed vs. current). - * - * Notes: - * - A temporary copy of `inclusion_indicator` is made to ensure the - * variable’s self-term (diagonal entry) is included. - * - Only the variable under update changes between current and proposed states; - * all other variables and pairwise effects remain fixed. - * - This function does not add prior contributions — only pseudolikelihood terms. - */ +// Computes the log pseudolikelihood ratio for updating a single main-effect parameter (bgmCompare model). +// +// This function is used in Metropolis–Hastings updates for main effects. +// It compares the likelihood of the data under the current vs. proposed +// value of a single variable’s main-effect parameter, while keeping +// all other parameters fixed. +// +// Procedure: +// - For each group: +// * Compute group-specific main effects for the variable (current vs. proposed). +// * Add contributions from observed sufficient statistics +// (category counts or Blume–Capel stats). +// - Add the ratio of pseudolikelihood normalizing constants by calling +// `log_ratio_pseudolikelihood_constant_variable()`. +// +// Inputs: +// - current_main_effects: Matrix of main-effect parameters (current state). +// - proposed_main_effects: Matrix of main-effect parameters (candidate state). +// - current_pairwise_effects: Matrix of pairwise-effect parameters (fixed at current state). +// - main_effect_indices: Index ranges [row_start,row_end] for each variable. +// - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. +// - projection: Group projection matrix (num_groups × (num_groups − 1)). +// - observations: Observation matrix (persons × variables). +// - group_indices: Row ranges [start,end] for each group in observations. +// - num_categories: Number of categories per variable. +// - counts_per_category_group: Per-group category counts (for ordinal variables). +// - blume_capel_stats_group: Per-group sufficient statistics (for Blume–Capel variables). +// - num_groups: Number of groups. +// - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). +// - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). +// - baseline_category: Reference categories for Blume–Capel variables. +// - variable: Index of the variable being updated. +// +// Returns: +// - The scalar log pseudolikelihood ratio (proposed vs. current). +// +// Notes: +// - A temporary copy of `inclusion_indicator` is made to ensure the +// variable’s self-term (diagonal entry) is included. +// - Only the variable under update changes between current and proposed states; +// all other variables and pairwise effects remain fixed. +// - This function does not add prior contributions — only pseudolikelihood terms. double log_pseudolikelihood_ratio_main( const arma::mat& current_main_effects, const arma::mat& proposed_main_effects, @@ -1359,49 +1343,47 @@ double log_pseudolikelihood_ratio_main( } -/** - * Computes the log pseudolikelihood ratio for updating a single pairwise-effect parameter (bgmCompare model). - * - * This function is used in Metropolis–Hastings updates for pairwise effects. - * It compares the likelihood of the data under the current vs. proposed - * value of a single interaction (var1,var2), while keeping all other - * parameters fixed. - * - * Procedure: - * - Ensure the interaction is included in a temporary copy of inclusion_indicator. - * - For each group: - * * Compute group-specific pairwise effect for (var1,var2), current vs. proposed. - * * Add linear contribution from the pairwise sufficient statistic. - * - Add the ratio of pseudolikelihood normalizing constants for both variables: - * * Call `log_ratio_pseudolikelihood_constant_variable()` separately for var1 and var2, - * comparing current vs. proposed pairwise weights. - * - * Inputs: - * - main_effects: Matrix of main-effect parameters (fixed). - * - current_pairwise_effects: Matrix of pairwise-effect parameters (current state). - * - proposed_pairwise_effects: Matrix of pairwise-effect parameters (candidate state). - * - main_effect_indices: Index ranges [row_start,row_end] for each variable. - * - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. - * - projection: Group projection matrix (num_groups × (num_groups − 1)). - * - observations: Observation matrix (persons × variables). - * - group_indices: Row ranges [start,end] for each group in observations. - * - num_categories: Number of categories per variable. - * - pairwise_stats_group: Per-group pairwise sufficient statistics. - * - num_groups: Number of groups. - * - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). - * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). - * - baseline_category: Reference categories for Blume–Capel variables. - * - var1, var2: Indices of the variable pair being updated. - * - * Returns: - * - The scalar log pseudolikelihood ratio (proposed vs. current). - * - * Notes: - * - A temporary copy of `inclusion_indicator` is used to force the edge (var1,var2) as active. - * - Only the selected pair changes between current and proposed states; - * all other effects remain fixed. - * - This function does not add prior contributions — only pseudolikelihood terms. - */ +// Computes the log pseudolikelihood ratio for updating a single pairwise-effect parameter (bgmCompare model). +// +// This function is used in Metropolis–Hastings updates for pairwise effects. +// It compares the likelihood of the data under the current vs. proposed +// value of a single interaction (var1,var2), while keeping all other +// parameters fixed. +// +// Procedure: +// - Ensure the interaction is included in a temporary copy of inclusion_indicator. +// - For each group: +// * Compute group-specific pairwise effect for (var1,var2), current vs. proposed. +// * Add linear contribution from the pairwise sufficient statistic. +// - Add the ratio of pseudolikelihood normalizing constants for both variables: +// * Call `log_ratio_pseudolikelihood_constant_variable()` separately for var1 and var2, +// comparing current vs. proposed pairwise weights. +// +// Inputs: +// - main_effects: Matrix of main-effect parameters (fixed). +// - current_pairwise_effects: Matrix of pairwise-effect parameters (current state). +// - proposed_pairwise_effects: Matrix of pairwise-effect parameters (candidate state). +// - main_effect_indices: Index ranges [row_start,row_end] for each variable. +// - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. +// - projection: Group projection matrix (num_groups × (num_groups − 1)). +// - observations: Observation matrix (persons × variables). +// - group_indices: Row ranges [start,end] for each group in observations. +// - num_categories: Number of categories per variable. +// - pairwise_stats_group: Per-group pairwise sufficient statistics. +// - num_groups: Number of groups. +// - inclusion_indicator: Symmetric binary matrix of active variables (diag) and pairs (off-diag). +// - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). +// - baseline_category: Reference categories for Blume–Capel variables. +// - var1, var2: Indices of the variable pair being updated. +// +// Returns: +// - The scalar log pseudolikelihood ratio (proposed vs. current). +// +// Notes: +// - A temporary copy of `inclusion_indicator` is used to force the edge (var1,var2) as active. +// - Only the selected pair changes between current and proposed states; +// all other effects remain fixed. +// - This function does not add prior contributions — only pseudolikelihood terms. double log_pseudolikelihood_ratio_pairwise( const arma::mat& main_effects, const arma::mat& current_pairwise_effects, diff --git a/src/bgmCompare/bgmCompare_logp_and_grad.h b/src/bgmCompare/bgmCompare_logp_and_grad.h index a9a7c827..01011458 100644 --- a/src/bgmCompare/bgmCompare_logp_and_grad.h +++ b/src/bgmCompare/bgmCompare_logp_and_grad.h @@ -254,7 +254,7 @@ double log_pseudoposterior_pair_component( * * Compares proposed vs. current main-effect parameters across all groups, * combining sufficient-statistic differences with normalizing-constant ratios. - * Used by the reversible-jump indicator update for main effects. + * Used by the Metropolis-Hastings indicator update for main effects. * * @param current_main_effects Current main-effect matrix * @param proposed_main_effects Proposed main-effect matrix @@ -285,7 +285,7 @@ double log_pseudolikelihood_ratio_main( * * Compares proposed vs. current pairwise-effect parameters for a single edge, * summing the data contribution and normalizing-constant ratios for both - * endpoint variables. Used by the reversible-jump indicator update. + * endpoint variables. Used by the Metropolis-Hastings indicator update. * * @param current_pairwise_effects Current pairwise-effect matrix * @param proposed_pairwise_effects Proposed pairwise-effect matrix diff --git a/src/bgmCompare/bgmCompare_output.h b/src/bgmCompare/bgmCompare_output.h index 148a1984..632566d2 100644 --- a/src/bgmCompare/bgmCompare_output.h +++ b/src/bgmCompare/bgmCompare_output.h @@ -10,28 +10,27 @@ * * Stores posterior samples of main and pairwise effects, optional * inclusion indicators, and diagnostics for HMC/NUTS runs. - * - * Members: - * - main_samples: [iter × (#main × groups)] matrix of main-effect samples. - * - pairwise_samples:[iter × (#pair × groups)] matrix of pairwise-effect samples. - * - indicator_samples:[iter × (#edges + #variables)] indicator samples (if used). - * - treedepth_samples:[iter] tree depth diagnostics (NUTS only). - * - divergent_samples:[iter] divergent transition flags (NUTS only). - * - energy_samples: [iter] energy diagnostic (NUTS only). - * - chain_id: Identifier of the chain. - * - has_indicator: True if indicator samples are stored. */ struct bgmCompareOutput { + /// Main-effect samples [iter x (#main x groups)]. arma::mat main_samples; + /// Pairwise-effect samples [iter x (#pair x groups)]. arma::mat pairwise_samples; + /// Inclusion indicator samples [iter x (#edges + #variables)] (if used). arma::imat indicator_samples; + /// Tree depth diagnostics [iter] (NUTS only). arma::ivec treedepth_samples; + /// Divergent transition flags [iter] (NUTS only). arma::ivec divergent_samples; + /// Energy diagnostic [iter] (NUTS only). arma::vec energy_samples; + /// Identifier of the chain. int chain_id; + /// True if indicator samples are stored. bool has_indicator; + /// True if the chain was interrupted by the user. bool userInterrupt; }; diff --git a/src/bgmCompare/bgmCompare_sampler.cpp b/src/bgmCompare/bgmCompare_sampler.cpp index 19fbd8d2..cb5ee5f6 100644 --- a/src/bgmCompare/bgmCompare_sampler.cpp +++ b/src/bgmCompare/bgmCompare_sampler.cpp @@ -17,51 +17,49 @@ -/** - * Imputes missing observations for the bgmCompare model. - * - * This function performs single imputation of missing values during Gibbs sampling. - * Each missing entry is resampled from its conditional distribution given: - * - the current main and pairwise effect parameters, - * - the observed data for that individual, - * - group-specific sufficient statistics. - * - * Workflow: - * 1. For each missing entry, identify its (person, variable, group). - * 2. Compute group-specific main and pairwise effects via projections. - * 3. Calculate unnormalized probabilities for all categories of the variable: - * - Ordinal: softmax using category-specific thresholds. - * - Blume–Capel: quadratic + linear score with baseline centering. - * 4. Sample a new category with inverse transform sampling. - * 5. If the imputed value differs from the old one, update: - * - `observations` (raw data matrix), - * - `counts_per_category` or `blume_capel_stats` (main-effect sufficient stats), - * - `pairwise_stats` (pairwise sufficient stats). - * - * Inputs: - * - main_effects, pairwise_effects: Current parameter matrices. - * - main_effect_indices, pairwise_effect_indices: Lookup tables for variable/pair rows. - * - inclusion_indicator: Indicates which differences/pairs are included. - * - projection: Group projection matrix. - * - observations: Data matrix [persons × variables]; updated in place. - * - num_groups: Number of groups. - * - group_membership: Group assignment for each person. - * - group_indices: Row ranges [start,end] for each group. - * - counts_per_category: Group-level sufficient statistics for ordinal variables. - * - blume_capel_stats: Group-level sufficient statistics for Blume–Capel variables. - * - pairwise_stats: Group-level sufficient statistics for pairwise interactions. - * - num_categories: Number of categories for each variable in each group. - * - missing_data_indices: Matrix of (person, variable) pairs with missing values. - * - is_ordinal_variable: Indicator vector (1 = ordinal, 0 = Blume–Capel). - * - baseline_category: Reference categories for Blume–Capel variables. - * - rng: Random number generator. - * - * Notes: - * - The function updates both raw data and sufficient statistics in-place. - * - Group-specific pairwise effects are recomputed per missing entry. - * - For efficiency, you may consider incremental updates to `pairwise_stats` - * instead of full recomputation (`obs.t() * obs`) after each change. - */ +// Imputes missing observations for the bgmCompare model. +// +// This function performs single imputation of missing values during Gibbs sampling. +// Each missing entry is resampled from its conditional distribution given: +// - the current main and pairwise effect parameters, +// - the observed data for that individual, +// - group-specific sufficient statistics. +// +// Workflow: +// 1. For each missing entry, identify its (person, variable, group). +// 2. Compute group-specific main and pairwise effects via projections. +// 3. Calculate unnormalized probabilities for all categories of the variable: +// - Ordinal: softmax using category-specific thresholds. +// - Blume–Capel: quadratic + linear score with baseline centering. +// 4. Sample a new category with inverse transform sampling. +// 5. If the imputed value differs from the old one, update: +// - `observations` (raw data matrix), +// - `counts_per_category` or `blume_capel_stats` (main-effect sufficient stats), +// - `pairwise_stats` (pairwise sufficient stats). +// +// Inputs: +// - main_effects, pairwise_effects: Current parameter matrices. +// - main_effect_indices, pairwise_effect_indices: Lookup tables for variable/pair rows. +// - inclusion_indicator: Indicates which differences/pairs are included. +// - projection: Group projection matrix. +// - observations: Data matrix [persons × variables]; updated in place. +// - num_groups: Number of groups. +// - group_membership: Group assignment for each person. +// - group_indices: Row ranges [start,end] for each group. +// - counts_per_category: Group-level sufficient statistics for ordinal variables. +// - blume_capel_stats: Group-level sufficient statistics for Blume–Capel variables. +// - pairwise_stats: Group-level sufficient statistics for pairwise interactions. +// - num_categories: Number of categories for each variable in each group. +// - missing_data_indices: Matrix of (person, variable) pairs with missing values. +// - is_ordinal_variable: Indicator vector (1 = ordinal, 0 = Blume–Capel). +// - baseline_category: Reference categories for Blume–Capel variables. +// - rng: Random number generator. +// +// Notes: +// - The function updates both raw data and sufficient statistics in-place. +// - Group-specific pairwise effects are recomputed per missing entry. +// - For efficiency, you may consider incremental updates to `pairwise_stats` +// instead of full recomputation (`obs.t() * obs`) after each change. void impute_missing_bgmcompare( const arma::mat& main_effects, const arma::mat& pairwise_effects, @@ -188,58 +186,56 @@ void impute_missing_bgmcompare( -/** - * Updates main effect parameters in bgmCompare using a random-walk Metropolis step. - * - * For each variable, the function proposes new parameter values for either: - * - all categories (ordinal variables), or - * - two parameters (linear and quadratic, Blume–Capel variables). - * - * If group-specific differences are enabled (`inclusion_indicator(v,v) == 1`), - * additional parameters (one per group contrast) are also updated. - * - * Each proposed parameter is evaluated via - * `log_pseudoposterior_main_component()`, and accepted/rejected using - * the Metropolis–Hastings rule. Proposal standard deviations are adapted - * online with `MetropolisAdaptationController`. - * - * Workflow: - * 1. Iterate over all variables. - * 2. For each category (ordinal) or parameter (Blume–Capel): - * - Update the "overall" effect (h=0). - * - Optionally update group-difference effects (h=1..G-1). - * 3. Record acceptance probabilities and update adaptation statistics. - * - * Inputs: - * - main_effects: Matrix of main effect parameters [rows = effects, cols = groups]; - * updated in place. - * - pairwise_effects: Current pairwise effects (passed through to log posterior). - * - main_effect_indices: Row index ranges for each variable’s main effects. - * - pairwise_effect_indices: Index map for pairwise effects. - * - inclusion_indicator: Indicator matrix; diagonal entries control group differences. - * - projection: Group projection matrix. - * - num_categories: Number of categories for each variable. - * - observations: Data matrix [persons × variables]. - * - num_groups: Number of groups (G). - * - group_indices: Row ranges per group in `observations`. - * - counts_per_category, blume_capel_stats: Group-specific sufficient statistics. - * - is_ordinal_variable: Indicator for ordinal vs. Blume–Capel. - * - baseline_category: Reference categories (Blume–Capel only). - * - difference_scale: Scale parameter for group difference priors. - * - main_alpha, main_beta: Parameters for the Beta prior on main effects. - * - iteration: Current iteration index (for adaptation). - * - rwm_adapt: Adaptation controller for proposal SDs. - * - rng: Random number generator. - * - proposal_sd_main: Proposal standard deviations [same shape as `main_effects`]; - * updated in place. - * - * Notes: - * - Acceptance probabilities are stored per parameter and fed to `metropolis_adapt.update()`. - * - This function does not alter pairwise effects, but passes them into - * the posterior for likelihood consistency. - * - The helper lambda `do_update` encapsulates the proposal/accept/revert loop - * for a single parameter, improving readability. - */ +// Updates main effect parameters in bgmCompare using a random-walk Metropolis step. +// +// For each variable, the function proposes new parameter values for either: +// - all categories (ordinal variables), or +// - two parameters (linear and quadratic, Blume–Capel variables). +// +// If group-specific differences are enabled (`inclusion_indicator(v,v) == 1`), +// additional parameters (one per group contrast) are also updated. +// +// Each proposed parameter is evaluated via +// `log_pseudoposterior_main_component()`, and accepted/rejected using +// the Metropolis–Hastings rule. Proposal standard deviations are adapted +// online with `MetropolisAdaptationController`. +// +// Workflow: +// 1. Iterate over all variables. +// 2. For each category (ordinal) or parameter (Blume–Capel): +// - Update the "overall" effect (h=0). +// - Optionally update group-difference effects (h=1..G-1). +// 3. Record acceptance probabilities and update adaptation statistics. +// +// Inputs: +// - main_effects: Matrix of main effect parameters [rows = effects, cols = groups]; +// updated in place. +// - pairwise_effects: Current pairwise effects (passed through to log posterior). +// - main_effect_indices: Row index ranges for each variable’s main effects. +// - pairwise_effect_indices: Index map for pairwise effects. +// - inclusion_indicator: Indicator matrix; diagonal entries control group differences. +// - projection: Group projection matrix. +// - num_categories: Number of categories for each variable. +// - observations: Data matrix [persons × variables]. +// - num_groups: Number of groups (G). +// - group_indices: Row ranges per group in `observations`. +// - counts_per_category, blume_capel_stats: Group-specific sufficient statistics. +// - is_ordinal_variable: Indicator for ordinal vs. Blume–Capel. +// - baseline_category: Reference categories (Blume–Capel only). +// - difference_scale: Scale parameter for group difference priors. +// - main_alpha, main_beta: Parameters for the Beta prior on main effects. +// - iteration: Current iteration index (for adaptation). +// - rwm_adapt: Adaptation controller for proposal SDs. +// - rng: Random number generator. +// - proposal_sd_main: Proposal standard deviations [same shape as `main_effects`]; +// updated in place. +// +// Notes: +// - Acceptance probabilities are stored per parameter and fed to `metropolis_adapt.update()`. +// - This function does not alter pairwise effects, but passes them into +// the posterior for likelihood consistency. +// - The helper lambda `do_update` encapsulates the proposal/accept/revert loop +// for a single parameter, improving readability. void update_main_effects_metropolis_bgmcompare ( arma::mat& main_effects, arma::mat& pairwise_effects, @@ -324,57 +320,55 @@ void update_main_effects_metropolis_bgmcompare ( -/** - * Updates pairwise interaction parameters in bgmCompare using a random-walk - * Metropolis step. - * - * For each variable pair (var1,var2), the function proposes new parameter - * values for: - * - the overall interaction (h=0), and - * - optionally group-difference effects (h=1..G-1) if enabled in - * `inclusion_indicator`. - * - * Each proposed parameter is evaluated via - * `log_pseudoposterior_pair_component()`, and accepted/rejected using the - * Metropolis–Hastings rule. Proposal standard deviations are adapted online - * through `MetropolisAdaptationController`. - * - * Workflow: - * 1. Iterate over all unique pairs of variables. - * 2. For each pair, update the overall effect (h=0). - * 3. If group differences are active, update group-specific difference - * effects (h=1..G-1). - * 4. Record acceptance probabilities and update proposal SDs via `rwm_adapt`. - * - * Inputs: - * - main_effects: Matrix of main effect parameters (passed through to log posterior). - * - pairwise_effects: Matrix of pairwise interaction parameters [rows = pairs, cols = groups]; - * updated in place. - * - main_effect_indices: Index map for main effects (per variable). - * - pairwise_effect_indices: Row index map for pairwise effects (per var1,var2). - * - inclusion_indicator: Indicator matrix; off-diagonal entries control group differences. - * - projection: Group projection matrix. - * - num_categories: Number of categories per variable. - * - observations: Data matrix [persons × variables]. - * - num_groups: Number of groups (G). - * - group_indices: Row ranges per group in `observations`. - * - pairwise_stats: Group-specific sufficient statistics for pairwise effects. - * - is_ordinal_variable: Indicator for ordinal vs. Blume–Capel variables. - * - baseline_category: Reference categories (Blume–Capel only). - * - pairwise_scale: Scale parameter for overall interaction prior. - * - difference_scale: Scale parameter for group difference priors. - * - iteration: Current iteration index (for adaptation). - * - rwm_adapt: Adaptation controller for proposal SDs. - * - rng: Random number generator. - * - proposal_sd_pair: Proposal standard deviations [same shape as `pairwise_effects`]; - * updated in place. - * - * Notes: - * - Acceptance probabilities are tracked per parameter and fed to - * `metropolis_adapt.update()`. - * - The helper lambda `do_update` encapsulates the proposal/accept/reject - * logic for a single parameter. - */ +// Updates pairwise interaction parameters in bgmCompare using a random-walk +// Metropolis step. +// +// For each variable pair (var1,var2), the function proposes new parameter +// values for: +// - the overall interaction (h=0), and +// - optionally group-difference effects (h=1..G-1) if enabled in +// `inclusion_indicator`. +// +// Each proposed parameter is evaluated via +// `log_pseudoposterior_pair_component()`, and accepted/rejected using the +// Metropolis–Hastings rule. Proposal standard deviations are adapted online +// through `MetropolisAdaptationController`. +// +// Workflow: +// 1. Iterate over all unique pairs of variables. +// 2. For each pair, update the overall effect (h=0). +// 3. If group differences are active, update group-specific difference +// effects (h=1..G-1). +// 4. Record acceptance probabilities and update proposal SDs via `rwm_adapt`. +// +// Inputs: +// - main_effects: Matrix of main effect parameters (passed through to log posterior). +// - pairwise_effects: Matrix of pairwise interaction parameters [rows = pairs, cols = groups]; +// updated in place. +// - main_effect_indices: Index map for main effects (per variable). +// - pairwise_effect_indices: Row index map for pairwise effects (per var1,var2). +// - inclusion_indicator: Indicator matrix; off-diagonal entries control group differences. +// - projection: Group projection matrix. +// - num_categories: Number of categories per variable. +// - observations: Data matrix [persons × variables]. +// - num_groups: Number of groups (G). +// - group_indices: Row ranges per group in `observations`. +// - pairwise_stats: Group-specific sufficient statistics for pairwise effects. +// - is_ordinal_variable: Indicator for ordinal vs. Blume–Capel variables. +// - baseline_category: Reference categories (Blume–Capel only). +// - pairwise_scale: Scale parameter for overall interaction prior. +// - difference_scale: Scale parameter for group difference priors. +// - iteration: Current iteration index (for adaptation). +// - rwm_adapt: Adaptation controller for proposal SDs. +// - rng: Random number generator. +// - proposal_sd_pair: Proposal standard deviations [same shape as `pairwise_effects`]; +// updated in place. +// +// Notes: +// - Acceptance probabilities are tracked per parameter and fed to +// `metropolis_adapt.update()`. +// - The helper lambda `do_update` encapsulates the proposal/accept/reject +// logic for a single parameter. void update_pairwise_effects_metropolis_bgmcompare ( arma::mat& main_effects, arma::mat& pairwise_effects, @@ -489,54 +483,52 @@ void update_pairwise_effects_metropolis_bgmcompare ( -/** - * Heuristically determine an initial HMC/NUTS step size for bgmCompare. - * - * This function vectorizes the current model parameters, then repeatedly - * simulates short HMC trajectories to calibrate a stable starting step size - * that achieves a target acceptance rate. - * - * Workflow: - * 1. Vectorize current parameters into a single state vector. - * 2. Define closures for log-posterior evaluation and gradient computation: - * - `log_post`: unpacks parameters and evaluates the log pseudoposterior. - * - `grad`: unpacks parameters and evaluates the gradient of the - * pseudoposterior. - * 3. Pass these to `heuristic_initial_step_size`, which runs the heuristic - * tuning loop. - * - * Inputs: - * - main_effects: Matrix of main-effect parameters [n_main_rows × G]. - * - pairwise_effects: Matrix of pairwise interaction parameters [n_pairs × G]. - * - main_effect_indices: Row index ranges for each variable’s main effects. - * - pairwise_effect_indices: Row index map for pairwise effects. - * - inclusion_indicator: Matrix marking which main and pairwise differences - * are active. - * - projection: Group projection matrix (encodes contrasts). - * - num_categories: Number of categories per variable [V]. - * - observations: Data matrix [persons × variables]. - * - num_groups: Number of groups (G). - * - group_indices: Row ranges per group in `observations`. - * - counts_per_category: Per-group sufficient statistics for ordinal variables. - * - blume_capel_stats: Per-group sufficient statistics for Blume–Capel variables. - * - pairwise_stats: Per-group sufficient statistics for pairwise effects. - * - is_ordinal_variable: Indicator for ordinal vs. Blume–Capel variables [V]. - * - baseline_category: Reference categories for Blume–Capel variables [V]. - * - pairwise_scale: Scale parameter for overall pairwise priors. - * - difference_scale: Scale parameter for group difference priors. - * - main_alpha, main_beta: Hyperparameters for Beta prior on main effects. - * - target_acceptance: Desired acceptance probability (e.g. 0.8). - * - rng: Random number generator. - * - * Returns: - * - A double value for the initial HMC/NUTS step size. - * - * Notes: - * - This routine is only used during warmup to initialize - * `HMCAdaptationController`. - * - Correct indexing of parameters relies on `build_index_maps` to ensure - * consistency between vectorization and gradient computation. - */ +// Heuristically determine an initial HMC/NUTS step size for bgmCompare. +// +// This function vectorizes the current model parameters, then repeatedly +// simulates short HMC trajectories to calibrate a stable starting step size +// that achieves a target acceptance rate. +// +// Workflow: +// 1. Vectorize current parameters into a single state vector. +// 2. Define closures for log-posterior evaluation and gradient computation: +// - `log_post`: unpacks parameters and evaluates the log pseudoposterior. +// - `grad`: unpacks parameters and evaluates the gradient of the +// pseudoposterior. +// 3. Pass these to `heuristic_initial_step_size`, which runs the heuristic +// tuning loop. +// +// Inputs: +// - main_effects: Matrix of main-effect parameters [n_main_rows × G]. +// - pairwise_effects: Matrix of pairwise interaction parameters [n_pairs × G]. +// - main_effect_indices: Row index ranges for each variable’s main effects. +// - pairwise_effect_indices: Row index map for pairwise effects. +// - inclusion_indicator: Matrix marking which main and pairwise differences +// are active. +// - projection: Group projection matrix (encodes contrasts). +// - num_categories: Number of categories per variable [V]. +// - observations: Data matrix [persons × variables]. +// - num_groups: Number of groups (G). +// - group_indices: Row ranges per group in `observations`. +// - counts_per_category: Per-group sufficient statistics for ordinal variables. +// - blume_capel_stats: Per-group sufficient statistics for Blume–Capel variables. +// - pairwise_stats: Per-group sufficient statistics for pairwise effects. +// - is_ordinal_variable: Indicator for ordinal vs. Blume–Capel variables [V]. +// - baseline_category: Reference categories for Blume–Capel variables [V]. +// - pairwise_scale: Scale parameter for overall pairwise priors. +// - difference_scale: Scale parameter for group difference priors. +// - main_alpha, main_beta: Hyperparameters for Beta prior on main effects. +// - target_acceptance: Desired acceptance probability (e.g. 0.8). +// - rng: Random number generator. +// +// Returns: +// - A double value for the initial HMC/NUTS step size. +// +// Notes: +// - This routine is only used during warmup to initialize +// `HMCAdaptationController`. +// - Correct indexing of parameters relies on `build_index_maps` to ensure +// consistency between vectorization and gradient computation. double find_initial_stepsize_bgmcompare( arma::mat& main_effects, arma::mat& pairwise_effects, @@ -629,52 +621,50 @@ double find_initial_stepsize_bgmcompare( -/** - * Perform one Hamiltonian Monte Carlo (HMC) update step for the bgmCompare model. - * - * The function: - * 1. Vectorizes the current parameter state (main + pairwise effects). - * 2. Defines closures for log-posterior evaluation and gradient calculation. - * 3. Applies HMC with fixed leapfrog steps to propose a new state. - * 4. Unpacks the accepted state back into main and pairwise matrices. - * 5. Updates the adaptation controller with acceptance information. - * - * Inputs: - * - main_effects, pairwise_effects: Current parameter matrices, updated in place. - * - main_effect_indices, pairwise_effect_indices: Index maps for parameters. - * - inclusion_indicator: Indicates active main and pairwise differences. - * - projection: Group projection matrix for contrasts. - * - num_categories: Number of categories per variable [V]. - * - observations: Data matrix [N × V]. - * - num_groups: Number of groups. - * - group_indices: Row ranges for each group in `observations`. - * - counts_per_category, blume_capel_stats: Per-group sufficient statistics. - * - pairwise_stats: Per-group pairwise sufficient statistics. - * - is_ordinal_variable: Marks ordinal vs. Blume–Capel variables. - * - baseline_category: Reference categories for Blume–Capel variables. - * - pairwise_scale: Scale of overall pairwise prior. - * - difference_scale: Scale of group-difference prior. - * - main_alpha, main_beta: Hyperparameters for main-effect priors. - * - num_leapfrogs: Number of leapfrog steps in the HMC trajectory. - * - iteration: Current sampler iteration (for adaptation scheduling). - * - hmc_adapt: Adaptation controller for step size and mass matrix. - * - learn_mass_matrix: Whether to adapt the mass matrix. - * - selection: If true, restrict mass matrix to active parameters only. - * - rng: Random number generator. - * - * Side effects: - * - Updates `main_effects` and `pairwise_effects` with the new state. - * - Updates `hmc_adapt` with acceptance probability and diagnostics. - * - * Returns: - * - None directly; state is updated in place. - * - * Notes: - * - This variant is specific to bgmCompare, where parameters are stored in - * row-wise structures with possible group-difference columns. - * - Consistency between vectorization, unvectorization, and gradient - * indexing is enforced via `build_index_maps`. - */ +// Perform one Hamiltonian Monte Carlo (HMC) update step for the bgmCompare model. +// +// The function: +// 1. Vectorizes the current parameter state (main + pairwise effects). +// 2. Defines closures for log-posterior evaluation and gradient calculation. +// 3. Applies HMC with fixed leapfrog steps to propose a new state. +// 4. Unpacks the accepted state back into main and pairwise matrices. +// 5. Updates the adaptation controller with acceptance information. +// +// Inputs: +// - main_effects, pairwise_effects: Current parameter matrices, updated in place. +// - main_effect_indices, pairwise_effect_indices: Index maps for parameters. +// - inclusion_indicator: Indicates active main and pairwise differences. +// - projection: Group projection matrix for contrasts. +// - num_categories: Number of categories per variable [V]. +// - observations: Data matrix [N × V]. +// - num_groups: Number of groups. +// - group_indices: Row ranges for each group in `observations`. +// - counts_per_category, blume_capel_stats: Per-group sufficient statistics. +// - pairwise_stats: Per-group pairwise sufficient statistics. +// - is_ordinal_variable: Marks ordinal vs. Blume–Capel variables. +// - baseline_category: Reference categories for Blume–Capel variables. +// - pairwise_scale: Scale of overall pairwise prior. +// - difference_scale: Scale of group-difference prior. +// - main_alpha, main_beta: Hyperparameters for main-effect priors. +// - num_leapfrogs: Number of leapfrog steps in the HMC trajectory. +// - iteration: Current sampler iteration (for adaptation scheduling). +// - hmc_adapt: Adaptation controller for step size and mass matrix. +// - learn_mass_matrix: Whether to adapt the mass matrix. +// - selection: If true, restrict mass matrix to active parameters only. +// - rng: Random number generator. +// +// Side effects: +// - Updates `main_effects` and `pairwise_effects` with the new state. +// - Updates `hmc_adapt` with acceptance probability and diagnostics. +// +// Returns: +// - None directly; state is updated in place. +// +// Notes: +// - This variant is specific to bgmCompare, where parameters are stored in +// row-wise structures with possible group-difference columns. +// - Consistency between vectorization, unvectorization, and gradient +// indexing is enforced via `build_index_maps`. void update_hmc_bgmcompare( arma::mat& main_effects, arma::mat& pairwise_effects, @@ -809,51 +799,49 @@ void update_hmc_bgmcompare( -/** - * Perform one No-U-Turn Sampler (NUTS) update step for the bgmCompare model. - * - * The function: - * 1. Vectorizes the current parameter state (main + pairwise effects). - * 2. Defines closures for log-posterior evaluation and gradient calculation. - * 3. Runs a NUTS trajectory (adaptive tree-based extension of HMC). - * 4. Unpacks the accepted state back into main and pairwise matrices. - * 5. Updates the adaptation controller with acceptance probability. - * - * Inputs: - * - main_effects, pairwise_effects: Current parameter matrices, updated in place. - * - main_effect_indices, pairwise_effect_indices: Index maps for parameters. - * - inclusion_indicator: Indicates active main and pairwise differences. - * - projection: Group projection matrix for contrasts. - * - num_categories: Number of categories per variable [V]. - * - observations: Data matrix [N × V]. - * - num_groups: Number of groups. - * - group_indices: Row ranges for each group in `observations`. - * - counts_per_category, blume_capel_stats: Per-group sufficient statistics. - * - pairwise_stats: Per-group pairwise sufficient statistics. - * - is_ordinal_variable: Marks ordinal vs. Blume–Capel variables. - * - baseline_category: Reference categories for Blume–Capel variables. - * - pairwise_scale: Scale of overall pairwise prior. - * - difference_scale: Scale of group-difference prior. - * - main_alpha, main_beta: Hyperparameters for main-effect priors. - * - nuts_max_depth: Maximum tree depth for NUTS doubling procedure. - * - iteration: Current sampler iteration (for adaptation scheduling). - * - hmc_adapt: Adaptation controller for step size and mass matrix. - * - learn_mass_matrix: Whether to adapt the mass matrix (unused inside NUTS but relevant to controller). - * - selection: If true, restrict mass matrix to active parameters only. - * - rng: Random number generator. - * - * Returns: - * - A `StepResult` containing the accepted state and diagnostics - * (e.g. tree depth, divergences, energy). - * - * Notes: - * - This variant is specific to bgmCompare, where parameters are stored in - * row-wise structures with group-difference columns. - * - Consistency between vectorization, unvectorization, and gradient - * indexing is enforced via `build_index_maps`. - * - Diagnostics from the returned `StepResult` can be used to monitor - * sampler stability (e.g. divergences, tree depth). - */ +// Perform one No-U-Turn Sampler (NUTS) update step for the bgmCompare model. +// +// The function: +// 1. Vectorizes the current parameter state (main + pairwise effects). +// 2. Defines closures for log-posterior evaluation and gradient calculation. +// 3. Runs a NUTS trajectory (adaptive tree-based extension of HMC). +// 4. Unpacks the accepted state back into main and pairwise matrices. +// 5. Updates the adaptation controller with acceptance probability. +// +// Inputs: +// - main_effects, pairwise_effects: Current parameter matrices, updated in place. +// - main_effect_indices, pairwise_effect_indices: Index maps for parameters. +// - inclusion_indicator: Indicates active main and pairwise differences. +// - projection: Group projection matrix for contrasts. +// - num_categories: Number of categories per variable [V]. +// - observations: Data matrix [N × V]. +// - num_groups: Number of groups. +// - group_indices: Row ranges for each group in `observations`. +// - counts_per_category, blume_capel_stats: Per-group sufficient statistics. +// - pairwise_stats: Per-group pairwise sufficient statistics. +// - is_ordinal_variable: Marks ordinal vs. Blume–Capel variables. +// - baseline_category: Reference categories for Blume–Capel variables. +// - pairwise_scale: Scale of overall pairwise prior. +// - difference_scale: Scale of group-difference prior. +// - main_alpha, main_beta: Hyperparameters for main-effect priors. +// - nuts_max_depth: Maximum tree depth for NUTS doubling procedure. +// - iteration: Current sampler iteration (for adaptation scheduling). +// - hmc_adapt: Adaptation controller for step size and mass matrix. +// - learn_mass_matrix: Whether to adapt the mass matrix (unused inside NUTS but relevant to controller). +// - selection: If true, restrict mass matrix to active parameters only. +// - rng: Random number generator. +// +// Returns: +// - A `StepResult` containing the accepted state and diagnostics +// (e.g. tree depth, divergences, energy). +// +// Notes: +// - This variant is specific to bgmCompare, where parameters are stored in +// row-wise structures with group-difference columns. +// - Consistency between vectorization, unvectorization, and gradient +// indexing is enforced via `build_index_maps`. +// - Diagnostics from the returned `StepResult` can be used to monitor +// sampler stability (e.g. divergences, tree depth). StepResult update_nuts_bgmcompare( arma::mat& main_effects, arma::mat& pairwise_effects, @@ -990,55 +978,53 @@ StepResult update_nuts_bgmcompare( -/** - * Adapt proposal standard deviations (SDs) for main and pairwise effects - * during the warmup phase of the bgmCompare sampler. - * - * This function uses a Robbins–Monro stochastic approximation scheme to - * adjust proposal SDs toward a target acceptance rate. Adaptation occurs - * only when permitted by the current warmup schedule. - * - * Workflow: - * 1. For each main effect parameter (ordinal or Blume–Capel), run one - * random-walk Metropolis (RWM) step, update the parameter, and adjust - * the proposal SD. - * 2. For each pairwise effect parameter (overall and group differences), - * do the same. - * 3. Proposal SDs are updated symmetrically across group columns if - * differences are included. - * - * Inputs: - * - proposal_sd_main_effects: Current SDs for main effects, updated in place. - * - proposal_sd_pairwise_effects: Current SDs for pairwise effects, updated in place. - * - main_effects, pairwise_effects: Parameter matrices, updated in place. - * - main_effect_indices, pairwise_effect_indices: Index maps for main/pairwise parameters. - * - inclusion_indicator: Marks which group differences are active. - * - projection: Group projection matrix. - * - num_categories: Categories per variable. - * - observations: Data matrix [N × V]. - * - num_groups: Number of groups. - * - group_indices: Row ranges per group in `observations`. - * - counts_per_category, blume_capel_stats: Per-group sufficient statistics for main effects. - * - pairwise_stats: Per-group sufficient statistics for pairwise effects. - * - is_ordinal_variable: Marks ordinal vs. Blume–Capel variables. - * - baseline_category: Reference category for Blume–Capel variables. - * - pairwise_scale: Scale of Cauchy prior for pairwise effects. - * - difference_scale: Scale of Cauchy prior for group differences. - * - main_alpha, main_beta: Hyperparameters for Beta prior on main effects. - * - iteration: Current iteration (to check schedule stage). - * - rng: Random number generator. - * - sched: Warmup schedule controlling when adaptation is active. - * - target_accept: Desired acceptance probability (default 0.44). - * - rm_decay: Robbins–Monro decay rate (default 0.75). - * - * Side effects: - * - Updates `main_effects` and `pairwise_effects` with new parameter values. - * - Updates `proposal_sd_main_effects` and `proposal_sd_pairwise_effects`. - * - * Notes: - * - Adapts only when `sched.adapt_proposal_sd(iteration)` is true. - * - Helps stabilize RWM acceptance rates before switching to sampling. - */ +// Adapt proposal standard deviations (SDs) for main and pairwise effects +// during the warmup phase of the bgmCompare sampler. +// +// This function uses a Robbins–Monro stochastic approximation scheme to +// adjust proposal SDs toward a target acceptance rate. Adaptation occurs +// only when permitted by the current warmup schedule. +// +// Workflow: +// 1. For each main effect parameter (ordinal or Blume–Capel), run one +// random-walk Metropolis (RWM) step, update the parameter, and adjust +// the proposal SD. +// 2. For each pairwise effect parameter (overall and group differences), +// do the same. +// 3. Proposal SDs are updated symmetrically across group columns if +// differences are included. +// +// Inputs: +// - proposal_sd_main_effects: Current SDs for main effects, updated in place. +// - proposal_sd_pairwise_effects: Current SDs for pairwise effects, updated in place. +// - main_effects, pairwise_effects: Parameter matrices, updated in place. +// - main_effect_indices, pairwise_effect_indices: Index maps for main/pairwise parameters. +// - inclusion_indicator: Marks which group differences are active. +// - projection: Group projection matrix. +// - num_categories: Categories per variable. +// - observations: Data matrix [N × V]. +// - num_groups: Number of groups. +// - group_indices: Row ranges per group in `observations`. +// - counts_per_category, blume_capel_stats: Per-group sufficient statistics for main effects. +// - pairwise_stats: Per-group sufficient statistics for pairwise effects. +// - is_ordinal_variable: Marks ordinal vs. Blume–Capel variables. +// - baseline_category: Reference category for Blume–Capel variables. +// - pairwise_scale: Scale of Cauchy prior for pairwise effects. +// - difference_scale: Scale of Cauchy prior for group differences. +// - main_alpha, main_beta: Hyperparameters for Beta prior on main effects. +// - iteration: Current iteration (to check schedule stage). +// - rng: Random number generator. +// - sched: Warmup schedule controlling when adaptation is active. +// - target_accept: Desired acceptance probability (default 0.44). +// - rm_decay: Robbins–Monro decay rate (default 0.75). +// +// Side effects: +// - Updates `main_effects` and `pairwise_effects` with new parameter values. +// - Updates `proposal_sd_main_effects` and `proposal_sd_pairwise_effects`. +// +// Notes: +// - Adapts only when `sched.adapt_proposal_sd(iteration)` is true. +// - Helps stabilize RWM acceptance rates before switching to sampling. void tune_proposal_sd_bgmcompare( arma::mat& proposal_sd_main_effects, arma::mat& proposal_sd_pairwise_effects, @@ -1220,55 +1206,53 @@ void tune_proposal_sd_bgmcompare( -/** - * Metropolis–Hastings updates for difference-inclusion indicators in bgmCompare. - * - * This function toggles whether group-level differences are included for - * main effects (diagonal entries of `inclusion_indicator`) and pairwise - * effects (off-diagonal entries). Each update proposes either: - * - Turning a currently excluded difference “on” by drawing a new non-zero - * value from a Gaussian proposal, or - * - Turning an included difference “off” by setting its value(s) to zero. - * - * The acceptance probability combines: - * - Pseudolikelihood ratio (data contribution), - * - Prior ratio on inclusion indicators, - * - Prior ratio on parameter values (Cauchy vs. point-mass-at-zero), - * - Proposal density correction. - * - * Inputs: - * - inclusion_probability_difference: Prior inclusion probabilities for - * group differences [V × V]. - * - index: Matrix mapping pairwise interactions to variable indices. - * - main_effects, pairwise_effects: Parameter matrices, updated in place. - * - main_effect_indices, pairwise_effect_indices: Index maps for parameters. - * - projection: Group projection matrix. - * - observations: Data matrix [N × V]. - * - num_groups: Number of groups. - * - group_indices: Row ranges per group in `observations`. - * - num_categories: Categories per variable [V × G]. - * - inclusion_indicator: Indicator matrix for differences, updated in place. - * - is_ordinal_variable: Marks ordinal vs. Blume–Capel variables [V]. - * - baseline_category: Reference category for Blume–Capel variables [V]. - * - proposal_sd_main, proposal_sd_pairwise: Proposal SD matrices for main - * and pairwise effects. - * - difference_scale: Scale of Cauchy prior for group differences. - * - counts_per_category, blume_capel_stats: Per-group sufficient statistics - * for main effects. - * - pairwise_stats: Per-group sufficient statistics for pairwise effects. - * - rng: Random number generator. - * - * Side effects: - * - Updates `inclusion_indicator` entries for main/pairwise differences. - * - Updates corresponding slices of `main_effects` and `pairwise_effects`. - * - * Notes: - * - For main effects, differences correspond to columns 1..G-1 of the - * parameter matrix. - * - For pairwise effects, differences correspond to columns 1..G-1 of the - * pairwise-effect matrix rows. - * - Ensures symmetry of `inclusion_indicator` for pairwise updates. - */ +// Metropolis–Hastings updates for difference-inclusion indicators in bgmCompare. +// +// This function toggles whether group-level differences are included for +// main effects (diagonal entries of `inclusion_indicator`) and pairwise +// effects (off-diagonal entries). Each update proposes either: +// - Turning a currently excluded difference “on” by drawing a new non-zero +// value from a Gaussian proposal, or +// - Turning an included difference “off” by setting its value(s) to zero. +// +// The acceptance probability combines: +// - Pseudolikelihood ratio (data contribution), +// - Prior ratio on inclusion indicators, +// - Prior ratio on parameter values (Cauchy vs. point-mass-at-zero), +// - Proposal density correction. +// +// Inputs: +// - inclusion_probability_difference: Prior inclusion probabilities for +// group differences [V × V]. +// - index: Matrix mapping pairwise interactions to variable indices. +// - main_effects, pairwise_effects: Parameter matrices, updated in place. +// - main_effect_indices, pairwise_effect_indices: Index maps for parameters. +// - projection: Group projection matrix. +// - observations: Data matrix [N × V]. +// - num_groups: Number of groups. +// - group_indices: Row ranges per group in `observations`. +// - num_categories: Categories per variable [V × G]. +// - inclusion_indicator: Indicator matrix for differences, updated in place. +// - is_ordinal_variable: Marks ordinal vs. Blume–Capel variables [V]. +// - baseline_category: Reference category for Blume–Capel variables [V]. +// - proposal_sd_main, proposal_sd_pairwise: Proposal SD matrices for main +// and pairwise effects. +// - difference_scale: Scale of Cauchy prior for group differences. +// - counts_per_category, blume_capel_stats: Per-group sufficient statistics +// for main effects. +// - pairwise_stats: Per-group sufficient statistics for pairwise effects. +// - rng: Random number generator. +// +// Side effects: +// - Updates `inclusion_indicator` entries for main/pairwise differences. +// - Updates corresponding slices of `main_effects` and `pairwise_effects`. +// +// Notes: +// - For main effects, differences correspond to columns 1..G-1 of the +// parameter matrix. +// - For pairwise effects, differences correspond to columns 1..G-1 of the +// pairwise-effect matrix rows. +// - Ensures symmetry of `inclusion_indicator` for pairwise updates. void update_indicator_differences_metropolis_bgmcompare ( const arma::mat& inclusion_probability_difference, const arma::imat& index, @@ -1464,65 +1448,63 @@ void update_indicator_differences_metropolis_bgmcompare ( -/** - * Perform one Gibbs update step for the bgmCompare model. - * - * This function executes a single iteration of the Gibbs sampler, including: - * - * Step 0: (optional) Initialize graph structure if difference selection - * is enabled and the current iteration marks the start of Stage 3c. - * - * Step 1: (optional) Update inclusion indicators for group differences - * (main and pairwise effects) via Metropolis–Hastings proposals. - * - * Step 2: Update model parameters according to the selected update method: - * - "adaptive-metropolis": Update main and pairwise effects individually - * with random-walk Metropolis and adaptive proposal SDs. - * - "hamiltonian-mc": Update the full parameter vector using HMC. - * - "nuts": Update the full parameter vector using the No-U-Turn Sampler. - * If past burn-in, store NUTS diagnostics (tree depth, divergences, energy). - * - * Step 3: (Stage 3b only) Adapt proposal SDs for Metropolis updates using - * Robbins–Monro tuning. - * - * Inputs: - * - observations: Data matrix [N × V]. - * - num_categories: Number of categories per variable [V]. - * - pairwise_scale, difference_scale: Prior scale parameters. - * - counts_per_category, blume_capel_stats: Sufficient statistics per group. - * - main_alpha, main_beta: Hyperparameters for Beta prior on main effects. - * - inclusion_indicator: Matrix of active group differences [V × V], updated in place. - * - main_effects, pairwise_effects: Parameter matrices, updated in place. - * - is_ordinal_variable: Marks ordinal vs. Blume–Capel variables. - * - baseline_category: Reference categories for Blume–Capel variables. - * - iteration: Current iteration index. - * - pairwise_effect_indices, main_effect_indices: Index maps for parameters. - * - pairwise_stats: Per-group pairwise sufficient statistics. - * - nuts_max_depth: Maximum tree depth for NUTS. - * - hmc_adapt: Adaptation controller for HMC/NUTS. - * - metropolis_adapt_main, metropolis_adapt_pair: Adaptation controllers for RWM updates. - * - learn_mass_matrix: Whether to adapt the mass matrix in HMC/NUTS. - * - schedule: Warmup schedule, controls adaptation and selection phases. - * - treedepth_samples, divergent_samples, energy_samples: Buffers for NUTS diagnostics. - * - projection: Group projection matrix. - * - num_groups: Number of groups. - * - group_indices: Row ranges per group in `observations`. - * - rng: Random number generator. - * - inclusion_probability: Prior probabilities for including differences. - * - hmc_nuts_leapfrogs: Number of leapfrog steps for HMC updates. - * - update_method: Update strategy ("adaptive-metropolis", "hamiltonian-mc", "nuts"). - * - proposal_sd_main, proposal_sd_pair: Proposal SD matrices for Metropolis updates. - * - index: Index table for pairwise differences. - * - * Side effects: - * - Updates parameters, inclusion indicators, and sufficient statistics. - * - Updates adaptation controllers and (if NUTS) diagnostic buffers. - * - * Notes: - * - This function encapsulates all update logic for bgmCompare. - * - Choice of `update_method` governs whether updates are local (RWM) or - * global (HMC/NUTS). - */ +// Perform one Gibbs update step for the bgmCompare model. +// +// This function executes a single iteration of the Gibbs sampler, including: +// +// Step 0: (optional) Initialize graph structure if difference selection +// is enabled and the current iteration marks the start of Stage 3c. +// +// Step 1: (optional) Update inclusion indicators for group differences +// (main and pairwise effects) via Metropolis–Hastings proposals. +// +// Step 2: Update model parameters according to the selected update method: +// - "adaptive-metropolis": Update main and pairwise effects individually +// with random-walk Metropolis and adaptive proposal SDs. +// - "hamiltonian-mc": Update the full parameter vector using HMC. +// - "nuts": Update the full parameter vector using the No-U-Turn Sampler. +// If past burn-in, store NUTS diagnostics (tree depth, divergences, energy). +// +// Step 3: (Stage 3b only) Adapt proposal SDs for Metropolis updates using +// Robbins–Monro tuning. +// +// Inputs: +// - observations: Data matrix [N × V]. +// - num_categories: Number of categories per variable [V]. +// - pairwise_scale, difference_scale: Prior scale parameters. +// - counts_per_category, blume_capel_stats: Sufficient statistics per group. +// - main_alpha, main_beta: Hyperparameters for Beta prior on main effects. +// - inclusion_indicator: Matrix of active group differences [V × V], updated in place. +// - main_effects, pairwise_effects: Parameter matrices, updated in place. +// - is_ordinal_variable: Marks ordinal vs. Blume–Capel variables. +// - baseline_category: Reference categories for Blume–Capel variables. +// - iteration: Current iteration index. +// - pairwise_effect_indices, main_effect_indices: Index maps for parameters. +// - pairwise_stats: Per-group pairwise sufficient statistics. +// - nuts_max_depth: Maximum tree depth for NUTS. +// - hmc_adapt: Adaptation controller for HMC/NUTS. +// - metropolis_adapt_main, metropolis_adapt_pair: Adaptation controllers for RWM updates. +// - learn_mass_matrix: Whether to adapt the mass matrix in HMC/NUTS. +// - schedule: Warmup schedule, controls adaptation and selection phases. +// - treedepth_samples, divergent_samples, energy_samples: Buffers for NUTS diagnostics. +// - projection: Group projection matrix. +// - num_groups: Number of groups. +// - group_indices: Row ranges per group in `observations`. +// - rng: Random number generator. +// - inclusion_probability: Prior probabilities for including differences. +// - hmc_nuts_leapfrogs: Number of leapfrog steps for HMC updates. +// - update_method: Update strategy ("adaptive-metropolis", "hamiltonian-mc", "nuts"). +// - proposal_sd_main, proposal_sd_pair: Proposal SD matrices for Metropolis updates. +// - index: Index table for pairwise differences. +// +// Side effects: +// - Updates parameters, inclusion indicators, and sufficient statistics. +// - Updates adaptation controllers and (if NUTS) diagnostic buffers. +// +// Notes: +// - This function encapsulates all update logic for bgmCompare. +// - Choice of `update_method` governs whether updates are local (RWM) or +// global (HMC/NUTS). void gibbs_update_step_bgmcompare ( const arma::imat& observations, const arma::ivec& num_categories, @@ -1647,72 +1629,70 @@ void gibbs_update_step_bgmcompare ( -/** - * Run a full Gibbs sampler for the bgmCompare model. - * - * This function controls the full MCMC lifecycle for a single chain: - * - Initializes parameter matrices, proposal SDs, and adaptation controllers. - * - Optionally imputes missing data at each iteration. - * - Executes Gibbs updates for main and pairwise effects, including - * difference-selection if enabled. - * - Adapts step size, mass matrix, and proposal SDs during warmup. - * - Updates inclusion probabilities under the chosen prior - * (e.g. Beta–Bernoulli). - * - Collects posterior samples and diagnostics into a `SamplerOutput` struct. - * - * Inputs: - * - chain_id: Identifier for this chain (1-based). - * - observations: Data matrix [N × V]. - * - num_groups: Number of groups (G). - * - counts_per_category: Per-group sufficient statistics (ordinal variables). - * - blume_capel_stats: Per-group sufficient statistics (Blume–Capel variables). - * - pairwise_stats: Per-group sufficient statistics for pairwise effects. - * - num_categories: Number of categories per variable [V]. - * - main_alpha, main_beta: Hyperparameters for Beta prior on main effects. - * - pairwise_scale: Scale parameter for overall pairwise priors. - * - difference_scale: Scale parameter for group-difference priors. - * - difference_selection_alpha, difference_selection_beta: Hyperparameters - * for difference-selection prior. - * - difference_prior: Prior type for difference-selection ("Beta-Bernoulli", ...). - * - iter: Number of post–burn-in sampling iterations. - * - warmup: Number of warmup iterations. - * - na_impute: If true, impute missing observations at each iteration. - * - missing_data_indices: Matrix of [person, variable] indices of missings. - * - is_ordinal_variable: Marks ordinal vs. Blume–Capel variables. - * - baseline_category: Reference categories for Blume–Capel variables. - * - difference_selection: If true, include MH updates for group-difference indicators. - * - main_effect_indices, pairwise_effect_indices: Index maps for parameter rows. - * - target_accept: Target acceptance probability (HMC/NUTS). - * - nuts_max_depth: Maximum tree depth for NUTS. - * - learn_mass_matrix: Whether to adapt the mass matrix (HMC/NUTS). - * - projection: Group projection matrix for contrasts. - * - group_membership: Mapping of persons to groups. - * - group_indices: Row ranges per group in `observations`. - * - interaction_index_matrix: Index map for pairwise interactions. - * - inclusion_probability: Matrix of prior inclusion probabilities, updated in place. - * - rng: Random number generator. - * - update_method: Update strategy ("adaptive-metropolis", "hamiltonian-mc", "nuts"). - * - hmc_num_leapfrogs: Number of leapfrog steps for HMC. - * - * Returns: - * - A `SamplerOutput` struct containing: - * - main_samples: MCMC samples for main effects. - * - pairwise_samples: MCMC samples for pairwise effects. - * - indicator_samples: (optional) Inclusion indicator samples if - * difference-selection is enabled. - * - treedepth_samples, divergent_samples, energy_samples: - * Diagnostics (for NUTS). - * - chain_id: Identifier for this chain. - * - * Notes: - * - Warmup is orchestrated via `WarmupSchedule`, which controls adaptation - * phases and difference-selection activation. - * - Proposal SDs are tuned via Robbins–Monro during Stage 3b. - * - Difference-selection updates toggle inclusion indicators and adjust - * associated parameters with MH proposals. - * - This function runs entirely in C++ and is wrapped for parallel execution - * via `GibbsCompareChainRunner`. - */ +// Run a full Gibbs sampler for the bgmCompare model. +// +// This function controls the full MCMC lifecycle for a single chain: +// - Initializes parameter matrices, proposal SDs, and adaptation controllers. +// - Optionally imputes missing data at each iteration. +// - Executes Gibbs updates for main and pairwise effects, including +// difference-selection if enabled. +// - Adapts step size, mass matrix, and proposal SDs during warmup. +// - Updates inclusion probabilities under the chosen prior +// (e.g. Beta–Bernoulli). +// - Collects posterior samples and diagnostics into a `SamplerOutput` struct. +// +// Inputs: +// - chain_id: Identifier for this chain (1-based). +// - observations: Data matrix [N × V]. +// - num_groups: Number of groups (G). +// - counts_per_category: Per-group sufficient statistics (ordinal variables). +// - blume_capel_stats: Per-group sufficient statistics (Blume–Capel variables). +// - pairwise_stats: Per-group sufficient statistics for pairwise effects. +// - num_categories: Number of categories per variable [V]. +// - main_alpha, main_beta: Hyperparameters for Beta prior on main effects. +// - pairwise_scale: Scale parameter for overall pairwise priors. +// - difference_scale: Scale parameter for group-difference priors. +// - difference_selection_alpha, difference_selection_beta: Hyperparameters +// for difference-selection prior. +// - difference_prior: Prior type for difference-selection ("Beta-Bernoulli", ...). +// - iter: Number of post–burn-in sampling iterations. +// - warmup: Number of warmup iterations. +// - na_impute: If true, impute missing observations at each iteration. +// - missing_data_indices: Matrix of [person, variable] indices of missings. +// - is_ordinal_variable: Marks ordinal vs. Blume–Capel variables. +// - baseline_category: Reference categories for Blume–Capel variables. +// - difference_selection: If true, include MH updates for group-difference indicators. +// - main_effect_indices, pairwise_effect_indices: Index maps for parameter rows. +// - target_accept: Target acceptance probability (HMC/NUTS). +// - nuts_max_depth: Maximum tree depth for NUTS. +// - learn_mass_matrix: Whether to adapt the mass matrix (HMC/NUTS). +// - projection: Group projection matrix for contrasts. +// - group_membership: Mapping of persons to groups. +// - group_indices: Row ranges per group in `observations`. +// - interaction_index_matrix: Index map for pairwise interactions. +// - inclusion_probability: Matrix of prior inclusion probabilities, updated in place. +// - rng: Random number generator. +// - update_method: Update strategy ("adaptive-metropolis", "hamiltonian-mc", "nuts"). +// - hmc_num_leapfrogs: Number of leapfrog steps for HMC. +// +// Returns: +// - A `SamplerOutput` struct containing: +// - main_samples: MCMC samples for main effects. +// - pairwise_samples: MCMC samples for pairwise effects. +// - indicator_samples: (optional) Inclusion indicator samples if +// difference-selection is enabled. +// - treedepth_samples, divergent_samples, energy_samples: +// Diagnostics (for NUTS). +// - chain_id: Identifier for this chain. +// +// Notes: +// - Warmup is orchestrated via `WarmupSchedule`, which controls adaptation +// phases and difference-selection activation. +// - Proposal SDs are tuned via Robbins–Monro during Stage 3b. +// - Difference-selection updates toggle inclusion indicators and adjust +// associated parameters with MH proposals. +// - This function runs entirely in C++ and is wrapped for parallel execution +// via `GibbsCompareChainRunner`. bgmCompareOutput run_gibbs_sampler_bgmCompare( int chain_id, arma::imat observations, diff --git a/src/bgmCompare_interface.cpp b/src/bgmCompare_interface.cpp index 7941a23d..ff1730ff 100644 --- a/src/bgmCompare_interface.cpp +++ b/src/bgmCompare_interface.cpp @@ -15,21 +15,19 @@ using namespace RcppParallel; -/** - * Container for the result of a single MCMC chain (bgmCompare model). - * - * Fields: - * - error: True if the chain terminated with an error, false otherwise. - * - error_msg: Error message if an error occurred (empty if none). - * - chain_id: Integer identifier for the chain (1-based). - * - result: bgmCompareOutput object containing chain results - * (samples, diagnostics, metadata). - * - * Usage: - * - Used in parallel execution to collect results from each chain. - * - Checked after execution to propagate errors or assemble outputs - * into an R-accessible list. - */ +// Container for the result of a single MCMC chain (bgmCompare model). +// +// Fields: +// - error: True if the chain terminated with an error, false otherwise. +// - error_msg: Error message if an error occurred (empty if none). +// - chain_id: Integer identifier for the chain (1-based). +// - result: bgmCompareOutput object containing chain results +// (samples, diagnostics, metadata). +// +// Usage: +// - Used in parallel execution to collect results from each chain. +// - Checked after execution to propagate errors or assemble outputs +// into an R-accessible list. struct bgmCompareChainResult { bool error; std::string error_msg; @@ -39,60 +37,58 @@ struct bgmCompareChainResult { -/** - * Parallel worker for running a single Gibbs sampling chain (bgmCompare model). - * - * This struct wraps all inputs needed for one chain and provides an - * `operator()` so that multiple chains can be launched in parallel with TBB. - * - * Workflow per chain: - * - Construct a chain-specific RNG from `chain_rngs`. - * - Copy master statistics and observation data into per-chain buffers. - * - Call `run_gibbs_sampler_bgmCompare()` to execute the full chain. - * - Catch and record any errors (sets `error = true` and stores `error_msg`). - * - Store results into the shared `results` vector at the chain index. - * - * Inputs (stored as const references or values): - * - observations_master: Input observation matrix (persons × variables). - * - num_groups: Number of groups. - * - counts_per_category_master: Group-level category counts. - * - blume_capel_stats_master: Group-level Blume–Capel sufficient statistics. - * - pairwise_stats_master: Group-level pairwise sufficient statistics. - * - num_categories: Number of categories per variable. - * - main_alpha, main_beta: Hyperparameters for Beta priors on main effects. - * - pairwise_scale: Scale for Cauchy prior on baseline pairwise effects. - * - difference_scale: Scale for Cauchy prior on group differences. - * - difference_selection_alpha, difference_selection_beta: Hyperparameters for difference-selection prior. - * - difference_prior: Choice of prior distribution for group differences. - * - iter, warmup: Iteration counts. - * - na_impute: If true, perform missing data imputation. - * - missing_data_indices: Indices of missing observations. - * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). - * - baseline_category: Reference categories for Blume–Capel variables. - * - difference_selection: If true, perform difference selection updates. - * - main_effect_indices: Index ranges for main effects. - * - pairwise_effect_indices: Index ranges for pairwise effects. - * - target_accept: Target acceptance rate for adaptive methods. - * - nuts_max_depth: Maximum tree depth for NUTS. - * - learn_mass_matrix: If true, adapt mass matrix during warmup. - * - projection: Group projection matrix. - * - group_membership: Group assignment for each observation. - * - group_indices: Row ranges [start,end] for each group in observations. - * - interaction_index_matrix: Lookup table of variable pairs. - * - inclusion_probability_master: Prior inclusion probabilities for pairwise effects. - * - chain_rngs: Pre-initialized RNG engines (one per chain). - * - update_method: Sampler type ("adaptive-metropolis", "hamiltonian-mc", "nuts"). - * - hmc_num_leapfrogs: Number of leapfrog steps (HMC). - * - * Output: - * - results: Vector of `bgmCompareChainResult` objects, one per chain, filled in place. - * - * Notes: - * - Each worker instance is shared across threads but invoked with different - * [begin,end) ranges, corresponding to chain indices. - * - Per-chain copies of statistics and observations prevent cross-thread mutation. - * - Errors are caught locally so one failing chain does not crash the entire run. - */ +// Parallel worker for running a single Gibbs sampling chain (bgmCompare model). +// +// This struct wraps all inputs needed for one chain and provides an +// `operator()` so that multiple chains can be launched in parallel with TBB. +// +// Workflow per chain: +// - Construct a chain-specific RNG from `chain_rngs`. +// - Copy master statistics and observation data into per-chain buffers. +// - Call `run_gibbs_sampler_bgmCompare()` to execute the full chain. +// - Catch and record any errors (sets `error = true` and stores `error_msg`). +// - Store results into the shared `results` vector at the chain index. +// +// Inputs (stored as const references or values): +// - observations_master: Input observation matrix (persons × variables). +// - num_groups: Number of groups. +// - counts_per_category_master: Group-level category counts. +// - blume_capel_stats_master: Group-level Blume–Capel sufficient statistics. +// - pairwise_stats_master: Group-level pairwise sufficient statistics. +// - num_categories: Number of categories per variable. +// - main_alpha, main_beta: Hyperparameters for Beta priors on main effects. +// - pairwise_scale: Scale for Cauchy prior on baseline pairwise effects. +// - difference_scale: Scale for Cauchy prior on group differences. +// - difference_selection_alpha, difference_selection_beta: Hyperparameters for difference-selection prior. +// - difference_prior: Choice of prior distribution for group differences. +// - iter, warmup: Iteration counts. +// - na_impute: If true, perform missing data imputation. +// - missing_data_indices: Indices of missing observations. +// - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). +// - baseline_category: Reference categories for Blume–Capel variables. +// - difference_selection: If true, perform difference selection updates. +// - main_effect_indices: Index ranges for main effects. +// - pairwise_effect_indices: Index ranges for pairwise effects. +// - target_accept: Target acceptance rate for adaptive methods. +// - nuts_max_depth: Maximum tree depth for NUTS. +// - learn_mass_matrix: If true, adapt mass matrix during warmup. +// - projection: Group projection matrix. +// - group_membership: Group assignment for each observation. +// - group_indices: Row ranges [start,end] for each group in observations. +// - interaction_index_matrix: Lookup table of variable pairs. +// - inclusion_probability_master: Prior inclusion probabilities for pairwise effects. +// - chain_rngs: Pre-initialized RNG engines (one per chain). +// - update_method: Sampler type ("adaptive-metropolis", "hamiltonian-mc", "nuts"). +// - hmc_num_leapfrogs: Number of leapfrog steps (HMC). +// +// Output: +// - results: Vector of `bgmCompareChainResult` objects, one per chain, filled in place. +// +// Notes: +// - Each worker instance is shared across threads but invoked with different +// [begin,end) ranges, corresponding to chain indices. +// - Per-chain copies of statistics and observations prevent cross-thread mutation. +// - Errors are caught locally so one failing chain does not crash the entire run. struct GibbsCompareChainRunner : public Worker { const arma::imat& observations_master; const int num_groups; @@ -287,69 +283,67 @@ struct GibbsCompareChainRunner : public Worker { -/** - * Runs multiple parallel Gibbs sampling chains for the bgmCompare model. - * - * This function is the main entry point from R into the C++ backend for bgmCompare. - * It launches `num_chains` independent chains in parallel using TBB, - * each managed by `GibbsCompareChainRunner`. - * - * Workflow: - * - Initialize a per-chain RNG from the global seed. - * - Construct a `GibbsCompareChainRunner` worker with all shared inputs. - * - Launch the worker across chains using `parallelFor`. - * - Collect results from all chains into an Rcpp::List. - * - * Inputs: - * - observations: Observation matrix (persons × variables). - * - num_groups: Number of groups. - * - counts_per_category: Group-level category counts (for ordinal variables). - * - blume_capel_stats: Group-level sufficient statistics (for Blume–Capel variables). - * - pairwise_stats: Group-level pairwise sufficient statistics. - * - num_categories: Number of categories per variable. - * - main_alpha, main_beta: Hyperparameters for Beta priors on main effects. - * - pairwise_scale: Scale for Cauchy prior on baseline pairwise effects. - * - difference_scale: Scale for Cauchy prior on group differences. - * - difference_selection_alpha, difference_selection_beta: Hyperparameters for difference-selection prior. - * - difference_prior: Choice of prior distribution for group differences. - * - iter: Number of post-warmup iterations to draw. - * - warmup: Number of warmup iterations. - * - na_impute: If true, perform missing data imputation during sampling. - * - missing_data_indices: Indices of missing entries in observations. - * - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). - * - baseline_category: Reference categories for Blume–Capel variables. - * - difference_selection: If true, perform difference selection updates. - * - main_effect_indices: Index ranges [row_start,row_end] for each variable. - * - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. - * - target_accept: Target acceptance rate for adaptive samplers. - * - nuts_max_depth: Maximum tree depth for NUTS. - * - learn_mass_matrix: If true, adapt the mass matrix during warmup. - * - projection: Group projection matrix (num_groups × (num_groups − 1)). - * - group_membership: Group assignment for each observation. - * - group_indices: Row ranges [start,end] for each group in observations. - * - interaction_index_matrix: Lookup table of variable pairs. - * - inclusion_probability: Prior inclusion probabilities for pairwise effects. - * - num_chains: Number of chains to run. - * - nThreads: Maximum number of threads for parallel execution. - * - seed: Base random seed (incremented per chain). - * - update_method: Sampler type ("adaptive-metropolis", "hamiltonian-mc", "nuts"). - * - hmc_num_leapfrogs: Number of leapfrog steps for HMC. - * - * Returns: - * - Rcpp::List of length `num_chains`, where each element is either: - * * An error record (fields: "error", "chain_id"), if the chain failed. - * * A result list containing: - * - "main_samples": Posterior samples of main effects. - * - "pairwise_samples": Posterior samples of pairwise effects. - * - "treedepth__", "divergent__", "energy__": NUTS diagnostics. - * - "indicator_samples": Inclusion indicators (if selection was enabled). - * - "chain_id": Identifier of the chain. - * - * Notes: - * - Parallel execution is controlled by TBB; `nThreads` limits concurrency. - * - Each chain gets its own RNG stream, initialized as `seed + chain_id`. - * - This function is called by the exported R function `bgmCompare()`. - */ +// Runs multiple parallel Gibbs sampling chains for the bgmCompare model. +// +// This function is the main entry point from R into the C++ backend for bgmCompare. +// It launches `num_chains` independent chains in parallel using TBB, +// each managed by `GibbsCompareChainRunner`. +// +// Workflow: +// - Initialize a per-chain RNG from the global seed. +// - Construct a `GibbsCompareChainRunner` worker with all shared inputs. +// - Launch the worker across chains using `parallelFor`. +// - Collect results from all chains into an Rcpp::List. +// +// Inputs: +// - observations: Observation matrix (persons × variables). +// - num_groups: Number of groups. +// - counts_per_category: Group-level category counts (for ordinal variables). +// - blume_capel_stats: Group-level sufficient statistics (for Blume–Capel variables). +// - pairwise_stats: Group-level pairwise sufficient statistics. +// - num_categories: Number of categories per variable. +// - main_alpha, main_beta: Hyperparameters for Beta priors on main effects. +// - pairwise_scale: Scale for Cauchy prior on baseline pairwise effects. +// - difference_scale: Scale for Cauchy prior on group differences. +// - difference_selection_alpha, difference_selection_beta: Hyperparameters for difference-selection prior. +// - difference_prior: Choice of prior distribution for group differences. +// - iter: Number of post-warmup iterations to draw. +// - warmup: Number of warmup iterations. +// - na_impute: If true, perform missing data imputation during sampling. +// - missing_data_indices: Indices of missing entries in observations. +// - is_ordinal_variable: Indicator (1 = ordinal, 0 = Blume–Capel). +// - baseline_category: Reference categories for Blume–Capel variables. +// - difference_selection: If true, perform difference selection updates. +// - main_effect_indices: Index ranges [row_start,row_end] for each variable. +// - pairwise_effect_indices: Lookup table mapping (var1,var2) → row in pairwise_effects. +// - target_accept: Target acceptance rate for adaptive samplers. +// - nuts_max_depth: Maximum tree depth for NUTS. +// - learn_mass_matrix: If true, adapt the mass matrix during warmup. +// - projection: Group projection matrix (num_groups × (num_groups − 1)). +// - group_membership: Group assignment for each observation. +// - group_indices: Row ranges [start,end] for each group in observations. +// - interaction_index_matrix: Lookup table of variable pairs. +// - inclusion_probability: Prior inclusion probabilities for pairwise effects. +// - num_chains: Number of chains to run. +// - nThreads: Maximum number of threads for parallel execution. +// - seed: Base random seed (incremented per chain). +// - update_method: Sampler type ("adaptive-metropolis", "hamiltonian-mc", "nuts"). +// - hmc_num_leapfrogs: Number of leapfrog steps for HMC. +// +// Returns: +// - Rcpp::List of length `num_chains`, where each element is either: +// * An error record (fields: "error", "chain_id"), if the chain failed. +// * A result list containing: +// - "main_samples": Posterior samples of main effects. +// - "pairwise_samples": Posterior samples of pairwise effects. +// - "treedepth__", "divergent__", "energy__": NUTS diagnostics. +// - "indicator_samples": Inclusion indicators (if selection was enabled). +// - "chain_id": Identifier of the chain. +// +// Notes: +// - Parallel execution is controlled by TBB; `nThreads` limits concurrency. +// - Each chain gets its own RNG stream, initialized as `seed + chain_id`. +// - This function is called by the exported R function `bgmCompare()`. // [[Rcpp::export]] Rcpp::List run_bgmCompare_parallel( const arma::imat& observations, diff --git a/src/math/cholesky_helpers.h b/src/math/cholesky_helpers.h new file mode 100644 index 00000000..4bbf09e2 --- /dev/null +++ b/src/math/cholesky_helpers.h @@ -0,0 +1,45 @@ +#pragma once + +/** + * @file cholesky_helpers.h + * @brief Shared algebraic helpers for Cholesky-based precision updates. + * + * Pure functions with no model-specific state. Used by both GGMModel and + * MixedMRFModel for proposal constant extraction and log-determinant + * computation. + */ + +#include +#include + +namespace cholesky_helpers { + +/** + * Log-determinant of a positive-definite matrix from its upper-triangular + * Cholesky factor R (where Ω = R'R). + * + * @param R Upper-triangular Cholesky factor. + * @return log|Ω| = 2 Σ log(R_ii). + */ +inline double get_log_det(const arma::mat& R) { + return 2.0 * arma::accu(arma::log(R.diag())); +} + +/** + * Schur complement element: A(ii,jj) − A(ii,i) A(jj,i) / A(i,i). + * + * Used to compute entries of the inverse of a submatrix from the full + * covariance matrix. + * + * @param A Symmetric positive-definite matrix. + * @param i Conditioning index. + * @param ii Row index of the desired element. + * @param jj Column index of the desired element. + * @return Schur complement entry. + */ +inline double compute_inv_submatrix_i(const arma::mat& A, size_t i, + size_t ii, size_t jj) { + return A(ii, jj) - A(ii, i) * A(jj, i) / A(i, i); +} + +} // namespace cholesky_helpers diff --git a/src/models/ggm/cholupdate.cpp b/src/math/cholupdate.cpp similarity index 99% rename from src/models/ggm/cholupdate.cpp rename to src/math/cholupdate.cpp index 81331be9..3d3776b9 100644 --- a/src/models/ggm/cholupdate.cpp +++ b/src/math/cholupdate.cpp @@ -1,4 +1,4 @@ -#include "models/ggm/cholupdate.h" +#include "math/cholupdate.h" extern "C" { diff --git a/src/models/ggm/cholupdate.h b/src/math/cholupdate.h similarity index 100% rename from src/models/ggm/cholupdate.h rename to src/math/cholupdate.h diff --git a/src/mcmc/algorithms/nuts.cpp b/src/mcmc/algorithms/nuts.cpp index 5c5a4dcb..aaeb89cf 100644 --- a/src/mcmc/algorithms/nuts.cpp +++ b/src/mcmc/algorithms/nuts.cpp @@ -8,26 +8,22 @@ #include "rng/rng_utils.h" -/** - * The generalized U-turn criterion used here is described in Betancourt (2017). - * The implementation follows the approach in STAN's base_nuts.hpp (BSD-3-Clause license). - * - * References: - * Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo. - * arXiv preprint arXiv:1701.02434. - * Stan Development Team. base_nuts.hpp. - * https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/base_nuts.hpp - */ - - -/** - * Computes the generalized U-turn criterion for the NUTS algorithm - * - * @param p_sharp_minus Sharp momentum (M^{-1} p) at backward end - * @param p_sharp_plus Sharp momentum (M^{-1} p) at forward end - * @param rho Sum of momenta along the trajectory - * @return true if criterion satisfied (continue), false if U-turn detected (stop) - */ +// The generalized U-turn criterion used here is described in Betancourt (2017). +// The implementation follows the approach in STAN's base_nuts.hpp (BSD-3-Clause license). +// +// References: +// Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo. +// arXiv preprint arXiv:1701.02434. +// Stan Development Team. base_nuts.hpp. +// https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/base_nuts.hpp + + +// Computes the generalized U-turn criterion for the NUTS algorithm +// +// @param p_sharp_minus Sharp momentum (M^{-1} p) at backward end +// @param p_sharp_plus Sharp momentum (M^{-1} p) at forward end +// @param rho Sum of momenta along the trajectory +// @return true if criterion satisfied (continue), false if U-turn detected (stop) bool compute_criterion(const arma::vec& p_sharp_minus, const arma::vec& p_sharp_plus, const arma::vec& rho) { @@ -36,25 +32,23 @@ bool compute_criterion(const arma::vec& p_sharp_minus, -/** - * Recursively builds a binary tree of leapfrog steps in the NUTS algorithm - * - * Explores forward or backward in time, evaluating trajectory termination - * criteria. Based on Algorithm 6 in Hoffman & Gelman (2014). - * - * @param theta Current position at the base of the tree - * @param r Current momentum at the base of the tree - * @param log_u Log slice variable for accept/reject decision - * @param v Direction of expansion (-1 backward, +1 forward) - * @param j Current tree depth - * @param step_size Step size used in leapfrog integration - * @param theta_0 Initial position at the start of sampling - * @param r0 Initial momentum at the start of sampling - * @param logp0 Log posterior at initial position - * @param kin0 Kinetic energy at initial momentum - * @param memo Memoizer object for caching evaluations - * @return BuildTreeResult with updated endpoints, candidate sample, and diagnostics - */ +// Recursively builds a binary tree of leapfrog steps in the NUTS algorithm +// +// Explores forward or backward in time, evaluating trajectory termination +// criteria. Based on Algorithm 6 in Hoffman & Gelman (2014). +// +// @param theta Current position at the base of the tree +// @param r Current momentum at the base of the tree +// @param log_u Log slice variable for accept/reject decision +// @param v Direction of expansion (-1 backward, +1 forward) +// @param j Current tree depth +// @param step_size Step size used in leapfrog integration +// @param theta_0 Initial position at the start of sampling +// @param r0 Initial momentum at the start of sampling +// @param logp0 Log posterior at initial position +// @param kin0 Kinetic energy at initial momentum +// @param memo Memoizer object for caching evaluations +// @return BuildTreeResult with updated endpoints, candidate sample, and diagnostics BuildTreeResult build_tree( const arma::vec& theta, const arma::vec& r, diff --git a/src/mcmc/execution/chain_result.h b/src/mcmc/execution/chain_result.h index 32e24e8d..ae1b1e3c 100644 --- a/src/mcmc/execution/chain_result.h +++ b/src/mcmc/execution/chain_result.h @@ -14,29 +14,36 @@ class ChainResult { public: ChainResult() = default; - // Error handling + /// True if the chain terminated with an error. bool error = false; + /// True if the chain was interrupted by the user. bool userInterrupt = false; + /// Error message (empty if none). std::string error_msg; - // Chain identifier + /// Integer identifier for the chain (1-based). int chain_id = 0; - // Parameter samples (param_dim × n_iter) + /// Parameter samples (param_dim x n_iter). arma::mat samples; - // Edge indicator samples (n_edges × n_iter), only if edge_selection = true + /// Edge indicator samples (n_edges x n_iter), only if edge_selection = true. arma::imat indicator_samples; + /// Whether indicator samples are stored. bool has_indicators = false; - // SBM allocation samples (n_variables × n_iter), only if SBM edge prior + /// SBM allocation samples (n_variables x n_iter), only if SBM edge prior. arma::imat allocation_samples; + /// Whether allocation samples are stored. bool has_allocations = false; - // NUTS/HMC diagnostics (n_iter), only if using NUTS/HMC + /// NUTS/HMC tree depth diagnostics (n_iter). arma::ivec treedepth_samples; + /// NUTS/HMC divergent transition flags (n_iter). arma::ivec divergent_samples; + /// NUTS/HMC energy diagnostic (n_iter). arma::vec energy_samples; + /// Whether NUTS/HMC diagnostics are stored. bool has_nuts_diagnostics = false; /** diff --git a/src/mcmc/execution/chain_runner.cpp b/src/mcmc/execution/chain_runner.cpp index 8d4bdcfa..1635b15b 100644 --- a/src/mcmc/execution/chain_runner.cpp +++ b/src/mcmc/execution/chain_runner.cpp @@ -5,11 +5,14 @@ #include "mcmc/samplers/nuts_sampler.h" #include "mcmc/samplers/hmc_sampler.h" #include "mcmc/samplers/metropolis_sampler.h" +#include "mcmc/samplers/hybrid_nuts_sampler.h" std::unique_ptr create_sampler(const SamplerConfig& config, WarmupSchedule& schedule) { if (config.sampler_type == "nuts") { return std::make_unique(config, schedule); + } else if (config.sampler_type == "hybrid-nuts") { + return std::make_unique(config, schedule); } else if (config.sampler_type == "hmc" || config.sampler_type == "hamiltonian-mc") { return std::make_unique(config, schedule); } else if (config.sampler_type == "mh" || config.sampler_type == "adaptive-metropolis") { @@ -32,6 +35,7 @@ void run_mcmc_chain( // Construct warmup schedule (shared by runner and sampler) const bool learn_sd = (config.sampler_type == "nuts" || + config.sampler_type == "hybrid-nuts" || config.sampler_type == "hmc" || config.sampler_type == "hamiltonian-mc"); WarmupSchedule schedule(config.no_warmup, config.edge_selection, learn_sd); @@ -91,7 +95,7 @@ void run_mcmc_chain( } } - chain_result.store_sample(sample_index, model.get_full_vectorized_parameters()); + chain_result.store_sample(sample_index, model.get_storage_vectorized_parameters()); if (chain_result.has_indicators) { chain_result.store_indicators(sample_index, model.get_vectorized_indicator_parameters()); @@ -139,13 +143,14 @@ std::vector run_mcmc_sampler( const int no_threads, ProgressManager& pm ) { - const bool has_nuts_diag = (config.sampler_type == "nuts"); + const bool has_nuts_diag = (config.sampler_type == "nuts" || + config.sampler_type == "hybrid-nuts"); const bool has_sbm_alloc = edge_prior.has_allocations() || (config.edge_selection && dynamic_cast(&edge_prior) != nullptr); std::vector results(no_chains); for (int c = 0; c < no_chains; ++c) { - results[c].reserve(model.full_parameter_dimension(), config.no_iter); + results[c].reserve(model.storage_dimension(), config.no_iter); if (config.edge_selection) { size_t n_edges = model.get_vectorized_indicator_parameters().n_elem; diff --git a/src/mcmc/execution/sampler_config.h b/src/mcmc/execution/sampler_config.h index d2ce341b..595c819a 100644 --- a/src/mcmc/execution/sampler_config.h +++ b/src/mcmc/execution/sampler_config.h @@ -12,28 +12,32 @@ * - Edge selection settings */ struct SamplerConfig { - // Sampler type: "adaptive_metropolis", "nuts", "hmc" + /// Sampler type: "adaptive_metropolis", "nuts", or "hmc". std::string sampler_type = "adaptive_metropolis"; - // Iteration counts + /// Number of post-warmup iterations. int no_iter = 1000; + /// Number of warmup iterations. int no_warmup = 500; - // NUTS/HMC parameters + /// Maximum NUTS tree depth. int max_tree_depth = 10; - int num_leapfrogs = 10; // For HMC only + /// Number of leapfrog steps (HMC only). + int num_leapfrogs = 10; + /// Initial step size for gradient-based samplers. double initial_step_size = 0.1; + /// Target acceptance rate for dual-averaging adaptation. double target_acceptance = 0.8; - // Edge selection settings + /// Enable spike-and-slab edge selection. bool edge_selection = false; - // Missing data imputation + /// Enable missing-data imputation during sampling. bool na_impute = false; - // Random seed + /// Random seed. int seed = 42; - // Constructor with defaults + /// Default constructor. SamplerConfig() = default; }; diff --git a/src/mcmc/samplers/hmc_adaptation.h b/src/mcmc/samplers/hmc_adaptation.h index 9c31de64..064f0bbe 100644 --- a/src/mcmc/samplers/hmc_adaptation.h +++ b/src/mcmc/samplers/hmc_adaptation.h @@ -16,13 +16,21 @@ */ class DualAveraging { public: + /// Current log step size. double log_step_size; + /// Smoothed log step size (final estimate). double log_step_size_avg; + /// Running error statistic. double hbar; + /// Bias term: log(10 * initial_step_size). double mu; + /// Shrinkage parameter (default 0.05). double gamma; + /// Stabilisation offset (default 10). double t0; + /// Decay exponent for averaging weights (default 0.75). double kappa; + /// Iteration counter. int t; DualAveraging(double initial_step_size) @@ -68,8 +76,11 @@ class DualAveraging { */ class DiagMassMatrixAccumulator { public: + /// Number of samples accumulated. int count; + /// Running mean of parameter samples. arma::vec mean; + /// Running sum of squared deviations (Welford M2 statistic). arma::vec m2; DiagMassMatrixAccumulator(int dim) diff --git a/src/mcmc/samplers/hybrid_nuts_sampler.h b/src/mcmc/samplers/hybrid_nuts_sampler.h new file mode 100644 index 00000000..344d35d7 --- /dev/null +++ b/src/mcmc/samplers/hybrid_nuts_sampler.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include "mcmc/samplers/sampler_base.h" +#include "mcmc/algorithms/nuts.h" +#include "models/mixed/mixed_mrf_model.h" + +/** + * HybridNUTSSampler - NUTS for unconstrained block + MH for Kyy + * + * Designed for MixedMRFModel where the NUTS block covers (mux, Kxx, muy, Kxy) + * and the SPD-constrained Kyy is updated via component-wise Metropolis. + * Inherits warmup adaptation (step size, diagonal mass matrix) from + * GradientSamplerBase for the NUTS block. Kyy proposal SDs are adapted + * via the embedded Robbins-Monro schedule inside the model. + */ +class HybridNUTSSampler : public GradientSamplerBase { +public: + explicit HybridNUTSSampler(const SamplerConfig& config, WarmupSchedule& schedule) + : GradientSamplerBase(config.initial_step_size, config.target_acceptance, schedule), + max_tree_depth_(config.max_tree_depth), + schedule_(schedule) + {} + + bool has_nuts_diagnostics() const override { return true; } + + void initialize(BaseModel& model) override { + // Initialize NUTS adaptation (step-size heuristic + mass matrix) + GradientSamplerBase::initialize(model); + + // Initialize Kyy Metropolis adaptation (stores total_warmup_) + model.init_metropolis_adaptation(schedule_); + } + +protected: + StepResult do_gradient_step(BaseModel& model) override { + // --- Phase 1: NUTS step for the unconstrained block --- + arma::vec theta = model.get_vectorized_parameters(); + SafeRNG& rng = model.get_rng(); + + auto joint_fn = [&model](const arma::vec& params) + -> std::pair { + return model.logp_and_gradient(params); + }; + + arma::vec active_inv_mass = model.get_active_inv_mass(); + + StepResult result = nuts_step( + theta, step_size_, joint_fn, + active_inv_mass, rng, max_tree_depth_ + ); + + model.set_vectorized_parameters(result.state); + + // --- Phase 2: Kyy Metropolis step --- + auto& mixed = static_cast(model); + mixed.do_pairwise_continuous_metropolis_step(current_iteration_); + + return result; + } + +private: + int max_tree_depth_; + WarmupSchedule& schedule_; + +public: + // The iteration counter is set by the overridden step() method via + // the base class. We store it so do_gradient_step() can pass it + // to the Kyy Metropolis update for Robbins-Monro adaptation. + StepResult step(BaseModel& model, int iteration) override { + current_iteration_ = iteration; + return GradientSamplerBase::step(model, iteration); + } + +private: + int current_iteration_ = -1; +}; diff --git a/src/mcmc/samplers/metropolis_adaptation.h b/src/mcmc/samplers/metropolis_adaptation.h index 86eb8e06..ebd29e71 100644 --- a/src/mcmc/samplers/metropolis_adaptation.h +++ b/src/mcmc/samplers/metropolis_adaptation.h @@ -15,8 +15,11 @@ */ class MetropolisAdaptationController { public: + /// Reference to the proposal standard deviation matrix (modified in place). arma::mat& proposal_sd; + /// Total number of warmup iterations. const int total_warmup; + /// Target Metropolis acceptance rate. const double target_accept; MetropolisAdaptationController(arma::mat& proposal_sd_matrix, diff --git a/src/models/base_model.h b/src/models/base_model.h index 145ffebc..95981f92 100644 --- a/src/models/base_model.h +++ b/src/models/base_model.h @@ -125,7 +125,7 @@ class BaseModel { // ========================================================================= /** - * Update edge indicators via Metropolis-Hastings birth/death moves. + * Update edge indicators via Metropolis-Hastings add-delete moves. * * Only meaningful when has_edge_selection() returns true. GGMModel * handles this inside do_one_metropolis_step() instead. @@ -153,7 +153,11 @@ class BaseModel { /** * @return Full parameter dimension (fixed size, includes inactive parameters). * - * Used for fixed-size sample storage. Defaults to parameter_dimension(). + * Used by GradientSamplerBase for mass-matrix sizing and adaptation. + * For most models this equals the storage dimension. For models where + * some parameters are not sampled by NUTS (e.g., MixedMRFModel's Kyy), + * this returns the NUTS-block dimension. + * Defaults to parameter_dimension(). */ virtual size_t full_parameter_dimension() const { return parameter_dimension(); @@ -162,14 +166,35 @@ class BaseModel { /** * @return All parameters in a fixed-size vector (inactive edges are 0). * - * Used for sample storage to avoid dimension changes when edges are - * toggled on/off. + * Used by GradientSamplerBase for adaptation (online covariance). + * Dimension must match full_parameter_dimension(). */ virtual arma::vec get_full_vectorized_parameters() const = 0; /** @return Dimensionality of the active parameter space. Pure virtual. */ virtual size_t parameter_dimension() const = 0; + /** + * @return Dimension for sample storage (includes all parameters). + * + * For most models this equals full_parameter_dimension(). Override + * when storage needs more entries than the NUTS block (e.g., Kyy + * parameters in MixedMRFModel). + */ + virtual size_t storage_dimension() const { + return full_parameter_dimension(); + } + + /** + * @return All parameters in a fixed-size vector for sample storage. + * + * Dimension must match storage_dimension(). Default delegates to + * get_full_vectorized_parameters(). + */ + virtual arma::vec get_storage_vectorized_parameters() const { + return get_full_vectorized_parameters(); + } + // ========================================================================= // Infrastructure // ========================================================================= @@ -220,7 +245,7 @@ class BaseModel { /** * Enable or disable edge-selection proposals. - * @param active true to enable edge birth/death moves + * @param active true to enable edge add-delete moves */ virtual void set_edge_selection_active(bool active) { (void)active; diff --git a/src/models/ggm/ggm_model.cpp b/src/models/ggm/ggm_model.cpp index 725cca17..8dea6782 100644 --- a/src/models/ggm/ggm_model.cpp +++ b/src/models/ggm/ggm_model.cpp @@ -1,27 +1,24 @@ #include "models/ggm/ggm_model.h" #include "rng/rng_utils.h" -#include "models/ggm/cholupdate.h" +#include "math/explog_macros.h" +#include "math/cholupdate.h" #include "mcmc/execution/step_result.h" #include "mcmc/execution/warmup_schedule.h" -double GGMModel::compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const { - return(A(ii, jj) - A(ii, i) * A(jj, i) / A(i, i)); -} - void GGMModel::get_constants(size_t i, size_t j) { - double logdet_omega = get_log_det(cholesky_of_precision_); + double logdet_omega = cholesky_helpers::get_log_det(cholesky_of_precision_); - double log_adj_omega_ii = logdet_omega + std::log(std::abs(covariance_matrix_(i, i))); - double log_adj_omega_ij = logdet_omega + std::log(std::abs(covariance_matrix_(i, j))); - double log_adj_omega_jj = logdet_omega + std::log(std::abs(covariance_matrix_(j, j))); + double log_adj_omega_ii = logdet_omega + MY_LOG(std::abs(covariance_matrix_(i, i))); + double log_adj_omega_ij = logdet_omega + MY_LOG(std::abs(covariance_matrix_(i, j))); + double log_adj_omega_jj = logdet_omega + MY_LOG(std::abs(covariance_matrix_(j, j))); - double inv_omega_sub_j1j1 = compute_inv_submatrix_i(covariance_matrix_, i, j, j); - double log_abs_inv_omega_sub_jj = log_adj_omega_ii + std::log(std::abs(inv_omega_sub_j1j1)); - double Phi_q1q = (2 * std::signbit(covariance_matrix_(i, j)) - 1) * std::exp( + double inv_omega_sub_j1j1 = cholesky_helpers::compute_inv_submatrix_i(covariance_matrix_, i, j, j); + double log_abs_inv_omega_sub_jj = log_adj_omega_ii + MY_LOG(std::abs(inv_omega_sub_j1j1)); + double Phi_q1q = (2 * std::signbit(covariance_matrix_(i, j)) - 1) * MY_EXP( (log_adj_omega_ij - (log_adj_omega_jj + log_abs_inv_omega_sub_jj) / 2) ); - double Phi_q1q1 = std::exp((log_adj_omega_jj - log_abs_inv_omega_sub_jj) / 2); + double Phi_q1q1 = MY_EXP((log_adj_omega_jj - log_abs_inv_omega_sub_jj) / 2); constants_[0] = Phi_q1q; constants_[1] = Phi_q1q1; @@ -40,17 +37,12 @@ double GGMModel::constrained_diagonal(const double x) const { } } -double GGMModel::get_log_det(arma::mat triangular_A) const { - // log-determinant of A'A where A is upper-triangular Cholesky factor - return 2 * arma::accu(arma::log(triangular_A.diag())); -} - double GGMModel::log_density_impl(const arma::mat& omega, const arma::mat& phi) const { - double logdet_omega = get_log_det(phi); + double logdet_omega = cholesky_helpers::get_log_det(phi); double trace_prod = arma::accu(omega % suf_stat_); - double log_likelihood = n_ * (p_ * log(2 * arma::datum::pi) / 2 + logdet_omega / 2) - trace_prod / 2; + double log_likelihood = n_ * (p_ * MY_LOG(2 * arma::datum::pi) / 2 + logdet_omega / 2) - trace_prod / 2; return log_likelihood; } @@ -66,7 +58,7 @@ double GGMModel::log_density_impl_edge(size_t i, size_t j) const { double cc12 = 1 - (covariance_matrix_(i, j) * Ui2 + covariance_matrix_(j, j) * Uj2); double cc22 = 0 + Ui2 * Ui2 * covariance_matrix_(i, i) + 2 * Ui2 * Uj2 * covariance_matrix_(i, j) + Uj2 * Uj2 * covariance_matrix_(j, j); - double logdet = std::log(std::abs(cc11 * cc22 - cc12 * cc12)); + double logdet = MY_LOG(std::abs(cc11 * cc22 - cc12 * cc12)); // logdet - (logdet(aOmega_prop) - logdet(aOmega)) double trace_prod = -2 * (suf_stat_(j, j) * Uj2 + suf_stat_(i, j) * Ui2); @@ -84,7 +76,7 @@ double GGMModel::log_density_impl_diag(size_t j) const { double cc12 = 1 - covariance_matrix_(j, j) * Uj2; double cc22 = 0 + Uj2 * Uj2 * covariance_matrix_(j, j); - double logdet = std::log(std::abs(cc11 * cc22 - cc12 * cc12)); + double logdet = MY_LOG(std::abs(cc11 * cc22 - cc12 * cc12)); double trace_prod = -2 * suf_stat_(j, j) * Uj2; double log_likelihood_ratio = (n_ * logdet - trace_prod) / 2; @@ -120,7 +112,7 @@ void GGMModel::update_edge_parameter(size_t i, size_t j, int iteration) { ln_alpha += R::dcauchy(precision_proposal_(i, j), 0.0, pairwise_scale_, true); ln_alpha -= R::dcauchy(precision_matrix_(i, j), 0.0, pairwise_scale_, true); - if (std::log(runif(rng_)) < ln_alpha) { + if (MY_LOG(runif(rng_)) < ln_alpha) { double omega_ij_old = precision_matrix_(i, j); double omega_jj_old = precision_matrix_(j, j); @@ -173,8 +165,8 @@ void GGMModel::cholesky_update_after_edge(double omega_ij_old, double omega_jj_o } void GGMModel::update_diagonal_parameter(size_t i, int iteration) { - double logdet_omega = get_log_det(cholesky_of_precision_); - double logdet_omega_sub_ii = logdet_omega + std::log(covariance_matrix_(i, i)); + double logdet_omega = cholesky_helpers::get_log_det(cholesky_of_precision_); + double logdet_omega_sub_ii = logdet_omega + MY_LOG(covariance_matrix_(i, i)); size_t e = i * (i + 3) / 2; // parameter index in vectorized form (column-major upper triangle, i==j) double proposal_sd = proposal_sds_(e); @@ -183,15 +175,15 @@ void GGMModel::update_diagonal_parameter(size_t i, int iteration) { double theta_prop = rnorm(rng_, theta_curr, proposal_sd); precision_proposal_ = precision_matrix_; - precision_proposal_(i, i) = precision_matrix_(i, i) - std::exp(theta_curr) * std::exp(theta_curr) + std::exp(theta_prop) * std::exp(theta_prop); + precision_proposal_(i, i) = precision_matrix_(i, i) - MY_EXP(theta_curr) * MY_EXP(theta_curr) + MY_EXP(theta_prop) * MY_EXP(theta_prop); double ln_alpha = log_density_impl_diag(i); - ln_alpha += R::dgamma(exp(theta_prop), 1.0, 1.0, true); - ln_alpha -= R::dgamma(exp(theta_curr), 1.0, 1.0, true); + ln_alpha += R::dgamma(MY_EXP(theta_prop), 1.0, 1.0, true); + ln_alpha -= R::dgamma(MY_EXP(theta_curr), 1.0, 1.0, true); ln_alpha += theta_prop - theta_curr; // Jacobian adjustment - if (std::log(runif(rng_)) < ln_alpha) { + if (MY_LOG(runif(rng_)) < ln_alpha) { double omega_ii = precision_matrix_(i, i); precision_matrix_(i, i) = precision_proposal_(i, i); cholesky_update_after_diag(omega_ii, i); @@ -255,12 +247,12 @@ void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { // } - ln_alpha += std::log(1.0 - inclusion_probability_(i, j)) - std::log(inclusion_probability_(i, j)); + ln_alpha += MY_LOG(1.0 - inclusion_probability_(i, j)) - MY_LOG(inclusion_probability_(i, j)); - ln_alpha += R::dnorm(precision_matrix_(i, j) / constants_[3], 0.0, proposal_sd, true) - std::log(constants_[3]); + ln_alpha += R::dnorm(precision_matrix_(i, j) / constants_[3], 0.0, proposal_sd, true) - MY_LOG(constants_[3]); ln_alpha -= R::dcauchy(precision_matrix_(i, j), 0.0, pairwise_scale_, true); - if (std::log(runif(rng_)) < ln_alpha) { + if (MY_LOG(runif(rng_)) < ln_alpha) { // Store old values for Cholesky update double omega_ij_old = precision_matrix_(i, j); @@ -304,15 +296,15 @@ void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; // } // } - ln_alpha += std::log(inclusion_probability_(i, j)) - std::log(1.0 - inclusion_probability_(i, j)); + ln_alpha += MY_LOG(inclusion_probability_(i, j)) - MY_LOG(1.0 - inclusion_probability_(i, j)); // Prior change: add slab (Cauchy prior) ln_alpha += R::dcauchy(omega_prop_ij, 0.0, pairwise_scale_, true); // Proposal term: proposed edge value given it was generated from truncated normal - ln_alpha -= R::dnorm(omega_prop_ij / constants_[3], 0.0, proposal_sd, true) - std::log(constants_[3]); + ln_alpha -= R::dnorm(omega_prop_ij / constants_[3], 0.0, proposal_sd, true) - MY_LOG(constants_[3]); - if (std::log(runif(rng_)) < ln_alpha) { + if (MY_LOG(runif(rng_)) < ln_alpha) { // Accept: turn ON the edge // Store old values for Cholesky update double omega_ij_old = precision_matrix_(i, j); diff --git a/src/models/ggm/ggm_model.h b/src/models/ggm/ggm_model.h index 235998a6..ff9acff0 100644 --- a/src/models/ggm/ggm_model.h +++ b/src/models/ggm/ggm_model.h @@ -3,6 +3,7 @@ #include #include #include "models/base_model.h" +#include "math/cholesky_helpers.h" #include "rng/rng_utils.h" @@ -153,7 +154,7 @@ class GGMModel : public BaseModel { /** * Enable or disable edge-selection proposals. - * @param active true to enable edge birth/death moves + * @param active true to enable edge add-delete moves */ void set_edge_selection_active(bool active) override { edge_selection_active_ = active; @@ -179,7 +180,7 @@ class GGMModel : public BaseModel { * Perform one full Metropolis sweep. * * Iterates over all off-diagonal entries (edge updates), all diagonal - * entries, and (when active) all edge indicator birth/death moves. + * entries, and (when active) all edge indicator add-delete moves. * * @param iteration Current iteration index (for Robbins-Monro adaptation) */ @@ -275,7 +276,7 @@ class GGMModel : public BaseModel { arma::mat inclusion_probability_; /// Whether the model was constructed with edge selection. bool edge_selection_; - /// Whether edge birth/death proposals are currently active. + /// Whether edge add-delete proposals are currently active. bool edge_selection_active_ = false; /// Scale parameter of the Cauchy slab prior on off-diagonal elements. double pairwise_scale_; @@ -364,7 +365,7 @@ class GGMModel : public BaseModel { void update_diagonal_parameter(size_t i, int iteration); /** - * Reversible-jump birth/death move for an edge indicator. + * Metropolis-Hastings add-delete move for an edge indicator. * * If the edge is on, proposes deletion; if off, proposes a new value * from a scaled normal. Acceptance combines the likelihood ratio, @@ -388,17 +389,7 @@ class GGMModel : public BaseModel { */ void get_constants(size_t i, size_t j); - /** - * Compute the (ii, jj) entry of the Schur complement of A with - * row/column i eliminated. - * - * @param A Symmetric matrix - * @param i Index of the eliminated row/column - * @param ii Row index of the desired entry - * @param jj Column index of the desired entry - * @return A(ii,jj) - A(ii,i) * A(jj,i) / A(i,i) - */ - double compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const; + /** * Return the diagonal value omega_jj required to keep the precision @@ -434,12 +425,7 @@ class GGMModel : public BaseModel { */ double log_density_impl_diag(size_t j) const; - /** - * Compute log-determinant from a triangular matrix: 2 * sum(log(diag)). - * - * @param triangular_A Upper-triangular Cholesky factor - */ - double get_log_det(arma::mat triangular_A) const; + /** * Update the Cholesky factor after changing an off-diagonal element. diff --git a/src/models/mixed/mixed_mrf_gradient.cpp b/src/models/mixed/mixed_mrf_gradient.cpp new file mode 100644 index 00000000..c4119afc --- /dev/null +++ b/src/models/mixed/mixed_mrf_gradient.cpp @@ -0,0 +1,585 @@ +#include +#include "models/mixed/mixed_mrf_model.h" +#include "utils/variable_helpers.h" +#include "math/explog_macros.h" + + +// ============================================================================= +// Gradient cache +// ============================================================================= +// The gradient cache stores precomputed index mappings and observed-statistic +// contributions that do not change during a leapfrog trajectory. It is +// invalidated whenever edge indicators change (same pattern as the OMRF). +// ============================================================================= + +void MixedMRFModel::ensure_gradient_cache() { + if(gradient_cache_valid_) return; + + // --- Build index matrix for pairwise_effects_discrete_ upper-triangular entries --- + // Maps (i, j) to a position in the flat gradient vector (offset from + // the start of pairwise_discrete entries, which sits at num_main_). + kxx_index_cache_.set_size(p_, p_); + kxx_index_cache_.zeros(); + + int num_active_kxx = 0; + for(size_t i = 0; i < p_ - 1; ++i) { + for(size_t j = i + 1; j < p_; ++j) { + if(edge_indicators_(i, j) == 1) { + kxx_index_cache_(i, j) = num_main_ + num_active_kxx; + kxx_index_cache_(j, i) = kxx_index_cache_(i, j); + num_active_kxx++; + } + } + } + + // --- Build index matrix for pairwise_effects_cross_ entries --- + // Maps (i, j) to a position in the flat gradient vector (offset from + // the start of pairwise_cross entries, which sits at num_main_ + active_kxx + q). + kxy_index_cache_.set_size(p_, q_); + kxy_index_cache_.zeros(); + + int kxy_offset = num_main_ + num_active_kxx + static_cast(q_); + int num_active_kxy = 0; + for(size_t i = 0; i < p_; ++i) { + for(size_t j = 0; j < q_; ++j) { + if(edge_indicators_(i, p_ + j) == 1) { + kxy_index_cache_(i, j) = kxy_offset + num_active_kxy; + num_active_kxy++; + } + } + } + + // main_effects_continuous_ offset in gradient vector + main_effects_continuous_grad_offset_ = num_main_ + num_active_kxx; + + // --- Precompute observed statistics portion of the gradient --- + size_t active_dim = num_main_ + num_active_kxx + q_ + num_active_kxy; + grad_obs_cache_.set_size(active_dim); + grad_obs_cache_.zeros(); + + // Observed statistics for discrete main effects + int offset = 0; + for(size_t s = 0; s < p_; ++s) { + if(is_ordinal_variable_(s)) { + int C_s = num_categories_(s); + for(int c = 0; c < C_s; ++c) { + grad_obs_cache_(offset + c) = counts_per_category_(c + 1, s); + } + offset += C_s; + } else { + grad_obs_cache_(offset) = blume_capel_stats_(0, s); + grad_obs_cache_(offset + 1) = blume_capel_stats_(1, s); + offset += 2; + } + } + + // Observed statistics for pairwise_effects_discrete_ edges + for(size_t i = 0; i < p_ - 1; ++i) { + for(size_t j = i + 1; j < p_; ++j) { + if(edge_indicators_(i, j) == 0) continue; + int loc = kxx_index_cache_(i, j); + // Factor 2 from the symmetric double-count in the pseudo-likelihood + grad_obs_cache_(loc) = 2.0 * arma::dot( + discrete_observations_dbl_.col(i), + discrete_observations_dbl_.col(j) + ); + } + } + + // No precomputed observed stats for means or cross effects — those depend on + // continuous_observations_ combined with current parameters, so they + // are computed fresh each logp_and_gradient call. + + // Cache transpose of discrete observations for vectorized pairwise gradient + discrete_observations_dbl_t_ = discrete_observations_dbl_.t(); + + gradient_cache_valid_ = true; +} + + +void MixedMRFModel::invalidate_gradient_cache() { + gradient_cache_valid_ = false; +} + + +// ============================================================================= +// Unvectorize NUTS parameters into temporaries +// ============================================================================= +// Unpacks a NUTS-dimension parameter vector into temporary matrices without +// mutating model state. Used during leapfrog trajectory evaluation. +// ============================================================================= + +void MixedMRFModel::unvectorize_nuts_to_temps( + const arma::vec& params, + arma::mat& temp_main_discrete, + arma::mat& temp_pairwise_discrete, + arma::vec& temp_main_continuous, + arma::mat& temp_pairwise_cross +) const { + size_t idx = 0; + + // 1. main_effects_discrete_ + for(size_t s = 0; s < p_; ++s) { + if(is_ordinal_variable_(s)) { + for(int c = 0; c < num_categories_(s); ++c) { + temp_main_discrete(s, c) = params(idx++); + } + } else { + temp_main_discrete(s, 0) = params(idx++); + temp_main_discrete(s, 1) = params(idx++); + } + } + + // 2. pairwise_effects_discrete_ upper-triangular (active only) + for(size_t i = 0; i < p_ - 1; ++i) { + for(size_t j = i + 1; j < p_; ++j) { + if(edge_indicators_(i, j) == 1) { + temp_pairwise_discrete(i, j) = params(idx++); + temp_pairwise_discrete(j, i) = temp_pairwise_discrete(i, j); + } + } + } + + // 3. main_effects_continuous_ + for(size_t j = 0; j < q_; ++j) { + temp_main_continuous(j) = params(idx++); + } + + // 4. pairwise_effects_cross_ row-major (active only) + for(size_t i = 0; i < p_; ++i) { + for(size_t j = 0; j < q_; ++j) { + if(edge_indicators_(i, p_ + j) == 1) { + temp_pairwise_cross(i, j) = params(idx++); + } + } + } +} + + +// ============================================================================= +// gradient +// ============================================================================= + +arma::vec MixedMRFModel::gradient(const arma::vec& parameters) { + auto [logp, grad] = logp_and_gradient(parameters); + return grad; +} + + +// ============================================================================= +// logp_and_gradient — conditional pseudo-likelihood +// ============================================================================= +// Computes the log pseudo-posterior and its gradient with respect to the +// NUTS parameters (μ_x, K_xx, μ_y, K_xy). K_yy is treated as fixed. +// +// The pseudo-log-posterior is: +// l(θ) = sum_s log p(x_s | x_{-s}, y) [OMRF conditionals] +// + log p(y | x) [GGM conditional] +// + log π(θ) [priors] +// +// For marginal PL, the OMRF conditionals use Θ = K_xx + 2 K_xy Σ_yy K_xy' +// instead of K_xx directly, and derive rest scores and denominator offsets +// from Θ. The GGM conditional is the same in both modes. +// ============================================================================= + +std::pair MixedMRFModel::logp_and_gradient( + const arma::vec& parameters) +{ + ensure_gradient_cache(); + + // --- Unvectorize into temporaries --- + arma::mat temp_main_discrete = main_effects_discrete_; + arma::mat temp_pairwise_discrete = pairwise_effects_discrete_; + arma::vec temp_main_continuous = main_effects_continuous_; + arma::mat temp_pairwise_cross = pairwise_effects_cross_; + unvectorize_nuts_to_temps(parameters, temp_main_discrete, temp_pairwise_discrete, temp_main_continuous, temp_pairwise_cross); + + // --- Derived quantities --- + // Conditional mean: M_i = μ_y' + 2 x_i' K_xy Σ_yy (n x q) + arma::mat temp_cond_mean = arma::repmat(temp_main_continuous.t(), n_, 1) + + 2.0 * discrete_observations_dbl_ * temp_pairwise_cross * covariance_continuous_; + + // Residual: D = Y - M (n x q) + arma::mat D = continuous_observations_ - temp_cond_mean; + + // Theta for marginal PL + arma::mat temp_Theta; + if(use_marginal_pl_) { + temp_Theta = temp_pairwise_discrete + 2.0 * temp_pairwise_cross * covariance_continuous_ * temp_pairwise_cross.t(); + } + + // Start gradient from observed-statistics cache + arma::vec grad = grad_obs_cache_; + + double logp = 0.0; + + // For marginal PL: precompute K_xy Σ_yy (used in cross-contributions) + arma::mat cross_times_cov; // p x q + if(use_marginal_pl_) { + cross_times_cov = temp_pairwise_cross * covariance_continuous_; + } + + // ========================================================================= + // Part 1: OMRF conditionals + // ========================================================================= + + int main_effects_discrete_offset = 0; + for(size_t s = 0; s < p_; ++s) { + int C_s = num_categories_(s); + + // --- Rest score for variable s --- + arma::vec rest; + if(use_marginal_pl_) { + // Marginal: Θ-based rest + K_xy μ_y bias + double theta_ss = temp_Theta(s, s); + rest = discrete_observations_dbl_ * temp_Theta.col(s) + - discrete_observations_dbl_.col(s) * theta_ss + + 2.0 * arma::dot(temp_pairwise_cross.row(s), temp_main_continuous); + } else { + // Conditional: K_xx-based rest + 2 K_xy y + rest = discrete_observations_dbl_ * temp_pairwise_discrete.col(s) + - discrete_observations_dbl_.col(s) * temp_pairwise_discrete(s, s) + + 2.0 * continuous_observations_ * temp_pairwise_cross.row(s).t(); + } + + if(is_ordinal_variable_(s)) { + arma::vec main_param = temp_main_discrete.row(s).cols(0, C_s - 1).t(); + + // Marginal PL: absorb Theta_ss into main_param + if(use_marginal_pl_) { + double theta_ss = temp_Theta(s, s); + for(int c = 0; c < C_s; ++c) { + main_param(c) += static_cast((c + 1) * (c + 1)) * theta_ss; + } + } + + // bound = per-observation upper bound on log-scores for numerical + // stability. Must cover max_c(main_param(c) + (c+1)*rest(i)). + // The highest-category term main_param(C_s-1) + C_s*rest dominates + // when rest > 0; category 0 (score = 0) dominates when rest << 0. + arma::vec bound = main_param(C_s - 1) + static_cast(C_s) * rest; + bound = arma::max(bound, arma::zeros(bound.n_elem)); + + LogZAndProbs result = compute_logZ_and_probs_ordinal( + main_param, rest, bound, C_s + ); + + // log pseudo-posterior contribution + logp -= arma::accu(result.log_Z); + + // Main-effect gradient: ∂/∂main_effects_discrete_{s,c} = count_c - sum_i prob(c) + for(int c = 0; c < C_s; ++c) { + grad(main_effects_discrete_offset + c) -= arma::accu(result.probs.col(c + 1)); + } + + // Expected value E_s[c+1|rest] per observation + arma::vec weights = arma::regspace(1, C_s); + arma::vec E = result.probs.cols(1, C_s) * weights; + + // Pairwise discrete gradient: sum_i x_{i,t} * (x_{i,s}+1 - E_s) + // (uses pre-transposed discrete observations for BLAS efficiency) + arma::vec pw_grad = discrete_observations_dbl_t_ * E; + for(size_t t = 0; t < p_; ++t) { + if(edge_indicators_(s, t) == 0 || s == t) continue; + int loc = (s < t) ? kxx_index_cache_(s, t) : kxx_index_cache_(t, s); + grad(loc) -= pw_grad(t); + } + + if(use_marginal_pl_) { + // Additional pairwise_discrete gradient from Θ_ss in denominator: + // ∂/∂pairwise_effects_discrete_{st} through Θ_ss: zero (∂Θ_ss/∂pairwise_effects_discrete_st = δ_{st}) + // So pairwise_discrete gradient from Θ rest scores is already handled above. + + // Pairwise_cross gradient from marginal OMRF (through Θ): + // ∂Theta_{st}/∂pairwise_effects_cross_{a,j} has two terms: + // = 2 [Σyy pairwise_effects_cross_t']_j δ_{as} + 2 [pairwise_effects_cross_s Σyy]_j δ_{at} + // Self-contribution (a=s): from rest_s → pairwise_effects_cross_s + // Cross-contribution (a=t): from rest_s → pairwise_effects_cross_t for each t≠s + + arma::vec weights_sq = arma::square(weights); + arma::vec E_sq = result.probs.cols(1, C_s) * weights_sq; + + arma::vec diff_pw = discrete_observations_dbl_t_ * + (discrete_observations_dbl_.col(s) - E); + diff_pw(s) = 0.0; + + double diff_diag = arma::dot( + discrete_observations_dbl_.col(s), + discrete_observations_dbl_.col(s)) - arma::accu(E_sq); + + double sum_obs_minus_E = arma::accu(discrete_observations_dbl_.col(s)) - arma::accu(E); + + // Self-contribution: a = s + // Off-diagonal Theta: ∂Θ_{st}/∂pairwise_effects_cross_{s,j} = 2 [Σyy pairwise_effects_cross_t']_j + // Diagonal Theta: ∂Θ_{ss}/∂pairwise_effects_cross_{s,j} = 4 [Σyy pairwise_effects_cross_s']_j + // Rest-score bias: ∂(2 pairwise_effects_cross_s μy)/∂pairwise_effects_cross_{s,j} = 2 μy_j + arma::rowvec kxy_self = 2.0 * (diff_pw.t() * temp_pairwise_cross) * covariance_continuous_ + + 4.0 * diff_diag * cross_times_cov.row(s) + + 2.0 * sum_obs_minus_E * temp_main_continuous.t(); + + for(size_t j = 0; j < q_; ++j) { + if(edge_indicators_(s, p_ + j) == 0) continue; + int loc = kxy_index_cache_(s, j); + grad(loc) += kxy_self(j); + } + + // Cross-contribution: a = t, for each t ≠ s + // ∂l_s/∂pairwise_effects_cross_{t,:} = diff_pw(t) * 2 * pairwise_effects_cross_s * Σyy + arma::rowvec V_s = 2.0 * cross_times_cov.row(s); // 2 K_xy_s Σ_yy + for(size_t t = 0; t < p_; ++t) { + if(t == s || std::abs(diff_pw(t)) < 1e-300) continue; + for(size_t j = 0; j < q_; ++j) { + if(edge_indicators_(t, p_ + j) == 0) continue; + int loc = kxy_index_cache_(t, j); + grad(loc) += diff_pw(t) * V_s(j); + } + } + + // Continuous mean gradient from marginal OMRF: + // ∂l_s/∂main_effects_continuous_j = 2 pairwise_effects_cross_{sj} * sum_i (x_{is} - E_s) + for(size_t j = 0; j < q_; ++j) { + grad(main_effects_continuous_grad_offset_ + j) += 2.0 * temp_pairwise_cross(s, j) * sum_obs_minus_E; + } + } else { + // Conditional PL: pairwise_cross gradient from OMRF rest score + // ∂/∂pairwise_effects_cross_{s,j} = 2 sum_i y_{ij} (x_{is}+1 - E_s) + arma::rowvec kxy_grad_s = 2.0 * ( + discrete_observations_dbl_.col(s) - E + ).t() * continuous_observations_; + + for(size_t j = 0; j < q_; ++j) { + if(edge_indicators_(s, p_ + j) == 0) continue; + int loc = kxy_index_cache_(s, j); + grad(loc) += kxy_grad_s(j); + } + } + + main_effects_discrete_offset += C_s; + } else { + // --- Blume-Capel variable --- + int ref = baseline_category_(s); + double lin_eff = temp_main_discrete(s, 0); + double quad_eff = temp_main_discrete(s, 1); + + // Marginal PL: absorb Theta_ss into quadratic effect + double effective_quad = quad_eff; + if(use_marginal_pl_) { + effective_quad += temp_Theta(s, s); + } + + arma::vec bc_bound; + LogZAndProbs result = compute_logZ_and_probs_blume_capel( + rest, lin_eff, effective_quad, ref, C_s, bc_bound + ); + + logp -= arma::accu(result.log_Z); + + arma::vec score = arma::regspace(0, C_s) - static_cast(ref); + arma::vec sq_score = arma::square(score); + + // Main-effect gradient + grad(main_effects_discrete_offset) -= arma::accu(result.probs * score); + grad(main_effects_discrete_offset + 1) -= arma::accu(result.probs * sq_score); + + // Expected score per person + arma::vec E = result.probs * score; + + // Pairwise discrete gradient + arma::vec pw_grad = discrete_observations_dbl_t_ * E; + for(size_t t = 0; t < p_; ++t) { + if(edge_indicators_(s, t) == 0 || s == t) continue; + int loc = (s < t) ? kxx_index_cache_(s, t) : kxx_index_cache_(t, s); + grad(loc) -= pw_grad(t); + } + + if(use_marginal_pl_) { + // Pairwise_cross gradient from marginal OMRF (same structure as ordinal) + arma::vec E_sq = result.probs * sq_score; + + arma::vec diff_pw = discrete_observations_dbl_t_ * + (discrete_observations_dbl_.col(s) - E); + diff_pw(s) = 0.0; + + double diff_diag = arma::dot( + discrete_observations_dbl_.col(s), + discrete_observations_dbl_.col(s)) - arma::accu(E_sq); + + double sum_obs_minus_E = arma::accu(discrete_observations_dbl_.col(s)) - arma::accu(E); + + // Self-contribution: a = s + arma::rowvec kxy_self = 2.0 * (diff_pw.t() * temp_pairwise_cross) * covariance_continuous_ + + 4.0 * diff_diag * cross_times_cov.row(s) + + 2.0 * sum_obs_minus_E * temp_main_continuous.t(); + + for(size_t j = 0; j < q_; ++j) { + if(edge_indicators_(s, p_ + j) == 0) continue; + int loc = kxy_index_cache_(s, j); + grad(loc) += kxy_self(j); + } + + // Cross-contribution: a = t, for each t ≠ s + arma::rowvec V_s = 2.0 * cross_times_cov.row(s); + for(size_t t = 0; t < p_; ++t) { + if(t == s || std::abs(diff_pw(t)) < 1e-300) continue; + for(size_t j = 0; j < q_; ++j) { + if(edge_indicators_(t, p_ + j) == 0) continue; + int loc = kxy_index_cache_(t, j); + grad(loc) += diff_pw(t) * V_s(j); + } + } + + // Continuous mean gradient from marginal OMRF + for(size_t j = 0; j < q_; ++j) { + grad(main_effects_continuous_grad_offset_ + j) += 2.0 * temp_pairwise_cross(s, j) * sum_obs_minus_E; + } + } else { + // Conditional PL: pairwise_cross gradient from OMRF rest score + arma::rowvec kxy_grad_s = 2.0 * ( + discrete_observations_dbl_.col(s) - E + ).t() * continuous_observations_; + + for(size_t j = 0; j < q_; ++j) { + if(edge_indicators_(s, p_ + j) == 0) continue; + int loc = kxy_index_cache_(s, j); + grad(loc) += kxy_grad_s(j); + } + } + + main_effects_discrete_offset += 2; + } + } + + // Add numerator contribution to logp from discrete sufficient statistics + // (already in grad_obs_cache_ as counts, but logp needs the actual dot-products) + main_effects_discrete_offset = 0; + for(size_t s = 0; s < p_; ++s) { + int C_s = num_categories_(s); + arma::vec rest; + if(use_marginal_pl_) { + double theta_ss = temp_Theta(s, s); + rest = discrete_observations_dbl_ * temp_Theta.col(s) + - discrete_observations_dbl_.col(s) * theta_ss + + 2.0 * arma::dot(temp_pairwise_cross.row(s), temp_main_continuous); + // Theta_ss quadratic contribution + logp += temp_Theta(s, s) * arma::dot( + discrete_observations_dbl_.col(s), + discrete_observations_dbl_.col(s)); + } else { + rest = discrete_observations_dbl_ * temp_pairwise_discrete.col(s) + - discrete_observations_dbl_.col(s) * temp_pairwise_discrete(s, s) + + 2.0 * continuous_observations_ * temp_pairwise_cross.row(s).t(); + } + // Numerator: dot(x_s, rest) + main-effect sums + logp += arma::dot(discrete_observations_dbl_.col(s), rest); + + if(is_ordinal_variable_(s)) { + for(int c = 1; c <= C_s; ++c) { + logp += static_cast(counts_per_category_(c, s)) * temp_main_discrete(s, c - 1); + } + } else { + logp += temp_main_discrete(s, 0) * static_cast(blume_capel_stats_(0, s)) + + temp_main_discrete(s, 1) * static_cast(blume_capel_stats_(1, s)); + } + } + + // ========================================================================= + // Part 2: GGM conditional log-likelihood and gradients + // ========================================================================= + // log p(y | x) = n/2 (log|K_yy| - q log(2π)) - ½ trace(K_yy D'D) + // where D = Y - M, M_i = μ_y' + 2 x_i' K_xy Σ_yy + // + // K_yy is fixed, so log|K_yy| contributes to logp but not gradient. + + double quad_sum = arma::accu((D * pairwise_effects_continuous_) % D); + logp += static_cast(n_) / 2.0 * + (-static_cast(q_) * MY_LOG(2.0 * arma::datum::pi) + + log_det_precision_) + - quad_sum / 2.0; + + // ∂/∂μ_y: K_yy D' 1_n = K_yy sum_over_rows(D) + arma::vec D_colsums = arma::sum(D, 0).t(); // q-vector + arma::vec grad_main_effects_continuous_ggm = pairwise_effects_continuous_ * D_colsums; + + for(size_t j = 0; j < q_; ++j) { + grad(main_effects_continuous_grad_offset_ + j) += grad_main_effects_continuous_ggm(j); + } + + // ∂/∂K_xy: The GGM conditional depends on K_xy through M. + // ∂M/∂pairwise_effects_cross_{s,j} = 2 x_s [Σ_yy]_{j,:} + // ∂logp_ggm/∂K_xy = 2 X' D (shortcut: K_yy Σ_yy = I eliminates K_yy) + // + // Correctly: ∂(−½ trace(K_yy D'D))/∂pairwise_effects_cross_{s,j} + // = trace(K_yy D' ∂M/∂pairwise_effects_cross_{s,j}) + // = trace(K_yy D' · 2 x_s [Σ_yy]_{j,:}) + // = 2 [x_s' D K_yy Σ_yy]_j + // = 2 [x_s' D]_j (since K_yy Σ_yy = I) + arma::mat grad_pairwise_effects_cross_ggm = 2.0 * discrete_observations_dbl_t_ * D; // p x q + + for(size_t i = 0; i < p_; ++i) { + for(size_t j = 0; j < q_; ++j) { + if(edge_indicators_(i, p_ + j) == 0) continue; + int loc = kxy_index_cache_(i, j); + grad(loc) += grad_pairwise_effects_cross_ggm(i, j); + } + } + + // ========================================================================= + // Part 3: Prior log-densities and gradient contributions + // ========================================================================= + + // --- main_effects_discrete_ priors: Beta(alpha, beta) on sigmoid scale --- + main_effects_discrete_offset = 0; + for(size_t s = 0; s < p_; ++s) { + if(is_ordinal_variable_(s)) { + int C_s = num_categories_(s); + for(int c = 0; c < C_s; ++c) { + double val = temp_main_discrete(s, c); + logp += val * main_alpha_ - + std::log1p(MY_EXP(val)) * (main_alpha_ + main_beta_); + double p = 1.0 / (1.0 + MY_EXP(-val)); + grad(main_effects_discrete_offset + c) += main_alpha_ - (main_alpha_ + main_beta_) * p; + } + main_effects_discrete_offset += C_s; + } else { + for(int k = 0; k < 2; ++k) { + double val = temp_main_discrete(s, k); + logp += val * main_alpha_ - + std::log1p(MY_EXP(val)) * (main_alpha_ + main_beta_); + double p = 1.0 / (1.0 + MY_EXP(-val)); + grad(main_effects_discrete_offset + k) += main_alpha_ - (main_alpha_ + main_beta_) * p; + } + main_effects_discrete_offset += 2; + } + } + + // --- pairwise_effects_discrete_ priors: Cauchy(0, pairwise_scale_) --- + for(size_t i = 0; i < p_ - 1; ++i) { + for(size_t j = i + 1; j < p_; ++j) { + if(edge_indicators_(i, j) == 0) continue; + int loc = kxx_index_cache_(i, j); + double val = temp_pairwise_discrete(i, j); + logp += R::dcauchy(val, 0.0, pairwise_scale_, true); + grad(loc) -= 2.0 * val / (val * val + pairwise_scale_ * pairwise_scale_); + } + } + + // --- main_effects_continuous_ priors: Normal(0, 1) --- + for(size_t j = 0; j < q_; ++j) { + double val = temp_main_continuous(j); + logp += R::dnorm(val, 0.0, 1.0, true); + grad(main_effects_continuous_grad_offset_ + j) -= val; // -val from -val^2/2 + } + + // --- pairwise_effects_cross_ priors: Cauchy(0, pairwise_scale_) --- + for(size_t i = 0; i < p_; ++i) { + for(size_t j = 0; j < q_; ++j) { + if(edge_indicators_(i, p_ + j) == 0) continue; + int loc = kxy_index_cache_(i, j); + double val = temp_pairwise_cross(i, j); + logp += R::dcauchy(val, 0.0, pairwise_scale_, true); + grad(loc) -= 2.0 * val / (val * val + pairwise_scale_ * pairwise_scale_); + } + } + + return {logp, grad}; +} diff --git a/src/models/mixed/mixed_mrf_likelihoods.cpp b/src/models/mixed/mixed_mrf_likelihoods.cpp new file mode 100644 index 00000000..fabb80a1 --- /dev/null +++ b/src/models/mixed/mixed_mrf_likelihoods.cpp @@ -0,0 +1,140 @@ +#include +#include "models/mixed/mixed_mrf_model.h" +#include "utils/variable_helpers.h" +#include "math/explog_macros.h" + + +// ============================================================================= +// log_conditional_omrf +// ============================================================================= +// Conditional OMRF pseudolikelihood for discrete variable s: +// log f(x_s | x_{-s}, y) = numerator - sum_v log(Z_v) +// +// Branches on is_ordinal_variable_(s) to select ordinal thresholds or +// Blume-Capel (linear + quadratic) main effects. +// ============================================================================= + +double MixedMRFModel::log_conditional_omrf(int s) const { + int C_s = num_categories_(s); + + // Rest score: contribution from other discrete vars + continuous vars + arma::vec rest = discrete_observations_dbl_ * pairwise_effects_discrete_.col(s) + - discrete_observations_dbl_.col(s) * pairwise_effects_discrete_(s, s) + + 2.0 * continuous_observations_ * pairwise_effects_cross_.row(s).t(); + + // Numerator (sufficient-statistic form): dot(x_s, rest) + main-effect sums + double numer = arma::dot(discrete_observations_dbl_.col(s), rest); + + if(is_ordinal_variable_(s)) { + // Ordinal: add threshold contributions sum_{c=1}^{C_s} count_c * main_effects_discrete_(s, c-1) + for(int c = 1; c <= C_s; ++c) { + numer += static_cast(counts_per_category_(c, s)) * main_effects_discrete_(s, c - 1); + } + + // Denominator via compute_denom_ordinal (FAST/SAFE block-split) + arma::vec main_param = main_effects_discrete_.row(s).cols(0, C_s - 1).t(); + arma::vec bound = static_cast(C_s) * rest; + arma::vec denom = compute_denom_ordinal(rest, main_param, bound); + + return numer - arma::accu(bound + ARMA_MY_LOG(denom)); + } else { + // Blume-Capel: alpha * sum(x) + beta * sum(x^2) + double alpha = main_effects_discrete_(s, 0); + double beta = main_effects_discrete_(s, 1); + numer += alpha * static_cast(blume_capel_stats_(0, s)) + + beta * static_cast(blume_capel_stats_(1, s)); + + // Denominator via compute_denom_blume_capel (computes bound internally) + arma::vec bound; + arma::vec denom = compute_denom_blume_capel( + rest, alpha, beta, baseline_category_(s), C_s, bound + ); + + return numer - arma::accu(bound + ARMA_MY_LOG(denom)); + } +} + + +// ============================================================================= +// log_marginal_omrf +// ============================================================================= +// Marginal OMRF pseudolikelihood for discrete variable s: +// log f(x_s | x_{-s}) using Θ = K_xx + 2 K_xy K_yy^{-1} K_xy' +// +// Differs from conditional form in three ways: +// 1. rest score uses Theta_ instead of pairwise_effects_discrete_, minus self-interaction +// 2. scalar bias 2 K_xy(s,:) μ_y added to rest +// 3. numerator includes Θ(s,s) * sum(x_s^2) +// 4. denominator offsets include c^2 * Θ(s,s) +// ============================================================================= + +double MixedMRFModel::log_marginal_omrf(int s) const { + int C_s = num_categories_(s); + + // Rest score: Θ-based interaction + K_xy μ_y bias + double theta_ss = Theta_(s, s); + arma::vec rest = discrete_observations_dbl_ * Theta_.col(s) + - discrete_observations_dbl_.col(s) * theta_ss + + 2.0 * arma::dot(pairwise_effects_cross_.row(s), main_effects_continuous_); + + // Numerator: dot(x_s, rest) + theta_ss * dot(x_s, x_s) + main effects + double numer = arma::dot(discrete_observations_dbl_.col(s), rest) + + theta_ss * arma::dot(discrete_observations_dbl_.col(s), + discrete_observations_dbl_.col(s)); + + if(is_ordinal_variable_(s)) { + for(int c = 1; c <= C_s; ++c) { + numer += static_cast(counts_per_category_(c, s)) * main_effects_discrete_(s, c - 1); + } + + // Denominator: main_param(c) = μ_x(s,c) + (c+1)^2 Θ_ss + arma::vec main_param(C_s); + for(int c = 0; c < C_s; ++c) { + main_param(c) = main_effects_discrete_(s, c) + static_cast((c + 1) * (c + 1)) * theta_ss; + } + + arma::vec bound = static_cast(C_s) * rest; + arma::vec denom = compute_denom_ordinal(rest, main_param, bound); + + return numer - arma::accu(bound + ARMA_MY_LOG(denom)); + } else { + // Blume-Capel: alpha * sum(x) + beta * sum(x^2) + double alpha = main_effects_discrete_(s, 0); + double beta = main_effects_discrete_(s, 1); + numer += alpha * static_cast(blume_capel_stats_(0, s)) + + beta * static_cast(blume_capel_stats_(1, s)); + + // Denominator: theta_c includes Theta_(s,s) * (c - ref)^2 + int ref = baseline_category_(s); + double effective_beta = beta + theta_ss; + + arma::vec bound; + arma::vec denom = compute_denom_blume_capel( + rest, alpha, effective_beta, ref, C_s, bound + ); + + return numer - arma::accu(bound + ARMA_MY_LOG(denom)); + } +} + + +// ============================================================================= +// log_conditional_ggm +// ============================================================================= +// Conditional GGM log-likelihood: log f(y | x) +// y | x ~ N(conditional_mean_, covariance_continuous_) +// +// Uses cached covariance_continuous_, log_det_precision_, and conditional_mean_. +// ============================================================================= + +double MixedMRFModel::log_conditional_ggm() const { + arma::mat D = continuous_observations_ - conditional_mean_; + + // Quadratic form: trace(K_yy D'D) = sum((D K_yy) .* D) + double quad_sum = arma::accu((D * pairwise_effects_continuous_) % D); + + return static_cast(n_) / 2.0 * + (-static_cast(q_) * MY_LOG(2.0 * arma::datum::pi) + + log_det_precision_) + - quad_sum / 2.0; +} diff --git a/src/models/mixed/mixed_mrf_metropolis.cpp b/src/models/mixed/mixed_mrf_metropolis.cpp new file mode 100644 index 00000000..4cf945e0 --- /dev/null +++ b/src/models/mixed/mixed_mrf_metropolis.cpp @@ -0,0 +1,809 @@ +#include +#include "models/mixed/mixed_mrf_model.h" +#include "rng/rng_utils.h" +#include "mcmc/execution/step_result.h" +#include "math/explog_macros.h" + + +// ============================================================================= +// Beta-type prior used for all main effects (ordinal thresholds and BC α/β). +// Matches OMRFModel::log_pseudoposterior_main_component. +// ============================================================================= + +static double log_beta_prior(double x, double alpha, double beta) { + return x * alpha - std::log1p(MY_EXP(x)) * (alpha + beta); +} + + +// ============================================================================= +// update_main_effect +// ============================================================================= +// MH update for one main-effect parameter. +// Ordinal: main_effects_discrete_(s, c) = threshold for category c+1 (c in [0, C_s-1]) +// Blume-Capel: main_effects_discrete_(s, 0) = linear α, main_effects_discrete_(s, 1) = quadratic β +// (c indexes 0 or 1 for BC) +// +// The accept/reject uses log_conditional_omrf(s) + beta-type prior. +// ============================================================================= + +void MixedMRFModel::update_main_effect(int s, int c, int iteration) { + double& current = main_effects_discrete_(s, c); + double proposal_sd = proposal_sd_main_discrete_(s, c); + + double current_val = current; + double proposed = rnorm(rng_, current_val, proposal_sd); + + // Current log-posterior + double ll_curr = (use_marginal_pl_ ? log_marginal_omrf(s) : log_conditional_omrf(s)) + + log_beta_prior(current_val, main_alpha_, main_beta_); + + // Proposed log-posterior + current = proposed; + double ll_prop = (use_marginal_pl_ ? log_marginal_omrf(s) : log_conditional_omrf(s)) + + log_beta_prior(proposed, main_alpha_, main_beta_); + + double ln_alpha = ll_prop - ll_curr; + + if(MY_LOG(runif(rng_)) >= ln_alpha) { + current = current_val; // reject + } + + if(iteration >= 1 && iteration < total_warmup_) { + double rm_weight = std::pow(iteration, -0.75); + proposal_sd_main_discrete_(s, c) = update_proposal_sd_with_robbins_monro( + proposal_sd_main_discrete_(s, c), ln_alpha, rm_weight, 0.44); + } +} + + +// ============================================================================= +// update_continuous_mean +// ============================================================================= +// MH update for one continuous mean parameter main_effects_continuous_(j). +// The accept/reject uses log_conditional_ggm() + Normal(0, 1) prior. +// Must save/restore conditional_mean_ around the proposal. +// ============================================================================= + +void MixedMRFModel::update_continuous_mean(int j, int iteration) { + double current_val = main_effects_continuous_(j); + double proposed = rnorm(rng_, current_val, proposal_sd_main_continuous_(j)); + + // Current log-posterior (Normal(0,1) prior: -x^2/2 up to constant) + double ll_curr = log_conditional_ggm() + R::dnorm(current_val, 0.0, 1.0, true); + if(use_marginal_pl_) { + for(size_t s = 0; s < p_; ++s) + ll_curr += log_marginal_omrf(s); + } + + // Set proposed value and refresh conditional_mean_ + arma::mat cond_mean_saved = conditional_mean_; + main_effects_continuous_(j) = proposed; + recompute_conditional_mean(); + + double ll_prop = log_conditional_ggm() + R::dnorm(proposed, 0.0, 1.0, true); + if(use_marginal_pl_) { + for(size_t s = 0; s < p_; ++s) + ll_prop += log_marginal_omrf(s); + } + + double ln_alpha = ll_prop - ll_curr; + + if(MY_LOG(runif(rng_)) >= ln_alpha) { + main_effects_continuous_(j) = current_val; // reject + conditional_mean_ = std::move(cond_mean_saved); + } + + if(iteration >= 1 && iteration < total_warmup_) { + double rm_weight = std::pow(iteration, -0.75); + proposal_sd_main_continuous_(j) = update_proposal_sd_with_robbins_monro( + proposal_sd_main_continuous_(j), ln_alpha, rm_weight, 0.44); + } +} + + +// ============================================================================= +// update_pairwise_discrete +// ============================================================================= +// MH update for one discrete-discrete interaction pairwise_effects_discrete_(i, j). +// Symmetric: sets both (i,j) and (j,i). +// Acceptance: log_conditional_omrf(i) + log_conditional_omrf(j) + Cauchy prior. +// ============================================================================= + +void MixedMRFModel::update_pairwise_discrete(int i, int j, int iteration) { + double current_val = pairwise_effects_discrete_(i, j); + double proposed = rnorm(rng_, current_val, proposal_sd_pairwise_discrete_(i, j)); + + // Current log-posterior + double ll_curr, ll_prop; + if(use_marginal_pl_) { + ll_curr = log_marginal_omrf(i) + log_marginal_omrf(j) + + R::dcauchy(current_val, 0.0, pairwise_scale_, true); + + pairwise_effects_discrete_(i, j) = proposed; + pairwise_effects_discrete_(j, i) = proposed; + recompute_Theta(); + + ll_prop = log_marginal_omrf(i) + log_marginal_omrf(j) + + R::dcauchy(proposed, 0.0, pairwise_scale_, true); + } else { + ll_curr = log_conditional_omrf(i) + log_conditional_omrf(j) + + R::dcauchy(current_val, 0.0, pairwise_scale_, true); + + pairwise_effects_discrete_(i, j) = proposed; + pairwise_effects_discrete_(j, i) = proposed; + + ll_prop = log_conditional_omrf(i) + log_conditional_omrf(j) + + R::dcauchy(proposed, 0.0, pairwise_scale_, true); + } + + double ln_alpha = ll_prop - ll_curr; + + if(MY_LOG(runif(rng_)) >= ln_alpha) { + pairwise_effects_discrete_(i, j) = current_val; // reject + pairwise_effects_discrete_(j, i) = current_val; + if(use_marginal_pl_) recompute_Theta(); + } + + if(iteration >= 1 && iteration < total_warmup_) { + double rm_weight = std::pow(iteration, -0.75); + proposal_sd_pairwise_discrete_(i, j) = update_proposal_sd_with_robbins_monro( + proposal_sd_pairwise_discrete_(i, j), ln_alpha, rm_weight, 0.44); + } +} + + +// ============================================================================= +// Rank-1 precision proposal helpers (permutation-free) +// ============================================================================= +// Direct analogs of GGMModel::get_constants / constrained_diagonal, +// operating on pairwise_effects_continuous_, cholesky_of_precision_, and covariance_continuous_. +// ============================================================================= + +void MixedMRFModel::get_precision_constants(int i, int j) { + double logdet = cholesky_helpers::get_log_det(cholesky_of_precision_); + + double log_adj_ii = logdet + MY_LOG(std::abs(covariance_continuous_(i, i))); + double log_adj_ij = logdet + MY_LOG(std::abs(covariance_continuous_(i, j))); + double log_adj_jj = logdet + MY_LOG(std::abs(covariance_continuous_(j, j))); + + double inv_sub_jj = cholesky_helpers::compute_inv_submatrix_i(covariance_continuous_, i, j, j); + double log_abs_inv_sub_jj = log_adj_ii + MY_LOG(std::abs(inv_sub_jj)); + + double Phi_q1q = (2 * std::signbit(covariance_continuous_(i, j)) - 1) * MY_EXP( + (log_adj_ij - (log_adj_jj + log_abs_inv_sub_jj) / 2) + ); + double Phi_q1q1 = MY_EXP((log_adj_jj - log_abs_inv_sub_jj) / 2); + + kyy_constants_[0] = Phi_q1q; + kyy_constants_[1] = Phi_q1q1; + kyy_constants_[2] = pairwise_effects_continuous_(i, j) - Phi_q1q * Phi_q1q1; + kyy_constants_[3] = Phi_q1q1; + kyy_constants_[4] = pairwise_effects_continuous_(j, j) - Phi_q1q * Phi_q1q; + kyy_constants_[5] = kyy_constants_[4] + + kyy_constants_[2] * kyy_constants_[2] / (kyy_constants_[3] * kyy_constants_[3]); +} + +double MixedMRFModel::precision_constrained_diagonal(double x) const { + if(x == 0.0) { + return kyy_constants_[5]; + } else { + double t = (x - kyy_constants_[2]) / kyy_constants_[3]; + return kyy_constants_[4] + t * t; + } +} + + +// ============================================================================= +// log_ggm_ratio_edge +// ============================================================================= +// Log-likelihood ratio for a rank-2 off-diagonal precision change using the +// matrix determinant lemma for the log-det part and Woodbury for the +// quadratic-form part. Assumes precision_proposal_ is filled. +// +// TODO: replace the O(npq + nq²) quadratic-form computation with +// an O(nq) rank-2 shortcut. +// ============================================================================= + +double MixedMRFModel::log_ggm_ratio_edge(int i, int j) const { + size_t ui = static_cast(i); + size_t uj = static_cast(j); + + // --- Log-determinant ratio via matrix determinant lemma --- + // ΔΩ has 3 nonzero entries: (i,j), (j,i), (j,j). + // Ui = old - new off-diag, Uj = (old - new diag) / 2 + double Ui = pairwise_effects_continuous_(ui, uj) - precision_proposal_(ui, uj); + double Uj = (pairwise_effects_continuous_(uj, uj) - precision_proposal_(uj, uj)) / 2.0; + + double cc11 = covariance_continuous_(uj, uj); + double cc12 = 1.0 - (covariance_continuous_(ui, uj) * Ui + + covariance_continuous_(uj, uj) * Uj); + double cc22 = Ui * Ui * covariance_continuous_(ui, ui) + + 2.0 * Ui * Uj * covariance_continuous_(ui, uj) + + Uj * Uj * covariance_continuous_(uj, uj); + + double logdet_ratio = MY_LOG(std::abs(cc11 * cc22 - cc12 * cc12)); + + // --- Proposed covariance via Woodbury --- + // ΔΩ = vf1 vf2' + vf2 vf1' where vf1 = [0,...,-1,...] (j-th), + // vf2 = [0,...,Ui,...,Uj,...] (i-th and j-th). + // s1 = Σ vf1 = -Σ[:,j], s2 = Σ vf2 = Ui*Σ[:,i] + Uj*Σ[:,j] + arma::vec s1 = -covariance_continuous_.col(uj); + arma::vec s2 = Ui * covariance_continuous_.col(ui) + Uj * covariance_continuous_.col(uj); + + // 2×2 core matrix T = I + [vf2,vf1]' [s1,s2] + // T = [1 + vf2's1, vf2's2; vf1's1, 1 + vf1's2] + double t11 = 1.0 + Ui * s1(ui) + Uj * s1(uj); // 1 + vf2' s1 + double t12 = Ui * s2(ui) + Uj * s2(uj); // vf2' s2 + double t21 = -s1(uj); // vf1' s1 = Σ(j,j) + double t22 = 1.0 - s2(uj); // 1 + vf1' s2 + + double det_T = t11 * t22 - t12 * t21; + + // T^{-1} + double inv_t11 = t22 / det_T; + double inv_t12 = -t12 / det_T; + double inv_t21 = -t21 / det_T; + double inv_t22 = t11 / det_T; + + // Σ' = Σ - [s1,s2] T^{-1} [s2',s1'] + // = Σ - (inv_t11*s1 + inv_t21*s2)*s2' - (inv_t12*s1 + inv_t22*s2)*s1' + arma::vec w1 = inv_t11 * s1 + inv_t21 * s2; // coefficient for s2' row + arma::vec w2 = inv_t12 * s1 + inv_t22 * s2; // coefficient for s1' row + arma::mat cov_prop = covariance_continuous_ - w1 * s2.t() - w2 * s1.t(); + + // --- Proposed conditional mean --- + // M' = μ_y' + 2 X K_xy Σ' + arma::mat cond_mean_prop = arma::repmat(main_effects_continuous_.t(), n_, 1) + + 2.0 * discrete_observations_dbl_ * pairwise_effects_cross_ * cov_prop; + + // --- Quadratic form difference --- + arma::mat D_curr = continuous_observations_ - conditional_mean_; + arma::mat D_prop = continuous_observations_ - cond_mean_prop; + + double quad_curr = arma::accu((D_curr * pairwise_effects_continuous_) % D_curr); + double quad_prop = arma::accu((D_prop * precision_proposal_) % D_prop); + + double n = static_cast(n_); + return n / 2.0 * logdet_ratio - (quad_prop - quad_curr) / 2.0; +} + + +// ============================================================================= +// log_ggm_ratio_diag +// ============================================================================= +// Log-likelihood ratio for a rank-1 diagonal precision change. +// Same structure as log_ggm_ratio_edge but simpler (Ui = 0). +// ============================================================================= + +double MixedMRFModel::log_ggm_ratio_diag(int i) const { + size_t ui = static_cast(i); + + // --- Log-determinant ratio (rank-1) --- + double Uj = (pairwise_effects_continuous_(ui, ui) - precision_proposal_(ui, ui)) / 2.0; + + double cc11 = covariance_continuous_(ui, ui); + double cc12 = 1.0 - covariance_continuous_(ui, ui) * Uj; + double cc22 = Uj * Uj * covariance_continuous_(ui, ui); + + double logdet_ratio = MY_LOG(std::abs(cc11 * cc22 - cc12 * cc12)); + + // --- Proposed covariance via Sherman-Morrison (rank-1 special case) --- + // ΔΩ = -2Uj * e_i e_i', so Σ' = Σ + 2Uj * Σ[:,i] Σ[i,:]' / (1 - 2Uj * Σ(i,i)) + arma::vec s = covariance_continuous_.col(ui); + double denom = 1.0 - 2.0 * Uj * covariance_continuous_(ui, ui); + arma::mat cov_prop = covariance_continuous_ + (2.0 * Uj / denom) * s * s.t(); + + // --- Proposed conditional mean --- + arma::mat cond_mean_prop = arma::repmat(main_effects_continuous_.t(), n_, 1) + + 2.0 * discrete_observations_dbl_ * pairwise_effects_cross_ * cov_prop; + + // --- Quadratic form difference --- + arma::mat D_curr = continuous_observations_ - conditional_mean_; + arma::mat D_prop = continuous_observations_ - cond_mean_prop; + + double quad_curr = arma::accu((D_curr * pairwise_effects_continuous_) % D_curr); + double quad_prop = arma::accu((D_prop * precision_proposal_) % D_prop); + + double n = static_cast(n_); + return n / 2.0 * logdet_ratio - (quad_prop - quad_curr) / 2.0; +} + + +// ============================================================================= +// cholesky_update_after_precision_edge +// ============================================================================= +// Rank-2 Cholesky update after accepting an off-diagonal precision change. +// Decomposes ΔΩ = vf1*vf2' + vf2*vf1' into two rank-1 ops. +// Then recomputes inv_cholesky_of_precision_ and covariance_continuous_. +// ============================================================================= + +void MixedMRFModel::cholesky_update_after_precision_edge( + double old_ij, double old_jj, int i, int j) +{ + kyy_v2_[0] = old_ij - precision_proposal_(i, j); + kyy_v2_[1] = (old_jj - precision_proposal_(j, j)) / 2.0; + + kyy_vf1_[i] = kyy_v1_[0]; // 0 + kyy_vf1_[j] = kyy_v1_[1]; // -1 + kyy_vf2_[i] = kyy_v2_[0]; + kyy_vf2_[j] = kyy_v2_[1]; + + kyy_u1_ = (kyy_vf1_ + kyy_vf2_) / std::sqrt(2.0); + kyy_u2_ = (kyy_vf1_ - kyy_vf2_) / std::sqrt(2.0); + + cholesky_update(cholesky_of_precision_, kyy_u1_); + cholesky_downdate(cholesky_of_precision_, kyy_u2_); + + arma::inv(inv_cholesky_of_precision_, arma::trimatu(cholesky_of_precision_)); + covariance_continuous_ = inv_cholesky_of_precision_ * inv_cholesky_of_precision_.t(); + log_det_precision_ = cholesky_helpers::get_log_det(cholesky_of_precision_); + + kyy_vf1_[i] = 0.0; + kyy_vf1_[j] = 0.0; + kyy_vf2_[i] = 0.0; + kyy_vf2_[j] = 0.0; +} + + +// ============================================================================= +// cholesky_update_after_precision_diag +// ============================================================================= +// Rank-1 Cholesky update after accepting a diagonal precision change. +// ============================================================================= + +void MixedMRFModel::cholesky_update_after_precision_diag(double old_ii, int i) { + double delta = old_ii - precision_proposal_(i, i); + bool downdate = delta > 0.0; + + kyy_vf1_[i] = std::sqrt(std::abs(delta)); + + if(downdate) + cholesky_downdate(cholesky_of_precision_, kyy_vf1_); + else + cholesky_update(cholesky_of_precision_, kyy_vf1_); + + arma::inv(inv_cholesky_of_precision_, arma::trimatu(cholesky_of_precision_)); + covariance_continuous_ = inv_cholesky_of_precision_ * inv_cholesky_of_precision_.t(); + log_det_precision_ = cholesky_helpers::get_log_det(cholesky_of_precision_); + + kyy_vf1_[i] = 0.0; +} + + +// ============================================================================= +// update_pairwise_effects_continuous_offdiag +// ============================================================================= +// MH update for one off-diagonal element of the precision matrix pairwise_effects_continuous_(i, j). +// Uses rank-1 Cholesky infrastructure (GGM-style, no permutation): +// 1. Extract constants from covariance_continuous_ and cholesky_of_precision_ +// 2. Propose on the unconstrained Cholesky scale +// 3. Map to precision space with constrained diagonal +// 4. Evaluate rank-2 log-likelihood ratio +// 5. On accept: rank-1 Cholesky update +// +// Prior: Cauchy(0, pairwise_scale_) on off-diag, Gamma(1, 1) on diagonal. +// ============================================================================= + +void MixedMRFModel::update_pairwise_effects_continuous_offdiag(int i, int j, int iteration) { + get_precision_constants(i, j); + + double phi_curr = kyy_constants_[0]; // Phi_q1q + double phi_prop = rnorm(rng_, phi_curr, proposal_sd_pairwise_continuous_(i, j)); + + double omega_prop_ij = kyy_constants_[2] + kyy_constants_[3] * phi_prop; + double omega_prop_jj = precision_constrained_diagonal(omega_prop_ij); + double diag_curr = pairwise_effects_continuous_(j, j); + + // Fill proposal matrix (only the 3 changed entries matter) + precision_proposal_ = pairwise_effects_continuous_; + precision_proposal_(i, j) = omega_prop_ij; + precision_proposal_(j, i) = omega_prop_ij; + precision_proposal_(j, j) = omega_prop_jj; + + double ln_alpha = log_ggm_ratio_edge(i, j); + + // Marginal mode: add OMRF ratio with proposed Theta + if(use_marginal_pl_) { + for(size_t s = 0; s < p_; ++s) + ln_alpha -= log_marginal_omrf(s); + + arma::mat Theta_saved = Theta_; + arma::mat pairwise_effects_continuous_saved = pairwise_effects_continuous_; + pairwise_effects_continuous_ = precision_proposal_; + recompute_Theta(); + for(size_t s = 0; s < p_; ++s) + ln_alpha += log_marginal_omrf(s); + pairwise_effects_continuous_ = pairwise_effects_continuous_saved; + Theta_ = std::move(Theta_saved); + } + + // Prior ratio: Cauchy on off-diag + Gamma(1,1) on diagonal + ln_alpha += R::dcauchy(omega_prop_ij, 0.0, pairwise_scale_, true); + ln_alpha -= R::dcauchy(pairwise_effects_continuous_(i, j), 0.0, pairwise_scale_, true); + ln_alpha += R::dgamma(omega_prop_jj, 1.0, 1.0, true); + ln_alpha -= R::dgamma(diag_curr, 1.0, 1.0, true); + + if(MY_LOG(runif(rng_)) < ln_alpha) { + double old_ij = pairwise_effects_continuous_(i, j); + double old_jj = pairwise_effects_continuous_(j, j); + + pairwise_effects_continuous_(i, j) = omega_prop_ij; + pairwise_effects_continuous_(j, i) = omega_prop_ij; + pairwise_effects_continuous_(j, j) = omega_prop_jj; + + cholesky_update_after_precision_edge(old_ij, old_jj, i, j); + recompute_conditional_mean(); + if(use_marginal_pl_) recompute_Theta(); + } + + if(iteration >= 1 && iteration < total_warmup_) { + double rm_weight = std::pow(iteration, -0.75); + proposal_sd_pairwise_continuous_(i, j) = update_proposal_sd_with_robbins_monro( + proposal_sd_pairwise_continuous_(i, j), ln_alpha, rm_weight, 0.44); + } +} + + +// ============================================================================= +// update_pairwise_effects_continuous_diag +// ============================================================================= +// MH update for one diagonal element of the precision matrix. +// Proposes on the log-Cholesky scale to ensure positivity. +// Uses rank-1 Cholesky update on accept. +// Prior: Gamma(1, 1) on the diagonal element + Jacobian for log-scale proposal. +// ============================================================================= + +void MixedMRFModel::update_pairwise_effects_continuous_diag(int i, int iteration) { + double logdet = cholesky_helpers::get_log_det(cholesky_of_precision_); + double logdet_sub_ii = logdet + MY_LOG(covariance_continuous_(i, i)); + + double theta_curr = (logdet - logdet_sub_ii) / 2.0; + double theta_prop = rnorm(rng_, theta_curr, proposal_sd_pairwise_continuous_(i, i)); + + precision_proposal_ = pairwise_effects_continuous_; + precision_proposal_(i, i) = pairwise_effects_continuous_(i, i) + - MY_EXP(theta_curr) * MY_EXP(theta_curr) + + MY_EXP(theta_prop) * MY_EXP(theta_prop); + + double ln_alpha = log_ggm_ratio_diag(i); + + // Marginal mode: add OMRF ratio with proposed Theta + if(use_marginal_pl_) { + for(size_t s = 0; s < p_; ++s) + ln_alpha -= log_marginal_omrf(s); + + arma::mat Theta_saved = Theta_; + arma::mat pairwise_effects_continuous_saved = pairwise_effects_continuous_; + pairwise_effects_continuous_ = precision_proposal_; + recompute_Theta(); + for(size_t s = 0; s < p_; ++s) + ln_alpha += log_marginal_omrf(s); + pairwise_effects_continuous_ = pairwise_effects_continuous_saved; + Theta_ = std::move(Theta_saved); + } + + // Prior ratio: Gamma(1,1) on diagonal + ln_alpha += R::dgamma(precision_proposal_(i, i), 1.0, 1.0, true); + ln_alpha -= R::dgamma(pairwise_effects_continuous_(i, i), 1.0, 1.0, true); + + // Jacobian for log-scale proposal + ln_alpha += theta_prop - theta_curr; + + if(MY_LOG(runif(rng_)) < ln_alpha) { + double old_ii = pairwise_effects_continuous_(i, i); + pairwise_effects_continuous_(i, i) = precision_proposal_(i, i); + + cholesky_update_after_precision_diag(old_ii, i); + recompute_conditional_mean(); + if(use_marginal_pl_) recompute_Theta(); + } + + if(iteration >= 1 && iteration < total_warmup_) { + double rm_weight = std::pow(iteration, -0.75); + proposal_sd_pairwise_continuous_(i, i) = update_proposal_sd_with_robbins_monro( + proposal_sd_pairwise_continuous_(i, i), ln_alpha, rm_weight, 0.44); + } +} + + +// ============================================================================= +// update_pairwise_cross +// ============================================================================= +// MH update for one cross-type interaction pairwise_effects_cross_(i, j). +// Acceptance: log_conditional_omrf(i) + log_conditional_ggm() + Cauchy prior. +// Must save/restore conditional_mean_ around the proposal. +// ============================================================================= + +void MixedMRFModel::update_pairwise_cross(int i, int j, int iteration) { + double current_val = pairwise_effects_cross_(i, j); + double proposed = rnorm(rng_, current_val, proposal_sd_pairwise_cross_(i, j)); + + // Current log-posterior + double ll_curr = log_conditional_ggm() + + R::dcauchy(current_val, 0.0, pairwise_scale_, true); + if(use_marginal_pl_) { + for(size_t s = 0; s < p_; ++s) + ll_curr += log_marginal_omrf(s); + } else { + ll_curr += log_conditional_omrf(i); + } + + // Set proposed value and refresh caches + arma::mat cond_mean_saved = conditional_mean_; + arma::mat Theta_saved; + if(use_marginal_pl_) Theta_saved = Theta_; + pairwise_effects_cross_(i, j) = proposed; + recompute_conditional_mean(); + if(use_marginal_pl_) recompute_Theta(); + + double ll_prop = log_conditional_ggm() + + R::dcauchy(proposed, 0.0, pairwise_scale_, true); + if(use_marginal_pl_) { + for(size_t s = 0; s < p_; ++s) + ll_prop += log_marginal_omrf(s); + } else { + ll_prop += log_conditional_omrf(i); + } + + double ln_alpha = ll_prop - ll_curr; + + if(MY_LOG(runif(rng_)) >= ln_alpha) { + pairwise_effects_cross_(i, j) = current_val; // reject + conditional_mean_ = std::move(cond_mean_saved); + if(use_marginal_pl_) Theta_ = std::move(Theta_saved); + } + + if(iteration >= 1 && iteration < total_warmup_) { + double rm_weight = std::pow(iteration, -0.75); + proposal_sd_pairwise_cross_(i, j) = update_proposal_sd_with_robbins_monro( + proposal_sd_pairwise_cross_(i, j), ln_alpha, rm_weight, 0.44); + } +} + + +// ============================================================================= +// update_edge_indicator_discrete +// ============================================================================= +// Metropolis-Hastings add-delete move for a discrete-discrete edge (i, j). +// Add (G=0→1): propose k ~ N(0, σ), accept with slab + Hastings. +// Delete (G=1→0): set k = 0, accept with reverse terms. +// ============================================================================= + +void MixedMRFModel::update_edge_indicator_discrete(int i, int j) { + double k_curr = pairwise_effects_discrete_(i, j); + double prop_sd = proposal_sd_pairwise_discrete_(i, j); + + int g_curr = gxx(i, j); + int g_prop = 1 - g_curr; + + double k_prop; + if(g_prop == 1) { + k_prop = rnorm(rng_, k_curr, prop_sd); // k_curr = 0 on a true add + } else { + k_prop = 0.0; + } + + // --- Likelihood ratio --- + double ll_curr, ll_prop; + if(use_marginal_pl_) { + ll_curr = log_marginal_omrf(i) + log_marginal_omrf(j); + + pairwise_effects_discrete_(i, j) = k_prop; + pairwise_effects_discrete_(j, i) = k_prop; + recompute_Theta(); + + ll_prop = log_marginal_omrf(i) + log_marginal_omrf(j); + + // Restore + pairwise_effects_discrete_(i, j) = k_curr; + pairwise_effects_discrete_(j, i) = k_curr; + recompute_Theta(); + } else { + ll_curr = log_conditional_omrf(i) + log_conditional_omrf(j); + + pairwise_effects_discrete_(i, j) = k_prop; + pairwise_effects_discrete_(j, i) = k_prop; + + ll_prop = log_conditional_omrf(i) + log_conditional_omrf(j); + + // Restore + pairwise_effects_discrete_(i, j) = k_curr; + pairwise_effects_discrete_(j, i) = k_curr; + } + + double ln_alpha = ll_prop - ll_curr; + + if(g_prop == 1) { + // Add: slab prior, subtract proposal density, inclusion prior + ln_alpha += R::dcauchy(k_prop, 0.0, pairwise_scale_, true); + ln_alpha -= R::dnorm(k_prop, k_curr, prop_sd, true); + ln_alpha += MY_LOG(inclusion_probability_(i, j)) + - MY_LOG(1.0 - inclusion_probability_(i, j)); + } else { + // Delete: subtract slab prior, add reverse proposal density, inclusion prior + ln_alpha -= R::dcauchy(k_curr, 0.0, pairwise_scale_, true); + ln_alpha += R::dnorm(k_curr, k_prop, prop_sd, true); + ln_alpha -= MY_LOG(inclusion_probability_(i, j)) + - MY_LOG(1.0 - inclusion_probability_(i, j)); + } + + if(MY_LOG(runif(rng_)) < ln_alpha) { + pairwise_effects_discrete_(i, j) = k_prop; + pairwise_effects_discrete_(j, i) = k_prop; + set_gxx(i, j, g_prop); + if(use_marginal_pl_) recompute_Theta(); + } +} + + +// ============================================================================= +// update_edge_indicator_continuous +// ============================================================================= +// Metropolis-Hastings add-delete move for a continuous-continuous edge (i, j). +// Uses Cholesky reparameterization (permute-free constants extraction). +// Add (G=0→1): propose ε ~ N(0, σ), k = C[2]*ε, constrain diagonal. +// Delete (G=1→0): set off-diag = 0, constrain diagonal. +// ============================================================================= + +void MixedMRFModel::update_edge_indicator_continuous(int i, int j) { + get_precision_constants(i, j); + + int g_curr = gyy(i, j); + int g_prop = 1 - g_curr; + + double omega_prop_ij, omega_prop_jj; + + if(g_prop == 1) { + // Add: propose from N(0, σ) on reparameterized scale + double epsilon = rnorm(rng_, 0.0, proposal_sd_pairwise_continuous_(i, j)); + omega_prop_ij = kyy_constants_[3] * epsilon; + omega_prop_jj = precision_constrained_diagonal(omega_prop_ij); + } else { + // Delete: set off-diagonal to 0 + omega_prop_ij = 0.0; + omega_prop_jj = precision_constrained_diagonal(0.0); + } + + // Fill proposal + precision_proposal_ = pairwise_effects_continuous_; + precision_proposal_(i, j) = omega_prop_ij; + precision_proposal_(j, i) = omega_prop_ij; + precision_proposal_(j, j) = omega_prop_jj; + + // --- Likelihood ratio --- + double ln_alpha = log_ggm_ratio_edge(i, j); + + if(use_marginal_pl_) { + for(size_t s = 0; s < p_; ++s) + ln_alpha -= log_marginal_omrf(s); + + arma::mat Theta_saved = Theta_; + arma::mat pairwise_effects_continuous_saved = pairwise_effects_continuous_; + pairwise_effects_continuous_ = precision_proposal_; + recompute_Theta(); + for(size_t s = 0; s < p_; ++s) + ln_alpha += log_marginal_omrf(s); + pairwise_effects_continuous_ = pairwise_effects_continuous_saved; + Theta_ = std::move(Theta_saved); + } + + // --- Spike-and-slab terms --- + if(g_prop == 1) { + // Add: slab prior on proposed off-diag + ln_alpha += R::dcauchy(omega_prop_ij, 0.0, pairwise_scale_, true); + // Subtract proposal density: dnorm(k_prop / C[2], 0, σ) / C[2] + // = dnorm(epsilon, 0, σ) / C[2] + ln_alpha -= R::dnorm(omega_prop_ij / kyy_constants_[3], 0.0, + proposal_sd_pairwise_continuous_(i, j), true) + - MY_LOG(kyy_constants_[3]); + // Inclusion prior: log(π / (1-π)) + ln_alpha += MY_LOG(inclusion_probability_(p_ + i, p_ + j)) + - MY_LOG(1.0 - inclusion_probability_(p_ + i, p_ + j)); + } else { + // Delete: subtract slab prior on current off-diag + ln_alpha -= R::dcauchy(pairwise_effects_continuous_(i, j), 0.0, pairwise_scale_, true); + // Add reverse proposal density: dnorm(k_curr / C[2], 0, σ) / C[2] + ln_alpha += R::dnorm(pairwise_effects_continuous_(i, j) / kyy_constants_[3], 0.0, + proposal_sd_pairwise_continuous_(i, j), true) + - MY_LOG(kyy_constants_[3]); + // Inclusion prior: log((1-π) / π) + ln_alpha -= MY_LOG(inclusion_probability_(p_ + i, p_ + j)) + - MY_LOG(1.0 - inclusion_probability_(p_ + i, p_ + j)); + } + + if(MY_LOG(runif(rng_)) < ln_alpha) { + double old_ij = pairwise_effects_continuous_(i, j); + double old_jj = pairwise_effects_continuous_(j, j); + + pairwise_effects_continuous_(i, j) = omega_prop_ij; + pairwise_effects_continuous_(j, i) = omega_prop_ij; + pairwise_effects_continuous_(j, j) = omega_prop_jj; + + set_gyy(i, j, g_prop); + cholesky_update_after_precision_edge(old_ij, old_jj, i, j); + recompute_conditional_mean(); + if(use_marginal_pl_) recompute_Theta(); + } +} + + +// ============================================================================= +// update_edge_indicator_cross +// ============================================================================= +// Metropolis-Hastings add-delete move for a cross-type edge (i, j). +// Add (G=0→1): propose k ~ N(0, σ). +// Delete (G=1→0): set k = 0. +// ============================================================================= + +void MixedMRFModel::update_edge_indicator_cross(int i, int j) { + double k_curr = pairwise_effects_cross_(i, j); + double prop_sd = proposal_sd_pairwise_cross_(i, j); + + int g_curr = gxy(i, j); + int g_prop = 1 - g_curr; + + double k_prop; + if(g_prop == 1) { + k_prop = rnorm(rng_, k_curr, prop_sd); // k_curr = 0 on a true add + } else { + k_prop = 0.0; + } + + // --- Likelihood ratio --- + double ll_curr, ll_prop; + if(use_marginal_pl_) { + ll_curr = log_conditional_ggm(); + for(size_t s = 0; s < p_; ++s) + ll_curr += log_marginal_omrf(s); + + arma::mat cond_mean_saved = conditional_mean_; + arma::mat Theta_saved = Theta_; + pairwise_effects_cross_(i, j) = k_prop; + recompute_conditional_mean(); + recompute_Theta(); + + ll_prop = log_conditional_ggm(); + for(size_t s = 0; s < p_; ++s) + ll_prop += log_marginal_omrf(s); + + // Restore + pairwise_effects_cross_(i, j) = k_curr; + conditional_mean_ = std::move(cond_mean_saved); + Theta_ = std::move(Theta_saved); + } else { + ll_curr = log_conditional_omrf(i) + log_conditional_ggm(); + + arma::mat cond_mean_saved = conditional_mean_; + pairwise_effects_cross_(i, j) = k_prop; + recompute_conditional_mean(); + + ll_prop = log_conditional_omrf(i) + log_conditional_ggm(); + + // Restore + pairwise_effects_cross_(i, j) = k_curr; + conditional_mean_ = std::move(cond_mean_saved); + } + + double ln_alpha = ll_prop - ll_curr; + + if(g_prop == 1) { + // Add + ln_alpha += R::dcauchy(k_prop, 0.0, pairwise_scale_, true); + ln_alpha -= R::dnorm(k_prop, k_curr, prop_sd, true); + ln_alpha += MY_LOG(inclusion_probability_(i, p_ + j)) + - MY_LOG(1.0 - inclusion_probability_(i, p_ + j)); + } else { + // Delete + ln_alpha -= R::dcauchy(k_curr, 0.0, pairwise_scale_, true); + ln_alpha += R::dnorm(k_curr, k_prop, prop_sd, true); + ln_alpha -= MY_LOG(inclusion_probability_(i, p_ + j)) + - MY_LOG(1.0 - inclusion_probability_(i, p_ + j)); + } + + if(MY_LOG(runif(rng_)) < ln_alpha) { + pairwise_effects_cross_(i, j) = k_prop; + set_gxy(i, j, g_prop); + recompute_conditional_mean(); + if(use_marginal_pl_) recompute_Theta(); + } +} diff --git a/src/models/mixed/mixed_mrf_model.cpp b/src/models/mixed/mixed_mrf_model.cpp new file mode 100644 index 00000000..0d95387c --- /dev/null +++ b/src/models/mixed/mixed_mrf_model.cpp @@ -0,0 +1,839 @@ +#include +#include "models/mixed/mixed_mrf_model.h" +#include "math/explog_macros.h" +#include "rng/rng_utils.h" +#include "mcmc/execution/warmup_schedule.h" + + +// ============================================================================= +// Constructor +// ============================================================================= + +MixedMRFModel::MixedMRFModel( + const arma::imat& discrete_observations, + const arma::mat& continuous_observations, + const arma::ivec& num_categories, + const arma::uvec& is_ordinal_variable, + const arma::ivec& baseline_category, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + bool edge_selection, + const std::string& pseudolikelihood, + double main_alpha, + double main_beta, + double pairwise_scale, + int seed +) : + n_(discrete_observations.n_rows), + p_(discrete_observations.n_cols), + q_(continuous_observations.n_cols), + discrete_observations_(discrete_observations), + continuous_observations_(continuous_observations), + num_categories_(num_categories), + is_ordinal_variable_(is_ordinal_variable), + baseline_category_(baseline_category), + edge_indicators_(initial_edge_indicators), + inclusion_probability_(inclusion_probability), + edge_selection_(edge_selection), + edge_selection_active_(false), + main_alpha_(main_alpha), + main_beta_(main_beta), + pairwise_scale_(pairwise_scale), + use_marginal_pl_(pseudolikelihood == "marginal"), + rng_(seed) +{ + // Dimension counts + num_main_ = count_num_main_effects(); + num_pairwise_xx_ = (p_ * (p_ - 1)) / 2; + num_pairwise_yy_ = (q_ * (q_ - 1)) / 2; + num_cross_ = p_ * q_; + + max_cats_ = num_categories_.max(); + + // Center Blume-Capel observations at baseline category so that all + // downstream code operates in a shifted coordinate system where the + // reference corresponds to zero (same convention as OMRFModel). + for(size_t s = 0; s < p_; ++s) { + if(!is_ordinal_variable_(s)) { + discrete_observations_.col(s) -= baseline_category_(s); + } + } + discrete_observations_dbl_ = arma::conv_to::from(discrete_observations_); + + // Compute sufficient statistics + compute_sufficient_statistics(); + + // Initialize parameters to zero + main_effects_discrete_ = arma::zeros(p_, max_cats_); + main_effects_continuous_ = arma::zeros(q_); + pairwise_effects_discrete_ = arma::zeros(p_, p_); + pairwise_effects_continuous_ = arma::eye(q_, q_); + pairwise_effects_cross_ = arma::zeros(p_, q_); + + // Initialize proposal SDs + proposal_sd_main_discrete_ = arma::ones(p_, max_cats_); + proposal_sd_main_continuous_ = arma::ones(q_); + proposal_sd_pairwise_discrete_ = arma::ones(p_, p_); + proposal_sd_pairwise_continuous_ = arma::ones(q_, q_); + proposal_sd_pairwise_cross_ = arma::ones(p_, q_); + + // Initialize precision caches (K_yy starts as identity) + cholesky_of_precision_ = arma::eye(q_, q_); + inv_cholesky_of_precision_ = arma::eye(q_, q_); + covariance_continuous_ = arma::eye(q_, q_); + log_det_precision_ = 0.0; + + // Rank-1 Cholesky update workspace + precision_proposal_ = arma::mat(q_, q_, arma::fill::none); + kyy_vf1_ = arma::zeros(q_); + kyy_vf2_ = arma::zeros(q_); + kyy_u1_ = arma::zeros(q_); + kyy_u2_ = arma::zeros(q_); + + // Initialize conditional mean: M = μ_y' + 2 X K_xy Σ_yy + // With K_xy = 0 and K_yy = I, this reduces to μ_y' = 0. + conditional_mean_ = arma::zeros(n_, q_); + + // Initialize Theta (marginal PL only): Θ = K_xx + 2 K_xy Σ_yy K_xy' + // With K_xy = 0, Θ = K_xx = 0. + if(use_marginal_pl_) { + Theta_ = arma::zeros(p_, p_); + } + + // Initialize edge-order permutation vectors + edge_order_xx_ = arma::regspace(0, num_pairwise_xx_ - 1); + edge_order_yy_ = arma::regspace(0, num_pairwise_yy_ - 1); + edge_order_xy_ = arma::regspace(0, num_cross_ - 1); +} + + +// ============================================================================= +// Copy constructor +// ============================================================================= + +MixedMRFModel::MixedMRFModel(const MixedMRFModel& other) + : BaseModel(other), + n_(other.n_), + p_(other.p_), + q_(other.q_), + num_main_(other.num_main_), + num_pairwise_xx_(other.num_pairwise_xx_), + num_pairwise_yy_(other.num_pairwise_yy_), + num_cross_(other.num_cross_), + discrete_observations_(other.discrete_observations_), + discrete_observations_dbl_(other.discrete_observations_dbl_), + continuous_observations_(other.continuous_observations_), + num_categories_(other.num_categories_), + max_cats_(other.max_cats_), + is_ordinal_variable_(other.is_ordinal_variable_), + baseline_category_(other.baseline_category_), + missing_index_discrete_(other.missing_index_discrete_), + missing_index_continuous_(other.missing_index_continuous_), + has_missing_(other.has_missing_), + counts_per_category_(other.counts_per_category_), + blume_capel_stats_(other.blume_capel_stats_), + main_effects_discrete_(other.main_effects_discrete_), + main_effects_continuous_(other.main_effects_continuous_), + pairwise_effects_discrete_(other.pairwise_effects_discrete_), + pairwise_effects_continuous_(other.pairwise_effects_continuous_), + pairwise_effects_cross_(other.pairwise_effects_cross_), + edge_indicators_(other.edge_indicators_), + inclusion_probability_(other.inclusion_probability_), + edge_selection_(other.edge_selection_), + edge_selection_active_(other.edge_selection_active_), + main_alpha_(other.main_alpha_), + main_beta_(other.main_beta_), + pairwise_scale_(other.pairwise_scale_), + proposal_sd_main_discrete_(other.proposal_sd_main_discrete_), + proposal_sd_main_continuous_(other.proposal_sd_main_continuous_), + proposal_sd_pairwise_discrete_(other.proposal_sd_pairwise_discrete_), + proposal_sd_pairwise_continuous_(other.proposal_sd_pairwise_continuous_), + proposal_sd_pairwise_cross_(other.proposal_sd_pairwise_cross_), + total_warmup_(other.total_warmup_), + cholesky_of_precision_(other.cholesky_of_precision_), + inv_cholesky_of_precision_(other.inv_cholesky_of_precision_), + covariance_continuous_(other.covariance_continuous_), + log_det_precision_(other.log_det_precision_), + Theta_(other.Theta_), + conditional_mean_(other.conditional_mean_), + kyy_constants_(other.kyy_constants_), + precision_proposal_(other.precision_proposal_), + kyy_v1_(other.kyy_v1_), + kyy_v2_(other.kyy_v2_), + kyy_vf1_(other.kyy_vf1_), + kyy_vf2_(other.kyy_vf2_), + kyy_u1_(other.kyy_u1_), + kyy_u2_(other.kyy_u2_), + gradient_cache_valid_(false), + use_marginal_pl_(other.use_marginal_pl_), + rng_(other.rng_), + edge_order_xx_(other.edge_order_xx_), + edge_order_yy_(other.edge_order_yy_), + edge_order_xy_(other.edge_order_xy_) +{ +} + + +// ============================================================================= +// Sufficient statistics +// ============================================================================= + +void MixedMRFModel::compute_sufficient_statistics() { + // Category counts for ordinal variables + counts_per_category_ = arma::zeros(max_cats_ + 1, p_); + for(size_t s = 0; s < p_; ++s) { + if(is_ordinal_variable_(s)) { + for(size_t i = 0; i < n_; ++i) { + int cat = discrete_observations_(i, s); + if(cat >= 0 && cat <= num_categories_(s)) { + counts_per_category_(cat, s)++; + } + } + } + } + + // Blume-Capel statistics (linear and quadratic sums of centered obs) + blume_capel_stats_ = arma::zeros(2, p_); + for(size_t s = 0; s < p_; ++s) { + if(!is_ordinal_variable_(s)) { + for(size_t i = 0; i < n_; ++i) { + int val = discrete_observations_(i, s); // already centered + blume_capel_stats_(0, s) += val; + blume_capel_stats_(1, s) += val * val; + } + } + } +} + + +size_t MixedMRFModel::count_num_main_effects() const { + size_t count = 0; + for(size_t s = 0; s < p_; ++s) { + if(is_ordinal_variable_(s)) { + count += num_categories_(s); + } else { + count += 2; // linear α and quadratic β + } + } + return count; +} + + +// ============================================================================= +// Cache maintenance +// ============================================================================= + +void MixedMRFModel::recompute_conditional_mean() { + // M = μ_y' + 2 X K_xy Σ_yy + conditional_mean_ = arma::repmat(main_effects_continuous_.t(), n_, 1) + + 2.0 * discrete_observations_dbl_ * pairwise_effects_cross_ * covariance_continuous_; +} + +void MixedMRFModel::recompute_pairwise_effects_continuous_decomposition() { + cholesky_of_precision_ = arma::chol(pairwise_effects_continuous_); // upper Cholesky: K_yy = R'R + arma::inv(inv_cholesky_of_precision_, arma::trimatu(cholesky_of_precision_)); + covariance_continuous_ = inv_cholesky_of_precision_ * inv_cholesky_of_precision_.t(); + log_det_precision_ = cholesky_helpers::get_log_det(cholesky_of_precision_); +} + +void MixedMRFModel::recompute_Theta() { + // Θ = K_xx + 2 K_xy Σ_yy K_xy' + Theta_ = pairwise_effects_discrete_ + 2.0 * pairwise_effects_cross_ * covariance_continuous_ * pairwise_effects_cross_.t(); +} + + +// ============================================================================= +// Parameter vectorization +// ============================================================================= + +// NUTS vectorization order (excludes pairwise_effects_continuous_ — sampled by MH separately): +// 1. main_effects_discrete_: per-variable (ordinal: C_s thresholds; BC: 2 coefficients) +// 2. pairwise_effects_discrete_: upper-triangular, row-major — p(p-1)/2 +// 3. main_effects_continuous_: all q means +// 4. pairwise_effects_cross_: all p*q entries, row-major +// +// Storage vectorization order (includes pairwise_effects_continuous_): +// 1–4. Same as NUTS order +// 5. pairwise_effects_continuous_: upper-triangle including diagonal — q(q+1)/2 + +size_t MixedMRFModel::parameter_dimension() const { + if(!edge_selection_active_) { + return full_parameter_dimension(); + } + // Count active NUTS parameters only (no pairwise_effects_continuous_) + size_t dim = num_main_ + q_; // main effects always active + + // Active pairwise_effects_discrete_ edges + for(size_t i = 0; i < p_ - 1; ++i) { + for(size_t j = i + 1; j < p_; ++j) { + if(gxx(i, j)) dim++; + } + } + + // Active pairwise_effects_cross_ edges + for(size_t i = 0; i < p_; ++i) { + for(size_t j = 0; j < q_; ++j) { + if(gxy(i, j)) dim++; + } + } + + return dim; +} + +size_t MixedMRFModel::full_parameter_dimension() const { + // NUTS block: main + pairwise_discrete upper-tri + means + pairwise_cross (no precision) + return num_main_ + num_pairwise_xx_ + q_ + num_cross_; +} + +size_t MixedMRFModel::storage_dimension() const { + // All parameters including pairwise_effects_continuous_ + return num_main_ + num_pairwise_xx_ + q_ + + (q_ * (q_ + 1)) / 2 + num_cross_; +} + +arma::vec MixedMRFModel::get_vectorized_parameters() const { + // Active NUTS parameters only (excludes precision, excludes inactive edges) + arma::vec out(parameter_dimension()); + size_t idx = 0; + + // 1. main_effects_discrete_ + for(size_t s = 0; s < p_; ++s) { + if(is_ordinal_variable_(s)) { + for(int c = 0; c < num_categories_(s); ++c) { + out(idx++) = main_effects_discrete_(s, c); + } + } else { + out(idx++) = main_effects_discrete_(s, 0); + out(idx++) = main_effects_discrete_(s, 1); + } + } + + // 2. pairwise_effects_discrete_ upper-triangular (active edges only when selection is active) + for(size_t i = 0; i < p_ - 1; ++i) { + for(size_t j = i + 1; j < p_; ++j) { + if(!edge_selection_active_ || gxx(i, j) == 1) { + out(idx++) = pairwise_effects_discrete_(i, j); + } + } + } + + // 3. main_effects_continuous_ + for(size_t j = 0; j < q_; ++j) { + out(idx++) = main_effects_continuous_(j); + } + + // 4. pairwise_effects_cross_ row-major (active edges only when selection is active) + for(size_t i = 0; i < p_; ++i) { + for(size_t j = 0; j < q_; ++j) { + if(!edge_selection_active_ || gxy(i, j) == 1) { + out(idx++) = pairwise_effects_cross_(i, j); + } + } + } + + return out; +} + +arma::vec MixedMRFModel::get_full_vectorized_parameters() const { + // All NUTS parameters, fixed size (inactive edges are 0, no precision) + arma::vec out(full_parameter_dimension(), arma::fill::zeros); + size_t idx = 0; + + // 1. main_effects_discrete_ + for(size_t s = 0; s < p_; ++s) { + if(is_ordinal_variable_(s)) { + for(int c = 0; c < num_categories_(s); ++c) { + out(idx++) = main_effects_discrete_(s, c); + } + } else { + out(idx++) = main_effects_discrete_(s, 0); + out(idx++) = main_effects_discrete_(s, 1); + } + } + + // 2. pairwise_effects_discrete_ upper-triangular (all entries, zeros for inactive) + for(size_t i = 0; i < p_ - 1; ++i) { + for(size_t j = i + 1; j < p_; ++j) { + out(idx++) = pairwise_effects_discrete_(i, j); + } + } + + // 3. main_effects_continuous_ + for(size_t j = 0; j < q_; ++j) { + out(idx++) = main_effects_continuous_(j); + } + + // 4. pairwise_effects_cross_ row-major (all entries, zeros for inactive) + for(size_t i = 0; i < p_; ++i) { + for(size_t j = 0; j < q_; ++j) { + out(idx++) = pairwise_effects_cross_(i, j); + } + } + + return out; +} + +arma::vec MixedMRFModel::get_storage_vectorized_parameters() const { + // All parameters including pairwise_effects_continuous_, fixed size + arma::vec out(storage_dimension(), arma::fill::zeros); + size_t idx = 0; + + // 1. main_effects_discrete_ + for(size_t s = 0; s < p_; ++s) { + if(is_ordinal_variable_(s)) { + for(int c = 0; c < num_categories_(s); ++c) { + out(idx++) = main_effects_discrete_(s, c); + } + } else { + out(idx++) = main_effects_discrete_(s, 0); + out(idx++) = main_effects_discrete_(s, 1); + } + } + + // 2. pairwise_effects_discrete_ upper-triangular + for(size_t i = 0; i < p_ - 1; ++i) { + for(size_t j = i + 1; j < p_; ++j) { + out(idx++) = pairwise_effects_discrete_(i, j); + } + } + + // 3. main_effects_continuous_ + for(size_t j = 0; j < q_; ++j) { + out(idx++) = main_effects_continuous_(j); + } + + // 4. pairwise_effects_cross_ row-major + for(size_t i = 0; i < p_; ++i) { + for(size_t j = 0; j < q_; ++j) { + out(idx++) = pairwise_effects_cross_(i, j); + } + } + + // 5. pairwise_effects_continuous_ upper-triangle including diagonal + for(size_t i = 0; i < q_; ++i) { + for(size_t j = i; j < q_; ++j) { + out(idx++) = pairwise_effects_continuous_(i, j); + } + } + + return out; +} + +void MixedMRFModel::set_vectorized_parameters(const arma::vec& params) { + // Unpack NUTS block only (no pairwise_effects_continuous_) + size_t idx = 0; + + // 1. main_effects_discrete_ + for(size_t s = 0; s < p_; ++s) { + if(is_ordinal_variable_(s)) { + for(int c = 0; c < num_categories_(s); ++c) { + main_effects_discrete_(s, c) = params(idx++); + } + } else { + main_effects_discrete_(s, 0) = params(idx++); + main_effects_discrete_(s, 1) = params(idx++); + } + } + + // 2. pairwise_effects_discrete_ upper-triangular (active edges only when selection is active) + for(size_t i = 0; i < p_ - 1; ++i) { + for(size_t j = i + 1; j < p_; ++j) { + if(!edge_selection_active_ || gxx(i, j) == 1) { + pairwise_effects_discrete_(i, j) = params(idx); + pairwise_effects_discrete_(j, i) = params(idx); + idx++; + } + } + } + + // 3. main_effects_continuous_ + for(size_t j = 0; j < q_; ++j) { + main_effects_continuous_(j) = params(idx++); + } + + // 4. pairwise_effects_cross_ row-major (active edges only when selection is active) + for(size_t i = 0; i < p_; ++i) { + for(size_t j = 0; j < q_; ++j) { + if(!edge_selection_active_ || gxy(i, j) == 1) { + pairwise_effects_cross_(i, j) = params(idx++); + } + } + } + + // Refresh caches (precision unchanged, so no decomposition update needed) + recompute_conditional_mean(); + if(use_marginal_pl_) { + recompute_Theta(); + } +} + +arma::vec MixedMRFModel::get_active_inv_mass() const { + if(!edge_selection_active_) { + return inv_mass_; + } + + arma::vec active(parameter_dimension()); + // Main effects: always active + active.head(num_main_) = inv_mass_.head(num_main_); + + size_t offset_full = num_main_; + size_t offset_active = num_main_; + + // pairwise_effects_discrete_ active edges + for(size_t i = 0; i < p_ - 1; ++i) { + for(size_t j = i + 1; j < p_; ++j) { + if(gxx(i, j) == 1) { + active(offset_active++) = inv_mass_(offset_full); + } + offset_full++; + } + } + + // main_effects_continuous_: always active + for(size_t j = 0; j < q_; ++j) { + active(offset_active++) = inv_mass_(offset_full++); + } + + // pairwise_effects_cross_ active edges + for(size_t i = 0; i < p_; ++i) { + for(size_t j = 0; j < q_; ++j) { + if(gxy(i, j) == 1) { + active(offset_active++) = inv_mass_(offset_full); + } + offset_full++; + } + } + + return active; +} + +arma::ivec MixedMRFModel::get_vectorized_indicator_parameters() { + size_t total = num_pairwise_xx_ + num_pairwise_yy_ + num_cross_; + arma::ivec out(total); + size_t idx = 0; + + // 1. Upper-triangle of Gxx + for(size_t i = 0; i < p_ - 1; ++i) { + for(size_t j = i + 1; j < p_; ++j) { + out(idx++) = gxx(i, j); + } + } + + // 2. Upper-triangle of Gyy + for(size_t i = 0; i < q_ - 1; ++i) { + for(size_t j = i + 1; j < q_; ++j) { + out(idx++) = gyy(i, j); + } + } + + // 3. Full Gxy block row-major + for(size_t i = 0; i < p_; ++i) { + for(size_t j = 0; j < q_; ++j) { + out(idx++) = gxy(i, j); + } + } + + return out; +} + + +// ============================================================================= +// Infrastructure +// ============================================================================= + +void MixedMRFModel::set_seed(int seed) { + rng_ = SafeRNG(seed); +} + +std::unique_ptr MixedMRFModel::clone() const { + return std::make_unique(*this); +} + + +// ============================================================================= +// Missing data imputation +// ============================================================================= + +void MixedMRFModel::set_missing_data(const arma::imat& missing_discrete, + const arma::imat& missing_continuous) { + missing_index_discrete_ = missing_discrete; + missing_index_continuous_ = missing_continuous; + has_missing_ = (missing_index_discrete_.n_rows > 0 || + missing_index_continuous_.n_rows > 0); +} + +void MixedMRFModel::impute_missing() { + if(!has_missing_) return; + + // --- Phase 1: Impute discrete entries --- + const int num_disc_missing = missing_index_discrete_.n_rows; + if(num_disc_missing > 0) { + arma::vec category_probabilities(max_cats_ + 1); + + for(int miss = 0; miss < num_disc_missing; miss++) { + const int person = missing_index_discrete_(miss, 0); + const int variable = missing_index_discrete_(miss, 1); + const int num_cats = num_categories_(variable); + + // Rest score: sum_t x_vt K_xx(t,s) + 2 sum_j y_vj K_xy(s,j) + // K_xx diagonal is zero, so no self-interaction subtraction needed + double rest_v = 0.0; + for(size_t t = 0; t < p_; t++) { + rest_v += discrete_observations_dbl_(person, t) * pairwise_effects_discrete_(t, variable); + } + for(size_t j = 0; j < q_; j++) { + rest_v += 2.0 * continuous_observations_(person, j) * pairwise_effects_cross_(variable, j); + } + + double cumsum = 0.0; + + if(is_ordinal_variable_(variable)) { + // P(x=0) = 1, P(x=c) ∝ exp(c · rest + μ_x(s, c-1)) + cumsum = 1.0; + category_probabilities(0) = cumsum; + for(int c = 1; c <= num_cats; c++) { + double exponent = static_cast(c) * rest_v + + main_effects_discrete_(variable, c - 1); + cumsum += MY_EXP(exponent); + category_probabilities(c) = cumsum; + } + } else { + // Blume-Capel: categories centered at baseline + const int ref = baseline_category_(variable); + double alpha = main_effects_discrete_(variable, 0); + double beta = main_effects_discrete_(variable, 1); + cumsum = 0.0; + for(int cat = 0; cat <= num_cats; cat++) { + const int score = cat - ref; + double exponent = alpha * score + + beta * score * score + + score * rest_v; + cumsum += MY_EXP(exponent); + category_probabilities(cat) = cumsum; + } + } + + // Sample via inverse-transform + double u = runif(rng_) * cumsum; + int sampled = 0; + while(u > category_probabilities(sampled)) { + sampled++; + } + + int new_value = sampled; + if(!is_ordinal_variable_(variable)) { + new_value -= baseline_category_(variable); + } + const int old_value = discrete_observations_(person, variable); + + if(new_value != old_value) { + discrete_observations_(person, variable) = new_value; + discrete_observations_dbl_(person, variable) = + static_cast(new_value); + + if(is_ordinal_variable_(variable)) { + counts_per_category_(old_value, variable)--; + counts_per_category_(new_value, variable)++; + } else { + blume_capel_stats_(0, variable) += (new_value - old_value); + blume_capel_stats_(1, variable) += + (new_value * new_value - old_value * old_value); + } + } + } + } + + // --- Phase 2: Refresh conditional_mean_ (depends on discrete data) --- + if(num_disc_missing > 0 && missing_index_continuous_.n_rows > 0) { + recompute_conditional_mean(); + } + + // --- Phase 3: Impute continuous entries --- + const int num_cont_missing = missing_index_continuous_.n_rows; + if(num_cont_missing > 0) { + for(int miss = 0; miss < num_cont_missing; miss++) { + const int person = missing_index_continuous_(miss, 0); + const int variable = missing_index_continuous_(miss, 1); + + // Conditional: y_vj | y_{v,-j}, x ~ N(mu*, 1/pairwise_effects_continuous_jj) + // mu* = M_vj - (1/pairwise_effects_continuous_jj) * sum_{k!=j} pairwise_effects_continuous_jk * (y_vk - M_vk) + double cond_mean = conditional_mean_(person, variable); + for(size_t k = 0; k < q_; k++) { + if(k != static_cast(variable)) { + cond_mean -= (pairwise_effects_continuous_(variable, k) / pairwise_effects_continuous_(variable, variable)) * + (continuous_observations_(person, k) - + conditional_mean_(person, k)); + } + } + double cond_sd = std::sqrt(1.0 / pairwise_effects_continuous_(variable, variable)); + + continuous_observations_(person, variable) = + rnorm(rng_, cond_mean, cond_sd); + } + } + + // Invalidate gradient cache (observations changed) + invalidate_gradient_cache(); +} + + +// ============================================================================= +// Stubs (to be implemented in later phases) +// ============================================================================= + +void MixedMRFModel::do_one_metropolis_step(int iteration) { + // Step 1: Update all main effects (ordinal thresholds or BC α/β) + for(size_t s = 0; s < p_; ++s) { + if(is_ordinal_variable_(s)) { + for(int c = 0; c < num_categories_(s); ++c) + update_main_effect(s, c, iteration); + } else { + update_main_effect(s, 0, iteration); // linear α + update_main_effect(s, 1, iteration); // quadratic β + } + } + + // Step 2: Update all continuous means + for(size_t j = 0; j < q_; ++j) + update_continuous_mean(j, iteration); + + // Step 3: Update pairwise_effects_discrete_ (upper triangle, edge-gated) + for(size_t i = 0; i < p_ - 1; ++i) + for(size_t j = i + 1; j < p_; ++j) + if(!edge_selection_active_ || gxx(i, j) == 1) + update_pairwise_discrete(i, j, iteration); + + // Step 4: Update pairwise_effects_continuous_ (off-diag + diagonal, edge-gated) + if(q_ >= 2) { + for(size_t i = 0; i < q_ - 1; ++i) + for(size_t j = i + 1; j < q_; ++j) + if(!edge_selection_active_ || gyy(i, j) == 1) + update_pairwise_effects_continuous_offdiag(i, j, iteration); + } + for(size_t i = 0; i < q_; ++i) + update_pairwise_effects_continuous_diag(i, iteration); + + // Step 5: Update pairwise_effects_cross_ (edge-gated) + for(size_t i = 0; i < p_; ++i) + for(size_t j = 0; j < q_; ++j) + if(!edge_selection_active_ || gxy(i, j) == 1) + update_pairwise_cross(i, j, iteration); + + // Edge-indicator updates are handled by ChainRunner, not here. + // (Matches the OMRF pattern; avoids double-counting indicator proposals.) +} + +void MixedMRFModel::do_pairwise_continuous_metropolis_step(int iteration) { + // Off-diagonal precision (edge-gated) + if(q_ >= 2) { + for(size_t i = 0; i < q_ - 1; ++i) + for(size_t j = i + 1; j < q_; ++j) + if(!edge_selection_active_ || gyy(i, j) == 1) + update_pairwise_effects_continuous_offdiag(i, j, iteration); + } + // Diagonal precision (always updated) + for(size_t i = 0; i < q_; ++i) + update_pairwise_effects_continuous_diag(i, iteration); +} + +void MixedMRFModel::update_edge_indicators() { + if(!edge_selection_active_) return; + + invalidate_gradient_cache(); + + // Discrete-discrete edges (shuffled order) + for(size_t e = 0; e < num_pairwise_xx_; ++e) { + size_t idx = edge_order_xx_(e); + // Decode upper-triangle index to (i, j) + size_t i = 0, j = 1; + size_t count = 0; + for(i = 0; i < p_ - 1; ++i) { + size_t row_len = p_ - 1 - i; + if(count + row_len > idx) { + j = i + 1 + (idx - count); + break; + } + count += row_len; + } + update_edge_indicator_discrete(i, j); + } + + // Continuous-continuous edges (shuffled order) + for(size_t e = 0; e < num_pairwise_yy_; ++e) { + size_t idx = edge_order_yy_(e); + size_t i = 0, j = 1; + size_t count = 0; + for(i = 0; i < q_ - 1; ++i) { + size_t row_len = q_ - 1 - i; + if(count + row_len > idx) { + j = i + 1 + (idx - count); + break; + } + count += row_len; + } + update_edge_indicator_continuous(i, j); + } + + // Cross edges (shuffled order) + for(size_t e = 0; e < num_cross_; ++e) { + size_t idx = edge_order_xy_(e); + size_t i = idx / q_; + size_t j = idx % q_; + update_edge_indicator_cross(i, j); + } +} + +void MixedMRFModel::initialize_graph() { + // Draw initial graph from prior inclusion probabilities. + // Zero out parameters for excluded edges. + for(size_t i = 0; i < p_ - 1; ++i) { + for(size_t j = i + 1; j < p_; ++j) { + if(runif(rng_) >= inclusion_probability_(i, j)) { + set_gxx(i, j, 0); + pairwise_effects_discrete_(i, j) = 0.0; + pairwise_effects_discrete_(j, i) = 0.0; + } + } + } + + for(size_t i = 0; i < q_ - 1; ++i) { + for(size_t j = i + 1; j < q_; ++j) { + if(runif(rng_) >= inclusion_probability_(p_ + i, p_ + j)) { + set_gyy(i, j, 0); + pairwise_effects_continuous_(i, j) = 0.0; + pairwise_effects_continuous_(j, i) = 0.0; + } + } + } + // Recompute precision decomposition after potential zeroing + recompute_pairwise_effects_continuous_decomposition(); + recompute_conditional_mean(); + + for(size_t i = 0; i < p_; ++i) { + for(size_t j = 0; j < q_; ++j) { + if(runif(rng_) >= inclusion_probability_(i, p_ + j)) { + set_gxy(i, j, 0); + pairwise_effects_cross_(i, j) = 0.0; + } + } + } + recompute_conditional_mean(); + if(use_marginal_pl_) recompute_Theta(); +} + +void MixedMRFModel::prepare_iteration() { + // Shuffle edge-update order to avoid order bias. + // Always called, even when edge selection is off, to keep RNG consistent. + edge_order_xx_ = arma_randperm(rng_, num_pairwise_xx_); + edge_order_yy_ = arma_randperm(rng_, num_pairwise_yy_); + edge_order_xy_ = arma_randperm(rng_, num_cross_); +} + +void MixedMRFModel::init_metropolis_adaptation(const WarmupSchedule& schedule) { + total_warmup_ = schedule.total_warmup; +} + +void MixedMRFModel::tune_proposal_sd(int /*iteration*/, const WarmupSchedule& /*schedule*/) { + // Robbins-Monro adaptation is embedded in each MH update function, + // gated by iteration < total_warmup_. No separate tuning pass needed. +} diff --git a/src/models/mixed/mixed_mrf_model.h b/src/models/mixed/mixed_mrf_model.h new file mode 100644 index 00000000..9b7995ec --- /dev/null +++ b/src/models/mixed/mixed_mrf_model.h @@ -0,0 +1,503 @@ +#pragma once + +#include +#include +#include "models/base_model.h" +#include "math/cholesky_helpers.h" +#include "math/cholupdate.h" +#include "rng/rng_utils.h" + +/** + * MixedMRFModel - Mixed Markov Random Field Model + * + * Joint model for p discrete (ordinal or Blume-Capel) variables x and + * q continuous variables y. The joint density is: + * + * log f(x, y) ∝ Σ_s μ_{x,s}(x_s) + x' Kxx x + * - ½ (y - μ_y)' Kyy (y - μ_y) + 2 x' Kxy y + * + * Supports both conditional and marginal pseudo-likelihood, with and + * without edge selection via spike-and-slab priors. + * + * Discrete variables are either ordinal (free category thresholds, category + * 0 as reference) or Blume-Capel (linear α + quadratic β, user-specified + * reference). Blume-Capel observations are centered at their baseline + * category in the constructor, matching the OMRFModel convention. + * + * Inherits from BaseModel for compatibility with the generic MCMC framework + * (ChainRunner, MetropolisSampler, WarmupSchedule). + */ +class MixedMRFModel : public BaseModel { +public: + + // ========================================================================= + // Construction + // ========================================================================= + + /** + * Construct from raw observations. + * + * @param discrete_observations Integer matrix of discrete observations (n × p, 0-based) + * @param continuous_observations Continuous observations (n × q) + * @param num_categories Number of categories per discrete variable (p-vector) + * @param is_ordinal_variable 1 = ordinal, 0 = Blume-Capel (p-vector) + * @param baseline_category Reference category per discrete variable (p-vector) + * @param inclusion_probability Prior inclusion probabilities ((p+q) × (p+q)) + * @param initial_edge_indicators Initial edge inclusion matrix ((p+q) × (p+q)) + * @param edge_selection Enable edge selection (spike-and-slab) + * @param pseudolikelihood "conditional" or "marginal" + * @param main_alpha Beta prior hyperparameter α for main effects + * @param main_beta Beta prior hyperparameter β for main effects + * @param pairwise_scale Scale parameter of Cauchy prior on interactions + * @param seed RNG seed for reproducibility + */ + MixedMRFModel( + const arma::imat& discrete_observations, + const arma::mat& continuous_observations, + const arma::ivec& num_categories, + const arma::uvec& is_ordinal_variable, + const arma::ivec& baseline_category, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + bool edge_selection, + const std::string& pseudolikelihood, + double main_alpha = 1.0, + double main_beta = 1.0, + double pairwise_scale = 2.5, + int seed = 1 + ); + + /** Copy constructor for cloning (required for parallel chains). */ + MixedMRFModel(const MixedMRFModel& other); + + // ========================================================================= + // Capability queries + // ========================================================================= + + /** @return true (MixedMRFModel supports NUTS gradient for the unconstrained block). */ + bool has_gradient() const override { return true; } + /** @return true (supports adaptive Metropolis via Robbins-Monro). */ + bool has_adaptive_metropolis() const override { return true; } + /** @return true when edge selection is enabled. */ + bool has_edge_selection() const override { return edge_selection_; } + /** @return true when missing-data imputation is active. */ + bool has_missing_data() const override { return has_missing_; } + + // ========================================================================= + // Core sampling methods + // ========================================================================= + + /** + * Compute gradient of the log pseudo-posterior (NUTS block only). + * @param parameters NUTS-dimension parameter vector + * @return Gradient vector (same size as parameters) + */ + arma::vec gradient(const arma::vec& parameters) override; + + /** + * Combined log pseudo-posterior and gradient evaluation. + * @param parameters NUTS-dimension parameter vector + * @return Pair of (log-pseudo-posterior, gradient) + */ + std::pair logp_and_gradient( + const arma::vec& parameters) override; + + /** + * Perform one full Metropolis sweep over all parameter groups. + * @param iteration Current iteration (for Robbins-Monro adaptation) + */ + void do_one_metropolis_step(int iteration = -1) override; + + /** + * Update only continuous precision parameters via Metropolis (hybrid NUTS+MH). + * @param iteration Current iteration (for Robbins-Monro adaptation) + */ + void do_pairwise_continuous_metropolis_step(int iteration = -1); + + /** + * Initialize Metropolis adaptation controllers for proposal-SD tuning. + * Called before warmup begins. + */ + void init_metropolis_adaptation(const WarmupSchedule& schedule) override; + + /** + * Tune proposal SDs via Robbins-Monro (Stage 3b). + * Called every iteration; checks schedule internally. + */ + void tune_proposal_sd(int iteration, const WarmupSchedule& schedule) override; + + /** + * Shuffle edge update order at the start of each iteration. + * Advances the RNG state consistently even when edge selection is off. + */ + void prepare_iteration() override; + + // ========================================================================= + // Edge selection + // ========================================================================= + + /** Perform one sweep of Metropolis-Hastings edge add-delete moves. */ + void update_edge_indicators() override; + + /** Initialize a random graph structure for starting edge selection. */ + void initialize_graph() override; + + /** + * Enable or disable edge-selection proposals. + * @param active true to enable edge add-delete moves + */ + void set_edge_selection_active(bool active) override { + edge_selection_active_ = active; + } + + // ========================================================================= + // Parameter vectorization + // ========================================================================= + + /** + * Dimensionality of the active NUTS parameter space (excludes precision). + * When edge selection is active, excludes parameters for inactive edges. + */ + size_t parameter_dimension() const override; + + /** + * Full NUTS-block dimension (all NUTS params, regardless of edge state). + * Excludes continuous precision. Used for mass-matrix sizing and adaptation. + */ + size_t full_parameter_dimension() const override; + + /** + * Storage dimension (all parameters including continuous precision, + * regardless of edge state). Used for fixed-size sample storage. + */ + size_t storage_dimension() const override; + + /** Get active NUTS parameters as a flat vector (excludes precision). */ + arma::vec get_vectorized_parameters() const override; + + /** Get all NUTS parameters (inactive edges zeroed, excludes precision). */ + arma::vec get_full_vectorized_parameters() const override; + + /** Get all parameters including continuous precision for sample storage. */ + arma::vec get_storage_vectorized_parameters() const override; + + /** Set NUTS parameters from a flat vector (does not touch precision). */ + void set_vectorized_parameters(const arma::vec& params) override; + + /** Get vectorized edge indicators (Gxx upper-tri, Gyy upper-tri, Gxy full). */ + arma::ivec get_vectorized_indicator_parameters() override; + + /** Get active subset of inverse mass diagonal (NUTS params only, excludes precision). */ + arma::vec get_active_inv_mass() const override; + + // ========================================================================= + // Infrastructure + // ========================================================================= + + /** Set random seed for reproducibility. */ + void set_seed(int seed) override; + + /** Clone the model for parallel execution. */ + std::unique_ptr clone() const override; + + /** @return Reference to the model's random number generator. */ + SafeRNG& get_rng() override { return rng_; } + + /** @return Current edge-indicator matrix ((p+q) × (p+q)). */ + const arma::imat& get_edge_indicators() const override { + return edge_indicators_; + } + + /** @return Mutable reference to the prior inclusion-probability matrix. */ + arma::mat& get_inclusion_probability() override { + return inclusion_probability_; + } + + /** @return Total number of variables (p + q). */ + int get_num_variables() const override { + return static_cast(p_ + q_); + } + + /** + * Number of unique off-diagonal pairs in the (p+q) × (p+q) indicator + * matrix: p(p-1)/2 + q(q-1)/2 + p*q. + */ + int get_num_pairwise() const override { + return static_cast(num_pairwise_xx_ + num_pairwise_yy_ + num_cross_); + } + + // ========================================================================= + // Missing data + // ========================================================================= + + /** Impute missing entries from full-conditional distributions. */ + void impute_missing() override; + + /** + * Register missing-data locations for discrete and continuous sub-matrices. + * + * @param missing_discrete M_d x 2 matrix of 0-based (row, col) indices into discrete_observations_ + * @param missing_continuous M_c x 2 matrix of 0-based (row, col) indices into continuous_observations_ + */ + void set_missing_data(const arma::imat& missing_discrete, + const arma::imat& missing_continuous); + +private: + + // ========================================================================= + // Counts and dimensions + // ========================================================================= + + size_t n_; ///< Number of observations + size_t p_; ///< Number of discrete variables + size_t q_; ///< Number of continuous variables + size_t num_main_; ///< Total main-effect params (sum C_s for ord + 2 per BC) + size_t num_pairwise_xx_; ///< p(p-1)/2 + size_t num_pairwise_yy_; ///< q(q-1)/2 + size_t num_cross_; ///< p * q + + // ========================================================================= + // Data + // ========================================================================= + + arma::imat discrete_observations_; ///< Discrete observations (n x p), BC columns centered + arma::mat discrete_observations_dbl_; ///< Double version (post-centering) + arma::mat continuous_observations_; ///< Continuous observations (n x q) + arma::ivec num_categories_; ///< Categories per discrete variable (p-vector) + int max_cats_; ///< max(num_categories) + arma::uvec is_ordinal_variable_; ///< 1 = ordinal, 0 = Blume-Capel (p-vector) + arma::ivec baseline_category_; ///< Reference category per discrete variable (p-vector) + + // ========================================================================= + // Missing data + // ========================================================================= + + arma::imat missing_index_discrete_; ///< M_d x 2 (row, col) for missing discrete entries + arma::imat missing_index_continuous_; ///< M_c x 2 (row, col) for missing continuous entries + bool has_missing_ = false; ///< Whether imputation is active + + // ========================================================================= + // Sufficient statistics + // ========================================================================= + + arma::imat counts_per_category_; ///< (max_cats+1) x p category counts (ordinal only) + arma::imat blume_capel_stats_; ///< 2 x p linear/quadratic sums (BC only) + + // ========================================================================= + // Parameters + // ========================================================================= + + arma::mat main_effects_discrete_; ///< p x max_cats main effects (thresholds or alpha/beta) + arma::vec main_effects_continuous_; ///< q-vector continuous means + arma::mat pairwise_effects_discrete_; ///< p x p discrete interactions (symmetric, zero diag) + arma::mat pairwise_effects_continuous_; ///< q x q SPD precision matrix + arma::mat pairwise_effects_cross_; ///< p x q cross-type interactions + + // ========================================================================= + // Edge indicators + // ========================================================================= + + /// Combined (p+q) x (p+q) indicator matrix. + /// Gxx block: rows [0,p), cols [0,p) -- symmetric, zero diag. + /// Gyy block: rows [p,p+q), cols [p,p+q) -- symmetric, zero diag. + /// Gxy block: rows [0,p), cols [p,p+q) -- full p x q rectangle. + arma::imat edge_indicators_; + arma::mat inclusion_probability_; ///< Prior inclusion probabilities + bool edge_selection_; ///< Enable edge selection + bool edge_selection_active_; ///< Currently in edge selection phase + + // ========================================================================= + // Priors + // ========================================================================= + + double main_alpha_; ///< Beta prior alpha for main effects + double main_beta_; ///< Beta prior beta for main effects + double pairwise_scale_; ///< Cauchy scale for interaction priors + + // ========================================================================= + // Proposal SDs (Robbins-Monro adapted) + // ========================================================================= + + arma::mat proposal_sd_main_discrete_; ///< p x max_cats + arma::vec proposal_sd_main_continuous_; ///< q-vector + arma::mat proposal_sd_pairwise_discrete_; ///< p x p + arma::mat proposal_sd_pairwise_continuous_; ///< q x q + arma::mat proposal_sd_pairwise_cross_; ///< p x q + int total_warmup_ = 0; ///< Stored by init_metropolis_adaptation + + // ========================================================================= + // Cached quantities + // ========================================================================= + + arma::mat cholesky_of_precision_; ///< q x q upper Cholesky R (K_yy = R'R) + arma::mat inv_cholesky_of_precision_; ///< q x q R^{-1} (upper triangular) + arma::mat covariance_continuous_; ///< q x q K_yy^{-1} = R^{-1} R^{-T} + double log_det_precision_; ///< log|K_yy| + arma::mat Theta_; ///< p x p marginal PL interaction matrix + arma::mat conditional_mean_; ///< n x q conditional mean mu_y + 2 X K_xy Sigma_yy + + // Rank-1 Cholesky update workspace + std::array kyy_constants_{}; ///< Reparameterization constants + arma::mat precision_proposal_; ///< q x q scratch for proposed precision + arma::vec kyy_v1_ = {0, -1}; ///< Rank-2 decomposition helper 1 + arma::vec kyy_v2_ = {0, 0}; ///< Rank-2 decomposition helper 2 + arma::vec kyy_vf1_; ///< q-vector, zeroed between uses + arma::vec kyy_vf2_; ///< q-vector, zeroed between uses + arma::vec kyy_u1_; ///< q-vector workspace + arma::vec kyy_u2_; ///< q-vector workspace + + // ========================================================================= + // Gradient cache (populated by ensure_gradient_cache) + // ========================================================================= + + arma::mat discrete_observations_dbl_t_; ///< p x n transpose (BLAS gradient) + arma::vec grad_obs_cache_; ///< Cached observed-data gradient component + arma::imat kxx_index_cache_; ///< p x p map from (i,j) to gradient index + arma::imat kxy_index_cache_; ///< p x q map from (i,j) to gradient index + int main_effects_continuous_grad_offset_ = 0; ///< Offset of main_effects_continuous block in gradient vector + bool gradient_cache_valid_ = false; ///< Whether gradient cache is current + + // ========================================================================= + // Configuration + // ========================================================================= + + bool use_marginal_pl_; ///< true = marginal, false = conditional + + // ========================================================================= + // RNG and edge-update order + // ========================================================================= + + SafeRNG rng_; ///< Per-chain random number generator + arma::uvec edge_order_xx_; ///< Shuffled xx-edge pair indices + arma::uvec edge_order_yy_; ///< Shuffled yy-edge pair indices + arma::uvec edge_order_xy_; ///< Shuffled xy-edge pair indices + + // ========================================================================= + // Private helpers + // ========================================================================= + + /** Count total main-effect parameters across all discrete variables. */ + size_t count_num_main_effects() const; + + /** Compute category counts and BC sufficient statistics from discrete_observations_. */ + void compute_sufficient_statistics(); + + /** Recompute conditional_mean_ from main_effects_continuous_, pairwise_effects_cross_, covariance_continuous_. */ + void recompute_conditional_mean(); + + /** Recompute cholesky_of_precision_, inv_cholesky_of_precision_, covariance_continuous_, log_det_precision_ from pairwise_effects_continuous_. */ + void recompute_pairwise_effects_continuous_decomposition(); + + /** Recompute Theta_ from pairwise_effects_discrete_, pairwise_effects_cross_, covariance_continuous_ (marginal PL only). */ + void recompute_Theta(); + + // ========================================================================= + // Gradient helpers (implemented in mixed_mrf_gradient.cpp) + // ========================================================================= + + /** Rebuild gradient index maps after edge-indicator changes. */ + void ensure_gradient_cache(); + + /** Mark gradient cache as stale (call after edge-indicator changes). */ + void invalidate_gradient_cache(); + + /** Unpack NUTS-vector into temporary parameter matrices (no model mutation). */ + void unvectorize_nuts_to_temps( + const arma::vec& params, + arma::mat& temp_main_discrete, + arma::mat& temp_pairwise_discrete, + arma::vec& temp_main_continuous, + arma::mat& temp_pairwise_cross + ) const; + + // ========================================================================= + // Likelihood functions (implemented in mixed_mrf_likelihoods.cpp) + // ========================================================================= + + /** Conditional OMRF pseudolikelihood for discrete variable s, summed over all n. */ + double log_conditional_omrf(int s) const; + + /** Marginal OMRF pseudolikelihood for discrete variable s, using Theta_. */ + double log_marginal_omrf(int s) const; + + /** Conditional GGM log-likelihood: log f(y | x), using cached decomposition. */ + double log_conditional_ggm() const; + + // ========================================================================= + // MH update functions (implemented in mixed_mrf_metropolis.cpp) + // ========================================================================= + + // --- Rank-1 precision proposal helpers (permutation-free) --- + + // Extract reparameterization constants for the (i,j) off-diagonal precision update. + // Populates kyy_constants_[0..5] from cholesky_of_precision_ and covariance_continuous_. + void get_precision_constants(int i, int j); + + // Constrained diagonal value for a proposed off-diagonal precision element. + double precision_constrained_diagonal(double x) const; + + // Log-likelihood ratio for a proposed off-diagonal precision change (rank-2). + // Assumes precision_proposal_ is already filled by the caller. + double log_ggm_ratio_edge(int i, int j) const; + + // Log-likelihood ratio for a proposed diagonal precision change (rank-1). + // Assumes precision_proposal_ is already filled by the caller. + double log_ggm_ratio_diag(int i) const; + + // Rank-1 Cholesky update after accepting an off-diagonal precision change. + void cholesky_update_after_precision_edge(double old_ij, double old_jj, int i, int j); + + // Rank-1 Cholesky update after accepting a diagonal precision change. + void cholesky_update_after_precision_diag(double old_ii, int i); + + // --- Parameter update sweeps --- + + /** Update one main-effect: main_effects_discrete_(s, c). Ordinal threshold or BC α/β. */ + void update_main_effect(int s, int c, int iteration); + + /** Update one continuous mean: main_effects_continuous_(j). */ + void update_continuous_mean(int j, int iteration); + + /** Update one discrete interaction: pairwise_effects_discrete_(i, j). Symmetric. */ + void update_pairwise_discrete(int i, int j, int iteration); + + /** Update one off-diagonal precision element. Cholesky-based. */ + void update_pairwise_effects_continuous_offdiag(int i, int j, int iteration); + + /** Update one diagonal precision element. Log-scale Cholesky. */ + void update_pairwise_effects_continuous_diag(int i, int iteration); + + /** Update one cross interaction: pairwise_effects_cross_(i, j). */ + void update_pairwise_cross(int i, int j, int iteration); + + // --- Edge-indicator update sweeps --- + + /** Metropolis-Hastings add-delete move for one discrete-discrete edge. */ + void update_edge_indicator_discrete(int i, int j); + + /** Metropolis-Hastings add-delete move for one continuous-continuous edge. */ + void update_edge_indicator_continuous(int i, int j); + + /** Metropolis-Hastings add-delete move for one cross-type edge. */ + void update_edge_indicator_cross(int i, int j); + + // ========================================================================= + // Edge-indicator accessor helpers + // ========================================================================= + + int gxx(int i, int j) const { return edge_indicators_(i, j); } + int gyy(int i, int j) const { return edge_indicators_(p_ + i, p_ + j); } + int gxy(int i, int j) const { return edge_indicators_(i, p_ + j); } + + void set_gxx(int i, int j, int val) { + edge_indicators_(i, j) = val; + edge_indicators_(j, i) = val; + } + void set_gyy(int i, int j, int val) { + edge_indicators_(p_ + i, p_ + j) = val; + edge_indicators_(p_ + j, p_ + i) = val; + } + void set_gxy(int i, int j, int val) { + edge_indicators_(i, p_ + j) = val; + } +}; diff --git a/src/models/omrf/omrf_model.cpp b/src/models/omrf/omrf_model.cpp index 6dbe6189..5fc2c886 100644 --- a/src/models/omrf/omrf_model.cpp +++ b/src/models/omrf/omrf_model.cpp @@ -586,7 +586,7 @@ double OMRFModel::log_pseudoposterior_with_state( double log_post = 0.0; auto log_beta_prior = [this](double x) { - return x * main_alpha_ - std::log1p(std::exp(x)) * (main_alpha_ + main_beta_); + return x * main_alpha_ - std::log1p(MY_EXP(x)) * (main_alpha_ + main_beta_); }; // Main effect contributions (priors and sufficient statistics) diff --git a/src/models/omrf/omrf_model.h b/src/models/omrf/omrf_model.h index 5d86c138..ef893e9e 100644 --- a/src/models/omrf/omrf_model.h +++ b/src/models/omrf/omrf_model.h @@ -254,7 +254,7 @@ class OMRFModel : public BaseModel { /** * Enable or disable edge-selection proposals. - * @param active true to enable edge birth/death moves + * @param active true to enable edge add-delete moves */ void set_edge_selection_active(bool active) override { edge_selection_active_ = active; } /** @return true when edge-selection proposals are currently active. */ @@ -266,68 +266,68 @@ class OMRFModel : public BaseModel { // ========================================================================= // Data - size_t n_; // Number of observations - size_t p_; // Number of variables - arma::imat observations_; // Categorical observations (n × p) - arma::mat observations_double_; // Observations as double (for efficient matrix ops) - arma::mat observations_double_t_; // Transposed observations (for BLAS pairwise gradient) - arma::ivec num_categories_; // Categories per variable - arma::uvec is_ordinal_variable_; // 1 = ordinal, 0 = Blume-Capel - arma::ivec baseline_category_; // Reference category for Blume-Capel + size_t n_; ///< Number of observations + size_t p_; ///< Number of variables + arma::imat observations_; ///< Categorical observations (n x p) + arma::mat observations_double_; ///< Observations as double (for efficient matrix ops) + arma::mat observations_double_t_; ///< Transposed observations (for BLAS pairwise gradient) + arma::ivec num_categories_; ///< Categories per variable + arma::uvec is_ordinal_variable_; ///< 1 = ordinal, 0 = Blume-Capel + arma::ivec baseline_category_; ///< Reference category for Blume-Capel // Sufficient statistics - arma::imat counts_per_category_; // Category counts (max_cats+1 × p) - arma::imat blume_capel_stats_; // [linear_sum, quadratic_sum] for BC vars (2 × p) - arma::imat pairwise_stats_; // X^T X - arma::mat residual_matrix_; // X * pairwise_effects (n × p) + arma::imat counts_per_category_; ///< Category counts (max_cats+1 x p) + arma::imat blume_capel_stats_; ///< [linear_sum, quadratic_sum] for BC vars (2 x p) + arma::imat pairwise_stats_; ///< X^T X + arma::mat residual_matrix_; ///< X * pairwise_effects (n x p) // Parameters - arma::mat main_effects_; // Main effect parameters (p × max_cats) - arma::mat pairwise_effects_; // Pairwise interaction strengths (p × p, symmetric) - arma::imat edge_indicators_; // Edge inclusion indicators (p × p, symmetric binary) + arma::mat main_effects_; ///< Main effect parameters (p x max_cats) + arma::mat pairwise_effects_; ///< Pairwise interaction strengths (p x p, symmetric) + arma::imat edge_indicators_; ///< Edge inclusion indicators (p x p, symmetric binary) // Priors - arma::mat inclusion_probability_; // Prior inclusion probabilities - double main_alpha_; // Beta prior α - double main_beta_; // Beta prior β - double pairwise_scale_; // Cauchy scale for pairwise effects - arma::mat pairwise_scaling_factors_; // Per-pair scaling factors for Cauchy prior + arma::mat inclusion_probability_; ///< Prior inclusion probabilities + double main_alpha_; ///< Beta prior alpha + double main_beta_; ///< Beta prior beta + double pairwise_scale_; ///< Cauchy scale for pairwise effects + arma::mat pairwise_scaling_factors_; ///< Per-pair scaling factors for Cauchy prior // Model configuration - bool edge_selection_; // Enable edge selection - bool edge_selection_active_; // Currently in edge selection phase + bool edge_selection_; ///< Enable edge selection + bool edge_selection_active_; ///< Currently in edge selection phase // Dimension tracking - size_t num_main_; // Total number of main effect parameters - size_t num_pairwise_; // Number of possible pairwise effects + size_t num_main_; ///< Total number of main effect parameters + size_t num_pairwise_; ///< Number of possible pairwise effects // Proposal SDs (adapted by MetropolisAdaptationController during warmup) - arma::mat proposal_sd_main_; - arma::mat proposal_sd_pairwise_; + arma::mat proposal_sd_main_; ///< Proposal SD for main effects + arma::mat proposal_sd_pairwise_; ///< Proposal SD for pairwise effects // Metropolis adaptation controllers (created by init_metropolis_adaptation) - std::unique_ptr metropolis_main_adapter_; - std::unique_ptr metropolis_pairwise_adapter_; + std::unique_ptr metropolis_main_adapter_; ///< Main-effect adapter + std::unique_ptr metropolis_pairwise_adapter_; ///< Pairwise-effect adapter // RNG - SafeRNG rng_; + SafeRNG rng_; ///< Per-chain random number generator // NUTS/HMC settings - double step_size_; - arma::vec inv_mass_; + double step_size_; ///< Current step size for gradient-based samplers + arma::vec inv_mass_; ///< Inverse mass diagonal // Missing data handling - bool has_missing_; - arma::imat missing_index_; + bool has_missing_; ///< Whether the data contains missing values + arma::imat missing_index_; ///< (row, col) indices of missing entries // Cached gradient components - arma::vec grad_obs_cache_; - arma::imat index_matrix_cache_; - bool gradient_cache_valid_; + arma::vec grad_obs_cache_; ///< Cached observed-data gradient + arma::imat index_matrix_cache_; ///< Cached parameter index map + bool gradient_cache_valid_; ///< Whether the gradient cache is current // Interaction indexing (for edge updates) - arma::imat interaction_index_; - arma::uvec shuffled_edge_order_; // Pre-shuffled order (set in prepare_iteration) + arma::imat interaction_index_; ///< Maps edge pair to index + arma::uvec shuffled_edge_order_; ///< Pre-shuffled order (set in prepare_iteration) // ========================================================================= // Private helper methods diff --git a/src/mrf_prediction.cpp b/src/mrf_prediction.cpp index 5b890d83..f16f98c7 100644 --- a/src/mrf_prediction.cpp +++ b/src/mrf_prediction.cpp @@ -8,22 +8,20 @@ using namespace Rcpp; // GGM Conditional Prediction // ============================================================================ -/** - * Compute conditional Gaussian parameters for a GGM. - * - * For a GGM with precision matrix Omega, the conditional distribution of - * X_j given X_{-j} = x_{-j} is: - * - * X_j | X_{-j} ~ N( -omega_{jj}^{-1} sum_{k != j} omega_{jk} x_k, - * omega_{jj}^{-1} ) - * - * @param observations n x p matrix of observed continuous data - * @param predict_vars 0-based indices of variables to predict - * @param precision p x p precision matrix (Omega) - * - * @return List of n x 2 matrices (one per predicted variable), where - * column 0 = conditional mean, column 1 = conditional SD. - */ +// Compute conditional Gaussian parameters for a GGM. +// +// For a GGM with precision matrix Omega, the conditional distribution of +// X_j given X_{-j} = x_{-j} is: +// +// X_j | X_{-j} ~ N( -omega_{jj}^{-1} sum_{k != j} omega_{jk} x_k, +// omega_{jj}^{-1} ) +// +// @param observations n x p matrix of observed continuous data +// @param predict_vars 0-based indices of variables to predict +// @param precision p x p precision matrix (Omega) +// +// @return List of n x 2 matrices (one per predicted variable), where +// column 0 = conditional mean, column 1 = conditional SD. // [[Rcpp::export]] Rcpp::List compute_conditional_ggm( const arma::mat& observations, @@ -159,3 +157,145 @@ Rcpp::List compute_conditional_probs( return result; } + + +// ============================================================================ +// Mixed MRF Conditional Prediction +// ============================================================================ + +// Compute conditional distributions for a mixed MRF. +// +// For discrete variables: P(x_s = c | x_{-s}, y) using the conditional OMRF. +// For continuous variables: E(y_j | y_{-j}, x) and SD(y_j | y_{-j}, x) +// using the conditional GGM. +// +// @param x_observations n x p integer matrix of discrete data. +// @param y_observations n x q numeric matrix of continuous data. +// @param predict_vars 0-based indices into the combined (p+q) variable list. +// Indices 0..p-1 refer to discrete variables, +// p..p+q-1 refer to continuous variables. +// @param Kxx p x p pairwise interactions (diagonal zero). +// @param Kxy p x q cross interactions. +// @param Kyy q x q precision matrix. +// @param mux p x max_cats threshold / Blume-Capel parameters. +// @param muy q-vector of continuous means. +// @param num_categories p-vector: categories per discrete variable. +// @param variable_type p-vector: "ordinal" or "blume-capel". +// @param baseline_category p-vector. +// +// @return List of prediction matrices (one per predicted variable). +// For discrete: n x (num_cats+1) probability matrix. +// For continuous: n x 2 matrix (mean, sd). +// [[Rcpp::export]] +Rcpp::List compute_conditional_mixed( + const arma::imat& x_observations, + const arma::mat& y_observations, + const arma::ivec& predict_vars, + const arma::mat& Kxx, + const arma::mat& Kxy, + const arma::mat& Kyy, + const arma::mat& mux, + const arma::vec& muy, + const arma::ivec& num_categories, + const Rcpp::StringVector& variable_type, + const arma::ivec& baseline_category +) { + int n = x_observations.n_rows; + int p = x_observations.n_cols; + int q = y_observations.n_cols; + int num_predict_vars = predict_vars.n_elem; + + // Convert discrete to double (centered for rest-score computation) + arma::mat x_dbl = arma::conv_to::from(x_observations); + + Rcpp::List result(num_predict_vars); + + for (int pv = 0; pv < num_predict_vars; pv++) { + int var_idx = predict_vars[pv]; + + if (var_idx < p) { + // --- Discrete variable: P(x_s = c | x_{-s}, y) --- + int s = var_idx; + int Cs = num_categories[s]; + + // Rest score from discrete neighbours (centered by baseline) + arma::vec rest_discrete(n, arma::fill::zeros); + for (int k = 0; k < p; k++) { + if (k == s) continue; + arma::vec obs_k = x_dbl.col(k); + double ref_k = static_cast(baseline_category[k]); + rest_discrete += (obs_k - ref_k) * Kxx(k, s); + } + + // Rest score from continuous (factor of 2) + arma::vec rest_continuous(n, arma::fill::zeros); + for (int j = 0; j < q; j++) { + rest_continuous += 2.0 * Kxy(s, j) * y_observations.col(j); + } + + arma::vec rest_scores = rest_discrete + rest_continuous; + + arma::mat probs; + if (std::string(variable_type[s]) == "blume-capel") { + int ref = baseline_category[s]; + double lin_eff = mux(s, 0); + double quad_eff = mux(s, 1); + arma::vec bound; + probs = compute_probs_blume_capel( + rest_scores, lin_eff, quad_eff, ref, Cs, bound + ); + } else { + arma::vec main_param = mux.row(s).head(Cs).t(); + arma::vec bound(n, arma::fill::zeros); + for (int c = 0; c < Cs; c++) { + arma::vec exps = main_param[c] + (c + 1) * rest_scores; + bound = arma::max(bound, exps); + } + probs = compute_probs_ordinal(main_param, rest_scores, bound, Cs); + } + + Rcpp::NumericMatrix prob_mat(n, Cs + 1); + for (int i = 0; i < n; i++) { + for (int c = 0; c <= Cs; c++) { + prob_mat(i, c) = probs(i, c); + } + } + result[pv] = prob_mat; + + } else { + // --- Continuous variable: y_j | y_{-j}, x --- + int j = var_idx - p; + + double omega_jj = Kyy(j, j); + double cond_var = 1.0 / omega_jj; + double cond_sd = std::sqrt(cond_var); + + // Contribution from other continuous variables: + // -sum_{k != j} Kyy[j,k] * (y_k - muy_k) + arma::vec lp_continuous(n, arma::fill::zeros); + for (int k = 0; k < q; k++) { + if (k == j) continue; + lp_continuous -= Kyy(j, k) * (y_observations.col(k) - muy(k)); + } + + // Contribution from discrete variables (factor of 2): + // sum_s 2 * Kxy(s, j) * x_s_centered + arma::vec lp_discrete(n, arma::fill::zeros); + for (int s = 0; s < p; s++) { + double ref_s = static_cast(baseline_category[s]); + lp_discrete += 2.0 * Kxy(s, j) * (x_dbl.col(s) - ref_s); + } + + arma::vec cond_means = muy(j) + cond_var * (lp_continuous + lp_discrete); + + Rcpp::NumericMatrix out(n, 2); + for (int i = 0; i < n; i++) { + out(i, 0) = cond_means[i]; + out(i, 1) = cond_sd; + } + result[pv] = out; + } + } + + return result; +} \ No newline at end of file diff --git a/src/mrf_simulation.cpp b/src/mrf_simulation.cpp index 059b6145..23db2cae 100644 --- a/src/mrf_simulation.cpp +++ b/src/mrf_simulation.cpp @@ -15,32 +15,30 @@ using namespace RcppParallel; // MRF Simulation Core Functions (Thread-Safe) // ============================================================================ -/** - * Function: simulate_mrf - * - * Simulates observations from a Markov Random Field using Gibbs sampling. - * Supports both ordinal and Blume-Capel variable types. - * - * Inputs: - * - num_states: Number of observations to simulate. - * - num_variables: Number of variables in the MRF. - * - num_categories: Number of categories per variable (on top of baseline 0). - * - pairwise: Symmetric pairwise interaction matrix (diagonal ignored). - * - main: Main effect parameters (variables x max_categories). - * For ordinal: threshold parameters for categories 1..K. - * For Blume-Capel: column 0 = linear (alpha), column 1 = quadratic (beta). - * - variable_type: Type of each variable ("ordinal" or "blume-capel"). - * - baseline_category: Baseline category for Blume-Capel variables (0 for ordinal). - * - iter: Number of Gibbs sampling iterations. - * - rng: Thread-safe random number generator. - * - * Returns: - * - Integer matrix of simulated observations (num_states x num_variables). - * - * Notes: - * - Diagonal of pairwise matrix is explicitly ignored (set to zero internally). - * - For ordinal variables, baseline_category should be 0. - */ +// Function: simulate_mrf +// +// Simulates observations from a Markov Random Field using Gibbs sampling. +// Supports both ordinal and Blume-Capel variable types. +// +// Inputs: +// - num_states: Number of observations to simulate. +// - num_variables: Number of variables in the MRF. +// - num_categories: Number of categories per variable (on top of baseline 0). +// - pairwise: Symmetric pairwise interaction matrix (diagonal ignored). +// - main: Main effect parameters (variables x max_categories). +// For ordinal: threshold parameters for categories 1..K. +// For Blume-Capel: column 0 = linear (alpha), column 1 = quadratic (beta). +// - variable_type: Type of each variable ("ordinal" or "blume-capel"). +// - baseline_category: Baseline category for Blume-Capel variables (0 for ordinal). +// - iter: Number of Gibbs sampling iterations. +// - rng: Thread-safe random number generator. +// +// Returns: +// - Integer matrix of simulated observations (num_states x num_variables). +// +// Notes: +// - Diagonal of pairwise matrix is explicitly ignored (set to zero internally). +// - For ordinal variables, baseline_category should be 0. arma::imat simulate_mrf( int num_states, int num_variables, @@ -236,25 +234,23 @@ IntegerMatrix sample_bcomrf_gibbs(int num_states, // GGM Simulation (Direct Multivariate Normal Sampling) // ============================================================================ -/** - * Simulate observations from a Gaussian Graphical Model. - * - * Given a precision matrix Omega, draws num_states observations from - * N(means, Omega^{-1}) using the Cholesky factorization of the covariance. - * - * Algorithm: - * 1. Compute Sigma = Omega^{-1} via arma::inv_sympd. - * 2. Cholesky decompose: L = chol(Sigma, "lower") so Sigma = L L'. - * 3. Draw Z ~ N(0, I) of size (num_states x p). - * 4. Return X = ones * means' + Z * L'. - * - * @param num_states Number of observations to simulate. - * @param precision p x p positive-definite precision matrix (Omega). - * @param means p-vector of variable means (can be all zeros). - * @param rng Thread-safe random number generator. - * - * @return num_states x p matrix of simulated continuous observations. - */ +// Simulate observations from a Gaussian Graphical Model. +// +// Given a precision matrix Omega, draws num_states observations from +// N(means, Omega^{-1}) using the Cholesky factorization of the covariance. +// +// Algorithm: +// 1. Compute Sigma = Omega^{-1} via arma::inv_sympd. +// 2. Cholesky decompose: L = chol(Sigma, "lower") so Sigma = L L'. +// 3. Draw Z ~ N(0, I) of size (num_states x p). +// 4. Return X = ones * means' + Z * L'. +// +// @param num_states Number of observations to simulate. +// @param precision p x p positive-definite precision matrix (Omega). +// @param means p-vector of variable means (can be all zeros). +// @param rng Thread-safe random number generator. +// +// @return num_states x p matrix of simulated continuous observations. arma::mat simulate_ggm( int num_states, const arma::mat& precision, @@ -286,16 +282,14 @@ arma::mat simulate_ggm( // R Interface for GGM Simulation (standalone simulate_ggm) // ============================================================================ -/** - * R-callable wrapper for single GGM simulation. - * - * @param num_states Number of observations to simulate. - * @param precision p x p precision matrix (Omega). - * @param means p-vector of means (default zeros). - * @param seed Random seed for reproducibility. - * - * @return num_states x p numeric matrix. - */ +// R-callable wrapper for single GGM simulation. +// +// @param num_states Number of observations to simulate. +// @param precision p x p precision matrix (Omega). +// @param means p-vector of means (default zeros). +// @param seed Random seed for reproducibility. +// +// @return num_states x p numeric matrix. // [[Rcpp::export]] NumericMatrix sample_ggm_direct(int num_states, NumericMatrix precision, @@ -333,12 +327,9 @@ struct SimulationResult { }; -/** - * Worker class for parallel simulation across posterior draws - */ +// Worker class for parallel simulation across posterior draws class SimulationWorker : public RcppParallel::Worker { public: - // Input data const arma::mat& pairwise_samples; const arma::mat& main_samples; const arma::ivec& draw_indices; @@ -349,14 +340,8 @@ class SimulationWorker : public RcppParallel::Worker { const arma::ivec& baseline_category; const int iter; const arma::ivec& main_param_counts; - - // RNGs const std::vector& draw_rngs; - - // Progress ProgressManager& pm; - - // Output std::vector& results; SimulationWorker( @@ -390,13 +375,15 @@ class SimulationWorker : public RcppParallel::Worker { {} void operator()(std::size_t begin, std::size_t end) { + bool is_main = (begin == 0); for (std::size_t i = begin; i < end; ++i) { + if (pm.shouldExit()) return; + SimulationResult result; result.draw_index = draw_indices[i]; result.error = false; try { - // Get RNG for this draw SafeRNG rng = draw_rngs[i]; // Reconstruct pairwise matrix from flat vector @@ -421,7 +408,6 @@ class SimulationWorker : public RcppParallel::Worker { } } - // Simulate observations via Gibbs sampling result.observations = simulate_mrf( num_states, num_variables, @@ -443,32 +429,28 @@ class SimulationWorker : public RcppParallel::Worker { } results[i] = result; - - // Update progress - treating each draw as a "chain" for progress display - pm.update(0); + if (is_main) pm.update(0); } } }; -/** - * Run parallel simulations across posterior draws - * - * @param pairwise_samples Matrix of pairwise samples (ndraws x n_pairwise) - * @param main_samples Matrix of main/threshold samples (ndraws x n_main) - * @param draw_indices 1-based indices of which draws to use - * @param num_states Number of observations to simulate per draw - * @param num_variables Number of variables - * @param num_categories Number of categories per variable - * @param variable_type Type of each variable ("ordinal" or "blume-capel") - * @param baseline_category Baseline category for each variable - * @param iter Number of Gibbs iterations per simulation - * @param nThreads Number of parallel threads - * @param seed Random seed - * @param progress_type Progress bar type (0=none, 1=total, 2=per-chain) - * - * @return List of simulation results (each is an integer matrix) - */ +// Run parallel simulations across posterior draws +// +// @param pairwise_samples Matrix of pairwise samples (ndraws x n_pairwise) +// @param main_samples Matrix of main/threshold samples (ndraws x n_main) +// @param draw_indices 1-based indices of which draws to use +// @param num_states Number of observations to simulate per draw +// @param num_variables Number of variables +// @param num_categories Number of categories per variable +// @param variable_type Type of each variable ("ordinal" or "blume-capel") +// @param baseline_category Baseline category for each variable +// @param iter Number of Gibbs iterations per simulation +// @param nThreads Number of parallel threads +// @param seed Random seed +// @param progress_type Progress bar type (0=none, 1=total, 2=per-chain) +// +// @return List of simulation results (each is an integer matrix) // [[Rcpp::export]] Rcpp::List run_simulation_parallel( const arma::mat& pairwise_samples, @@ -517,10 +499,8 @@ Rcpp::List run_simulation_parallel( // Prepare results storage std::vector results(ndraws); - // Single-chain progress (we report across all draws as one unit) ProgressManager pm(1, ndraws, 0, 50, progress_type); - // Create worker SimulationWorker worker( pairwise_samples, main_samples, @@ -537,12 +517,11 @@ Rcpp::List run_simulation_parallel( results ); - // Run in parallel { - tbb::global_control control(tbb::global_control::max_allowed_parallelism, nThreads); + tbb::global_control control( + tbb::global_control::max_allowed_parallelism, nThreads); parallelFor(0, ndraws, worker); } - pm.finish(); // Convert results to R list @@ -572,13 +551,11 @@ struct GGMSimulationResult { }; -/** - * Worker class for parallel GGM simulation across posterior draws - */ +// Worker class for parallel GGM simulation across posterior draws class GGMSimulationWorker : public RcppParallel::Worker { public: - const arma::mat& pairwise_samples; // ndraws x p*(p-1)/2 - const arma::mat& main_samples; // ndraws x p (diagonal precisions) + const arma::mat& pairwise_samples; + const arma::mat& main_samples; const arma::ivec& draw_indices; const int num_states; const int num_variables; @@ -610,7 +587,10 @@ class GGMSimulationWorker : public RcppParallel::Worker { {} void operator()(std::size_t begin, std::size_t end) { + bool is_main = (begin == 0); for (std::size_t i = begin; i < end; ++i) { + if (pm.shouldExit()) return; + GGMSimulationResult result; result.draw_index = draw_indices[i]; result.error = false; @@ -628,7 +608,6 @@ class GGMSimulationWorker : public RcppParallel::Worker { idx++; } } - // Diagonal for (int v = 0; v < num_variables; v++) { precision(v, v) = main_samples(draw_indices[i] - 1, v); } @@ -649,27 +628,25 @@ class GGMSimulationWorker : public RcppParallel::Worker { } results[i] = result; - pm.update(0); + if (is_main) pm.update(0); } } }; -/** - * Run parallel GGM simulations across posterior draws. - * - * @param pairwise_samples Matrix of off-diagonal precision samples (ndraws x p*(p-1)/2) - * @param main_samples Matrix of diagonal precision samples (ndraws x p) - * @param draw_indices 1-based indices of which draws to use - * @param num_states Number of observations to simulate per draw - * @param num_variables Number of variables - * @param means p-vector of variable means - * @param nThreads Number of parallel threads - * @param seed Random seed - * @param progress_type Progress bar type (0=none, 1=total, 2=per-chain) - * - * @return List of simulation results (each is a numeric matrix n x p) - */ +// Run parallel GGM simulations across posterior draws. +// +// @param pairwise_samples Matrix of off-diagonal precision samples (ndraws x p*(p-1)/2) +// @param main_samples Matrix of diagonal precision samples (ndraws x p) +// @param draw_indices 1-based indices of which draws to use +// @param num_states Number of observations to simulate per draw +// @param num_variables Number of variables +// @param means p-vector of variable means +// @param nThreads Number of parallel threads +// @param seed Random seed +// @param progress_type Progress bar type (0=none, 1=total, 2=per-chain) +// +// @return List of simulation results (each is a numeric matrix n x p) // [[Rcpp::export]] Rcpp::List run_ggm_simulation_parallel( const arma::mat& pairwise_samples, @@ -706,10 +683,10 @@ Rcpp::List run_ggm_simulation_parallel( ); { - tbb::global_control control(tbb::global_control::max_allowed_parallelism, nThreads); + tbb::global_control control( + tbb::global_control::max_allowed_parallelism, nThreads); parallelFor(0, ndraws, worker); } - pm.finish(); Rcpp::List output(ndraws); @@ -721,5 +698,495 @@ Rcpp::List run_ggm_simulation_parallel( output[i] = Rcpp::wrap(results[i].observations); } + return output; +} + + +// ============================================================================ +// Mixed MRF Simulation (Block Gibbs: Discrete + Continuous) +// ============================================================================ + +// Simulate observations from a mixed MRF using block Gibbs sampling. +// +// Each iteration updates all discrete variables from their full conditional +// given (x_{-s}, y), then updates all continuous variables from +// y | x ~ N(mu_y + 2 * x * Kxy * Kyy^{-1}, Kyy^{-1}). +// +// @param num_states Number of observations to simulate. +// @param Kxx p x p discrete pairwise interactions (diagonal zero). +// @param Kxy p x q cross interactions. +// @param Kyy q x q SPD continuous precision. +// @param mux p x max_cats threshold / Blume-Capel parameters. +// @param muy q-vector of continuous means. +// @param num_categories p-vector: number of categories per discrete variable +// (on top of baseline 0). +// @param variable_type p-vector: "ordinal" or "blume-capel". +// @param baseline_category p-vector: reference category per variable. +// @param iter Number of Gibbs iterations for burn-in. +// @param rng Thread-safe RNG. +// @param x_out Output: n x p integer matrix of discrete observations. +// @param y_out Output: n x q numeric matrix of continuous observations. +void simulate_mixed_mrf( + int num_states, + const arma::mat& Kxx, + const arma::mat& Kxy, + const arma::mat& Kyy, + const arma::mat& mux, + const arma::vec& muy, + const arma::ivec& num_categories, + const std::vector& variable_type, + const arma::ivec& baseline_category, + int iter, + SafeRNG& rng, + arma::imat& x_out, + arma::mat& y_out) { + + int p = Kxx.n_rows; + int q = Kyy.n_rows; + + // Precompute Kyy decomposition + arma::mat Sigma_y = arma::inv_sympd(Kyy); + arma::mat L_Sigma = arma::chol(Sigma_y, "lower"); + + // Precompute Kxy * Kyy^{-1} for conditional mean + arma::mat Kxy_Sigma = Kxy * Sigma_y; // p x q + + // Copy Kxx with zeroed diagonal for safety + arma::mat Kxx_safe = Kxx; + Kxx_safe.diag().zeros(); + + // Generate each observation independently + for (int obs = 0; obs < num_states; obs++) { + // Initialize discrete variables uniformly + arma::ivec x_current(p); + for (int s = 0; s < p; s++) { + int max_cat = num_categories(s); + x_current(s) = static_cast(runif(rng) * (max_cat + 1)); + if (x_current(s) > max_cat) x_current(s) = max_cat; + } + + // Initialize continuous from marginal N(muy, Sigma_y) + arma::vec z = arma_rnorm_vec(rng, q); + arma::vec y_current = muy + L_Sigma * z; + + // Gibbs iterations + for (int it = 0; it < iter; it++) { + + // --- Update discrete variables from f(x_s | x_{-s}, y) --- + for (int s = 0; s < p; s++) { + int Cs = num_categories(s); + + // Rest score from discrete neighbours + double rest_discrete = 0.0; + for (int k = 0; k < p; k++) { + if (k != s) { + int obs_k = x_current(k); + int ref_k = baseline_category(k); + rest_discrete += (obs_k - ref_k) * Kxx_safe(k, s); + } + } + + // Rest score from continuous (factor of 2) + double rest_continuous = 0.0; + for (int j = 0; j < q; j++) { + rest_continuous += 2.0 * Kxy(s, j) * y_current(j); + } + + double rest = rest_discrete + rest_continuous; + + if (variable_type[s] == "blume-capel") { + int ref = baseline_category(s); + double alpha = mux(s, 0); + double beta = mux(s, 1); + + // Compute log-probabilities for categories 0..Cs + arma::vec log_probs(Cs + 1); + double max_lp = -std::numeric_limits::infinity(); + for (int c = 0; c <= Cs; c++) { + int d = c - ref; + log_probs(c) = alpha * d + beta * d * d + d * rest; + if (log_probs(c) > max_lp) max_lp = log_probs(c); + } + + // Stabilize and convert to cumulative probabilities + double cumsum = 0.0; + arma::vec cum_probs(Cs + 1); + for (int c = 0; c <= Cs; c++) { + cumsum += std::exp(log_probs(c) - max_lp); + cum_probs(c) = cumsum; + } + + // Sample + double u = runif(rng) * cumsum; + int sampled = 0; + while (sampled < Cs && u > cum_probs(sampled)) sampled++; + x_current(s) = sampled; + + } else { + // Ordinal: category 0 is reference with log-prob = 0 + arma::vec log_probs(Cs + 1); + log_probs(0) = 0.0; + double max_lp = 0.0; + for (int c = 1; c <= Cs; c++) { + log_probs(c) = mux(s, c - 1) + c * rest; + if (log_probs(c) > max_lp) max_lp = log_probs(c); + } + + double cumsum = 0.0; + arma::vec cum_probs(Cs + 1); + for (int c = 0; c <= Cs; c++) { + cumsum += std::exp(log_probs(c) - max_lp); + cum_probs(c) = cumsum; + } + + double u = runif(rng) * cumsum; + int sampled = 0; + while (sampled < Cs && u > cum_probs(sampled)) sampled++; + x_current(s) = sampled; + } + } + + // --- Update continuous variables from y | x --- + // y | x ~ N(muy + 2 * Kxy_Sigma^T * x_centered, Sigma_y) + // Compute centered discrete observations + arma::vec x_centered(p); + for (int s = 0; s < p; s++) { + x_centered(s) = static_cast(x_current(s) - baseline_category(s)); + } + + // Conditional mean: muy + 2 * Sigma_y * Kxy^T * x_centered + // = muy + 2 * Kxy_Sigma^T * x_centered + arma::vec cond_mean = muy + 2.0 * Kxy_Sigma.t() * x_centered; + + // Sample y ~ N(cond_mean, Sigma_y) + arma::vec z2 = arma_rnorm_vec(rng, q); + y_current = cond_mean + L_Sigma * z2; + } + + // Store final state + x_out.row(obs) = x_current.t(); + y_out.row(obs) = y_current.t(); + } +} + + +// ============================================================================ +// R Interface for Mixed MRF Simulation (standalone) +// ============================================================================ + +// [[Rcpp::export]] +Rcpp::List sample_mixed_mrf_gibbs( + int num_states, + NumericMatrix Kxx_r, + NumericMatrix Kxy_r, + NumericMatrix Kyy_r, + NumericMatrix mux_r, + NumericVector muy_r, + IntegerVector num_categories_r, + Rcpp::StringVector variable_type_r, + IntegerVector baseline_category_r, + int iter, + int seed) { + + SafeRNG rng(seed); + + int p = Kxx_r.nrow(); + int q = Kyy_r.nrow(); + + arma::mat Kxx = Rcpp::as(Kxx_r); + arma::mat Kxy = Rcpp::as(Kxy_r); + arma::mat Kyy = Rcpp::as(Kyy_r); + arma::mat mux = Rcpp::as(mux_r); + arma::vec muy = Rcpp::as(muy_r); + arma::ivec num_categories = Rcpp::as(num_categories_r); + arma::ivec baseline_category = Rcpp::as(baseline_category_r); + + // Convert variable_type + std::vector variable_type(p); + for (int i = 0; i < p; i++) { + variable_type[i] = Rcpp::as(variable_type_r[i]); + if (variable_type[i] != "blume-capel") { + baseline_category[i] = 0; + } + } + + arma::imat x_out(num_states, p); + arma::mat y_out(num_states, q); + + simulate_mixed_mrf( + num_states, Kxx, Kxy, Kyy, mux, muy, + num_categories, variable_type, baseline_category, + iter, rng, x_out, y_out + ); + + Rcpp::checkUserInterrupt(); + + return Rcpp::List::create( + Rcpp::Named("x") = Rcpp::wrap(x_out), + Rcpp::Named("y") = Rcpp::wrap(y_out) + ); +} + + +// ============================================================================ +// Parallel Mixed MRF Simulation for simulate.bgms() with Posterior Draws +// ============================================================================ + +struct MixedSimulationResult { + arma::imat x_observations; + arma::mat y_observations; + int draw_index; + bool error; + std::string error_msg; +}; + + +class MixedSimulationWorker : public RcppParallel::Worker { +public: + // Input: posterior samples as flat vectors (one row per draw) + const arma::mat& mux_samples; // ndraws x n_mux + const arma::mat& kxx_samples; // ndraws x p*(p-1)/2 + const arma::mat& muy_samples; // ndraws x q + const arma::mat& kyy_samples; // ndraws x q*(q+1)/2 + const arma::mat& kxy_samples; // ndraws x p*q + + const arma::ivec& draw_indices; + const int num_states; + const int p, q; + const arma::ivec& num_categories; + const std::vector& variable_type; + const arma::ivec& baseline_category; + const int iter; + const arma::ivec& mux_param_counts; + + const std::vector& draw_rngs; + ProgressManager& pm; + std::vector& results; + + MixedSimulationWorker( + const arma::mat& mux_samples, + const arma::mat& kxx_samples, + const arma::mat& muy_samples, + const arma::mat& kyy_samples, + const arma::mat& kxy_samples, + const arma::ivec& draw_indices, + int num_states, + int p, int q, + const arma::ivec& num_categories, + const std::vector& variable_type, + const arma::ivec& baseline_category, + int iter, + const arma::ivec& mux_param_counts, + const std::vector& draw_rngs, + ProgressManager& pm, + std::vector& results + ) : + mux_samples(mux_samples), + kxx_samples(kxx_samples), + muy_samples(muy_samples), + kyy_samples(kyy_samples), + kxy_samples(kxy_samples), + draw_indices(draw_indices), + num_states(num_states), + p(p), q(q), + num_categories(num_categories), + variable_type(variable_type), + baseline_category(baseline_category), + iter(iter), + mux_param_counts(mux_param_counts), + draw_rngs(draw_rngs), + pm(pm), + results(results) + {} + + void operator()(std::size_t begin, std::size_t end) { + bool is_main = (begin == 0); + for (std::size_t i = begin; i < end; i++) { + if (pm.shouldExit()) return; + + MixedSimulationResult result; + result.draw_index = draw_indices[i]; + result.error = false; + + try { + SafeRNG rng = draw_rngs[i]; + int draw_row = draw_indices[i] - 1; // 1-based to 0-based + + // Reconstruct mux (p x max_cats) + int max_cats = arma::max(mux_param_counts); + arma::mat mux(p, max_cats, arma::fill::zeros); + int idx = 0; + for (int s = 0; s < p; s++) { + for (int c = 0; c < mux_param_counts[s]; c++) { + mux(s, c) = mux_samples(draw_row, idx); + idx++; + } + } + + // Reconstruct Kxx (p x p symmetric, zero diagonal) + arma::mat Kxx(p, p, arma::fill::zeros); + idx = 0; + for (int col = 0; col < p; col++) { + for (int row = col + 1; row < p; row++) { + Kxx(row, col) = kxx_samples(draw_row, idx); + Kxx(col, row) = kxx_samples(draw_row, idx); + idx++; + } + } + + // Reconstruct muy (q-vector) + arma::vec muy(q); + for (int j = 0; j < q; j++) { + muy(j) = muy_samples(draw_row, j); + } + + // Reconstruct Kyy (q x q symmetric, upper triangle including diagonal) + arma::mat Kyy(q, q, arma::fill::zeros); + idx = 0; + for (int col = 0; col < q; col++) { + for (int row = col; row < q; row++) { + if (row == col) { + Kyy(row, col) = kyy_samples(draw_row, idx); + } else { + Kyy(row, col) = kyy_samples(draw_row, idx); + Kyy(col, row) = kyy_samples(draw_row, idx); + } + idx++; + } + } + + // Reconstruct Kxy (p x q, row-major) + arma::mat Kxy(p, q, arma::fill::zeros); + idx = 0; + for (int s = 0; s < p; s++) { + for (int j = 0; j < q; j++) { + Kxy(s, j) = kxy_samples(draw_row, idx); + idx++; + } + } + + // Simulate + result.x_observations.set_size(num_states, p); + result.y_observations.set_size(num_states, q); + + simulate_mixed_mrf( + num_states, Kxx, Kxy, Kyy, mux, muy, + num_categories, variable_type, baseline_category, + iter, rng, + result.x_observations, result.y_observations + ); + + } catch (const std::exception& e) { + result.error = true; + result.error_msg = e.what(); + } catch (...) { + result.error = true; + result.error_msg = "Unknown error"; + } + + results[i] = result; + if (is_main) pm.update(0); + } + } +}; + + +// Run parallel mixed MRF simulations across posterior draws. +// +// The R layer splits the flat parameter vector into the 5 component matrices +// and passes them as separate sample matrices. +// +// @param mux_samples ndraws x n_mux +// @param kxx_samples ndraws x p*(p-1)/2 +// @param muy_samples ndraws x q +// @param kyy_samples ndraws x q*(q+1)/2 +// @param kxy_samples ndraws x p*q +// @param draw_indices 1-based indices of which draws to use +// @param num_states Number of observations per draw +// @param p Number of discrete variables +// @param q Number of continuous variables +// @param num_categories p-vector: categories per discrete variable +// @param variable_type_r p-vector: "ordinal" or "blume-capel" +// @param baseline_category p-vector +// @param iter Gibbs burn-in iterations +// @param nThreads Number of threads +// @param seed Random seed +// @param progress_type Progress bar type +// +// @return List of lists, each containing "x" (integer matrix) and "y" (numeric matrix). +// [[Rcpp::export]] +Rcpp::List run_mixed_simulation_parallel( + const arma::mat& mux_samples, + const arma::mat& kxx_samples, + const arma::mat& muy_samples, + const arma::mat& kyy_samples, + const arma::mat& kxy_samples, + const arma::ivec& draw_indices, + int num_states, + int p, + int q, + const arma::ivec& num_categories, + const Rcpp::StringVector& variable_type_r, + const arma::ivec& baseline_category, + int iter, + int nThreads, + int seed, + int progress_type) { + + int ndraws = draw_indices.n_elem; + + std::vector variable_type(p); + arma::ivec baseline_category_safe = baseline_category; + for (int i = 0; i < p; i++) { + variable_type[i] = Rcpp::as(variable_type_r[i]); + if (variable_type[i] != "blume-capel") { + baseline_category_safe[i] = 0; + } + } + + // Compute mux param counts per discrete variable + arma::ivec mux_param_counts(p); + for (int s = 0; s < p; s++) { + if (variable_type[s] == "blume-capel") { + mux_param_counts[s] = 2; + } else { + mux_param_counts[s] = num_categories[s]; + } + } + + std::vector draw_rngs(ndraws); + for (int d = 0; d < ndraws; d++) { + draw_rngs[d] = SafeRNG(seed + d); + } + + std::vector results(ndraws); + ProgressManager pm(1, ndraws, 0, 50, progress_type); + + MixedSimulationWorker worker( + mux_samples, kxx_samples, muy_samples, kyy_samples, kxy_samples, + draw_indices, num_states, p, q, + num_categories, variable_type, baseline_category_safe, + iter, mux_param_counts, draw_rngs, pm, results + ); + + { + tbb::global_control control( + tbb::global_control::max_allowed_parallelism, nThreads); + parallelFor(0, ndraws, worker); + } + pm.finish(); + + Rcpp::List output(ndraws); + for (int i = 0; i < ndraws; i++) { + if (results[i].error) { + Rcpp::stop("Error in mixed MRF simulation draw %d: %s", + results[i].draw_index, results[i].error_msg.c_str()); + } + output[i] = Rcpp::List::create( + Rcpp::Named("x") = Rcpp::wrap(results[i].x_observations), + Rcpp::Named("y") = Rcpp::wrap(results[i].y_observations) + ); + } + return output; } \ No newline at end of file diff --git a/src/priors/sbm_edge_prior.cpp b/src/priors/sbm_edge_prior.cpp index 81c92ac5..ce84b00c 100644 --- a/src/priors/sbm_edge_prior.cpp +++ b/src/priors/sbm_edge_prior.cpp @@ -59,24 +59,22 @@ arma::mat add_row_col_block_prob_matrix(arma::mat X, -/** - * Function: log_likelihood_mfm_sbm - * - * Computes the log-likelihood contribution for a single node under the - * Mixture of Finite Mixtures Stochastic Block Model (MFM-SBM). Evaluates - * the probability of observed edges between the node and all other nodes - * given their cluster assignments and cluster connection probabilities. - * - * Inputs: - * - cluster_assign: Vector of cluster assignments for all nodes. - * - cluster_probs: Matrix of edge probabilities between clusters. - * - indicator: Upper-triangular matrix of edge indicators (1 = edge present). - * - node: Index of the node whose contribution is computed. - * - no_variables: Total number of nodes in the network. - * - * Returns: - * - Log-likelihood contribution for the specified node. - */ +// Function: log_likelihood_mfm_sbm +// +// Computes the log-likelihood contribution for a single node under the +// Mixture of Finite Mixtures Stochastic Block Model (MFM-SBM). Evaluates +// the probability of observed edges between the node and all other nodes +// given their cluster assignments and cluster connection probabilities. +// +// Inputs: +// - cluster_assign: Vector of cluster assignments for all nodes. +// - cluster_probs: Matrix of edge probabilities between clusters. +// - indicator: Upper-triangular matrix of edge indicators (1 = edge present). +// - node: Index of the node whose contribution is computed. +// - no_variables: Total number of nodes in the network. +// +// Returns: +// - Log-likelihood contribution for the specified node. double log_likelihood_mfm_sbm(arma::uvec cluster_assign, arma::mat cluster_probs, arma::umat indicator, @@ -103,27 +101,25 @@ double log_likelihood_mfm_sbm(arma::uvec cluster_assign, return output; } -/** - * Function: log_marginal_mfm_sbm - * - * Computes the log-marginal likelihood contribution for a single node under - * the MFM-SBM after integrating out cluster connection probabilities. Uses - * Beta-Bernoulli conjugacy with separate hyperparameters for within-cluster - * and between-cluster edges. - * - * Inputs: - * - cluster_assign: Vector of cluster assignments for all nodes. - * - indicator: Upper-triangular matrix of edge indicators (1 = edge present). - * - node: Index of the node whose contribution is computed. - * - no_variables: Total number of nodes in the network. - * - beta_bernoulli_alpha: Alpha hyperparameter for within-cluster edges. - * - beta_bernoulli_beta: Beta hyperparameter for within-cluster edges. - * - beta_bernoulli_alpha_between: Alpha hyperparameter for between-cluster edges. - * - beta_bernoulli_beta_between: Beta hyperparameter for between-cluster edges. - * - * Returns: - * - Log-marginal likelihood contribution for the specified node. - */ +// Function: log_marginal_mfm_sbm +// +// Computes the log-marginal likelihood contribution for a single node under +// the MFM-SBM after integrating out cluster connection probabilities. Uses +// Beta-Bernoulli conjugacy with separate hyperparameters for within-cluster +// and between-cluster edges. +// +// Inputs: +// - cluster_assign: Vector of cluster assignments for all nodes. +// - indicator: Upper-triangular matrix of edge indicators (1 = edge present). +// - node: Index of the node whose contribution is computed. +// - no_variables: Total number of nodes in the network. +// - beta_bernoulli_alpha: Alpha hyperparameter for within-cluster edges. +// - beta_bernoulli_beta: Beta hyperparameter for within-cluster edges. +// - beta_bernoulli_alpha_between: Alpha hyperparameter for between-cluster edges. +// - beta_bernoulli_beta_between: Beta hyperparameter for between-cluster edges. +// +// Returns: +// - Log-marginal likelihood contribution for the specified node. double log_marginal_mfm_sbm(arma::uvec cluster_assign, arma::umat indicator, arma::uword node, diff --git a/src/sample_mixed.cpp b/src/sample_mixed.cpp new file mode 100644 index 00000000..af712cef --- /dev/null +++ b/src/sample_mixed.cpp @@ -0,0 +1,150 @@ +// sample_mixed.cpp - R interface for Mixed MRF model sampling +// +// Uses the unified MCMC runner infrastructure to sample from models with +// both discrete (ordinal / Blume-Capel) and continuous variables. +// Supports MH and hybrid-nuts (NUTS for unconstrained block + MH for Kyy) +// samplers, with optional edge selection. +#include +#include +#include + +#include "models/mixed/mixed_mrf_model.h" +#include "utils/progress_manager.h" +#include "utils/common_helpers.h" +#include "priors/edge_prior.h" +#include "mcmc/execution/chain_result.h" +#include "mcmc/execution/chain_runner.h" +#include "mcmc/execution/sampler_config.h" + +// R-exported function to sample from a Mixed MRF model. +// +// @param inputFromR List with model specification: +// discrete_observations (integer matrix n x p), +// continuous_observations (numeric matrix n x q), +// num_categories (integer vector, length p), +// is_ordinal_variable (integer vector, length p), +// baseline_category (integer vector, length p), +// main_alpha, main_beta, pairwise_scale (doubles), +// pseudolikelihood (string: "conditional" or "marginal") +// @param prior_inclusion_prob Prior inclusion probabilities ((p+q) x (p+q) matrix) +// @param initial_edge_indicators Initial edge indicators ((p+q) x (p+q) integer matrix) +// @param no_iter Number of post-warmup iterations +// @param no_warmup Number of warmup iterations +// @param no_chains Number of parallel chains +// @param edge_selection Whether to do edge selection (spike-and-slab) +// @param seed Random seed +// @param no_threads Number of threads for parallel execution +// @param progress_type Progress bar type +// @param edge_prior Edge prior type +// @param beta_bernoulli_alpha Beta-Bernoulli alpha hyperparameter +// @param beta_bernoulli_beta Beta-Bernoulli beta hyperparameter +// @param beta_bernoulli_alpha_between SBM between-cluster alpha +// @param beta_bernoulli_beta_between SBM between-cluster beta +// @param dirichlet_alpha Dirichlet alpha for SBM +// @param lambda Lambda for SBM +// @param sampler_type Sampler type string ("mh", "hybrid-nuts", etc.) +// @param target_acceptance Target acceptance rate for gradient-based samplers +// @param max_tree_depth Maximum tree depth for NUTS +// @param num_leapfrogs Number of leapfrog steps for HMC +// @param na_impute Whether to impute missing data +// @param missing_index_discrete Matrix of missing discrete indices (n_miss x 2, 0-based) +// @param missing_index_continuous Matrix of missing continuous indices (n_miss x 2, 0-based) +// +// @return List with per-chain results including samples and diagnostics +// [[Rcpp::export]] +Rcpp::List sample_mixed_mrf( + const Rcpp::List& inputFromR, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const int no_iter, + const int no_warmup, + const int no_chains, + const bool edge_selection, + const int seed, + const int no_threads, + const int progress_type, + const std::string& edge_prior = "Bernoulli", + const double beta_bernoulli_alpha = 1.0, + const double beta_bernoulli_beta = 1.0, + const double beta_bernoulli_alpha_between = 1.0, + const double beta_bernoulli_beta_between = 1.0, + const double dirichlet_alpha = 1.0, + const double lambda = 1.0, + const std::string& sampler_type = "mh", + const double target_acceptance = 0.80, + const int max_tree_depth = 10, + const int num_leapfrogs = 100, + const bool na_impute = false, + const Rcpp::Nullable missing_index_discrete_nullable = R_NilValue, + const Rcpp::Nullable missing_index_continuous_nullable = R_NilValue +) { + // Extract model inputs from R list + arma::imat discrete_obs = Rcpp::as(inputFromR["discrete_observations"]); + arma::mat continuous_obs = Rcpp::as(inputFromR["continuous_observations"]); + arma::ivec num_categories = Rcpp::as(inputFromR["num_categories"]); + arma::uvec is_ordinal = Rcpp::as(inputFromR["is_ordinal_variable"]); + arma::ivec baseline_cat = Rcpp::as(inputFromR["baseline_category"]); + double main_alpha = Rcpp::as(inputFromR["main_alpha"]); + double main_beta = Rcpp::as(inputFromR["main_beta"]); + double pairwise_scale = Rcpp::as(inputFromR["pairwise_scale"]); + std::string pseudolikelihood = Rcpp::as(inputFromR["pseudolikelihood"]); + + // Create model + MixedMRFModel model( + discrete_obs, continuous_obs, + num_categories, is_ordinal, baseline_cat, + prior_inclusion_prob, initial_edge_indicators, + edge_selection, pseudolikelihood, + main_alpha, main_beta, pairwise_scale, + seed + ); + + // Set up missing data imputation + if(na_impute) { + arma::imat missing_disc, missing_cont; + if(missing_index_discrete_nullable.isNotNull()) { + missing_disc = Rcpp::as( + Rcpp::IntegerMatrix(missing_index_discrete_nullable.get())); + } + if(missing_index_continuous_nullable.isNotNull()) { + missing_cont = Rcpp::as( + Rcpp::IntegerMatrix(missing_index_continuous_nullable.get())); + } + model.set_missing_data(missing_disc, missing_cont); + } + + // Create edge prior + EdgePrior edge_prior_enum = edge_prior_from_string(edge_prior); + auto edge_prior_obj = create_edge_prior( + edge_prior_enum, + beta_bernoulli_alpha, beta_bernoulli_beta, + beta_bernoulli_alpha_between, beta_bernoulli_beta_between, + dirichlet_alpha, lambda + ); + + // Configure sampler + SamplerConfig config; + config.sampler_type = sampler_type; + config.no_iter = no_iter; + config.no_warmup = no_warmup; + config.edge_selection = edge_selection; + config.seed = seed; + config.na_impute = na_impute; + config.target_acceptance = target_acceptance; + config.max_tree_depth = max_tree_depth; + config.num_leapfrogs = num_leapfrogs; + + // Set up progress manager + ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); + + // Run MCMC using unified infrastructure + std::vector results = run_mcmc_sampler( + model, *edge_prior_obj, config, no_chains, no_threads, pm); + + // Convert to R list format + Rcpp::List output = convert_results_to_list(results); + + pm.finish(); + + return output; +} diff --git a/src/sample_omrf.cpp b/src/sample_omrf.cpp index 14198eb9..f2b96bcc 100644 --- a/src/sample_omrf.cpp +++ b/src/sample_omrf.cpp @@ -1,9 +1,7 @@ -/** - * sample_omrf.cpp - R interface for OMRF model sampling - * - * Uses the unified MCMC runner infrastructure to sample from OMRF models. - * Supports MH, NUTS, and HMC samplers with optional edge selection. - */ +// sample_omrf.cpp - R interface for OMRF model sampling +// +// Uses the unified MCMC runner infrastructure to sample from OMRF models. +// Supports MH, NUTS, and HMC samplers with optional edge selection. #include #include #include @@ -16,35 +14,33 @@ #include "mcmc/execution/chain_runner.h" #include "mcmc/execution/sampler_config.h" -/** - * R-exported function to sample from an OMRF model - * - * @param inputFromR List with model specification - * @param prior_inclusion_prob Prior inclusion probabilities (p x p matrix) - * @param initial_edge_indicators Initial edge indicators (p x p integer matrix) - * @param no_iter Number of post-warmup iterations - * @param no_warmup Number of warmup iterations - * @param no_chains Number of parallel chains - * @param edge_selection Whether to do edge selection (spike-and-slab) - * @param sampler_type "mh", "nuts", or "hmc" - * @param seed Random seed - * @param no_threads Number of threads for parallel execution - * @param progress_type Progress bar type - * @param edge_prior Edge prior type: "Bernoulli", "Beta-Bernoulli", "Stochastic-Block" - * @param na_impute Whether to impute missing data - * @param missing_index Matrix of missing data indices (n_missing x 2, 0-based) - * @param beta_bernoulli_alpha Beta-Bernoulli alpha hyperparameter - * @param beta_bernoulli_beta Beta-Bernoulli beta hyperparameter - * @param beta_bernoulli_alpha_between SBM between-cluster alpha - * @param beta_bernoulli_beta_between SBM between-cluster beta - * @param dirichlet_alpha Dirichlet alpha for SBM - * @param lambda Lambda for SBM - * @param target_acceptance Target acceptance rate for NUTS/HMC (default: 0.8) - * @param max_tree_depth Maximum tree depth for NUTS (default: 10) - * @param num_leapfrogs Number of leapfrog steps for HMC (default: 10) - * - * @return List with per-chain results including samples and diagnostics - */ +// R-exported function to sample from an OMRF model +// +// @param inputFromR List with model specification +// @param prior_inclusion_prob Prior inclusion probabilities (p x p matrix) +// @param initial_edge_indicators Initial edge indicators (p x p integer matrix) +// @param no_iter Number of post-warmup iterations +// @param no_warmup Number of warmup iterations +// @param no_chains Number of parallel chains +// @param edge_selection Whether to do edge selection (spike-and-slab) +// @param sampler_type "mh", "nuts", or "hmc" +// @param seed Random seed +// @param no_threads Number of threads for parallel execution +// @param progress_type Progress bar type +// @param edge_prior Edge prior type: "Bernoulli", "Beta-Bernoulli", "Stochastic-Block" +// @param na_impute Whether to impute missing data +// @param missing_index Matrix of missing data indices (n_missing x 2, 0-based) +// @param beta_bernoulli_alpha Beta-Bernoulli alpha hyperparameter +// @param beta_bernoulli_beta Beta-Bernoulli beta hyperparameter +// @param beta_bernoulli_alpha_between SBM between-cluster alpha +// @param beta_bernoulli_beta_between SBM between-cluster beta +// @param dirichlet_alpha Dirichlet alpha for SBM +// @param lambda Lambda for SBM +// @param target_acceptance Target acceptance rate for NUTS/HMC (default: 0.8) +// @param max_tree_depth Maximum tree depth for NUTS (default: 10) +// @param num_leapfrogs Number of leapfrog steps for HMC (default: 10) +// +// @return List with per-chain results including samples and diagnostics // [[Rcpp::export]] Rcpp::List sample_omrf( const Rcpp::List& inputFromR, diff --git a/src/utils/common_helpers.h b/src/utils/common_helpers.h index 963d28ae..1179eeab 100644 --- a/src/utils/common_helpers.h +++ b/src/utils/common_helpers.h @@ -26,8 +26,19 @@ inline int count_num_main_effects(const arma::ivec& num_categories, return n_params; } -enum UpdateMethod { adaptive_metropolis, hamiltonian_mc, nuts }; +/// MCMC update method. +enum UpdateMethod { + adaptive_metropolis, ///< Robbins-Monro adaptive Metropolis-Hastings + hamiltonian_mc, ///< Fixed-trajectory Hamiltonian Monte Carlo + nuts ///< No-U-Turn Sampler +}; +/** + * Convert a string identifier to an UpdateMethod enum value. + * + * @param update_method One of "adaptive-metropolis", "hamiltonian-mc", "nuts". + * @return Corresponding UpdateMethod value. + */ inline UpdateMethod update_method_from_string(const std::string& update_method) { if (update_method == "adaptive-metropolis") return adaptive_metropolis; @@ -41,8 +52,21 @@ inline UpdateMethod update_method_from_string(const std::string& update_method) throw std::invalid_argument("Invalid update_method: " + update_method); } -enum EdgePrior { Stochastic_Block, Beta_Bernoulli, Bernoulli, Not_Applicable }; +/// Edge inclusion prior type. +enum EdgePrior { + Stochastic_Block, ///< MFM Stochastic Block Model + Beta_Bernoulli, ///< Shared Beta-Bernoulli prior + Bernoulli, ///< Fixed Bernoulli prior + Not_Applicable ///< No edge prior (all edges included) +}; +/** + * Convert a string identifier to an EdgePrior enum value. + * + * @param edge_prior One of "Stochastic-Block", "Beta-Bernoulli", + * "Bernoulli", "Not Applicable". + * @return Corresponding EdgePrior value. + */ inline EdgePrior edge_prior_from_string(const std::string& edge_prior) { if (edge_prior == "Stochastic-Block") return Stochastic_Block; diff --git a/src/utils/print_mutex.h b/src/utils/print_mutex.h index 12f9f21f..0f9a4900 100644 --- a/src/utils/print_mutex.h +++ b/src/utils/print_mutex.h @@ -3,19 +3,21 @@ #include +/** + * Return a process-global mutex for thread-safe console output. + * + * Include this header and lock the returned mutex before printing + * from parallel code: + * @code + * { + * tbb::mutex::scoped_lock lock(get_print_mutex()); + * std::cout << "message" << std::endl; + * } + * @endcode + */ inline tbb::mutex& get_print_mutex() { static tbb::mutex m; return m; } -#endif // PRINT_MUTEX_H - -// Add this header to the parallel code you wish to print from -// + the below code to print in parallel code: -// -// { -// tbb::mutex::scoped_lock lock(get_print_mutex()); -// std::cout -// << "print " -// << std::endl; -// } \ No newline at end of file +#endif // PRINT_MUTEX_H \ No newline at end of file diff --git a/src/utils/variable_helpers.cpp b/src/utils/variable_helpers.cpp index 586501cd..b3a745a4 100644 --- a/src/utils/variable_helpers.cpp +++ b/src/utils/variable_helpers.cpp @@ -359,6 +359,13 @@ LogZAndProbs compute_logZ_and_probs_ordinal( const arma::vec eM = ARMA_MY_EXP(main_param); + // The fast block computes eM[c] * pow % eB where the intermediate + // eM[c] * pow = exp(main_param(c) + (c+1)*rest) can overflow even when + // the final product exp(main_param(c) + (c+1)*rest - bound) is finite. + // Reduce the fast-block threshold by max(|main_param|) to prevent this. + const double max_abs_main = arma::max(arma::abs(main_param)); + const double FAST_LIM = std::max(0.0, EXP_BOUND - max_abs_main); + auto do_fast_block = [&](arma::uword i0, arma::uword i1) { auto P = result.probs.rows(i0, i1).cols(1, num_cats); arma::vec r = residual_score.rows(i0, i1); @@ -396,10 +403,10 @@ LogZAndProbs compute_logZ_and_probs_ordinal( const double* bp = bound.memptr(); arma::uword i = 0; while (i < N) { - const bool fast = !(bp[i] < -EXP_BOUND || bp[i] > EXP_BOUND); + const bool fast = !(bp[i] < -FAST_LIM || bp[i] > FAST_LIM); arma::uword j = i + 1; while (j < N) { - const bool fast_j = !(bp[j] < -EXP_BOUND || bp[j] > EXP_BOUND); + const bool fast_j = !(bp[j] < -FAST_LIM || bp[j] > FAST_LIM); if (fast_j != fast) break; j++; } @@ -488,14 +495,19 @@ LogZAndProbs compute_logZ_and_probs_blume_capel( result.log_Z.rows(i0, i1) = bb + ARMA_MY_LOG(denom); }; + // Same intermediate overflow guard as the ordinal function: + // exp_theta[c] * pow can overflow before the implicit cancellation with b. + const double max_abs_theta = arma::max(arma::abs(theta)); + const double THETA_LIM = std::max(0.0, EXP_BOUND - max_abs_theta); + const double* bp = b.memptr(); const double* pp = pow_bound.memptr(); arma::uword i = 0; while (i < N) { - const bool fast_i = (std::abs(bp[i]) <= EXP_BOUND) && (std::abs(pp[i]) <= EXP_BOUND); + const bool fast_i = (std::abs(bp[i]) <= EXP_BOUND) && (std::abs(pp[i]) <= THETA_LIM); arma::uword j = i + 1; while (j < N) { - const bool fast_j = (std::abs(bp[j]) <= EXP_BOUND) && (std::abs(pp[j]) <= EXP_BOUND); + const bool fast_j = (std::abs(bp[j]) <= EXP_BOUND) && (std::abs(pp[j]) <= THETA_LIM); if (fast_j != fast_i) break; j++; } diff --git a/src/utils/variable_helpers.h b/src/utils/variable_helpers.h index b915683e..b20390ba 100644 --- a/src/utils/variable_helpers.h +++ b/src/utils/variable_helpers.h @@ -5,58 +5,62 @@ #include "math/explog_macros.h" -// ----------------------------------------------------------------------------- -// Struct to hold both log-normalizer and probabilities from joint computation. -// Used by logp_and_gradient to avoid duplicate probability/denominator calculations. -// ----------------------------------------------------------------------------- +/** + * Holds both log-normalizer and probabilities from joint computation. + * + * Used by logp_and_gradient to avoid duplicate probability/denominator + * calculations. + */ struct LogZAndProbs { - arma::vec log_Z; // log-normalizer for each person - arma::mat probs; // (num_persons x num_cats+1) probability matrix + /// Log-normalizer for each person. + arma::vec log_Z; + /// Probability matrix (num_persons x num_cats+1). + arma::mat probs; }; -// ----------------------------------------------------------------------------- -// Compute a numerically stable sum of the form: -// -// denom = exp(-bound) + sum_{cat=0}^{K-1} exp(main_effect_param(cat) -// + (cat + 1) * residual_score - bound) -// -// but evaluated efficiently using precomputed exponentials: -// -// exp_r = exp(residual_score) -// exp_m = exp(main_effect_param) -// denom = exp(-bound) * ( 1 + sum_c exp_m[c] * exp_r^(c+1) ) -// -// If non-finite values arise (overflow, underflow, NaN), a safe fallback -// recomputes the naive version using direct exponentials. -// ---------------------------------------------------------------------------- +/** + * Compute a numerically stable sum of the form: + * + * denom = exp(-bound) + sum_{cat=0}^{K-1} exp(main_effect_param(cat) + * + (cat + 1) * residual_score - bound) + * + * but evaluated efficiently using precomputed exponentials: + * + * exp_r = exp(residual_score) + * exp_m = exp(main_effect_param) + * denom = exp(-bound) * ( 1 + sum_c exp_m[c] * exp_r^(c+1) ) + * + * If non-finite values arise (overflow, underflow, NaN), a safe fallback + * recomputes the naive version using direct exponentials. + */ arma::vec compute_denom_ordinal( const arma::vec& residual, const arma::vec& main_eff, const arma::vec& bound ); -// ----------------------------------------------------------------------------- -// Compute denom = Σ_c exp( θ(c) + c*r - b ), with -// θ(c) = lin_eff*(c-ref) + quad_eff*(c-ref)^2 -// b = max_c( θ(c) + c*r ) (vectorized) -// -// Two modes: -// -// FAST (preexp + power-chain): -// denom = Σ_c exp_theta[c] * exp(-b) * exp(r)^c -// Used only when all exponent terms are safe: -// |b| ≤ EXP_BOUND, -// underflow_bound ≥ -EXP_BOUND, -// num_cats*r - b ≤ EXP_BOUND. -// This guarantees the recursive pow-chain stays finite. -// -// SAFE (direct evaluation): -// denom = Σ_c exp(θ(c) + c*r - b) -// Used whenever any FAST-condition fails. Slower but always stable. -// -// FAST gives identical results when safe, otherwise SAFE is used. -// ----------------------------------------------------------------------------- +/** + * Compute denom = Sigma_c exp( theta(c) + c*r - b ), with + * theta(c) = lin_eff*(c-ref) + quad_eff*(c-ref)^2 + * b = max_c( theta(c) + c*r ) (vectorized) + * + * Two modes: + * + * FAST (preexp + power-chain): + * denom = Sigma_c exp_theta[c] * exp(-b) * exp(r)^c + * Used only when all exponent terms are safe: + * |b| <= EXP_BOUND, + * underflow_bound >= -EXP_BOUND, + * num_cats*r - b <= EXP_BOUND. + * This guarantees the recursive pow-chain stays finite. + * + * SAFE (direct evaluation): + * denom = Sigma_c exp(theta(c) + c*r - b) + * Used whenever any FAST-condition fails. Slower but always stable. + * + * FAST gives identical results when safe, otherwise SAFE is used. + */ arma::vec compute_denom_blume_capel( const arma::vec& residual, const double lin_eff, @@ -87,25 +91,25 @@ arma::mat compute_probs_ordinal( int num_cats ); -// ----------------------------------------------------------------------------- -// Blume–Capel probabilities, numerically stable via FAST/SAFE split. -// -// Model: -// θ(c) = lin_eff * (c - ref) + quad_eff * (c - ref)^2, c = 0..num_cats -// exps_i(c) = θ(c) + c * r_i -// b_i = max_c exps_i(c) -// -// Probabilities: -// p_i(c) ∝ exp( exps_i(c) - b_i ) -// -// FAST (preexp + power-chain, same bounds as compute_denom_blume_capel): -// used when |b_i| ≤ EXP_BOUND and pow_bound_i = num_cats * r_i - b_i ≤ EXP_BOUND -// -// SAFE (direct): -// used otherwise: direct exp(θ(c) + (c-ref) * r_i - b_i) -// -// Under these conditions, denom is finite and > 0, so no one-hot fallback. -// ----------------------------------------------------------------------------- +/** + * Blume-Capel probabilities, numerically stable via FAST/SAFE split. + * + * Model: + * theta(c) = lin_eff * (c - ref) + quad_eff * (c - ref)^2, c = 0..num_cats + * exps_i(c) = theta(c) + c * r_i + * b_i = max_c exps_i(c) + * + * Probabilities: + * p_i(c) proportional to exp( exps_i(c) - b_i ) + * + * FAST (preexp + power-chain, same bounds as compute_denom_blume_capel): + * used when |b_i| <= EXP_BOUND and pow_bound_i = num_cats * r_i - b_i <= EXP_BOUND + * + * SAFE (direct): + * used otherwise: direct exp(theta(c) + (c-ref) * r_i - b_i) + * + * Under these conditions, denom is finite and > 0, so no one-hot fallback. + */ arma::mat compute_probs_blume_capel( const arma::vec& residual, const double lin_eff, @@ -115,10 +119,11 @@ arma::mat compute_probs_blume_capel( arma::vec& b ); -// ----------------------------------------------------------------------------- -// Joint computation of log-normalizer and probabilities for ordinal variables. -// Avoids redundant computation by computing both in a single pass. -// ----------------------------------------------------------------------------- +/** + * Joint computation of log-normalizer and probabilities for ordinal variables. + * + * Avoids redundant computation by computing both in a single pass. + */ LogZAndProbs compute_logZ_and_probs_ordinal( const arma::vec& main_param, const arma::vec& residual_score, @@ -126,10 +131,11 @@ LogZAndProbs compute_logZ_and_probs_ordinal( int num_cats ); -// ----------------------------------------------------------------------------- -// Joint computation of log-normalizer and probabilities for Blume-Capel variables. -// Avoids redundant computation by computing both in a single pass. -// ----------------------------------------------------------------------------- +/** + * Joint computation of log-normalizer and probabilities for Blume-Capel variables. + * + * Avoids redundant computation by computing both in a single pass. + */ LogZAndProbs compute_logZ_and_probs_blume_capel( const arma::vec& residual, const double lin_eff, diff --git a/dev/bitwise_compliance/generate_fixtures.R b/tests/compliance/generate_fixtures.R similarity index 88% rename from dev/bitwise_compliance/generate_fixtures.R rename to tests/compliance/generate_fixtures.R index 062f218f..da38590f 100644 --- a/dev/bitwise_compliance/generate_fixtures.R +++ b/tests/compliance/generate_fixtures.R @@ -11,16 +11,16 @@ # for discrete (ordinal / binary / Blume-Capel) models. # # Usage: -# Rscript dev/bitwise_compliance/generate_fixtures.R +# Rscript tests/compliance/generate_fixtures.R # # Output: -# dev/fixtures/compliance/ — one .rds per configuration + manifest.rds +# tests/compliance/fixtures/ — one .rds per configuration + manifest.rds # # ============================================================================== library(callr) -fixture_dir = file.path("dev", "fixtures", "compliance") +fixture_dir = file.path("tests", "compliance", "fixtures") dir.create(fixture_dir, recursive = TRUE, showWarnings = FALSE) # ============================================================================== @@ -46,7 +46,7 @@ installed_version = callr::r( ) cat("Installed version:", installed_version, "\n") -if (installed_version != "0.1.6.3") { +if(installed_version != "0.1.6.3") { cat("WARNING: Expected 0.1.6.3, got", installed_version, "\n") cat("Fixtures will be tagged with the actual version.\n") } @@ -68,9 +68,7 @@ if (installed_version != "0.1.6.3") { # ============================================================================== bgm_configs = list( - # --- Wenchuan ordinal --- - list( id = "bgm_wenchuan_nuts_bernoulli", desc = "bgm: Wenchuan 6v, NUTS, Bernoulli, edge_sel", @@ -306,9 +304,7 @@ bgm_configs = list( # ============================================================================== compare_configs = list( - # --- Wenchuan ordinal (split into two groups) --- - list( id = "cmp_wenchuan_nuts_bernoulli", desc = "bgmCompare: Wenchuan, NUTS, Bernoulli, diff_sel", @@ -480,8 +476,8 @@ prepare_datasets = function() { data(Boredom, package = "bgms", envir = environment()) wenchuan_small = Wenchuan[, 1:6] - adhd_small = ADHD[, 2:7] - boredom_small = Boredom[, 2:7] + adhd_small = ADHD[, 2:7] + boredom_small = Boredom[, 2:7] # Wenchuan with NAs (deterministic injection) wenchuan_na = wenchuan_small @@ -534,41 +530,44 @@ prepare_datasets = function() { extract_bgm_fixture = function(fit, config) { list( - id = config$id, + id = config$id, desc = config$desc, - posterior_summary_main = fit$posterior_summary_main, - posterior_summary_pairwise = fit$posterior_summary_pairwise, + posterior_summary_main = fit$posterior_summary_main, + posterior_summary_pairwise = fit$posterior_summary_pairwise, posterior_summary_indicator = fit$posterior_summary_indicator, - posterior_mean_main = fit$posterior_mean_main, - posterior_mean_pairwise = fit$posterior_mean_pairwise, + posterior_mean_main = fit$posterior_mean_main, + posterior_mean_pairwise = fit$posterior_mean_pairwise, posterior_mean_indicator = fit$posterior_mean_indicator, - raw_main_chain1 = fit$raw_samples$main[[1]], - raw_pairwise_chain1 = fit$raw_samples$pairwise[[1]], - raw_indicator_chain1 = if (!is.null(fit$raw_samples$indicator)) - fit$raw_samples$indicator[[1]] else NULL, + raw_main_chain1 = fit$raw_samples$main[[1]], + raw_pairwise_chain1 = fit$raw_samples$pairwise[[1]], + raw_indicator_chain1 = if(!is.null(fit$raw_samples$indicator)) { + fit$raw_samples$indicator[[1]] + } else { + NULL + }, nuts_diag = fit$nuts_diag, posterior_coclustering_matrix = fit$posterior_coclustering_matrix, - posterior_mean_allocations = fit$posterior_mean_allocations, + posterior_mean_allocations = fit$posterior_mean_allocations, bgms_version = as.character(packageVersion("bgms")) ) } extract_compare_fixture = function(fit, config) { list( - id = config$id, + id = config$id, desc = config$desc, - posterior_summary_main_baseline = fit$posterior_summary_main_baseline, - posterior_summary_pairwise_baseline = fit$posterior_summary_pairwise_baseline, - posterior_summary_main_differences = fit$posterior_summary_main_differences, + posterior_summary_main_baseline = fit$posterior_summary_main_baseline, + posterior_summary_pairwise_baseline = fit$posterior_summary_pairwise_baseline, + posterior_summary_main_differences = fit$posterior_summary_main_differences, posterior_summary_pairwise_differences = fit$posterior_summary_pairwise_differences, - posterior_summary_indicator = fit$posterior_summary_indicator, - posterior_mean_main_baseline = fit$posterior_mean_main_baseline, - posterior_mean_pairwise_baseline = fit$posterior_mean_pairwise_baseline, - posterior_mean_main_differences = fit$posterior_mean_main_differences, + posterior_summary_indicator = fit$posterior_summary_indicator, + posterior_mean_main_baseline = fit$posterior_mean_main_baseline, + posterior_mean_pairwise_baseline = fit$posterior_mean_pairwise_baseline, + posterior_mean_main_differences = fit$posterior_mean_main_differences, posterior_mean_pairwise_differences = fit$posterior_mean_pairwise_differences, - posterior_mean_indicator = fit$posterior_mean_indicator, + posterior_mean_indicator = fit$posterior_mean_indicator, raw_samples = fit$raw_samples, - nuts_diag = fit$nuts_diag, + nuts_diag = fit$nuts_diag, bgms_version = as.character(packageVersion("bgms")) ) } @@ -579,8 +578,8 @@ extract_compare_fixture = function(fit, config) { resolve_args = function(args, datasets) { resolved = args - for (nm in names(resolved)) { - if (is.character(resolved[[nm]]) && resolved[[nm]] %in% names(datasets)) { + for(nm in names(resolved)) { + if(is.character(resolved[[nm]]) && resolved[[nm]] %in% names(datasets)) { resolved[[nm]] = datasets[[resolved[[nm]]]] } } @@ -594,28 +593,31 @@ datasets = prepare_datasets() cat(sprintf("\nGenerating %d bgm fixtures...\n", length(bgm_configs))) bgm_manifest = list() -for (config in bgm_configs) { +for(config in bgm_configs) { cat(sprintf(" [%s] %s ... ", config$id, config$desc)) resolved = resolve_args(config$args, datasets) - result = tryCatch({ - callr::r( - function(args, lib_path, extract_fn) { - .libPaths(c(lib_path, .libPaths())) - library(bgms, lib.loc = lib_path) - set.seed(args$seed) - fit = do.call(bgm, args) - extract_fn(fit, list(id = "tmp", desc = "tmp")) - }, - args = list(args = resolved, lib_path = cran_lib, extract_fn = extract_bgm_fixture), - show = FALSE - ) - }, error = function(e) { - cat(sprintf("ERROR: %s\n", conditionMessage(e))) - NULL - }) - - if (!is.null(result)) { + result = tryCatch( + { + callr::r( + function(args, lib_path, extract_fn) { + .libPaths(c(lib_path, .libPaths())) + library(bgms, lib.loc = lib_path) + set.seed(args$seed) + fit = do.call(bgm, args) + extract_fn(fit, list(id = "tmp", desc = "tmp")) + }, + args = list(args = resolved, lib_path = cran_lib, extract_fn = extract_bgm_fixture), + show = FALSE + ) + }, + error = function(e) { + cat(sprintf("ERROR: %s\n", conditionMessage(e))) + NULL + } + ) + + if(!is.null(result)) { result$id = config$id result$desc = config$desc result$bgms_version = installed_version @@ -634,28 +636,31 @@ for (config in bgm_configs) { cat(sprintf("\nGenerating %d bgmCompare fixtures...\n", length(compare_configs))) compare_manifest = list() -for (config in compare_configs) { +for(config in compare_configs) { cat(sprintf(" [%s] %s ... ", config$id, config$desc)) resolved = resolve_args(config$args, datasets) - result = tryCatch({ - callr::r( - function(args, lib_path, extract_fn) { - .libPaths(c(lib_path, .libPaths())) - library(bgms, lib.loc = lib_path) - set.seed(args$seed) - fit = do.call(bgmCompare, args) - extract_fn(fit, list(id = "tmp", desc = "tmp")) - }, - args = list(args = resolved, lib_path = cran_lib, extract_fn = extract_compare_fixture), - show = FALSE - ) - }, error = function(e) { - cat(sprintf("ERROR: %s\n", conditionMessage(e))) - NULL - }) - - if (!is.null(result)) { + result = tryCatch( + { + callr::r( + function(args, lib_path, extract_fn) { + .libPaths(c(lib_path, .libPaths())) + library(bgms, lib.loc = lib_path) + set.seed(args$seed) + fit = do.call(bgmCompare, args) + extract_fn(fit, list(id = "tmp", desc = "tmp")) + }, + args = list(args = resolved, lib_path = cran_lib, extract_fn = extract_compare_fixture), + show = FALSE + ) + }, + error = function(e) { + cat(sprintf("ERROR: %s\n", conditionMessage(e))) + NULL + } + ) + + if(!is.null(result)) { result$id = config$id result$desc = config$desc result$bgms_version = installed_version diff --git a/dev/bitwise_compliance/test_compliance.R b/tests/compliance/test_compliance.R similarity index 87% rename from dev/bitwise_compliance/test_compliance.R rename to tests/compliance/test_compliance.R index 6a552521..f35a1825 100644 --- a/dev/bitwise_compliance/test_compliance.R +++ b/tests/compliance/test_compliance.R @@ -6,10 +6,10 @@ # fixture set and verifies bitwise-identical output. # # Usage: -# Rscript dev/bitwise_compliance/test_compliance.R +# Rscript tests/compliance/test_compliance.R # # Prerequisites: -# Rscript dev/bitwise_compliance/generate_fixtures.R +# Rscript tests/compliance/generate_fixtures.R # # Exit code: # 0 = all pass, 1 = any fail @@ -70,14 +70,29 @@ # mismatch was a regression introduced and fixed within PR #78; it was # never in a CRAN release. # +# 7. Intermediate-overflow guard in compute_logZ_and_probs_ordinal and +# compute_logZ_and_probs_blume_capel (ACCEPTED — not a bug): +# Commit 04b9562 tightened the fast/slow block threshold from EXP_BOUND +# (709) to FAST_LIM = max(0, EXP_BOUND - max_abs_main) to prevent +# intermediate overflow in exp(main_param(c) + (c+1)*rest) before the +# cancellation with exp(-bound). Both code paths are mathematically +# identical but differ at floating-point level. During HMC leapfrog +# integration, parameters can temporarily reach extreme values where +# max_abs_main is large enough that FAST_LIM < EXP_BOUND, reclassifying +# some observations between the fast (vectorized) and slow (per-element) +# paths. The resulting floating-point perturbation cascades through the +# fixed-step leapfrog integrator. This is needed for mixed MRF models +# where Theta_ss is absorbed into main_param. The affected configs use +# structure-only comparison against CRAN fixtures. +# # ============================================================================== library(bgms) -fixture_dir = file.path("dev", "fixtures", "compliance") +fixture_dir = file.path("tests", "compliance", "fixtures") if(!file.exists(file.path(fixture_dir, "manifest.rds"))) { - stop("No fixtures found. Run dev/bitwise_compliance/generate_fixtures.R first.") + stop("No fixtures found. Run tests/compliance/generate_fixtures.R first.") } manifest = readRDS(file.path(fixture_dir, "manifest.rds")) @@ -93,8 +108,8 @@ data(ADHD) data(Boredom) wenchuan_small = Wenchuan[, 1:6] -adhd_small = ADHD[, 2:7] -boredom_small = Boredom[, 2:7] +adhd_small = ADHD[, 2:7] +boredom_small = Boredom[, 2:7] wenchuan_na = wenchuan_small set.seed(999) @@ -379,36 +394,39 @@ resolve_args = function(args) { extract_bgm_actual = function(fit) { list( - posterior_summary_main = fit$posterior_summary_main, - posterior_summary_pairwise = fit$posterior_summary_pairwise, + posterior_summary_main = fit$posterior_summary_main, + posterior_summary_pairwise = fit$posterior_summary_pairwise, posterior_summary_indicator = fit$posterior_summary_indicator, - posterior_mean_main = fit$posterior_mean_main, - posterior_mean_pairwise = fit$posterior_mean_pairwise, + posterior_mean_main = fit$posterior_mean_main, + posterior_mean_pairwise = fit$posterior_mean_pairwise, posterior_mean_indicator = fit$posterior_mean_indicator, - raw_main_chain1 = fit$raw_samples$main[[1]], - raw_pairwise_chain1 = fit$raw_samples$pairwise[[1]], - raw_indicator_chain1 = if(!is.null(fit$raw_samples$indicator)) - fit$raw_samples$indicator[[1]] else NULL, + raw_main_chain1 = fit$raw_samples$main[[1]], + raw_pairwise_chain1 = fit$raw_samples$pairwise[[1]], + raw_indicator_chain1 = if(!is.null(fit$raw_samples$indicator)) { + fit$raw_samples$indicator[[1]] + } else { + NULL + }, nuts_diag = fit$nuts_diag, posterior_coclustering_matrix = fit$posterior_coclustering_matrix, - posterior_mean_allocations = fit$posterior_mean_allocations + posterior_mean_allocations = fit$posterior_mean_allocations ) } extract_compare_actual = function(fit) { list( - posterior_summary_main_baseline = fit$posterior_summary_main_baseline, - posterior_summary_pairwise_baseline = fit$posterior_summary_pairwise_baseline, - posterior_summary_main_differences = fit$posterior_summary_main_differences, + posterior_summary_main_baseline = fit$posterior_summary_main_baseline, + posterior_summary_pairwise_baseline = fit$posterior_summary_pairwise_baseline, + posterior_summary_main_differences = fit$posterior_summary_main_differences, posterior_summary_pairwise_differences = fit$posterior_summary_pairwise_differences, - posterior_summary_indicator = fit$posterior_summary_indicator, - posterior_mean_main_baseline = fit$posterior_mean_main_baseline, - posterior_mean_pairwise_baseline = fit$posterior_mean_pairwise_baseline, - posterior_mean_main_differences = fit$posterior_mean_main_differences, + posterior_summary_indicator = fit$posterior_summary_indicator, + posterior_mean_main_baseline = fit$posterior_mean_main_baseline, + posterior_mean_pairwise_baseline = fit$posterior_mean_pairwise_baseline, + posterior_mean_main_differences = fit$posterior_mean_main_differences, posterior_mean_pairwise_differences = fit$posterior_mean_pairwise_differences, - posterior_mean_indicator = fit$posterior_mean_indicator, + posterior_mean_indicator = fit$posterior_mean_indicator, raw_samples = fit$raw_samples, - nuts_diag = fit$nuts_diag + nuts_diag = fit$nuts_diag ) } @@ -425,11 +443,13 @@ na_bugfix_ids = c( ) # Configs excluded from bitwise comparison due to confirmed algorithm changes -# (see header note 5). Checked for structural match only. +# (see header notes 5 and 7). Checked for structural match only. structure_only_ids = c( - "bgm_wenchuan_nuts_blumecapel_impute", # Blume-Capel imputation bug fix - "bgm_wenchuan_nuts_sbm", # SBM lazy init changes RNG order (not a bug) - "bgm_adhd_nuts_sbm" # SBM lazy init changes RNG order (not a bug) + "bgm_wenchuan_nuts_blumecapel_impute", # Blume-Capel imputation bug fix (note 5c) + "bgm_wenchuan_nuts_sbm", # SBM lazy init changes RNG order (note 4) + "bgm_adhd_nuts_sbm", # SBM lazy init changes RNG order (note 4) + "bgm_boredom_hmc_bernoulli", # overflow guard reclassifies fast/slow (note 7) + "cmp_wenchuan_hmc_bernoulli" # overflow guard reclassifies fast/slow (note 7) ) compare_fields = function(expected, actual, type, id) { @@ -478,7 +498,7 @@ compare_fields = function(expected, actual, type, id) { # code computes values. Compare only cells that are non-NA in the fixture. is_posterior = grepl("^posterior_summary|^posterior_mean", field) if(allow_na_skip && is_posterior && - (is.data.frame(exp_val) || is.matrix(exp_val))) { + (is.data.frame(exp_val) || is.matrix(exp_val))) { exp_m = as.matrix(exp_val) act_m = as.matrix(act_val) non_na = !is.na(exp_m) @@ -555,13 +575,17 @@ check_structure = function(expected, actual, type) { next } if(!identical(class(exp_val), class(act_val))) { - mismatches = c(mismatches, sprintf(" %s: class mismatch (%s vs %s)", - field, paste(class(exp_val), collapse = "/"), paste(class(act_val), collapse = "/"))) + mismatches = c(mismatches, sprintf( + " %s: class mismatch (%s vs %s)", + field, paste(class(exp_val), collapse = "/"), paste(class(act_val), collapse = "/") + )) next } if(!identical(dim(exp_val), dim(act_val)) && !identical(length(exp_val), length(act_val))) { - mismatches = c(mismatches, sprintf(" %s: dim mismatch (%s vs %s)", - field, paste(dim(exp_val), collapse = "x"), paste(dim(act_val), collapse = "x"))) + mismatches = c(mismatches, sprintf( + " %s: dim mismatch (%s vs %s)", + field, paste(dim(exp_val), collapse = "x"), paste(dim(act_val), collapse = "x") + )) } } mismatches @@ -601,16 +625,19 @@ for(entry in manifest) { # Run current build set.seed(args$seed) - fit = tryCatch({ - if(type == "bgm") { - do.call(bgm, args) - } else { - do.call(bgmCompare, args) + fit = tryCatch( + { + if(type == "bgm") { + do.call(bgm, args) + } else { + do.call(bgmCompare, args) + } + }, + error = function(e) { + cat(sprintf("ERROR: %s\n", conditionMessage(e))) + NULL } - }, error = function(e) { - cat(sprintf("ERROR: %s\n", conditionMessage(e))) - NULL - }) + ) if(is.null(fit)) { error_count = error_count + 1 diff --git a/dev/generate_legacy_fixtures.R b/tests/fixtures/generate_legacy_fixtures.R similarity index 67% rename from dev/generate_legacy_fixtures.R rename to tests/fixtures/generate_legacy_fixtures.R index 63eb85a1..f74acb65 100644 --- a/dev/generate_legacy_fixtures.R +++ b/tests/fixtures/generate_legacy_fixtures.R @@ -19,37 +19,37 @@ library(callr) # Output directory (in tests, not inst) -legacy_dir <- file.path("tests", "testthat", "fixtures", "legacy") +legacy_dir = file.path("tests", "testthat", "fixtures", "legacy") dir.create(legacy_dir, recursive = TRUE, showWarnings = FALSE) # Minimal data for fitting set.seed(42) -test_data <- matrix(sample(0:1, 60, replace = TRUE), ncol = 3) -colnames(test_data) <- c("X1", "X2", "X3") +test_data = matrix(sample(0:1, 60, replace = TRUE), ncol = 3) +colnames(test_data) = c("X1", "X2", "X3") # Version → snapshot date mapping (use date when version was current) # Dates derived from CRAN archive release dates -version_snapshots <- list( - "0.1.3" = "2024-03-15", # Released 2024-02-25 - "0.1.3.1" = "2024-06-01", # Released 2024-05-15 - "0.1.4" = "2024-11-01", # Released 2024-10-21 - "0.1.4.1" = "2024-11-20", # Released 2024-11-12 - "0.1.4.2" = "2024-12-15", # Released 2024-12-05 - "0.1.6.0" = "source", # Install from source (no binary available) - "0.1.6.1" = "source", # Install from source (no binary available) - "0.1.6.2" = "source", # Released 2026-01-20 - "0.1.6.3" = "current" # Current CRAN version (2026-02-14) +version_snapshots = list( + "0.1.3" = "2024-03-15", # Released 2024-02-25 + "0.1.3.1" = "2024-06-01", # Released 2024-05-15 + "0.1.4" = "2024-11-01", # Released 2024-10-21 + "0.1.4.1" = "2024-11-20", # Released 2024-11-12 + "0.1.4.2" = "2024-12-15", # Released 2024-12-05 + "0.1.6.0" = "source", # Install from source (no binary available) + "0.1.6.1" = "source", # Install from source (no binary available) + "0.1.6.2" = "source", # Released 2026-01-20 + "0.1.6.3" = "current" # Current CRAN version (2026-02-14) ) # Helper function to install and run in isolated environment -create_legacy_fit <- function(version, snapshot_date, output_name) { +create_legacy_fit = function(version, snapshot_date, output_name) { cat(sprintf("\n=== Creating fixture: %s (v%s) ===\n", output_name, version)) - - tmp_lib <- tempfile("bgms_lib_") + + tmp_lib = tempfile("bgms_lib_") dir.create(tmp_lib) - + tryCatch({ - if (snapshot_date == "current") { + if(snapshot_date == "current") { # Install from current CRAN cat("Installing bgms", version, "from current CRAN\n") install.packages( @@ -58,9 +58,9 @@ create_legacy_fit <- function(version, snapshot_date, output_name) { lib = tmp_lib, quiet = TRUE ) - } else if (snapshot_date == "source") { + } else if(snapshot_date == "source") { # Install from CRAN archive source - url <- paste0("https://cran.r-project.org/src/contrib/Archive/bgms/bgms_", version, ".tar.gz") + url = paste0("https://cran.r-project.org/src/contrib/Archive/bgms/bgms_", version, ".tar.gz") cat("Installing bgms", version, "from source:", url, "\n") install.packages( url, @@ -71,7 +71,7 @@ create_legacy_fit <- function(version, snapshot_date, output_name) { ) } else { # Install from Posit Package Manager binary snapshot - repos <- paste0("https://packagemanager.posit.co/cran/", snapshot_date) + repos = paste0("https://packagemanager.posit.co/cran/", snapshot_date) cat("Installing bgms", version, "from", repos, "(binary)\n") install.packages( "bgms", @@ -81,36 +81,35 @@ create_legacy_fit <- function(version, snapshot_date, output_name) { quiet = TRUE ) } - + # Verify installation - if (!"bgms" %in% list.files(tmp_lib)) { + if(!"bgms" %in% list.files(tmp_lib)) { stop("bgms was not installed") } - + # Run fit in separate R process with isolated library - result <- callr::r( + result = callr::r( function(data, lib_path) { .libPaths(c(lib_path, .libPaths())) library(bgms, lib.loc = lib_path) - pkg_version <- as.character(packageVersion("bgms")) + pkg_version = as.character(packageVersion("bgms")) cat("bgms version:", pkg_version, "\n") - + # Create fit with edge selection (save=TRUE for raw samples) - fit <- bgm(data, iter = 100, burnin = 50, save = TRUE) - + fit = bgm(data, iter = 100, burnin = 50, save = TRUE) + # Return fit object fit }, args = list(data = test_data, lib_path = tmp_lib), show = TRUE ) - + # Save fixture - output_path <- file.path(legacy_dir, paste0(output_name, ".rds")) + output_path = file.path(legacy_dir, paste0(output_name, ".rds")) saveRDS(result, output_path) cat("Saved:", output_path, "\n") return(TRUE) - }, error = function(e) { cat("ERROR:", conditionMessage(e), "\n") return(FALSE) @@ -123,26 +122,25 @@ create_legacy_fit <- function(version, snapshot_date, output_name) { # Generate fixtures for all versions # ============================================================================== -results <- list() +results = list() + +for(version in names(version_snapshots)) { + snapshot_date = version_snapshots[[version]] + output_name = paste0("fit_v", version) -for (version in names(version_snapshots)) { - snapshot_date <- version_snapshots[[version]] - output_name <- paste0("fit_v", version) - - success <- create_legacy_fit( + success = create_legacy_fit( version = version, snapshot_date = snapshot_date, output_name = output_name ) - - results[[version]] <- success + + results[[version]] = success } # Summary cat("\n=== Summary (bgm) ===\n") -for (version in names(results)) { - - status <- if (results[[version]]) "SUCCESS" else "FAILED" +for(version in names(results)) { + status = if(results[[version]]) "SUCCESS" else "FAILED" cat(sprintf(" %s: %s\n", version, status)) } @@ -157,13 +155,13 @@ for (version in names(results)) { # Create two-group test data for bgmCompare set.seed(42) -group1_data <- matrix(sample(0:1, 60, replace = TRUE), ncol = 3) -group2_data <- matrix(sample(0:1, 60, replace = TRUE), ncol = 3) -colnames(group1_data) <- c("X1", "X2", "X3") -colnames(group2_data) <- c("X1", "X2", "X3") +group1_data = matrix(sample(0:1, 60, replace = TRUE), ncol = 3) +group2_data = matrix(sample(0:1, 60, replace = TRUE), ncol = 3) +colnames(group1_data) = c("X1", "X2", "X3") +colnames(group2_data) = c("X1", "X2", "X3") # Versions that have bgmCompare (introduced in 0.1.4) -bgmcompare_versions <- list( +bgmcompare_versions = list( "0.1.4" = "2024-11-01", "0.1.4.1" = "2024-11-20", "0.1.4.2" = "2024-12-15", @@ -174,14 +172,14 @@ bgmcompare_versions <- list( ) # Helper function to create bgmCompare legacy fit -create_legacy_bgmcompare_fit <- function(version, snapshot_date, output_name) { +create_legacy_bgmcompare_fit = function(version, snapshot_date, output_name) { cat(sprintf("\n=== Creating bgmCompare fixture: %s (v%s) ===\n", output_name, version)) - - tmp_lib <- tempfile("bgms_lib_") + + tmp_lib = tempfile("bgms_lib_") dir.create(tmp_lib) - + tryCatch({ - if (snapshot_date == "current") { + if(snapshot_date == "current") { # Install from current CRAN cat("Installing bgms", version, "from current CRAN\n") install.packages( @@ -190,9 +188,9 @@ create_legacy_bgmcompare_fit <- function(version, snapshot_date, output_name) { lib = tmp_lib, quiet = TRUE ) - } else if (snapshot_date == "source") { + } else if(snapshot_date == "source") { # Install from CRAN archive source - url <- paste0("https://cran.r-project.org/src/contrib/Archive/bgms/bgms_", version, ".tar.gz") + url = paste0("https://cran.r-project.org/src/contrib/Archive/bgms/bgms_", version, ".tar.gz") cat("Installing bgms", version, "from source:", url, "\n") install.packages( url, @@ -203,7 +201,7 @@ create_legacy_bgmcompare_fit <- function(version, snapshot_date, output_name) { ) } else { # Install from Posit Package Manager binary snapshot - repos <- paste0("https://packagemanager.posit.co/cran/", snapshot_date) + repos = paste0("https://packagemanager.posit.co/cran/", snapshot_date) cat("Installing bgms", version, "from", repos, "(binary)\n") install.packages( "bgms", @@ -213,23 +211,23 @@ create_legacy_bgmcompare_fit <- function(version, snapshot_date, output_name) { quiet = TRUE ) } - + # Verify installation - if (!"bgms" %in% list.files(tmp_lib)) { + if(!"bgms" %in% list.files(tmp_lib)) { stop("bgms was not installed") } - + # Run fit in separate R process with isolated library - result <- callr::r( + result = callr::r( function(data_x, data_y, lib_path) { .libPaths(c(lib_path, .libPaths())) library(bgms, lib.loc = lib_path) - pkg_version <- as.character(packageVersion("bgms")) + pkg_version = as.character(packageVersion("bgms")) cat("bgms version:", pkg_version, "\n") - + # Create bgmCompare fit with difference_selection (save=TRUE for raw samples) # Old API uses separate x and y data frames - fit <- bgmCompare( + fit = bgmCompare( x = data_x, y = data_y, iter = 100, @@ -237,20 +235,19 @@ create_legacy_bgmcompare_fit <- function(version, snapshot_date, output_name) { difference_selection = TRUE, save = TRUE ) - + # Return fit object fit }, args = list(data_x = group1_data, data_y = group2_data, lib_path = tmp_lib), show = TRUE ) - + # Save fixture - output_path <- file.path(legacy_dir, paste0(output_name, ".rds")) + output_path = file.path(legacy_dir, paste0(output_name, ".rds")) saveRDS(result, output_path) cat("Saved:", output_path, "\n") return(TRUE) - }, error = function(e) { cat("ERROR:", conditionMessage(e), "\n") return(FALSE) @@ -260,25 +257,25 @@ create_legacy_bgmcompare_fit <- function(version, snapshot_date, output_name) { } # Generate bgmCompare fixtures -bgmcompare_results <- list() +bgmcompare_results = list() -for (version in names(bgmcompare_versions)) { - snapshot_date <- bgmcompare_versions[[version]] - output_name <- paste0("bgmcompare_v", version) - - success <- create_legacy_bgmcompare_fit( +for(version in names(bgmcompare_versions)) { + snapshot_date = bgmcompare_versions[[version]] + output_name = paste0("bgmcompare_v", version) + + success = create_legacy_bgmcompare_fit( version = version, snapshot_date = snapshot_date, output_name = output_name ) - - bgmcompare_results[[version]] <- success + + bgmcompare_results[[version]] = success } # Summary cat("\n=== Summary (bgmCompare) ===\n") -for (version in names(bgmcompare_results)) { - status <- if (bgmcompare_results[[version]]) "SUCCESS" else "FAILED" +for(version in names(bgmcompare_results)) { + status = if(bgmcompare_results[[version]]) "SUCCESS" else "FAILED" cat(sprintf(" %s: %s\n", version, status)) } diff --git a/tests/testthat/helper-fixtures.R b/tests/testthat/helper-fixtures.R index b217a35f..47306c26 100644 --- a/tests/testthat/helper-fixtures.R +++ b/tests/testthat/helper-fixtures.R @@ -60,7 +60,8 @@ options(bgms.verbose = FALSE) .test_cache = new.env(parent = emptyenv()) -#' @description Get cached bgms fit (4 binary variables, edge selection, 2 chains) +#' @description Get cached bgms fit +#' (4 binary variables, edge selection, 2 chains) get_bgms_fit = function() { if(is.null(.test_cache$bgms_fit)) { data("ADHD", package = "bgms") @@ -74,7 +75,8 @@ get_bgms_fit = function() { .test_cache$bgms_fit } -#' @description Get cached bgms fit (4 ordinal variables, edge selection, 2 chains) +#' @description Get cached bgms fit +#' (4 ordinal variables, edge selection, 2 chains) get_bgms_fit_ordinal = function() { if(is.null(.test_cache$bgms_fit_ordinal)) { data("Wenchuan", package = "bgms") @@ -88,7 +90,8 @@ get_bgms_fit_ordinal = function() { .test_cache$bgms_fit_ordinal } -#' @description Get cached bgmCompare fit (4 binary variables, 2 groups, 2 chains) +#' @description Get cached bgmCompare fit +#' (4 binary variables, 2 groups, 2 chains) get_bgmcompare_fit = function() { if(is.null(.test_cache$bgmcompare_fit)) { data("ADHD", package = "bgms") @@ -103,7 +106,8 @@ get_bgmcompare_fit = function() { .test_cache$bgmcompare_fit } -#' @description Get cached bgmCompare fit using x,y interface (4 ordinal variables, 2 chains) +#' @description Get cached bgmCompare fit +#' using x,y interface (4 ordinal variables, 2 chains) get_bgmcompare_fit_xy = function() { if(is.null(.test_cache$bgmcompare_fit_xy)) { data("Wenchuan", package = "bgms") @@ -119,7 +123,8 @@ get_bgmcompare_fit_xy = function() { .test_cache$bgmcompare_fit_xy } -#' @description Get cached bgmCompare fit (4 ordinal variables, 2 groups, 2 chains) +#' @description Get cached bgmCompare fit +#' (4 ordinal variables, 2 groups, 2 chains) get_bgmcompare_fit_ordinal = function() { if(is.null(.test_cache$bgmcompare_fit_ordinal)) { data("Wenchuan", package = "bgms") @@ -151,7 +156,8 @@ get_bgms_fit_blumecapel = function() { .test_cache$bgms_fit_blumecapel } -#' @description Get cached bgms fit with single chain (for R-hat edge case testing) +#' @description Get cached bgms fit with +#' single chain (for R-hat edge case testing) get_bgms_fit_single_chain = function() { if(is.null(.test_cache$bgms_fit_single)) { data("ADHD", package = "bgms") @@ -234,7 +240,9 @@ get_bgmcompare_fit_hmc_blumecapel = function() { .test_cache$bgmcompare_fit_hmc_bc } -#' @description Get cached bgmCompare fit with main_difference_selection = TRUE + Blume-Capel (1 chain) +#' @description Get cached bgmCompare fit with +#' main_difference_selection = TRUE + +#' Blume-Capel (1 chain) #' Crosses Blume-Capel with difference_selection (Bernoulli prior) get_bgmcompare_fit_main_selection = function() { if(is.null(.test_cache$bgmcompare_fit_main_sel)) { @@ -257,7 +265,9 @@ get_bgmcompare_fit_main_selection = function() { .test_cache$bgmcompare_fit_main_sel } -#' @description Get cached bgmCompare fit with Beta-Bernoulli difference prior + ordinal (1 chain) +#' @description Get cached bgmCompare fit with +#' Beta-Bernoulli difference prior + +#' ordinal (1 chain) #' Crosses Beta-Bernoulli prior with ordinal variables get_bgmcompare_fit_beta_bernoulli = function() { if(is.null(.test_cache$bgmcompare_fit_bb)) { @@ -296,7 +306,8 @@ get_bgms_fit_beta_bernoulli = function() { .test_cache$bgms_fit_bb } -#' @description Get cached bgms fit with Stochastic-Block Model edge prior (2 chains) +#' @description Get cached bgms fit with +#' Stochastic-Block Model edge prior (2 chains) get_bgms_fit_sbm = function() { if(is.null(.test_cache$bgms_fit_sbm)) { data("ADHD", package = "bgms") @@ -329,7 +340,8 @@ get_bgms_fit_hmc = function() { .test_cache$bgms_fit_hmc } -#' @description Get cached bgms fit with adaptive-metropolis + Blume-Capel (1 chain) +#' @description Get cached bgms fit with +#' adaptive-metropolis + Blume-Capel (1 chain) get_bgms_fit_am_blumecapel = function() { if(is.null(.test_cache$bgms_fit_am_bc)) { data("Wenchuan", package = "bgms") @@ -400,7 +412,9 @@ get_bgmcompare_fit_blumecapel = function() { .test_cache$bgmcompare_fit_bc } -#' @description Get cached bgmCompare fit with adaptive-metropolis + Blume-Capel (1 chain) +#' @description Get cached bgmCompare fit +#' with adaptive-metropolis + Blume-Capel +#' (1 chain) get_bgmcompare_fit_am_blumecapel = function() { if(is.null(.test_cache$bgmcompare_fit_am_bc)) { data("Boredom", package = "bgms") @@ -441,7 +455,9 @@ get_bgmcompare_fit_impute = function() { .test_cache$bgmcompare_fit_impute } -#' @description Get cached bgmCompare fit with Blume-Capel + missing data imputation (1 chain) +#' @description Get cached bgmCompare fit +#' with Blume-Capel + missing data +#' imputation (1 chain) get_bgmcompare_fit_blumecapel_impute = function() { if(is.null(.test_cache$bgmcompare_fit_bc_impute)) { data("Boredom", package = "bgms") @@ -483,7 +499,9 @@ get_bgmcompare_fit_standardize = function() { .test_cache$bgmcompare_fit_std } -#' @description Get cached bgms fit for GGM with edge selection (4 continuous variables, 1 chain) +#' @description Get cached bgms fit for GGM +#' with edge selection +#' (4 continuous variables, 1 chain) get_bgms_fit_ggm = function() { if(is.null(.test_cache$bgms_fit_ggm)) { set.seed(42) @@ -501,7 +519,9 @@ get_bgms_fit_ggm = function() { .test_cache$bgms_fit_ggm } -#' @description Get cached bgms fit for GGM without edge selection (4 continuous variables, 1 chain) +#' @description Get cached bgms fit for GGM +#' without edge selection +#' (4 continuous variables, 1 chain) get_bgms_fit_ggm_no_es = function() { if(is.null(.test_cache$bgms_fit_ggm_no_es)) { set.seed(42) @@ -519,6 +539,313 @@ get_bgms_fit_ggm_no_es = function() { .test_cache$bgms_fit_ggm_no_es } +get_bgms_fit_mixed_mrf = function() { + if(is.null(.test_cache$bgms_fit_mixed_mrf)) { + set.seed(99) + n = 80 + x = cbind( + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE) + ) + colnames(x) = c("d1", "c1", "d2", "c2", "d3") + .test_cache$bgms_fit_mixed_mrf = bgm( + x = x, + variable_type = c( + "ordinal", "continuous", "ordinal", + "continuous", "ordinal" + ), + edge_selection = TRUE, + iter = 50, warmup = 100, chains = 1, + seed = 77771, + display_progress = "none" + ) + } + .test_cache$bgms_fit_mixed_mrf +} + +get_bgms_fit_mixed_mrf_no_es = function() { + if(is.null(.test_cache$bgms_fit_mixed_mrf_no_es)) { + set.seed(99) + n = 80 + x = cbind( + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE) + ) + colnames(x) = c("d1", "c1", "d2", "c2", "d3") + .test_cache$bgms_fit_mixed_mrf_no_es = bgm( + x = x, + variable_type = c( + "ordinal", "continuous", "ordinal", + "continuous", "ordinal" + ), + edge_selection = FALSE, + iter = 50, warmup = 100, chains = 1, + seed = 77772, + display_progress = "none" + ) + } + .test_cache$bgms_fit_mixed_mrf_no_es +} + +get_bgms_fit_mixed_mrf_marginal = function() { + if(is.null(.test_cache$bgms_fit_mixed_mrf_marginal)) { + set.seed(99) + n = 80 + x = cbind( + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE) + ) + colnames(x) = c("d1", "c1", "d2", "c2", "d3") + .test_cache$bgms_fit_mixed_mrf_marginal = bgm( + x = x, + variable_type = c( + "ordinal", "continuous", "ordinal", + "continuous", "ordinal" + ), + edge_selection = FALSE, + pseudolikelihood = "marginal", + iter = 50, warmup = 100, chains = 1, + seed = 77773, + display_progress = "none" + ) + } + .test_cache$bgms_fit_mixed_mrf_marginal +} + +get_bgms_fit_mixed_mrf_marginal_es = function() { + if(is.null(.test_cache$bgms_fit_mixed_mrf_marginal_es)) { + set.seed(99) + n = 80 + x = cbind( + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE) + ) + colnames(x) = c("d1", "c1", "d2", "c2", "d3") + .test_cache$bgms_fit_mixed_mrf_marginal_es = bgm( + x = x, + variable_type = c( + "ordinal", "continuous", "ordinal", + "continuous", "ordinal" + ), + edge_selection = TRUE, + pseudolikelihood = "marginal", + iter = 50, warmup = 100, chains = 1, + seed = 77774, + display_progress = "none" + ) + } + .test_cache$bgms_fit_mixed_mrf_marginal_es +} + +get_bgms_fit_mixed_mrf_nuts = function() { + if(is.null(.test_cache$bgms_fit_mixed_mrf_nuts)) { + set.seed(99) + n = 80 + x = cbind( + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE) + ) + colnames(x) = c("d1", "c1", "d2", "c2", "d3") + .test_cache$bgms_fit_mixed_mrf_nuts = bgm( + x = x, + variable_type = c( + "ordinal", "continuous", "ordinal", + "continuous", "ordinal" + ), + edge_selection = TRUE, + update_method = "nuts", + iter = 50, warmup = 100, chains = 1, + seed = 77775, + display_progress = "none" + ) + } + .test_cache$bgms_fit_mixed_mrf_nuts +} + +get_bgms_fit_mixed_mrf_nuts_no_es = function() { + if(is.null(.test_cache$bgms_fit_mixed_mrf_nuts_no_es)) { + set.seed(99) + n = 80 + x = cbind( + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE) + ) + colnames(x) = c("d1", "c1", "d2", "c2", "d3") + .test_cache$bgms_fit_mixed_mrf_nuts_no_es = bgm( + x = x, + variable_type = c( + "ordinal", "continuous", "ordinal", + "continuous", "ordinal" + ), + edge_selection = FALSE, + update_method = "nuts", + iter = 50, warmup = 100, chains = 1, + seed = 77776, + display_progress = "none" + ) + } + .test_cache$bgms_fit_mixed_mrf_nuts_no_es +} + +get_bgms_fit_mixed_mrf_beta_bernoulli = function() { + if(is.null(.test_cache$bgms_fit_mixed_mrf_beta_bernoulli)) { + set.seed(99) + n = 80 + x = cbind( + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE) + ) + colnames(x) = c("d1", "c1", "d2", "c2", "d3") + .test_cache$bgms_fit_mixed_mrf_beta_bernoulli = bgm( + x = x, + variable_type = c( + "ordinal", "continuous", "ordinal", + "continuous", "ordinal" + ), + edge_selection = TRUE, + edge_prior = "Beta-Bernoulli", + iter = 50, warmup = 100, chains = 1, + seed = 77777, + display_progress = "none" + ) + } + .test_cache$bgms_fit_mixed_mrf_beta_bernoulli +} + +get_bgms_fit_mixed_mrf_sbm = function() { + if(is.null(.test_cache$bgms_fit_mixed_mrf_sbm)) { + set.seed(99) + n = 80 + x = cbind( + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE) + ) + colnames(x) = c("d1", "c1", "d2", "c2", "d3") + .test_cache$bgms_fit_mixed_mrf_sbm = bgm( + x = x, + variable_type = c( + "ordinal", "continuous", "ordinal", + "continuous", "ordinal" + ), + edge_selection = TRUE, + edge_prior = "Stochastic-Block", + iter = 50, warmup = 100, chains = 1, + seed = 77778, + display_progress = "none" + ) + } + .test_cache$bgms_fit_mixed_mrf_sbm +} + +get_bgms_fit_mixed_mrf_bc = function() { + if(is.null(.test_cache$bgms_fit_mixed_mrf_bc)) { + set.seed(99) + n = 80 + x = cbind( + sample(0:4, n, replace = TRUE), + rnorm(n), + sample(0:4, n, replace = TRUE), + rnorm(n) + ) + colnames(x) = c("bc1", "c1", "bc2", "c2") + .test_cache$bgms_fit_mixed_mrf_bc = bgm( + x = x, + variable_type = c( + "blume-capel", "continuous", + "blume-capel", "continuous" + ), + baseline_category = 2L, + edge_selection = TRUE, + iter = 50, warmup = 100, chains = 1, + seed = 77779, + display_progress = "none" + ) + } + .test_cache$bgms_fit_mixed_mrf_bc +} + +get_bgms_fit_mixed_mrf_impute = function() { + if(is.null(.test_cache$bgms_fit_mixed_mrf_impute)) { + set.seed(99) + n = 80 + x = cbind( + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE) + ) + colnames(x) = c("d1", "c1", "d2", "c2", "d3") + # Insert NAs in both discrete and continuous columns + x[1, 1] = NA + x[2, 2] = NA + .test_cache$bgms_fit_mixed_mrf_impute = bgm( + x = x, + variable_type = c( + "ordinal", "continuous", "ordinal", + "continuous", "ordinal" + ), + edge_selection = TRUE, + na_action = "impute", + iter = 50, warmup = 100, chains = 1, + seed = 77780, + display_progress = "none" + ) + } + .test_cache$bgms_fit_mixed_mrf_impute +} + +get_bgms_fit_mixed_mrf_multichain = function() { + if(is.null(.test_cache$bgms_fit_mixed_mrf_multichain)) { + set.seed(99) + n = 80 + x = cbind( + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE) + ) + colnames(x) = c("d1", "c1", "d2", "c2", "d3") + .test_cache$bgms_fit_mixed_mrf_multichain = bgm( + x = x, + variable_type = c( + "ordinal", "continuous", "ordinal", + "continuous", "ordinal" + ), + edge_selection = TRUE, + iter = 50, warmup = 100, chains = 2, + seed = 77781, + display_progress = "none" + ) + } + .test_cache$bgms_fit_mixed_mrf_multichain +} + # ------------------------------------------------------------------------------ # 2. Prediction Data Helpers # ------------------------------------------------------------------------------ @@ -561,6 +888,37 @@ get_prediction_data_ggm = function(n = 10) { x } +#' Get prediction data matching the mixed MRF bgms fixture +#' Columns: d1 (ordinal 0-2), c1 (continuous), d2 (ordinal 0-2), +#' c2 (continuous), d3 (ordinal 0-2) +get_prediction_data_mixed = function(n = 10) { + set.seed(199) + x = cbind( + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE) + ) + colnames(x) = c("d1", "c1", "d2", "c2", "d3") + x +} + +#' Get prediction data matching the mixed MRF Blume-Capel bgms fixture +#' Columns: bc1 (ordinal 0-4), c1 (continuous), bc2 (ordinal 0-4), +#' c2 (continuous) +get_prediction_data_mixed_bc = function(n = 10) { + set.seed(299) + x = cbind( + sample(0:4, n, replace = TRUE), + rnorm(n), + sample(0:4, n, replace = TRUE), + rnorm(n) + ) + colnames(x) = c("bc1", "c1", "bc2", "c2") + x +} + # ------------------------------------------------------------------------------ # 3. Test Data Generators # ------------------------------------------------------------------------------ @@ -624,7 +982,8 @@ upper_vals = function(M) { M[upper.tri(M)] } -#' Check that named summary entries match matrix positions (ordering consistency) +#' Check that named summary entries match +#' matrix positions (ordering consistency) #' #' For each row of summary_df (named "Vi-Vj"), verify that summary_df$mean[k] #' equals matrix_val[Vi, Vj]. Returns a logical vector (TRUE = match). @@ -633,15 +992,20 @@ check_summary_matrix_consistency = function(summary_df, matrix_val) { matches = logical(nrow(summary_df)) for(k in seq_len(nrow(summary_df))) { parts = strsplit(rownames(summary_df)[k], "-")[[1]] - matches[k] = abs(summary_df$mean[k] - matrix_val[parts[1], parts[2]]) < 1e-10 + matches[k] = abs( + summary_df$mean[k] - + matrix_val[parts[1], parts[2]] + ) < 1e-10 } matches } -#' Check that extractor column means match matrix positions (ordering consistency) +#' Check that extractor column means match +#' matrix positions (ordering consistency) #' #' For each named element of extracted_means (named "Vi-Vj"), verify that -#' the value matches matrix_val[Vi, Vj]. Returns a logical vector (TRUE = match). +#' the value matches matrix_val[Vi, Vj]. +#' Returns a logical vector (TRUE = match). check_extractor_matrix_consistency = function(extracted_means, matrix_val) { matches = logical(length(extracted_means)) for(k in seq_along(extracted_means)) { @@ -668,7 +1032,10 @@ expect_extractor_structure = function(obj, type, expected_dim = NULL, # Type check expect_true( inherits(obj, type), - info = sprintf("Expected class %s, got %s", type, paste(class(obj), collapse = ", ")) + info = sprintf( + "Expected class %s, got %s", + type, paste(class(obj), collapse = ", ") + ) ) # Dimension check @@ -740,3 +1107,363 @@ moderate_mcmc_args = function() { display_progress = "none" ) } + + +# ============================================================================== +# 6. Consolidated Fixture Spec Lists (single source of truth) +# ============================================================================== +# These are used by test-methods.R, test-simulate-predict-regression.R, +# and test-extractor-functions.R. Each entry is a named list describing one +# fixture: label, get_fit, var_type, and boolean flags (is_continuous, is_mixed). +# Entries used by simulate/predict also carry get_prediction_data. + +# ------------------------------------------------------------------ +# get_bgms_fixtures +# ------------------------------------------------------------------ +# All bgms fit variants for parameterized testing. +# Entries carry get_prediction_data for simulate/predict loops. +# +# Returns: list of fixture spec lists. +# ------------------------------------------------------------------ +get_bgms_fixtures = function() { + list( + list( + label = "binary", + get_fit = get_bgms_fit, + get_prediction_data = get_prediction_data_binary, + var_type = "binary", + is_continuous = FALSE + ), + list( + label = "ordinal", + get_fit = get_bgms_fit_ordinal, + get_prediction_data = get_prediction_data_ordinal, + var_type = "ordinal", + is_continuous = FALSE + ), + list( + label = "single-chain", + get_fit = get_bgms_fit_single_chain, + get_prediction_data = get_prediction_data_binary, + var_type = "binary", + is_continuous = FALSE + ), + list( + label = "blume-capel", + get_fit = get_bgms_fit_blumecapel, + get_prediction_data = get_prediction_data_ordinal, + var_type = "blume-capel", + is_continuous = FALSE + ), + list( + label = "adaptive-metropolis", + get_fit = get_bgms_fit_adaptive_metropolis, + get_prediction_data = get_prediction_data_binary, + var_type = "binary", + is_continuous = FALSE + ), + list( + label = "hmc", + get_fit = get_bgms_fit_hmc, + get_prediction_data = get_prediction_data_ordinal, + var_type = "ordinal", + is_continuous = FALSE + ), + list( + label = "am-blumecapel", + get_fit = get_bgms_fit_am_blumecapel, + get_prediction_data = get_prediction_data_ordinal, + var_type = "blume-capel", + is_continuous = FALSE + ), + list( + label = "impute", + get_fit = get_bgms_fit_impute, + get_prediction_data = get_prediction_data_ordinal, + var_type = "ordinal", + is_continuous = FALSE + ), + list( + label = "standardize", + get_fit = get_bgms_fit_standardize, + get_prediction_data = get_prediction_data_ordinal, + var_type = "ordinal", + is_continuous = FALSE + ), + list( + label = "beta-bernoulli", + get_fit = get_bgms_fit_beta_bernoulli, + get_prediction_data = get_prediction_data_binary, + var_type = "binary", + is_continuous = FALSE + ), + list( + label = "sbm", + get_fit = get_bgms_fit_sbm, + get_prediction_data = get_prediction_data_binary, + var_type = "binary", + is_continuous = FALSE + ), + list( + label = "ggm", + get_fit = get_bgms_fit_ggm, + get_prediction_data = get_prediction_data_ggm, + var_type = "continuous", + is_continuous = TRUE + ), + list( + label = "ggm-no-es", + get_fit = get_bgms_fit_ggm_no_es, + get_prediction_data = get_prediction_data_ggm, + var_type = "continuous", + is_continuous = TRUE + ), + list( + label = "mixed-mrf", + get_fit = get_bgms_fit_mixed_mrf, + get_prediction_data = get_prediction_data_mixed, + var_type = "mixed", + is_continuous = FALSE, + is_mixed = TRUE + ), + list( + label = "mixed-mrf-no-es", + get_fit = get_bgms_fit_mixed_mrf_no_es, + get_prediction_data = get_prediction_data_mixed, + var_type = "mixed", + is_continuous = FALSE, + is_mixed = TRUE + ), + list( + label = "mixed-mrf-marginal", + get_fit = get_bgms_fit_mixed_mrf_marginal, + get_prediction_data = get_prediction_data_mixed, + var_type = "mixed", + is_continuous = FALSE, + is_mixed = TRUE + ), + list( + label = "mixed-mrf-marginal-es", + get_fit = get_bgms_fit_mixed_mrf_marginal_es, + get_prediction_data = get_prediction_data_mixed, + var_type = "mixed", + is_continuous = FALSE, + is_mixed = TRUE + ), + list( + label = "mixed-mrf-nuts", + get_fit = get_bgms_fit_mixed_mrf_nuts, + get_prediction_data = get_prediction_data_mixed, + var_type = "mixed", + is_continuous = FALSE, + is_mixed = TRUE + ), + list( + label = "mixed-mrf-nuts-no-es", + get_fit = get_bgms_fit_mixed_mrf_nuts_no_es, + get_prediction_data = get_prediction_data_mixed, + var_type = "mixed", + is_continuous = FALSE, + is_mixed = TRUE + ), + list( + label = "mixed-mrf-beta-bernoulli", + get_fit = get_bgms_fit_mixed_mrf_beta_bernoulli, + get_prediction_data = get_prediction_data_mixed, + var_type = "mixed", + is_continuous = FALSE, + is_mixed = TRUE + ), + list( + label = "mixed-mrf-sbm", + get_fit = get_bgms_fit_mixed_mrf_sbm, + get_prediction_data = get_prediction_data_mixed, + var_type = "mixed", + is_continuous = FALSE, + is_mixed = TRUE + ), + list( + label = "mixed-mrf-bc", + get_fit = get_bgms_fit_mixed_mrf_bc, + get_prediction_data = get_prediction_data_mixed_bc, + var_type = "mixed", + is_continuous = FALSE, + is_mixed = TRUE + ), + list( + label = "mixed-mrf-impute", + get_fit = get_bgms_fit_mixed_mrf_impute, + get_prediction_data = get_prediction_data_mixed, + var_type = "mixed", + is_continuous = FALSE, + is_mixed = TRUE + ), + list( + label = "mixed-mrf-multichain", + get_fit = get_bgms_fit_mixed_mrf_multichain, + get_prediction_data = get_prediction_data_mixed, + var_type = "mixed", + is_continuous = FALSE, + is_mixed = TRUE + ) + ) +} + +# ------------------------------------------------------------------ +# get_bgmcompare_fixtures +# ------------------------------------------------------------------ +# All bgmCompare fit variants for parameterized testing. +# +# Returns: list of fixture spec lists. +# ------------------------------------------------------------------ +get_bgmcompare_fixtures = function() { + list( + list( + label = "binary", + get_fit = get_bgmcompare_fit, + get_prediction_data = get_prediction_data_bgmcompare_binary, + var_type = "binary" + ), + list( + label = "ordinal", + get_fit = get_bgmcompare_fit_ordinal, + get_prediction_data = get_prediction_data_bgmcompare_ordinal, + var_type = "ordinal" + ), + list( + label = "adaptive-metropolis", + get_fit = get_bgmcompare_fit_adaptive_metropolis, + get_prediction_data = get_prediction_data_bgmcompare_binary, + var_type = "binary" + ), + list( + label = "hmc", + get_fit = get_bgmcompare_fit_hmc, + get_prediction_data = get_prediction_data_bgmcompare_binary, + var_type = "binary" + ), + list( + label = "hmc-blume-capel", + get_fit = get_bgmcompare_fit_hmc_blumecapel, + get_prediction_data = get_prediction_data_bgmcompare_blumecapel, + var_type = "blume-capel" + ), + list( + label = "blume-capel", + get_fit = get_bgmcompare_fit_blumecapel, + get_prediction_data = get_prediction_data_bgmcompare_blumecapel, + var_type = "blume-capel" + ), + list( + label = "am-blume-capel", + get_fit = get_bgmcompare_fit_am_blumecapel, + get_prediction_data = get_prediction_data_bgmcompare_blumecapel, + var_type = "blume-capel" + ), + list( + label = "impute", + get_fit = get_bgmcompare_fit_impute, + get_prediction_data = get_prediction_data_bgmcompare_ordinal, + var_type = "ordinal" + ), + list( + label = "blume-capel-impute", + get_fit = get_bgmcompare_fit_blumecapel_impute, + get_prediction_data = get_prediction_data_bgmcompare_blumecapel, + var_type = "blume-capel" + ), + list( + label = "beta-bernoulli", + get_fit = get_bgmcompare_fit_beta_bernoulli, + get_prediction_data = get_prediction_data_bgmcompare_ordinal, + var_type = "ordinal" + ), + list( + label = "standardize", + get_fit = get_bgmcompare_fit_standardize, + get_prediction_data = get_prediction_data_bgmcompare_ordinal, + var_type = "ordinal" + ), + list( + label = "xy", + get_fit = get_bgmcompare_fit_xy, + get_prediction_data = get_prediction_data_bgmcompare_ordinal, + var_type = "ordinal" + ), + list( + label = "main-selection", + get_fit = get_bgmcompare_fit_main_selection, + get_prediction_data = get_prediction_data_bgmcompare_blumecapel, + var_type = "blume-capel" + ) + ) +} + +# ------------------------------------------------------------------ +# get_extractor_fixtures +# ------------------------------------------------------------------ +# Representative subset for extractor-function contract tests. +# Covers all model families (OMRF, GGM, mixed) and both object classes. +# +# Returns: list of fixture spec lists. +# ------------------------------------------------------------------ +get_extractor_fixtures = function() { + list( + list( + label = "bgms_binary", + get_fit = get_bgms_fit, + type = "bgms", + var_type = "binary" + ), + list( + label = "bgms_ordinal", + get_fit = get_bgms_fit_ordinal, + type = "bgms", + var_type = "ordinal" + ), + list( + label = "bgms_blumecapel", + get_fit = get_bgms_fit_blumecapel, + type = "bgms", + var_type = "blume-capel" + ), + list( + label = "bgms_ggm", + get_fit = get_bgms_fit_ggm, + type = "bgms", + var_type = "continuous", + is_continuous = TRUE + ), + list( + label = "bgms_mixed", + get_fit = get_bgms_fit_mixed_mrf, + type = "bgms", + var_type = "mixed", + is_mixed = TRUE + ), + list( + label = "bgmCompare_binary", + get_fit = get_bgmcompare_fit, + type = "bgmCompare", + var_type = "binary" + ), + list( + label = "bgmCompare_ordinal", + get_fit = get_bgmcompare_fit_ordinal, + type = "bgmCompare", + var_type = "ordinal" + ), + list( + label = "bgmCompare_blumecapel", + get_fit = get_bgmcompare_fit_blumecapel, + type = "bgmCompare", + var_type = "blume-capel" + ), + list( + label = "bgmCompare_main_sel", + get_fit = get_bgmcompare_fit_main_selection, + type = "bgmCompare", + var_type = "blume-capel" + ) + ) +} diff --git a/tests/testthat/helper-validation.R b/tests/testthat/helper-validation.R new file mode 100644 index 00000000..2d1624d8 --- /dev/null +++ b/tests/testthat/helper-validation.R @@ -0,0 +1,431 @@ +# =========================================================================== +# Shared helpers for the mixed MRF validation test suite +# =========================================================================== +# Provides: +# - make_network() : generate reproducible true parameter sets +# - generate_data() : simulate data from true parameters via bgms or +# mixedGM +# - extract_bgms_blocks(): pull (mux, muy, Kxx, Kxy, Kyy) from bgms fit +# - extract_mgm_blocks() : pull (mux, muy, Kxx, Kxy, Kyy) from mixedGM fit +# - flatten_params() : flatten all blocks to a single named vector +# - recovery_table() : compare estimated vs true as a data.frame +# - recovery_scatter() : scatterplot of estimated vs true +# - summarise_recovery() : one-line correlation + bias + RMSE summary +# =========================================================================== + +# ------------------------------------------------------------------ +# make_network +# ------------------------------------------------------------------ +# Build a reproducible mixed MRF parameter set. +# +# @param p Number of discrete variables. +# @param q Number of continuous variables. +# @param n_cat Integer vector of length p: number of categories per +# discrete variable (bgms convention = max index, +# e.g. binary = 1). +# @param variable_type Character vector of length p: "ordinal" or +# "blume-capel" per discrete variable. Default: all +# ordinal. +# @param baseline_category Integer vector of length p: baseline category +# for Blume-Capel variables. Ignored for ordinal. +# Default: all zeros. +# @param density Approximate fraction of non-zero edges. +# @param seed Random seed for reproducibility. +# +# Returns: named list with mux, muy, Kxx, Kxy, Kyy, n_cat, p, q, +# variable_type, baseline_category. +# ------------------------------------------------------------------ +make_network = function(p, q, n_cat, variable_type = rep("ordinal", p), + baseline_category = rep(0L, p), + density = 0.5, seed = 42) { + set.seed(seed) + + max_cat = max(n_cat) + # BC variables use 2 columns (alpha, beta); ensure mux is wide enough + mux_cols = max(max_cat, 2L) + + # --- Main effects (mux) --- + # Ordinal: threshold parameters, NA-padded. + # Blume-Capel: alpha (linear) and beta (quadratic), rest NA. + mux = matrix(NA_real_, nrow = p, ncol = mux_cols) + for(i in seq_len(p)) { + if(variable_type[i] == "blume-capel") { + mux[i, 1] = round(runif(1, -0.3, 0.3), 2) # alpha + mux[i, 2] = round(runif(1, -0.4, -0.05), 2) # beta (negative = peaked) + } else { + vals = sort(round(seq(-0.5, 0.5, length.out = n_cat[i]), 2)) + if(n_cat[i] == 1) vals = round(runif(1, -0.3, 0.3), 2) + mux[i, seq_len(n_cat[i])] = vals + } + } + + # --- Continuous means --- + muy = round(runif(q, -0.5, 0.5), 2) + + # --- Kxx: ordinal-ordinal interactions (symmetric, zero diagonal) --- + n_edges_xx = p * (p - 1) / 2 + mask_xx = rbinom(n_edges_xx, 1, density) + vals_xx = mask_xx * round(runif(n_edges_xx, 0.15, 0.4) * sample(c(-1, 1), n_edges_xx, replace = TRUE), 2) + Kxx = matrix(0, p, p) + Kxx[upper.tri(Kxx)] = vals_xx + Kxx = Kxx + t(Kxx) + + # --- Kyy: continuous precision (positive definite, sparse off-diag with diagonal dominance) --- + n_edges_yy = q * (q - 1) / 2 + mask_yy = rbinom(n_edges_yy, 1, density) + vals_yy = mask_yy * round(runif(n_edges_yy, 0.05, 0.2) * sample(c(-1, 1), n_edges_yy, replace = TRUE), 2) + Kyy = matrix(0, q, q) + Kyy[upper.tri(Kyy)] = vals_yy + Kyy = Kyy + t(Kyy) + diag(Kyy) = abs(rowSums(Kyy)) + runif(q, 1.2, 1.8) + + # --- Kxy: ordinal-continuous cross interactions --- + n_cross = p * q + mask_xy = rbinom(n_cross, 1, density) + vals_xy = mask_xy * round(runif(n_cross, 0.1, 0.3) * sample(c(-1, 1), n_cross, replace = TRUE), 2) + Kxy = matrix(vals_xy, p, q) + + list( + mux = mux, muy = muy, + Kxx = Kxx, Kxy = Kxy, Kyy = Kyy, + n_cat = n_cat, p = p, q = q, + variable_type = variable_type, + baseline_category = as.integer(baseline_category) + ) +} + +# ------------------------------------------------------------------ +# generate_data +# ------------------------------------------------------------------ +# Simulate data from true network parameters. +# +# @param net Output from make_network(). +# @param n Number of observations. +# @param source "bgms" or "mixedGM". +# @param iter Gibbs burn-in iterations. +# @param seed Random seed. +# +# Returns: data.frame (n x (p+q)) with ordinal columns first, then continuous. +# ------------------------------------------------------------------ +generate_data = function(net, n, source = "bgms", iter = 1000L, seed = 1) { + if(source == "bgms") { + # simulate_mrf() does not support mixed types; use the C++ Gibbs + # sampler directly via sample_mixed_mrf_gibbs(). + sim = bgms:::sample_mixed_mrf_gibbs( + num_states = as.integer(n), + Kxx_r = net$Kxx, + Kxy_r = net$Kxy, + Kyy_r = net$Kyy, + mux_r = net$mux, + muy_r = net$muy, + num_categories_r = as.integer(net$n_cat), + variable_type_r = net$variable_type, + baseline_category_r = net$baseline_category, + iter = as.integer(iter), + seed = as.integer(seed) + ) + df = as.data.frame(cbind(sim$x, sim$y)) + names(df) = c(paste0("X", seq_len(net$p)), paste0("Y", seq_len(net$q))) + df + } else if(source == "mixedGM") { + sim = mixedGM::mixed_gibbs_generate( + n = n, + Kxx = net$Kxx, Kxy = net$Kxy, Kyy = net$Kyy, + mux = net$mux, muy = net$muy, + num_categories = net$n_cat + 1L, + n_burnin = iter + ) + df = as.data.frame(cbind(sim$x, sim$y)) + names(df) = c(paste0("X", seq_len(net$p)), paste0("Y", seq_len(net$q))) + df + } else { + stop('generate_data(): source must be "bgms" or "mixedGM", not "', source, '".') + } +} + +# ------------------------------------------------------------------ +# extract_bgms_blocks +# ------------------------------------------------------------------ +# Extract parameter blocks from a bgms fit object. +# +# @param fit A bgms object. +# @param net The true network (for dimension reference). +# +# Returns: list(mux, muy, Kxx, Kxy, Kyy). +# ------------------------------------------------------------------ +extract_bgms_blocks = function(fit, net) { + pm = coef(fit) + mux = pm$main$discrete # p x max_cat + muy_vec = pm$main$continuous[, "mean"] # length q + Kyy_diag = pm$main$continuous[, "precision"] # length q + + pw = pm$pairwise # (p+q) x (p+q) + p = net$p + q = net$q + Kxx = pw[seq_len(p), seq_len(p)] + Kxy = pw[seq_len(p), p + seq_len(q)] + Kyy_off = pw[p + seq_len(q), p + seq_len(q)] + Kyy = Kyy_off + diag(Kyy) = Kyy_diag + + list(mux = mux, muy = muy_vec, Kxx = Kxx, Kxy = Kxy, Kyy = Kyy) +} + +# ------------------------------------------------------------------ +# extract_mgm_blocks +# ------------------------------------------------------------------ +# Extract posterior means from a mixedGM fit object. +# +# @param fit Output of mixedGM::mixed_sampler(). +# @param n_cat Integer vector of length p: number of categories per ordinal +# variable (bgms convention). Used to NA-pad unused threshold +# slots so flatten_params() produces matching lengths. +# +# Returns: list(mux, muy, Kxx, Kxy, Kyy). +# ------------------------------------------------------------------ +extract_mgm_blocks = function(fit, n_cat) { + mux = apply(fit$samples$mux, c(2, 3), mean) + for(i in seq_along(n_cat)) { + if(n_cat[i] < ncol(mux)) { + mux[i, (n_cat[i] + 1):ncol(mux)] = NA_real_ + } + } + list( + mux = mux, + muy = colMeans(fit$samples$muy), + Kxx = apply(fit$samples$Kxx, c(2, 3), mean), + Kxy = apply(fit$samples$Kxy, c(2, 3), mean), + Kyy = apply(fit$samples$Kyy, c(2, 3), mean) + ) +} + +# ------------------------------------------------------------------ +# flatten_params +# ------------------------------------------------------------------ +# Flatten parameter blocks to a single named vector for comparison. +# Excludes NA entries in mux. +# +# @param blocks list(mux, muy, Kxx, Kxy, Kyy). +# @param prefix Optional prefix for names. +# +# Returns: named numeric vector. +# ------------------------------------------------------------------ +flatten_params = function(blocks, prefix = "") { + mux_vals = as.vector(t(blocks$mux)) + mux_keep = !is.na(mux_vals) + mux_named = mux_vals[mux_keep] + names(mux_named) = paste0(prefix, "mux_", which(mux_keep)) + + muy_named = blocks$muy + names(muy_named) = paste0(prefix, "muy_", seq_along(muy_named)) + + # Kxx upper triangle + kxx_ut = blocks$Kxx[upper.tri(blocks$Kxx)] + names(kxx_ut) = paste0(prefix, "Kxx_", seq_along(kxx_ut)) + + # Kxy full + kxy_vals = as.vector(blocks$Kxy) + names(kxy_vals) = paste0(prefix, "Kxy_", seq_along(kxy_vals)) + + # Kyy upper triangle (includes diagonal) + kyy_ut = blocks$Kyy[upper.tri(blocks$Kyy, diag = TRUE)] + names(kyy_ut) = paste0(prefix, "Kyy_", seq_along(kyy_ut)) + + c(mux_named, muy_named, kxx_ut, kxy_vals, kyy_ut) +} + +# ------------------------------------------------------------------ +# recovery_table +# ------------------------------------------------------------------ +# Build a data.frame comparing true vs estimated parameters. +# +# @param true_blocks list(mux, muy, Kxx, Kxy, Kyy). +# @param est_blocks list(mux, muy, Kxx, Kxy, Kyy). +# @param label Character label for the method. +# +# Returns: data.frame with columns: parameter, true, estimate, diff, block, method. +# ------------------------------------------------------------------ +recovery_table = function(true_blocks, est_blocks, label = "estimate") { + true_flat = flatten_params(true_blocks) + est_flat = flatten_params(est_blocks) + + # Determine block membership + block = ifelse(grepl("^mux", names(true_flat)), "mux", + ifelse(grepl("^muy", names(true_flat)), "muy", + ifelse(grepl("^Kxx", names(true_flat)), "Kxx", + ifelse(grepl("^Kxy", names(true_flat)), "Kxy", "Kyy") + ) + ) + ) + + data.frame( + parameter = names(true_flat), + true = true_flat, + estimate = est_flat, + diff = est_flat - true_flat, + block = block, + method = label, + stringsAsFactors = FALSE, + row.names = NULL + ) +} + +# ------------------------------------------------------------------ +# recovery_scatter +# ------------------------------------------------------------------ +# Scatterplot of estimated vs true, coloured by block. +# +# @param tab Output of recovery_table() (possibly rbind of multiple). +# @param main Plot title. +# ------------------------------------------------------------------ +recovery_scatter = function(tab, main = "Parameter recovery") { + block_cols = c( + mux = "#E41A1C", muy = "#377EB8", Kxx = "#4DAF4A", + Kxy = "#984EA3", Kyy = "#FF7F00" + ) + rng = range(c(tab$true, tab$estimate), na.rm = TRUE) + rng = rng + diff(rng) * c(-0.05, 0.05) + + plot(tab$true, tab$estimate, + pch = 19, cex = 0.9, + col = adjustcolor(block_cols[tab$block], 0.7), + xlim = rng, ylim = rng, asp = 1, + xlab = "True value", ylab = "Posterior mean", + main = main + ) + abline(0, 1, lty = 2, col = "grey40") + + r = cor(tab$true, tab$estimate) + rmse = sqrt(mean(tab$diff^2)) + legend("topleft", + legend = c( + names(block_cols), + sprintf("r = %.3f", r), + sprintf("RMSE = %.3f", rmse) + ), + col = c(block_cols, NA, NA), + pch = c(rep(19, 5), NA, NA), + bty = "n", cex = 0.8 + ) +} + +# ------------------------------------------------------------------ +# summarise_recovery +# ------------------------------------------------------------------ +# Print a one-line summary of recovery quality. +# +# @param tab Output of recovery_table(). +# @param label String label. +# +# Returns: invisible data.frame with summary stats. +# ------------------------------------------------------------------ +summarise_recovery = function(tab, label = "") { + r = cor(tab$true, tab$estimate) + bias = mean(tab$diff) + rmse = sqrt(mean(tab$diff^2)) + max_diff = max(abs(tab$diff)) + cat(sprintf( + "[%s] r = %.4f | bias = %.4f | RMSE = %.4f | max|diff| = %.4f\n", + label, r, bias, rmse, max_diff + )) + + # Per-block summary + blocks = unique(tab$block) + block_summary = do.call(rbind, lapply(blocks, function(b) { + sub = tab[tab$block == b, ] + r_val = if(nrow(sub) >= 3 && sd(sub$true) > 0 && sd(sub$estimate) > 0) { + cor(sub$true, sub$estimate) + } else { + NA_real_ + } + data.frame( + block = b, + n = nrow(sub), + r = r_val, + bias = mean(sub$diff), + rmse = sqrt(mean(sub$diff^2)), + max_diff = max(abs(sub$diff)), + stringsAsFactors = FALSE + ) + })) + + cat(" Per-block breakdown:\n") + print(block_summary, row.names = FALSE) + invisible(block_summary) +} + +# ------------------------------------------------------------------ +# agreement_scatter +# ------------------------------------------------------------------ +# Scatterplot of method A vs method B estimates, coloured by block. +# +# @param tab_a recovery_table for method A. +# @param tab_b recovery_table for method B. +# @param label_a, label_b Axis labels. +# @param main Plot title. +# ------------------------------------------------------------------ +agreement_scatter = function(tab_a, tab_b, label_a = "Method A", + label_b = "Method B", main = "Agreement") { + block_cols = c( + mux = "#E41A1C", muy = "#377EB8", Kxx = "#4DAF4A", + Kxy = "#984EA3", Kyy = "#FF7F00" + ) + rng = range(c(tab_a$estimate, tab_b$estimate), na.rm = TRUE) + rng = rng + diff(rng) * c(-0.05, 0.05) + + plot(tab_a$estimate, tab_b$estimate, + pch = 19, cex = 0.9, + col = adjustcolor(block_cols[tab_a$block], 0.7), + xlim = rng, ylim = rng, asp = 1, + xlab = label_a, ylab = label_b, main = main + ) + abline(0, 1, lty = 2, col = "grey40") + + r = cor(tab_a$estimate, tab_b$estimate) + rmse = sqrt(mean((tab_a$estimate - tab_b$estimate)^2)) + legend("topleft", + legend = c( + names(block_cols), + sprintf("r = %.3f", r), + sprintf("RMSE = %.3f", rmse) + ), + col = c(block_cols, NA, NA), + pch = c(rep(19, 5), NA, NA), + bty = "n", cex = 0.8 + ) +} + +# ------------------------------------------------------------------ +# trace_panel +# ------------------------------------------------------------------ +# Plot trace + density panels for selected parameters from raw samples. +# +# @param samples Matrix of MCMC samples (iterations x parameters). +# @param names Column names to plot. +# @param true_vals Optional named vector of true values. +# @param main Overall title. +# ------------------------------------------------------------------ +trace_panel = function(samples, names = NULL, true_vals = NULL, main = "") { + if(is.null(names)) names = colnames(samples)[seq_len(min(6, ncol(samples)))] + n_par = length(names) + par(mfrow = c(n_par, 2), mar = c(3, 3, 2, 1), mgp = c(2, 0.6, 0)) + for(nm in names) { + vals = samples[, nm] + # Trace + plot(vals, + type = "l", col = adjustcolor("grey30", 0.5), + main = paste(nm, "trace"), xlab = "Iteration", ylab = nm, cex.main = 0.9 + ) + if(!is.null(true_vals) && nm %in% names(true_vals)) { + abline(h = true_vals[nm], col = "red", lwd = 2) + } + abline(h = mean(vals), col = "steelblue", lwd = 1.5) + # Density + d = density(vals) + plot(d, main = paste(nm, "density"), xlab = nm, cex.main = 0.9) + if(!is.null(true_vals) && nm %in% names(true_vals)) { + abline(v = true_vals[nm], col = "red", lwd = 2) + } + abline(v = mean(vals), col = "steelblue", lwd = 1.5) + } +} diff --git a/tests/testthat/test-bgm-spec.R b/tests/testthat/test-bgm-spec.R index 702d3b51..b7706f46 100644 --- a/tests/testthat/test-bgm-spec.R +++ b/tests/testthat/test-bgm-spec.R @@ -3,6 +3,12 @@ # Phase B.5 of the R scaffolding refactor. # ============================================================================== +# These are internal (non-exported) functions — bind via ::: for testing. +bgm_spec = bgms:::bgm_spec +validate_bgm_spec = bgms:::validate_bgm_spec +new_bgm_spec = bgms:::new_bgm_spec +sampler_sublist = bgms:::sampler_sublist + # ============================================================================== # Shared helpers / fixtures # ============================================================================== diff --git a/tests/testthat/test-bgm.R b/tests/testthat/test-bgm.R index 67b2cdec..367d123e 100644 --- a/tests/testthat/test-bgm.R +++ b/tests/testthat/test-bgm.R @@ -84,9 +84,19 @@ test_that("bgm GGM output has correct dimensions", { args = extract_arguments(fit) p = args$num_variables - # main: p diagonal precision elements - expect_equal(nrow(fit$posterior_summary_main), p) - expect_equal(nrow(fit$posterior_mean_main), p) + # GGM has no main effects; precision diagonal is in quadratic + expect_null(fit$posterior_summary_main) + expect_null(fit$posterior_mean_main) + expect_equal(nrow(fit$posterior_summary_quadratic), p) + + # pairwise: p*(p-1)/2 off-diagonal elements + n_edges = p * (p - 1) / 2 + expect_equal(nrow(fit$posterior_summary_pairwise), n_edges) + expect_equal(nrow(fit$posterior_mean_pairwise), p) + expect_equal(ncol(fit$posterior_mean_pairwise), p) + + # precision diagonal lives on the pairwise matrix diagonal + expect_true(all(diag(fit$posterior_mean_pairwise) > 0)) # pairwise: p*(p-1)/2 off-diagonal elements n_edges = p * (p - 1) / 2 @@ -116,7 +126,7 @@ test_that("bgm GGM without edge selection omits indicators", { test_that("bgm GGM posterior precision diagonals are positive", { fit = get_bgms_fit_ggm_no_es() - expect_true(all(fit$posterior_summary_main$mean > 0)) + expect_true(all(fit$posterior_summary_quadratic$mean > 0)) }) # ============================================================================== @@ -202,8 +212,13 @@ test_that("bgm GGM output has correct parameter ordering", { # Extractor column means -> matrix positions pw_means = colMeans(extract_pairwise_interactions(fit)) expect_true( - all(check_extractor_matrix_consistency(pw_means, fit$posterior_mean_pairwise)), - info = "GGM extract_pairwise_interactions() names do not match matrix positions" + all(check_extractor_matrix_consistency( + pw_means, fit$posterior_mean_pairwise + )), + info = paste( + "GGM extract_pairwise_interactions()", + "names do not match matrix positions" + ) ) # Truth-based swap-position checks: @@ -266,8 +281,13 @@ test_that("bgm OMRF output has correct parameter ordering", { # Extractor column means -> matrix positions pw_means = colMeans(extract_pairwise_interactions(fit)) expect_true( - all(check_extractor_matrix_consistency(pw_means, fit$posterior_mean_pairwise)), - info = "OMRF extract_pairwise_interactions() names do not match matrix positions" + all(check_extractor_matrix_consistency( + pw_means, fit$posterior_mean_pairwise + )), + info = paste( + "OMRF extract_pairwise_interactions()", + "names do not match matrix positions" + ) ) # Indicator summary names -> matrix positions @@ -316,12 +336,12 @@ test_that("bgm GGM multi-chain produces valid Rhat", { ) # All Rhat values should be below 1.1 for converged chains - rhat_main = fit$posterior_summary_main$Rhat + rhat_quad = fit$posterior_summary_quadratic$Rhat rhat_pair = fit$posterior_summary_pairwise$Rhat expect_true( - all(rhat_main < 1.1), - info = sprintf("Max main Rhat = %.3f (expected < 1.1)", max(rhat_main)) + all(rhat_quad < 1.1), + info = sprintf("Max quadratic Rhat = %.3f (expected < 1.1)", max(rhat_quad)) ) expect_true( all(rhat_pair < 1.1), @@ -362,7 +382,6 @@ test_that("bgm GGM posterior mean approaches MLE for large n", { # Reconstruct posterior mean precision omega_hat = fit$posterior_mean_pairwise - diag(omega_hat) = as.numeric(fit$posterior_mean_main) # Posterior mean should correlate highly with MLE (likelihood dominates) cor_offdiag = cor( @@ -373,7 +392,10 @@ test_that("bgm GGM posterior mean approaches MLE for large n", { expect_true( cor_offdiag > 0.95, - info = sprintf("Off-diagonal cor with MLE = %.3f (expected > 0.95)", cor_offdiag) + info = sprintf( + "Off-diagonal cor with MLE = %.3f (expected > 0.95)", + cor_offdiag + ) ) expect_true( cor_diag > 0.95, @@ -613,17 +635,17 @@ test_that("bgm GGM with p = 15 produces valid output", { # All precision diagonals should be positive expect_true( - all(fit$posterior_summary_main$mean > 0), + all(fit$posterior_summary_quadratic$mean > 0), info = "Some diagonal precision elements are non-positive" ) # All values should be finite - expect_true(all(is.finite(fit$posterior_summary_main$mean))) + expect_true(all(is.finite(fit$posterior_summary_quadratic$mean))) expect_true(all(is.finite(fit$posterior_summary_pairwise$mean))) # Correct dimensions n_edges = p * (p - 1) / 2 - expect_equal(nrow(fit$posterior_summary_main), p) + expect_equal(nrow(fit$posterior_summary_quadratic), p) expect_equal(nrow(fit$posterior_summary_pairwise), n_edges) }) @@ -698,6 +720,382 @@ test_that("bgm GGM edge selection discriminates true edges", { # --- D.7: Conditional regression check ---------------------------------------- +# ============================================================================== +# Mixed MRF End-to-End Tests +# ============================================================================== + +test_that("bgm mixed MRF is reproducible", { + fit1 = get_bgms_fit_mixed_mrf() + + set.seed(99) + n = 80 + x = cbind( + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE) + ) + colnames(x) = c("d1", "c1", "d2", "c2", "d3") + + fit2 = bgm( + x = x, + variable_type = c( + "ordinal", "continuous", "ordinal", + "continuous", "ordinal" + ), + edge_selection = TRUE, + iter = 50, warmup = 100, chains = 1, + seed = 77771, + display_progress = "none" + ) + + testthat::expect_equal(fit1$raw_samples$main, fit2$raw_samples$main) + testthat::expect_equal(fit1$raw_samples$pairwise, fit2$raw_samples$pairwise) +}) + +test_that("bgm mixed MRF output has correct dimensions", { + fit = get_bgms_fit_mixed_mrf() + args = extract_arguments(fit) + p_total = args$num_variables # 5 + p = 3L # discrete + q = 2L # continuous + + # pairwise: p_total*(p_total-1)/2 edges + n_edges = p_total * (p_total - 1) / 2 + expect_equal(nrow(fit$posterior_summary_pairwise), n_edges) + expect_equal(nrow(fit$posterior_mean_pairwise), p_total) + expect_equal(ncol(fit$posterior_mean_pairwise), p_total) + + # indicators (edge selection = TRUE) + expect_equal(nrow(fit$posterior_summary_indicator), n_edges) + expect_equal(nrow(fit$posterior_mean_indicator), p_total) + expect_equal(ncol(fit$posterior_mean_indicator), p_total) + + # posterior_mean_main: list with discrete and continuous + expect_true(is.list(fit$posterior_mean_main)) + expect_equal(nrow(fit$posterior_mean_main$discrete), p) + expect_equal(nrow(fit$posterior_mean_main$continuous), q) + expect_equal(ncol(fit$posterior_mean_main$continuous), 1) # mean only + + # raw samples + expect_equal(ncol(fit$raw_samples$pairwise[[1]]), n_edges) + expect_equal(nrow(fit$raw_samples$main[[1]]), args$iter) +}) + +test_that("bgm mixed MRF without edge selection omits indicators", { + fit = get_bgms_fit_mixed_mrf_no_es() + + expect_s3_class(fit, "bgms") + expect_null(fit$posterior_summary_indicator) + expect_null(fit$posterior_mean_indicator) +}) + +test_that("bgm mixed MRF pairwise matrix has correct variable names", { + fit = get_bgms_fit_mixed_mrf() + + # Interleaved order: d1, c1, d2, c2, d3 + expected_names = c("d1", "c1", "d2", "c2", "d3") + expect_equal(rownames(fit$posterior_mean_pairwise), expected_names) + expect_equal(colnames(fit$posterior_mean_pairwise), expected_names) + expect_equal(rownames(fit$posterior_mean_indicator), expected_names) + expect_equal(colnames(fit$posterior_mean_indicator), expected_names) +}) + +test_that("bgm mixed MRF pairwise matrix is symmetric", { + fit = get_bgms_fit_mixed_mrf() + expect_equal(fit$posterior_mean_pairwise, t(fit$posterior_mean_pairwise)) + expect_equal(fit$posterior_mean_indicator, t(fit$posterior_mean_indicator)) +}) + +test_that("bgm mixed MRF summary-matrix consistency", { + fit = get_bgms_fit_mixed_mrf() + expect_true( + all(check_summary_matrix_consistency( + fit$posterior_summary_pairwise, + fit$posterior_mean_pairwise + )), + info = "Mixed MRF pairwise summary names do not match matrix positions" + ) + expect_true( + all(check_summary_matrix_consistency( + fit$posterior_summary_indicator, + fit$posterior_mean_indicator + )), + info = "Mixed MRF indicator summary names do not match matrix positions" + ) +}) + +test_that("bgm mixed MRF posterior precision diagonals are positive", { + fit = get_bgms_fit_mixed_mrf_no_es() + args = extract_arguments(fit) + cont_idx = args$continuous_indices + expect_true(all(diag(fit$posterior_mean_pairwise)[cont_idx] > 0)) +}) + +test_that("bgm mixed MRF marginal pseudolikelihood runs", { + fit = get_bgms_fit_mixed_mrf_marginal() + expect_s3_class(fit, "bgms") + expect_equal(nrow(fit$posterior_mean_pairwise), 5) + expect_true(all(is.finite(fit$posterior_mean_pairwise))) +}) + +test_that("bgm mixed MRF marginal PL with edge selection runs", { + fit = get_bgms_fit_mixed_mrf_marginal_es() + expect_s3_class(fit, "bgms") + expect_equal(nrow(fit$posterior_mean_pairwise), 5) + expect_equal(ncol(fit$posterior_mean_pairwise), 5) + expect_true(all(is.finite(fit$posterior_mean_pairwise))) + # Edge selection produces indicator matrix + expect_false(is.null(fit$posterior_mean_indicator)) + expect_equal(nrow(fit$posterior_mean_indicator), 5) + expect_true(all(fit$posterior_mean_indicator >= 0 & + fit$posterior_mean_indicator <= 1)) +}) + +test_that("bgm mixed MRF hybrid-NUTS runs", { + fit = get_bgms_fit_mixed_mrf_nuts() + expect_s3_class(fit, "bgms") + expect_equal(nrow(fit$posterior_mean_pairwise), 5) + expect_equal(ncol(fit$posterior_mean_pairwise), 5) + expect_true(all(is.finite(fit$posterior_mean_pairwise))) + # Edge selection active + expect_false(is.null(fit$posterior_mean_indicator)) +}) + +test_that("bgm mixed MRF hybrid-NUTS output dimensions", { + fit = get_bgms_fit_mixed_mrf_nuts() + args = extract_arguments(fit) + p_total = args$num_variables + n_edges = p_total * (p_total - 1) / 2 + + expect_equal(nrow(fit$posterior_summary_pairwise), n_edges) + expect_equal(nrow(fit$posterior_mean_pairwise), p_total) + expect_equal(ncol(fit$posterior_mean_pairwise), p_total) + expect_equal(nrow(fit$posterior_summary_indicator), n_edges) + expect_equal(ncol(fit$raw_samples$pairwise[[1]]), n_edges) + expect_equal(nrow(fit$raw_samples$main[[1]]), args$iter) +}) + +test_that("bgm mixed MRF hybrid-NUTS is reproducible", { + fit1 = get_bgms_fit_mixed_mrf_nuts() + + set.seed(99) + n = 80 + x = cbind( + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE), + rnorm(n), + sample(0:2, n, replace = TRUE) + ) + colnames(x) = c("d1", "c1", "d2", "c2", "d3") + + fit2 = bgm( + x = x, + variable_type = c( + "ordinal", "continuous", "ordinal", + "continuous", "ordinal" + ), + edge_selection = TRUE, + update_method = "nuts", + iter = 50, warmup = 100, chains = 1, + seed = 77775, + display_progress = "none" + ) + + expect_equal(fit1$raw_samples$main, fit2$raw_samples$main) + expect_equal(fit1$raw_samples$pairwise, fit2$raw_samples$pairwise) +}) + +test_that("bgm mixed MRF hybrid-NUTS without edge selection runs", { + fit = get_bgms_fit_mixed_mrf_nuts_no_es() + expect_s3_class(fit, "bgms") + expect_equal(nrow(fit$posterior_mean_pairwise), 5) + expect_true(all(is.finite(fit$posterior_mean_pairwise))) + expect_null(fit$posterior_summary_indicator) + expect_null(fit$posterior_mean_indicator) +}) + +test_that("bgm mixed MRF Beta-Bernoulli prior runs", { + fit = get_bgms_fit_mixed_mrf_beta_bernoulli() + expect_s3_class(fit, "bgms") + expect_equal(nrow(fit$posterior_mean_pairwise), 5) + expect_true(all(is.finite(fit$posterior_mean_pairwise))) + expect_false(is.null(fit$posterior_mean_indicator)) + expect_true(all(fit$posterior_mean_indicator >= 0 & + fit$posterior_mean_indicator <= 1)) +}) + +test_that("bgm mixed MRF Stochastic-Block prior runs", { + fit = get_bgms_fit_mixed_mrf_sbm() + expect_s3_class(fit, "bgms") + expect_equal(nrow(fit$posterior_mean_pairwise), 5) + expect_true(all(is.finite(fit$posterior_mean_pairwise))) + expect_false(is.null(fit$posterior_mean_indicator)) + expect_true(all(fit$posterior_mean_indicator >= 0 & + fit$posterior_mean_indicator <= 1)) +}) + +test_that("bgm mixed MRF Blume-Capel + continuous runs", { + fit = get_bgms_fit_mixed_mrf_bc() + expect_s3_class(fit, "bgms") + p_total = 4 # bc1, c1, bc2, c2 + expect_equal(nrow(fit$posterior_mean_pairwise), p_total) + expect_equal(ncol(fit$posterior_mean_pairwise), p_total) + expect_true(all(is.finite(fit$posterior_mean_pairwise))) + expect_false(is.null(fit$posterior_mean_indicator)) + + # Blume-Capel main effects: quadratic structure (linear + quadratic terms) + expect_true(is.list(fit$posterior_mean_main)) + expect_true("discrete" %in% names(fit$posterior_mean_main)) + expect_true("continuous" %in% names(fit$posterior_mean_main)) +}) + +test_that("bgm mixed MRF imputation runs", { + fit = get_bgms_fit_mixed_mrf_impute() + expect_s3_class(fit, "bgms") + expect_equal(nrow(fit$posterior_mean_pairwise), 5) + expect_true(all(is.finite(fit$posterior_mean_pairwise))) + expect_false(is.null(fit$posterior_mean_indicator)) +}) + +test_that("bgm mixed MRF multi-chain R-hat and ESS", { + fit = get_bgms_fit_mixed_mrf_multichain() + expect_s3_class(fit, "bgms") + expect_equal(nrow(fit$posterior_mean_pairwise), 5) + expect_true(all(is.finite(fit$posterior_mean_pairwise))) + + # Multi-chain produces multiple raw sample chains + expect_equal(length(fit$raw_samples$pairwise), 2) + expect_equal(length(fit$raw_samples$main), 2) + + # R-hat and ESS should be computable + rhat = extract_rhat(fit) + expect_true(is.numeric(rhat$pairwise)) + expect_true(all(is.na(rhat$pairwise) | rhat$pairwise > 0)) + ess = extract_ess(fit) + expect_true(is.numeric(ess$pairwise)) + expect_true(all(is.na(ess$pairwise) | ess$pairwise > 0)) +}) + +test_that("bgm mixed MRF output has correct parameter ordering", { + skip_on_cran() + + # 5 interleaved variables: d1, c1, d2, c2, d3 (p=3 discrete, q=2 continuous) + # Internal C++ order: d1, d2, d3, c1, c2 + # + # For a 5x5 upper triangle, position 2 differs between orderings: + # Row-major position 2 = (1,4) = d1-c2 + # Col-major position 2 = (2,3) = c1-d2 + # Strategic zeros make any swap detectable. + p = 3L + q = 2L + n = 500L + + # Parameters in internal (dd/cc/dc block) order + Kxx = matrix(c( + 0, -0.4, 0.2, + -0.4, 0, 0.0, + 0.2, 0.0, 0 + ), p, p, byrow = TRUE) + + Kxy = matrix(c( + 0.3, 0.0, # d1-c1 = 0.3, d1-c2 = 0.0 (swap sentinel) + 0.5, 0.3, # d2-c1 = 0.5 (swap sentinel), d2-c2 = 0.3 + -0.3, 0.15 # d3-c1 = -0.3, d3-c2 = 0.15 + ), p, q, byrow = TRUE) + + Kyy = diag(c(1.5, 2.0)) + Kyy[1, 2] = Kyy[2, 1] = 0.0 # c1-c2 = 0 (swap sentinel) + + nc = c(2L, 2L, 2L) + mux = matrix(0, p, max(nc) + 1) + muy = rep(0, q) + + result = sample_mixed_mrf_gibbs( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = muy, num_categories_r = nc, + variable_type_r = rep("ordinal", p), + baseline_category_r = rep(0L, p), iter = 500L, seed = 42L + ) + + # Reassemble in user (interleaved) order: d1, c1, d2, c2, d3 + x = data.frame( + d1 = result$x[, 1], + c1 = result$y[, 1], + d2 = result$x[, 2], + c2 = result$y[, 2], + d3 = result$x[, 3] + ) + + fit = bgm( + x, + variable_type = c( + "ordinal", "continuous", "ordinal", + "continuous", "ordinal" + ), + iter = 1000, warmup = 500, chains = 1, + edge_selection = FALSE, seed = 42, + display_progress = "none" + ) + + # Extractor column means -> matrix positions + pw_means = colMeans(extract_pairwise_interactions(fit)) + expect_true( + all(check_extractor_matrix_consistency( + pw_means, fit$posterior_mean_pairwise + )), + info = paste( + "Mixed MRF extract_pairwise_interactions()", + "names do not match matrix positions" + ) + ) + + # Truth-based swap checks (user-order variable names): + # d1-c2 (true = 0.0) should be near zero, not ~0.5 (d2-c1's value) + expect_true( + abs(fit$posterior_mean_pairwise["d1", "c2"]) < 0.2, + info = sprintf( + "d1-c2 should be ~0 but is %.3f (possible swap with c1-d2)", + fit$posterior_mean_pairwise["d1", "c2"] + ) + ) + # c1-d2 (true = 0.5) should be clearly positive, not ~0 (d1-c2's value) + expect_true( + fit$posterior_mean_pairwise["c1", "d2"] > 0.15, + info = sprintf( + "c1-d2 should be ~0.5 but is %.3f (possible swap with d1-c2)", + fit$posterior_mean_pairwise["c1", "d2"] + ) + ) + # c1-c2 (true = 0.0) should be near zero, not ~0.3 (d2-c2's value) + expect_true( + abs(fit$posterior_mean_pairwise["c1", "c2"]) < 0.2, + info = sprintf( + "c1-c2 should be ~0 but is %.3f (possible swap with d2-c2)", + fit$posterior_mean_pairwise["c1", "c2"] + ) + ) + # d2-d3 (true = 0.0) should be near zero + expect_true( + abs(fit$posterior_mean_pairwise["d2", "d3"]) < 0.2, + info = sprintf( + "d2-d3 should be ~0 but is %.3f", + fit$posterior_mean_pairwise["d2", "d3"] + ) + ) + # d1-d2 (true = -0.4) should be negative + expect_true( + fit$posterior_mean_pairwise["d1", "d2"] < -0.15, + info = sprintf( + "d1-d2 should be ~-0.4 but is %.3f", + fit$posterior_mean_pairwise["d1", "d2"] + ) + ) +}) + + test_that("bgm GGM implied regression matches OLS for large n", { skip_on_cran() @@ -725,7 +1123,6 @@ test_that("bgm GGM implied regression matches OLS for large n", { # Reconstruct posterior mean precision matrix omega_hat = fit$posterior_mean_pairwise - diag(omega_hat) = as.numeric(fit$posterior_mean_main) # For each variable j, the implied regression coefficients are: # beta_j = -omega_{j,-j} / omega_{jj} @@ -752,3 +1149,89 @@ test_that("bgm GGM implied regression matches OLS for large n", { ) } }) + + +# ============================================================================== +# Estimate-Simulate-Re-estimate Cycle Tests +# ============================================================================== +# Verify self-consistency: fit -> simulate -> re-fit -> compare. +# Posterior mean parameters from the re-fit should correlate with the original. + +test_that("estimate-simulate-re-estimate cycle recovers parameters (OMRF)", { + skip_on_cran() + + data("Wenchuan", package = "bgms") + fit1 = bgm(Wenchuan[1:100, 1:4], + iter = 2000, warmup = 500, + edge_selection = FALSE, chains = 1, display_progress = "none" + ) + sim = simulate(fit1, nsim = 200, method = "posterior-mean") + fit2 = bgm(sim, + iter = 2000, warmup = 500, + edge_selection = FALSE, chains = 1, display_progress = "none" + ) + + cor_pw = cor( + as.numeric(fit1$posterior_mean_pairwise), + as.numeric(fit2$posterior_mean_pairwise) + ) + expect_gt(cor_pw, 0.7) +}) + +test_that("estimate-simulate-re-estimate cycle recovers parameters (GGM)", { + skip_on_cran() + + set.seed(42) + x = matrix(rnorm(400), nrow = 100, ncol = 4) + colnames(x) = paste0("V", 1:4) + + fit1 = bgm(x, + variable_type = "continuous", + edge_selection = FALSE, iter = 2000, warmup = 500, + chains = 1, display_progress = "none" + ) + sim = simulate(fit1, nsim = 200, method = "posterior-mean") + fit2 = bgm(sim, + variable_type = "continuous", + edge_selection = FALSE, iter = 2000, warmup = 500, + chains = 1, display_progress = "none" + ) + + cor_pw = cor( + as.numeric(fit1$posterior_mean_pairwise), + as.numeric(fit2$posterior_mean_pairwise) + ) + expect_gt(cor_pw, 0.7) +}) + +test_that("estimate-simulate-re-estimate cycle recovers parameters (mixed MRF)", { + skip_on_cran() + + set.seed(99) + n = 100 + x = data.frame( + d1 = sample(0:2, n, replace = TRUE), + c1 = rnorm(n), + d2 = sample(0:2, n, replace = TRUE), + c2 = rnorm(n) + ) + vtypes = c("ordinal", "continuous", "ordinal", "continuous") + + fit1 = bgm(x, + variable_type = vtypes, + edge_selection = FALSE, iter = 2000, warmup = 500, + chains = 1, display_progress = "none" + ) + sim = simulate(fit1, nsim = 200, method = "posterior-mean") + fit2 = bgm(sim, + variable_type = vtypes, + edge_selection = FALSE, iter = 2000, warmup = 500, + chains = 1, display_progress = "none" + ) + + cor_pw = cor( + as.numeric(fit1$posterior_mean_pairwise), + as.numeric(fit2$posterior_mean_pairwise) + ) + expect_gt(cor_pw, 0.7) +}) diff --git a/tests/testthat/test-build-arguments.R b/tests/testthat/test-build-arguments.R index fa52e8cb..c31bd463 100644 --- a/tests/testthat/test-build-arguments.R +++ b/tests/testthat/test-build-arguments.R @@ -98,7 +98,8 @@ test_that("GGM build_arguments: all expected field names present", { "beta_bernoulli_alpha_between", "beta_bernoulli_beta_between", "dirichlet_alpha", "lambda", "na_action", "version", "update_method", "target_accept", "num_chains", - "data_columnnames", "no_variables", "is_continuous" + "data_columnnames", "no_variables", "is_continuous", + "model_type" ) expect_true(all(expected %in% names(a)), info = paste("Missing:", paste(setdiff(expected, names(a)), @@ -162,7 +163,8 @@ test_that("OMRF build_arguments: all expected field names present", { "nuts_max_depth", "learn_mass_matrix", "num_chains", "num_categories", "data_columnnames", "baseline_category", - "pairwise_scaling_factors", "no_variables" + "pairwise_scaling_factors", "no_variables", + "model_type" ) expect_true(all(expected %in% names(a)), info = paste("Missing:", paste(setdiff(expected, names(a)), @@ -222,7 +224,8 @@ test_that("Compare build_arguments: all expected field names present", { "num_chains", "num_groups", "data_columnnames", "projection", "num_categories", "is_ordinal_variable", - "group", "pairwise_scaling_factors" + "group", "pairwise_scaling_factors", + "model_type" ) expect_true(all(expected %in% names(a)), info = paste("Missing:", paste(setdiff(expected, names(a)), diff --git a/tests/testthat/test-extractor-functions.R b/tests/testthat/test-extractor-functions.R index 770a9777..19a5b77a 100644 --- a/tests/testthat/test-extractor-functions.R +++ b/tests/testthat/test-extractor-functions.R @@ -12,44 +12,9 @@ # ============================================================================== # ------------------------------------------------------------------------------ -# Fixture Specifications +# Fixture Specifications — defined in helper-fixtures.R +# get_extractor_fixtures() # ------------------------------------------------------------------------------ -# Define all fixtures to test against, with their properties - -get_all_fixtures = function() { - list( - list( - label = "bgms_binary", - get_fit = get_bgms_fit, - type = "bgms", - var_type = "binary" - ), - list( - label = "bgms_ordinal", - get_fit = get_bgms_fit_ordinal, - type = "bgms", - var_type = "ordinal" - ), - list( - label = "bgms_blumecapel", - get_fit = get_bgms_fit_blumecapel, - type = "bgms", - var_type = "blume-capel" - ), - list( - label = "bgmCompare_binary", - get_fit = get_bgmcompare_fit, - type = "bgmCompare", - var_type = "binary" - ), - list( - label = "bgmCompare_ordinal", - get_fit = get_bgmcompare_fit_ordinal, - type = "bgmCompare", - var_type = "ordinal" - ) - ) -} # ------------------------------------------------------------------------------ @@ -57,7 +22,7 @@ get_all_fixtures = function() { # ------------------------------------------------------------------------------ test_that("extract_arguments returns complete argument list for all fit types", { - fixtures = get_all_fixtures() + fixtures = get_extractor_fixtures() for(spec in fixtures) { ctx = sprintf("[%s]", spec$label) @@ -97,7 +62,7 @@ test_that("extract_arguments errors on non-bgms objects", { # ------------------------------------------------------------------------------ test_that("extract_pairwise_interactions returns valid matrix for all fit types", { - fixtures = get_all_fixtures() + fixtures = get_extractor_fixtures() for(spec in fixtures) { ctx = sprintf("[%s]", spec$label) @@ -125,28 +90,45 @@ test_that("extract_pairwise_interactions returns valid matrix for all fit types" # ------------------------------------------------------------------------------ -# extract_category_thresholds() Tests (parameterized) +# extract_main_effects() Tests (parameterized) # ------------------------------------------------------------------------------ -test_that("extract_category_thresholds returns valid output for all fit types", { - fixtures = get_all_fixtures() +test_that("extract_main_effects returns valid output for all fit types", { + fixtures = get_extractor_fixtures() for(spec in fixtures) { ctx = sprintf("[%s]", spec$label) fit = spec$get_fit() args = extract_arguments(fit) - thresholds = extract_category_thresholds(fit) - - # Structure checks - expect_true(is.matrix(thresholds), info = paste(ctx, "should be matrix")) - - # Values finite where not NA - vals = thresholds[!is.na(thresholds)] - expect_true(all(is.finite(vals)), info = paste(ctx, "non-NA values should be finite")) + main = extract_main_effects(fit) + + if(isTRUE(args$is_continuous)) { + # GGM: no main effects; returns NULL silently + main_null = extract_main_effects(fit) + expect_null(main_null, info = paste(ctx, "GGM should return NULL")) + } else if(isTRUE(args$is_mixed)) { + # Mixed MRF returns a list + expect_true(is.list(main), info = paste(ctx, "should be list for mixed")) + expect_true(is.matrix(main$discrete), info = paste(ctx, "$discrete should be matrix")) + expect_true(is.matrix(main$continuous), info = paste(ctx, "$continuous should be matrix")) + } else { + # OMRF / Blume-Capel return matrix + expect_true(is.matrix(main), info = paste(ctx, "should be matrix")) + vals = main[!is.na(main)] + expect_true(all(is.finite(vals)), info = paste(ctx, "non-NA values should be finite")) + } } }) +test_that("extract_category_thresholds emits deprecation warning", { + fit = get_bgms_fit() + expect_warning( + extract_category_thresholds(fit), + "extract_main_effects" + ) +}) + # ------------------------------------------------------------------------------ # extract_indicators() and extract_posterior_inclusion_probabilities() Tests @@ -234,7 +216,7 @@ test_that("extract_indicators errors when edge_selection = FALSE", { # ------------------------------------------------------------------------------ test_that("extract_rhat returns valid diagnostics for all fit types", { - fixtures = get_all_fixtures() + fixtures = get_extractor_fixtures() for(spec in fixtures) { ctx = sprintf("[%s]", spec$label) @@ -260,7 +242,7 @@ test_that("extract_rhat returns valid diagnostics for all fit types", { }) test_that("extract_ess returns valid diagnostics for all fit types", { - fixtures = get_all_fixtures() + fixtures = get_extractor_fixtures() for(spec in fixtures) { ctx = sprintf("[%s]", spec$label) @@ -390,7 +372,7 @@ test_that("extractor outputs are dimensionally consistent", { pairwise = extract_pairwise_interactions(fit) expect_equal(ncol(pairwise), n_edges, info = paste(ctx, "pairwise cols")) - thresholds = extract_category_thresholds(fit) + thresholds = suppressWarnings(extract_category_thresholds(fit)) expect_equal(nrow(thresholds), p, info = paste(ctx, "threshold rows")) } }) @@ -719,18 +701,18 @@ test_that("extract_indicator_priors returns Stochastic-Block parameters", { # These tests verify backward compatibility with fit objects from older bgms versions. # Legacy fixtures are stored in tests/testthat/fixtures/legacy/ (NOT shipped with package). # -# To generate fixtures, run: Rscript dev/generate_legacy_fixtures.R +# To generate fixtures, run: Rscript tests/fixtures/generate_legacy_fixtures.R # # Tests skip on CRAN since fixtures aren't available in installed package. # -# PATTERN: Unified fixture specs for both bgm and bgmCompare, mirroring get_all_fixtures() +# PATTERN: Unified fixture specs for both bgm and bgmCompare, mirroring get_extractor_fixtures() # ============================================================================== # Legacy Format Compatibility Tests # ============================================================================== # # These tests verify backward compatibility with fit objects from older bgms # versions. They require legacy fixture files (*.rds) that are: -# - Generated by dev/generate_legacy_fixtures.R +# - Generated by tests/fixtures/generate_legacy_fixtures.R # - Stored in tests/testthat/fixtures/legacy/ # - NOT shipped to CRAN (excluded via .Rbuildignore) # - Skipped on CRAN via skip_on_cran() in get_legacy_dir() @@ -770,12 +752,12 @@ get_legacy_dir = function() { load_legacy_fixture = function(filename) { legacy_dir = get_legacy_dir() if(is.null(legacy_dir)) { - skip("Legacy fixtures directory not found - run dev/generate_legacy_fixtures.R") + skip("Legacy fixtures directory not found - run tests/fixtures/generate_legacy_fixtures.R") } path = file.path(legacy_dir, paste0(filename, ".rds")) if(!file.exists(path)) { - skip(paste("Legacy fixture not found:", filename, "- run dev/generate_legacy_fixtures.R")) + skip(paste("Legacy fixture not found:", filename, "- run tests/fixtures/generate_legacy_fixtures.R")) } readRDS(path) } @@ -803,7 +785,7 @@ categorize_version = function(version, type = "bgm") { } # Build legacy fixture specs from available files -# Returns list of specs like get_all_fixtures(), with: +# Returns list of specs like get_extractor_fixtures(), with: # label, version, type (bgm/bgmCompare), era, get_fit get_legacy_fixture_specs = function() { legacy_dir = get_legacy_dir() @@ -888,8 +870,14 @@ test_that("pre-0.1.4 bgm formats emit deprecation warnings for pairwise/threshol expect_warning(extract_pairwise_interactions(fit), "deprecated", info = paste(spec$label, "extract_pairwise should warn") ) - expect_warning(extract_category_thresholds(fit), "deprecated", - info = paste(spec$label, "extract_thresholds should warn") + threshold_warnings = capture_warnings(extract_category_thresholds(fit)) + expect_true( + any(grepl("extract_main_effects", threshold_warnings)), + info = paste(spec$label, "extract_thresholds should warn about rename") + ) + expect_true( + any(grepl("deprecated", threshold_warnings)), + info = paste(spec$label, "extract_thresholds should warn about legacy format") ) } }) @@ -910,8 +898,14 @@ test_that("0.1.4-0.1.5 formats emit deprecation warnings", { expect_warning(extract_pairwise_interactions(fit), "deprecated", info = paste(spec$label, "extract_pairwise should warn") ) - expect_warning(extract_category_thresholds(fit), "deprecated", - info = paste(spec$label, "extract_thresholds should warn") + threshold_warnings = capture_warnings(extract_category_thresholds(fit)) + expect_true( + any(grepl("extract_main_effects", threshold_warnings)), + info = paste(spec$label, "extract_thresholds should warn about rename") + ) + expect_true( + any(grepl("deprecated", threshold_warnings)), + info = paste(spec$label, "extract_thresholds should warn about legacy format") ) # bgmCompare also has extract_group_params @@ -934,7 +928,8 @@ test_that("0.1.6+ formats work without deprecation warnings", { expect_no_warning(extract_indicators(fit)) expect_no_warning(extract_posterior_inclusion_probabilities(fit)) expect_no_warning(extract_pairwise_interactions(fit)) - expect_no_warning(extract_category_thresholds(fit)) + expect_warning(extract_category_thresholds(fit), "extract_main_effects") + expect_no_warning(extract_main_effects(fit)) } }) diff --git a/tests/testthat/test-input-validation.R b/tests/testthat/test-input-validation.R index d2994044..f761313c 100644 --- a/tests/testthat/test-input-validation.R +++ b/tests/testthat/test-input-validation.R @@ -84,14 +84,21 @@ test_that("GGM rejects NUTS and HMC update methods", { ) }) -test_that("Mixed continuous and ordinal variable types are rejected", { +test_that("Mixed continuous and ordinal variable types are accepted for bgm", { set.seed(42) - x = matrix(rnorm(200), nrow = 50, ncol = 4) - - expect_error( - bgm(x = x, variable_type = c("continuous", "ordinal", "ordinal", "ordinal")), - "all variables must be of type" + x = data.frame( + ord1 = sample(0:2, 50, replace = TRUE), + ord2 = sample(0:2, 50, replace = TRUE), + cont1 = rnorm(50), + cont2 = rnorm(50) ) + spec = bgm_spec( + x = x, + variable_type = c("ordinal", "ordinal", "continuous", "continuous") + ) + expect_equal(spec$model_type, "mixed_mrf") + expect_equal(spec$data$num_discrete, 2L) + expect_equal(spec$data$num_continuous, 2L) }) @@ -120,6 +127,37 @@ test_that("bgmCompare errors on mismatched group_indicator length", { ) }) +test_that("bgmCompare rejects continuous variable type", { + x = matrix(rnorm(100), nrow = 50, ncol = 2) + group_ind = rep(1:2, each = 25) + + expect_error( + bgmCompare( + x = x, group_indicator = group_ind, + variable_type = "continuous" + ), + regexp = "not of type continuous" + ) +}) + +test_that("bgmCompare rejects mixed ordinal + continuous variable types", { + x = data.frame( + ord1 = sample(0:2, 50, replace = TRUE), + cont1 = rnorm(50) + ) + group_ind = rep(1:2, each = 25) + + # allow_continuous = FALSE fires before the mixed check, so the error + # message is the same as for pure continuous input. + expect_error( + bgmCompare( + x = x, group_indicator = group_ind, + variable_type = c("ordinal", "continuous") + ), + regexp = "not of type continuous" + ) +}) + # ------------------------------------------------------------------------------ # simulate_mrf() Input Validation diff --git a/tests/testthat/test-methods.R b/tests/testthat/test-methods.R index cee76656..df589bc1 100644 --- a/tests/testthat/test-methods.R +++ b/tests/testthat/test-methods.R @@ -16,161 +16,54 @@ # ============================================================================== # ------------------------------------------------------------------------------ -# Fixture Specifications +# Fixture Specifications — defined in helper-fixtures.R +# get_bgms_fixtures(), get_bgmcompare_fixtures() # ------------------------------------------------------------------------------ -get_bgms_fixtures = function() { - list( - list( - label = "binary", - get_fit = get_bgms_fit, - get_prediction_data = get_prediction_data_binary, - var_type = "binary", - is_continuous = FALSE - ), - list( - label = "ordinal", - get_fit = get_bgms_fit_ordinal, - get_prediction_data = get_prediction_data_ordinal, - var_type = "ordinal", - is_continuous = FALSE - ), - list( - label = "single-chain", - get_fit = get_bgms_fit_single_chain, - get_prediction_data = get_prediction_data_binary, - var_type = "binary", - is_continuous = FALSE - ), - list( - label = "blume-capel", - get_fit = get_bgms_fit_blumecapel, - get_prediction_data = get_prediction_data_ordinal, - var_type = "blume-capel", - is_continuous = FALSE - ), - list( - label = "adaptive-metropolis", - get_fit = get_bgms_fit_adaptive_metropolis, - get_prediction_data = get_prediction_data_binary, - var_type = "binary", - is_continuous = FALSE - ), - list( - label = "hmc", - get_fit = get_bgms_fit_hmc, - get_prediction_data = get_prediction_data_ordinal, - var_type = "ordinal", - is_continuous = FALSE - ), - list( - label = "am-blumecapel", - get_fit = get_bgms_fit_am_blumecapel, - get_prediction_data = get_prediction_data_ordinal, - var_type = "blume-capel", - is_continuous = FALSE - ), - list( - label = "impute", - get_fit = get_bgms_fit_impute, - get_prediction_data = get_prediction_data_ordinal, - var_type = "ordinal", - is_continuous = FALSE - ), - list( - label = "standardize", - get_fit = get_bgms_fit_standardize, - get_prediction_data = get_prediction_data_ordinal, - var_type = "ordinal", - is_continuous = FALSE - ), - list( - label = "ggm", - get_fit = get_bgms_fit_ggm, - get_prediction_data = get_prediction_data_ggm, - var_type = "continuous", - is_continuous = TRUE - ), - list( - label = "ggm-no-es", - get_fit = get_bgms_fit_ggm_no_es, - get_prediction_data = get_prediction_data_ggm, - var_type = "continuous", - is_continuous = TRUE - ) + +# ============================================================================== +# Fixture Coverage Guards +# ============================================================================== +# Fail fast when a fixture list drifts from the expected coverage. + +test_that("get_bgms_fixtures covers all required labels", { + specs = get_bgms_fixtures() + labels = vapply(specs, `[[`, character(1), "label") + required = c("binary", "ordinal", "ggm", "mixed-mrf", "blume-capel") + for(r in required) { + expect_true(r %in% labels, info = sprintf("missing required label '%s'", r)) + } + expect_equal(length(specs), 24L, + info = "bgms fixture count changed — update this guard if intentional" ) -} - -get_bgmcompare_fixtures = function() { - list( - list( - label = "binary", - get_fit = get_bgmcompare_fit, - get_prediction_data = get_prediction_data_bgmcompare_binary, - var_type = "binary" - ), - list( - label = "ordinal", - get_fit = get_bgmcompare_fit_ordinal, - get_prediction_data = get_prediction_data_bgmcompare_ordinal, - var_type = "ordinal" - ), - list( - label = "adaptive-metropolis", - get_fit = get_bgmcompare_fit_adaptive_metropolis, - get_prediction_data = get_prediction_data_bgmcompare_binary, - var_type = "binary" - ), - list( - label = "hmc", - get_fit = get_bgmcompare_fit_hmc, - get_prediction_data = get_prediction_data_bgmcompare_binary, - var_type = "binary" - ), - list( - label = "hmc-blume-capel", - get_fit = get_bgmcompare_fit_hmc_blumecapel, - get_prediction_data = get_prediction_data_bgmcompare_blumecapel, - var_type = "blume-capel" - ), - list( - label = "blume-capel", - get_fit = get_bgmcompare_fit_blumecapel, - get_prediction_data = get_prediction_data_bgmcompare_blumecapel, - var_type = "blume-capel" - ), - list( - label = "am-blume-capel", - get_fit = get_bgmcompare_fit_am_blumecapel, - get_prediction_data = get_prediction_data_bgmcompare_blumecapel, - var_type = "blume-capel" - ), - list( - label = "impute", - get_fit = get_bgmcompare_fit_impute, - get_prediction_data = get_prediction_data_bgmcompare_ordinal, - var_type = "ordinal" - ), - list( - label = "blume-capel-impute", - get_fit = get_bgmcompare_fit_blumecapel_impute, - get_prediction_data = get_prediction_data_bgmcompare_blumecapel, - var_type = "blume-capel" - ), - list( - label = "beta-bernoulli", - get_fit = get_bgmcompare_fit_beta_bernoulli, - get_prediction_data = get_prediction_data_bgmcompare_ordinal, - var_type = "ordinal" - ), - list( - label = "standardize", - get_fit = get_bgmcompare_fit_standardize, - get_prediction_data = get_prediction_data_bgmcompare_ordinal, - var_type = "ordinal" - ) +}) + +test_that("get_bgmcompare_fixtures covers all required labels", { + specs = get_bgmcompare_fixtures() + labels = vapply(specs, `[[`, character(1), "label") + required = c("binary", "ordinal", "blume-capel") + for(r in required) { + expect_true(r %in% labels, info = sprintf("missing required label '%s'", r)) + } + expect_equal(length(specs), 13L, + info = "bgmcompare fixture count changed — update this guard if intentional" ) -} +}) + +test_that("get_extractor_fixtures covers all model families", { + specs = get_extractor_fixtures() + labels = vapply(specs, `[[`, character(1), "label") + required = c( + "bgms_binary", "bgms_ggm", "bgms_mixed", + "bgmCompare_binary", "bgmCompare_ordinal" + ) + for(r in required) { + expect_true(r %in% labels, info = sprintf("missing required label '%s'", r)) + } + expect_equal(length(specs), 9L, + info = "extractor fixture count changed — update this guard if intentional" + ) +}) # ============================================================================== @@ -210,7 +103,8 @@ test_that("summary.bgms returns correct structure for all fixture types", { summ = summary(fit) expect_s3_class(summ, "summary.bgms") - expect_true("main" %in% names(summ) || "pairwise" %in% names(summ), info = ctx) + has_content = !is.null(summ$main) || !is.null(summ$quadratic) || !is.null(summ$pairwise) + expect_true(has_content, info = ctx) } }) @@ -322,6 +216,7 @@ test_that("coef.bgmCompare returns group-specific effects for all fixture types" test_that("simulate.bgms returns matrix of correct size for ordinal fixtures", { for(spec in get_bgms_fixtures()) { if(isTRUE(spec$is_continuous)) next + if(isTRUE(spec$is_mixed)) next ctx = sprintf("[bgms %s]", spec$label) fit = spec$get_fit() @@ -371,6 +266,37 @@ test_that("simulate.bgms returns matrix of correct size for GGM fixtures", { } }) +test_that("simulate.bgms returns matrix of correct size for mixed-mrf fixtures", { + for(spec in get_bgms_fixtures()) { + if(!isTRUE(spec$is_mixed)) next + + ctx = sprintf("[bgms %s]", spec$label) + fit = spec$get_fit() + args = extract_arguments(fit) + + n_sim = 30 + simulated = simulate(fit, nsim = n_sim, method = "posterior-mean", seed = 123) + + expect_true(is.matrix(simulated), info = ctx) + expect_equal(nrow(simulated), n_sim, info = paste(ctx, "wrong nrow")) + expect_equal(ncol(simulated), args$num_variables, info = paste(ctx, "wrong ncol")) + expect_equal(colnames(simulated), args$data_columnnames, info = ctx) + expect_true(all(is.finite(simulated)), info = paste(ctx, "non-finite values")) + + # Ordinal columns should be non-negative integers + for(j in seq_len(args$num_variables)) { + if(args$variable_type[j] != "continuous") { + expect_true(all(simulated[, j] == round(simulated[, j])), + info = sprintf("%s variable %d should be integer", ctx, j) + ) + expect_true(all(simulated[, j] >= 0), + info = sprintf("%s variable %d has negative values", ctx, j) + ) + } + } + } +}) + test_that("simulate.bgms is reproducible with seed", { fit = get_bgms_fit() @@ -466,6 +392,7 @@ test_that("simulate.bgms GGM posterior-sample returns list of numeric matrices", test_that("predict.bgms returns valid probabilities for ordinal fixtures", { for(spec in get_bgms_fixtures()) { if(isTRUE(spec$is_continuous)) next + if(isTRUE(spec$is_mixed)) next ctx = sprintf("[bgms %s]", spec$label) fit = spec$get_fit() @@ -533,6 +460,51 @@ test_that("predict.bgms returns valid conditional moments for GGM fixtures", { } }) +test_that("predict.bgms returns valid predictions for mixed-mrf fixtures", { + for(spec in get_bgms_fixtures()) { + if(!isTRUE(spec$is_mixed)) next + + ctx = sprintf("[bgms %s]", spec$label) + fit = spec$get_fit() + args = extract_arguments(fit) + + newdata = spec$get_prediction_data(n = 5) + probs = predict(fit, newdata = newdata, type = "probabilities") + + expect_true(is.list(probs), info = ctx) + expect_equal(length(probs), args$num_variables, info = ctx) + + for(j in seq_len(args$num_variables)) { + vname = args$data_columnnames[j] + expect_equal(nrow(probs[[j]]), nrow(newdata), + info = sprintf("%s %s nrow", ctx, vname) + ) + expect_false(anyNA(probs[[j]]), + info = sprintf("%s %s has NAs", ctx, vname) + ) + + if(args$variable_type[j] %in% c("ordinal", "blume-capel")) { + # Discrete variables: probability rows sum to 1 + row_sums = rowSums(probs[[j]]) + expect_true( + all(abs(row_sums - 1) < 1e-6), + info = sprintf("%s %s probs don't sum to 1", ctx, vname) + ) + } else { + # Continuous variables: 2-column (mean, sd) matrix + expect_equal(ncol(probs[[j]]), 2, + info = sprintf("%s %s ncol", ctx, vname) + ) + } + } + + # type = "response" returns a matrix + resp = predict(fit, newdata = newdata, type = "response") + expect_true(is.matrix(resp), info = ctx) + expect_equal(dim(resp), c(nrow(newdata), args$num_variables), info = ctx) + } +}) + test_that("predict.bgms response returns integer categories", { fit = get_bgms_fit() args = extract_arguments(fit) @@ -652,7 +624,6 @@ test_that("predict.bgms GGM conditional mean matches analytic formula", { # Reconstruct the posterior mean precision matrix omega_hat = fit$posterior_mean_pairwise - diag(omega_hat) = as.numeric(fit$posterior_mean_main) p = args$num_variables # Center newdata by its own means (predict does the same internally) diff --git a/tests/testthat/test-mixed-mrf-likelihood.R b/tests/testthat/test-mixed-mrf-likelihood.R new file mode 100644 index 00000000..d6cdff08 --- /dev/null +++ b/tests/testthat/test-mixed-mrf-likelihood.R @@ -0,0 +1,596 @@ +# ============================================================================== +# Conditional Distribution Correctness Tests +# ============================================================================== +# +# Verify that the C++ conditional prediction functions produce values +# matching hand-computed reference values for minimal networks. +# +# Ported from mixedGM::test-likelihood-correctness.R, adapted for +# the bgms API surface: compute_conditional_probs (OMRF), +# compute_conditional_ggm (GGM), compute_conditional_mixed (mixed MRF). +# +# ============================================================================== + + +# ============================================================================== +# Test 1: OMRF binary conditional probabilities +# ============================================================================== +# Minimal example: p=2 binary ordinals, n=3. +# +# P(x_i = c | x_{-i}) proportional to: +# c = 0: 1 (reference) +# c = 1: exp(main[i,1] + 1 * rest_i) +# +# rest_i = sum_{k != i} pairwise[k,i] * x_k + +test_that("OMRF binary conditional probabilities match hand computation", { + n = 3L + p = 2L + + observations = matrix( + c( + 0L, 1L, 1L, + 1L, 0L, 1L + ), + nrow = n, ncol = p + ) + + pairwise = matrix( + c( + 0.0, 0.3, + 0.3, 0.0 + ), + nrow = p, byrow = TRUE + ) + main = matrix(c(-0.5, 0.2), nrow = p, ncol = 1) + num_categories = c(1L, 1L) + variable_type = c("ordinal", "ordinal") + baseline_category = c(0L, 0L) + + # --- Predict variable 0 (1st variable) --- + # rest = pairwise[1,0] * x_2 = 0.3 * c(1, 0, 1) = c(0.3, 0, 0.3) + rest_v0 = c(0.3, 0.0, 0.3) + mu0 = main[1, 1] # -0.5 + + # P(x_0 = 0) = 1 / (1 + exp(mu0 + rest)) + # P(x_0 = 1) = exp(mu0 + rest) / (1 + exp(mu0 + rest)) + logit_v0 = mu0 + rest_v0 + prob_v0_cat1 = exp(logit_v0) / (1 + exp(logit_v0)) + prob_v0_cat0 = 1 - prob_v0_cat1 + + probs_cpp = compute_conditional_probs( + observations = observations, + predict_vars = 0L, + pairwise = pairwise, + main = main, + num_categories = num_categories, + variable_type = variable_type, + baseline_category = baseline_category + ) + + expect_equal(probs_cpp[[1]][, 1], prob_v0_cat0, tolerance = 1e-10) + expect_equal(probs_cpp[[1]][, 2], prob_v0_cat1, tolerance = 1e-10) + + # --- Predict variable 1 (2nd variable) --- + # rest = pairwise[0,1] * x_1 = 0.3 * c(0, 1, 1) = c(0, 0.3, 0.3) + rest_v1 = c(0.0, 0.3, 0.3) + mu1 = main[2, 1] # 0.2 + + logit_v1 = mu1 + rest_v1 + prob_v1_cat1 = exp(logit_v1) / (1 + exp(logit_v1)) + prob_v1_cat0 = 1 - prob_v1_cat1 + + probs_cpp2 = compute_conditional_probs( + observations = observations, + predict_vars = 1L, + pairwise = pairwise, + main = main, + num_categories = num_categories, + variable_type = variable_type, + baseline_category = baseline_category + ) + + expect_equal(probs_cpp2[[1]][, 1], prob_v1_cat0, tolerance = 1e-10) + expect_equal(probs_cpp2[[1]][, 2], prob_v1_cat1, tolerance = 1e-10) + + # Probabilities sum to 1 + expect_equal(rowSums(probs_cpp[[1]]), rep(1, n), tolerance = 1e-10) + expect_equal(rowSums(probs_cpp2[[1]]), rep(1, n), tolerance = 1e-10) +}) + + +# ============================================================================== +# Test 2: Multi-category ordinal conditional probabilities +# ============================================================================== +# p=1 ordinal with 3 categories (0, 1, 2), n=4. +# +# P(x = c | ...) proportional to: +# c = 0: 1 +# c = 1: exp(main[1,1] + 1 * rest) +# c = 2: exp(main[1,2] + 2 * rest) + +test_that("OMRF multi-category conditional probabilities match hand computation", { + n = 4L + p = 2L + + observations = matrix( + c( + 0L, 1L, 2L, 1L, + 1L, 0L, 2L, 1L + ), + nrow = n, ncol = p + ) + + pairwise = matrix( + c( + 0.0, 0.25, + 0.25, 0.0 + ), + nrow = p, byrow = TRUE + ) + # 2 categories (3 levels: 0, 1, 2) => 2 threshold columns + main = matrix( + c( + -0.5, 0.1, + 0.2, -0.3 + ), + nrow = p, ncol = 2, byrow = TRUE + ) + num_categories = c(2L, 2L) + variable_type = c("ordinal", "ordinal") + baseline_category = c(0L, 0L) + + # Predict variable 0: + # rest = pairwise[1,0] * x_2 = 0.25 * c(1, 0, 2, 1) = c(0.25, 0, 0.5, 0.25) + rest = 0.25 * c(1, 0, 2, 1) + + # For each observation, compute unnormalized probabilities + hand_probs = matrix(NA_real_, nrow = n, ncol = 3) + for(i in seq_len(n)) { + log_unnorm = c( + 0, # c = 0 (reference) + main[1, 1] + 1 * rest[i], # c = 1 + main[1, 2] + 2 * rest[i] # c = 2 + ) + # Stable softmax + max_val = max(log_unnorm) + unnorm = exp(log_unnorm - max_val) + hand_probs[i, ] = unnorm / sum(unnorm) + } + + probs_cpp = compute_conditional_probs( + observations = observations, + predict_vars = 0L, + pairwise = pairwise, + main = main, + num_categories = num_categories, + variable_type = variable_type, + baseline_category = baseline_category + ) + + expect_equal(as.matrix(probs_cpp[[1]]), hand_probs, tolerance = 1e-10) + expect_equal(rowSums(probs_cpp[[1]]), rep(1, n), tolerance = 1e-10) +}) + + +# ============================================================================== +# Test 3: GGM conditional mean and sd +# ============================================================================== +# For precision matrix Omega: +# E[X_j | X_{-j}] = -(1/omega_jj) * sum_{k != j} omega_jk * x_k +# SD[X_j | X_{-j}] = sqrt(1/omega_jj) +# +# Note: compute_conditional_ggm operates on centered data. + +test_that("GGM conditional mean and sd match hand computation", { + n = 4L + p = 3L + + # Symmetric positive definite precision matrix + omega = matrix(c( + 2.0, -0.5, 0.3, + -0.5, 1.5, -0.2, + 0.3, -0.2, 1.0 + ), nrow = p, byrow = TRUE) + + # Centered observations + x = matrix(c( + 1.0, -0.5, 0.3, + 0.2, 0.8, -0.4, + -0.7, 0.1, 0.9, + 0.5, -0.3, 0.2 + ), nrow = n, byrow = TRUE) + + # Predict variable 0: + # mean = -(1/2) * ((-0.5) * x[,2] + 0.3 * x[,3]) + hand_mean_v0 = -(1 / omega[1, 1]) * + (omega[1, 2] * x[, 2] + omega[1, 3] * x[, 3]) + hand_sd_v0 = sqrt(1 / omega[1, 1]) + + result = compute_conditional_ggm( + observations = x, + predict_vars = 0L, + precision = omega + ) + + expect_equal(result[[1]][, 1], hand_mean_v0, tolerance = 1e-10) + expect_equal(result[[1]][1, 2], hand_sd_v0, tolerance = 1e-10) + + # Predict variable 1: + hand_mean_v1 = -(1 / omega[2, 2]) * + (omega[2, 1] * x[, 1] + omega[2, 3] * x[, 3]) + hand_sd_v1 = sqrt(1 / omega[2, 2]) + + result2 = compute_conditional_ggm( + observations = x, + predict_vars = 1L, + precision = omega + ) + + expect_equal(result2[[1]][, 1], hand_mean_v1, tolerance = 1e-10) + expect_equal(result2[[1]][1, 2], hand_sd_v1, tolerance = 1e-10) + + # Predict all variables at once + result_all = compute_conditional_ggm( + observations = x, + predict_vars = c(0L, 1L, 2L), + precision = omega + ) + + expect_equal(length(result_all), p) + expect_equal(result_all[[1]][, 1], hand_mean_v0, tolerance = 1e-10) + expect_equal(result_all[[2]][, 1], hand_mean_v1, tolerance = 1e-10) +}) + + +# ============================================================================== +# Test 4: Mixed MRF discrete conditional probabilities +# ============================================================================== +# p=2 binary ordinals, q=1 continuous, n=3. +# +# For discrete variable s, the rest scores in the mixed MRF are: +# rest = sum_{k != s} (x_k - ref_k) * Kxx[k,s] +# + sum_j 2 * Kxy[s,j] * y_j +# +# Then P(x_s = c | rest) follows the same softmax as the pure OMRF. + +test_that("mixed MRF discrete conditional probabilities match hand computation", { + n = 3L + p = 2L + q = 1L + + x_obs = matrix( + c( + 0L, 1L, 1L, + 1L, 0L, 1L + ), + nrow = n, ncol = p + ) + y_obs = matrix(c(0.5, -0.3, 1.2), nrow = n, ncol = q) + + Kxx = matrix(c( + 0.0, 0.3, + 0.3, 0.0 + ), nrow = p, byrow = TRUE) + Kxy = matrix(c(0.2, 0.4), nrow = p, ncol = q) + Kyy = matrix(2.0, nrow = q, ncol = q) + mux = matrix(c(-0.5, 0.2), nrow = p, ncol = 1) + muy = c(0.1) + + num_categories = c(1L, 1L) + variable_type = c("ordinal", "ordinal") + baseline_category = c(0L, 0L) + + # --- Predict discrete variable 0 --- + # rest_discrete = Kxx[1,0] * (x_2 - 0) = 0.3 * c(1, 0, 1) + # rest_continuous = 2 * Kxy[0,0] * y = 2 * 0.2 * c(0.5, -0.3, 1.2) + # rest = c(0.3, 0, 0.3) + c(0.2, -0.12, 0.48) = c(0.5, -0.12, 0.78) + rest_v0 = c(0.5, -0.12, 0.78) + mu0 = mux[1, 1] # -0.5 + + logit_v0 = mu0 + rest_v0 + hand_prob_cat1 = exp(logit_v0) / (1 + exp(logit_v0)) + hand_prob_cat0 = 1 - hand_prob_cat1 + + result = compute_conditional_mixed( + x_observations = x_obs, + y_observations = y_obs, + predict_vars = 0L, + Kxx = Kxx, + Kxy = Kxy, + Kyy = Kyy, + mux = mux, + muy = muy, + num_categories = num_categories, + variable_type = variable_type, + baseline_category = baseline_category + ) + + expect_equal(result[[1]][, 1], hand_prob_cat0, tolerance = 1e-10) + expect_equal(result[[1]][, 2], hand_prob_cat1, tolerance = 1e-10) + expect_equal(rowSums(result[[1]]), rep(1, n), tolerance = 1e-10) + + # --- Predict discrete variable 1 --- + # rest_discrete = Kxx[0,1] * (x_1 - 0) = 0.3 * c(0, 1, 1) + # rest_continuous = 2 * Kxy[1,0] * y = 2 * 0.4 * c(0.5, -0.3, 1.2) + # rest = c(0, 0.3, 0.3) + c(0.4, -0.24, 0.96) = c(0.4, 0.06, 1.26) + rest_v1 = c(0.4, 0.06, 1.26) + mu1 = mux[2, 1] # 0.2 + + logit_v1 = mu1 + rest_v1 + hand_prob1_cat1 = exp(logit_v1) / (1 + exp(logit_v1)) + hand_prob1_cat0 = 1 - hand_prob1_cat1 + + result2 = compute_conditional_mixed( + x_observations = x_obs, + y_observations = y_obs, + predict_vars = 1L, + Kxx = Kxx, + Kxy = Kxy, + Kyy = Kyy, + mux = mux, + muy = muy, + num_categories = num_categories, + variable_type = variable_type, + baseline_category = baseline_category + ) + + expect_equal(result2[[1]][, 1], hand_prob1_cat0, tolerance = 1e-10) + expect_equal(result2[[1]][, 2], hand_prob1_cat1, tolerance = 1e-10) +}) + + +# ============================================================================== +# Test 5: Mixed MRF continuous conditional mean and sd +# ============================================================================== +# For continuous variable j in the mixed MRF: +# cond_mean = muy_j + (1/Kyy_jj) * ( +# -sum_{k != j} Kyy[j,k] * (y_k - muy_k) +# + sum_s 2 * Kxy[s,j] * (x_s - ref_s) +# ) +# cond_sd = sqrt(1/Kyy_jj) + +test_that("mixed MRF continuous conditional matches hand computation", { + n = 3L + p = 2L + q = 1L + + x_obs = matrix( + c( + 0L, 1L, 1L, + 1L, 0L, 1L + ), + nrow = n, ncol = p + ) + y_obs = matrix(c(0.5, -0.3, 1.2), nrow = n, ncol = q) + + Kxx = matrix(c( + 0.0, 0.3, + 0.3, 0.0 + ), nrow = p, byrow = TRUE) + Kxy = matrix(c(0.2, 0.4), nrow = p, ncol = q) + Kyy = matrix(2.0, nrow = q, ncol = q) + mux = matrix(c(-0.5, 0.2), nrow = p, ncol = 1) + muy = c(0.1) + + num_categories = c(1L, 1L) + variable_type = c("ordinal", "ordinal") + baseline_category = c(0L, 0L) + + # Predict continuous variable (index = p = 2, since 0-based: [0,1] are discrete) + # cond_var = 1/Kyy[0,0] = 1/2 = 0.5 + # lp_continuous = 0 (only 1 continuous variable, skip self) + # lp_discrete = 2 * Kxy[0,0] * (x_1 - 0) + 2 * Kxy[1,0] * (x_2 - 0) + # = 2 * 0.2 * x_1 + 2 * 0.4 * x_2 + # For n=1: 0.4*0 + 0.8*1 = 0.8 + # For n=2: 0.4*1 + 0.8*0 = 0.4 + # For n=3: 0.4*1 + 0.8*1 = 1.2 + lp_discrete = c(0.8, 0.4, 1.2) + + hand_mean = muy + (1 / Kyy[1, 1]) * lp_discrete + # = 0.1 + 0.5 * c(0.8, 0.4, 1.2) = c(0.5, 0.3, 0.7) + hand_sd = sqrt(1 / Kyy[1, 1]) # sqrt(0.5) + + result = compute_conditional_mixed( + x_observations = x_obs, + y_observations = y_obs, + predict_vars = as.integer(p), # index 2 = first continuous variable + Kxx = Kxx, + Kxy = Kxy, + Kyy = Kyy, + mux = mux, + muy = muy, + num_categories = num_categories, + variable_type = variable_type, + baseline_category = baseline_category + ) + + expect_equal(result[[1]][, 1], hand_mean, tolerance = 1e-10) + expect_equal(result[[1]][1, 2], hand_sd, tolerance = 1e-10) +}) + + +# ============================================================================== +# Test 6: Mixed MRF with multiple continuous variables +# ============================================================================== +# p=1 binary ordinal, q=2 continuous, n=3. +# Tests cross-variable precision terms in the continuous conditional. + +test_that("mixed MRF continuous conditional works for q > 1", { + n = 3L + p = 1L + q = 2L + + x_obs = matrix(c(0L, 1L, 1L), nrow = n, ncol = p) + y_obs = matrix(c( + 0.5, 0.2, + -0.3, 0.8, + 1.2, -0.5 + ), nrow = n, ncol = q, byrow = TRUE) + + Kxx = matrix(0.0, nrow = p, ncol = p) + Kxy = matrix(c(0.2, 0.1), nrow = p, ncol = q) + Kyy = matrix(c(2.0, 0.3, 0.3, 1.5), nrow = q, byrow = TRUE) + mux = matrix(-0.5, nrow = p, ncol = 1) + muy = c(0.1, -0.2) + + num_categories = c(1L) + variable_type = c("ordinal") + baseline_category = c(0L) + + # --- Predict continuous variable 0 (internal index p = 1) --- + # cond_var_0 = 1/Kyy[0,0] = 1/2 = 0.5 + # lp_continuous = -Kyy[0,1] * (y[,1] - muy[1]) + # = -0.3 * (y[,1] - (-0.2)) = -0.3 * (y[,1] + 0.2) + # For n=1: -0.3 * (0.2 + 0.2) = -0.12 + # For n=2: -0.3 * (0.8 + 0.2) = -0.3 + # For n=3: -0.3 * (-0.5 + 0.2) = 0.09 + lp_cont = c(-0.12, -0.3, 0.09) + + # lp_discrete = 2 * Kxy[0,0] * (x - 0) = 0.4 * c(0, 1, 1) + lp_disc = c(0.0, 0.4, 0.4) + + hand_mean_y0 = muy[1] + 0.5 * (lp_cont + lp_disc) + hand_sd_y0 = sqrt(0.5) + + result = compute_conditional_mixed( + x_observations = x_obs, + y_observations = y_obs, + predict_vars = as.integer(p), # index 1 = first continuous + Kxx = Kxx, + Kxy = Kxy, + Kyy = Kyy, + mux = mux, + muy = muy, + num_categories = num_categories, + variable_type = variable_type, + baseline_category = baseline_category + ) + + expect_equal(result[[1]][, 1], hand_mean_y0, tolerance = 1e-10) + expect_equal(result[[1]][1, 2], hand_sd_y0, tolerance = 1e-10) + + # --- Predict continuous variable 1 (internal index p + 1 = 2) --- + # cond_var_1 = 1/Kyy[1,1] = 1/1.5 + # lp_continuous = -Kyy[1,0] * (y[,0] - muy[0]) + # = -0.3 * (y[,0] - 0.1) + # For n=1: -0.3 * (0.5 - 0.1) = -0.12 + # For n=2: -0.3 * (-0.3 - 0.1) = 0.12 + # For n=3: -0.3 * (1.2 - 0.1) = -0.33 + lp_cont2 = c(-0.12, 0.12, -0.33) + + # lp_discrete = 2 * Kxy[0,1] * (x - 0) = 0.2 * c(0, 1, 1) + lp_disc2 = c(0.0, 0.2, 0.2) + + hand_mean_y1 = muy[2] + (1 / Kyy[2, 2]) * (lp_cont2 + lp_disc2) + hand_sd_y1 = sqrt(1 / Kyy[2, 2]) + + result2 = compute_conditional_mixed( + x_observations = x_obs, + y_observations = y_obs, + predict_vars = as.integer(p + 1L), # index 2 = second continuous + Kxx = Kxx, + Kxy = Kxy, + Kyy = Kyy, + mux = mux, + muy = muy, + num_categories = num_categories, + variable_type = variable_type, + baseline_category = baseline_category + ) + + expect_equal(result2[[1]][, 1], hand_mean_y1, tolerance = 1e-10) + expect_equal(result2[[1]][1, 2], hand_sd_y1, tolerance = 1e-10) +}) + + +# ============================================================================== +# Test 7: Predicting all variables at once +# ============================================================================== +# Verify that predicting multiple variables yields the same results +# as predicting each one individually. + +test_that("predicting all variables at once matches individual predictions", { + n = 3L + p = 2L + q = 1L + + x_obs = matrix( + c( + 0L, 1L, 1L, + 1L, 0L, 1L + ), + nrow = n, ncol = p + ) + y_obs = matrix(c(0.5, -0.3, 1.2), nrow = n, ncol = q) + + Kxx = matrix(c(0.0, 0.3, 0.3, 0.0), nrow = p, byrow = TRUE) + Kxy = matrix(c(0.2, 0.4), nrow = p, ncol = q) + Kyy = matrix(2.0, nrow = q, ncol = q) + mux = matrix(c(-0.5, 0.2), nrow = p, ncol = 1) + muy = c(0.1) + + num_categories = c(1L, 1L) + variable_type = c("ordinal", "ordinal") + baseline_category = c(0L, 0L) + + # Predict all three variables at once + result_all = compute_conditional_mixed( + x_observations = x_obs, + y_observations = y_obs, + predict_vars = c(0L, 1L, as.integer(p)), + Kxx = Kxx, + Kxy = Kxy, + Kyy = Kyy, + mux = mux, + muy = muy, + num_categories = num_categories, + variable_type = variable_type, + baseline_category = baseline_category + ) + + # Predict each individually + result_v0 = compute_conditional_mixed( + x_observations = x_obs, + y_observations = y_obs, + predict_vars = 0L, + Kxx = Kxx, Kxy = Kxy, Kyy = Kyy, + mux = mux, muy = muy, + num_categories = num_categories, + variable_type = variable_type, + baseline_category = baseline_category + ) + + result_v1 = compute_conditional_mixed( + x_observations = x_obs, + y_observations = y_obs, + predict_vars = 1L, + Kxx = Kxx, Kxy = Kxy, Kyy = Kyy, + mux = mux, muy = muy, + num_categories = num_categories, + variable_type = variable_type, + baseline_category = baseline_category + ) + + result_vy = compute_conditional_mixed( + x_observations = x_obs, + y_observations = y_obs, + predict_vars = as.integer(p), + Kxx = Kxx, Kxy = Kxy, Kyy = Kyy, + mux = mux, muy = muy, + num_categories = num_categories, + variable_type = variable_type, + baseline_category = baseline_category + ) + + expect_equal(length(result_all), 3L) + expect_equal(as.matrix(result_all[[1]]), as.matrix(result_v0[[1]]), + tolerance = 1e-10 + ) + expect_equal(as.matrix(result_all[[2]]), as.matrix(result_v1[[1]]), + tolerance = 1e-10 + ) + expect_equal(as.matrix(result_all[[3]]), as.matrix(result_vy[[1]]), + tolerance = 1e-10 + ) +}) diff --git a/tests/testthat/test-mixed-mrf-simulate-predict.R b/tests/testthat/test-mixed-mrf-simulate-predict.R new file mode 100644 index 00000000..5273f856 --- /dev/null +++ b/tests/testthat/test-mixed-mrf-simulate-predict.R @@ -0,0 +1,715 @@ +# ============================================================================== +# tests/testthat/test-mixed-mrf-simulate-predict.R +# ============================================================================== +# Phase G tests for mixed MRF simulation and prediction. +# +# EXTENDS: test-simulate-predict-regression.R (which handles parameterized +# roundtrip tests via get_bgms_fixtures). This file covers: +# T25: Gibbs generator sanity (sample_mixed_mrf_gibbs) +# Mixed-specific structural tests for simulate.bgms / predict.bgms +# Edge cases: p=1, q=1, binary-only ordinal +# +# PATTERN: Stochastic-robust testing — dimensions, ranges, invariants, +# coarse distributional checks. No exact moment matching. +# ============================================================================== + + +# ============================================================================== +# 1. Gibbs generator sanity (T25) +# ============================================================================== +# Verify that sample_mixed_mrf_gibbs produces structurally correct output +# and coarse statistical properties match known targets. +# ============================================================================== + +test_that("sample_mixed_mrf_gibbs returns correct dimensions", { + p = 2 + q = 2 + n = 100 + nc = c(2L, 2L) + mux = matrix(c(0, 0.5, -0.3, 0, -0.2, 0.1), nrow = p, ncol = 3, byrow = TRUE) + Kxx = matrix(c(0, 0.3, 0.3, 0), p, p) + muy = c(0.5, -0.2) + Kyy = matrix(c(1.5, 0.2, 0.2, 1.8), q, q) + Kxy = matrix(c(0.1, -0.15, 0.2, 0.05), p, q) + + result = sample_mixed_mrf_gibbs( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = muy, num_categories_r = nc, + variable_type_r = c("ordinal", "ordinal"), + baseline_category_r = c(0L, 0L), iter = 200L, seed = 42L + ) + + expect_true(is.list(result)) + expect_equal(dim(result$x), c(n, p)) + expect_equal(dim(result$y), c(n, q)) +}) + +test_that("sample_mixed_mrf_gibbs: discrete values in valid range", { + p = 3 + q = 1 + n = 500 + nc = c(2L, 3L, 1L) # categories 0-2, 0-3, 0-1 (binary) + mux = matrix(0, p, 4) + Kxx = matrix(0, p, p) + muy = 0 + Kyy = matrix(2) + Kxy = matrix(0, p, q) + + result = sample_mixed_mrf_gibbs( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = muy, num_categories_r = nc, + variable_type_r = c("ordinal", "ordinal", "ordinal"), + baseline_category_r = c(0L, 0L, 0L), iter = 100L, seed = 1L + ) + + for(s in 1:p) { + expect_true(all(result$x[, s] >= 0), + info = sprintf("var %d has negative values", s) + ) + expect_true(all(result$x[, s] <= nc[s]), + info = sprintf("var %d exceeds max category %d", s, nc[s]) + ) + } +}) + +test_that("sample_mixed_mrf_gibbs: continuous marginal SD matches precision", { + # Independent model: Kxx = 0, Kxy = 0, so y ~ N(muy, Kyy^{-1}) + p = 1 + q = 2 + n = 2000 + nc = c(2L) + mux = matrix(0, p, 3) + Kxx = matrix(0) + muy = c(1.0, -0.5) + precision = c(2.0, 0.5) + Kyy = diag(precision) + Kxy = matrix(0, p, q) + + result = sample_mixed_mrf_gibbs( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = muy, num_categories_r = nc, + variable_type_r = "ordinal", baseline_category_r = 0L, + iter = 300L, seed = 99L + ) + + for(j in 1:q) { + expected_sd = 1 / sqrt(precision[j]) + empirical_sd = sd(result$y[, j]) + # Loose check: within 30% of expected + expect_true( + abs(empirical_sd - expected_sd) / expected_sd < 0.3, + info = sprintf( + "y%d SD: expected %.3f, got %.3f", + j, expected_sd, empirical_sd + ) + ) + } +}) + +test_that("sample_mixed_mrf_gibbs: seed reproducibility", { + p = 2 + q = 1 + n = 50 + nc = c(2L, 2L) + mux = matrix(c(0, 0.5, -0.3, 0, -0.2, 0.1), nrow = p, ncol = 3, byrow = TRUE) + Kxx = matrix(c(0, 0.3, 0.3, 0), p, p) + muy = 0 + Kyy = matrix(1) + Kxy = matrix(c(0.1, 0.2), p, q) + + args = list( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = muy, num_categories_r = nc, + variable_type_r = c("ordinal", "ordinal"), + baseline_category_r = c(0L, 0L), iter = 100L, seed = 42L + ) + + r1 = do.call(sample_mixed_mrf_gibbs, args) + r2 = do.call(sample_mixed_mrf_gibbs, args) + + expect_equal(r1$x, r2$x) + expect_equal(r1$y, r2$y) +}) + +test_that("sample_mixed_mrf_gibbs: Blume-Capel variable works", { + p = 2 + q = 1 + n = 200 + nc = c(2L, 4L) + # ordinal: mux has num_categories entries; BC: 2 entries (alpha, beta) + max_cols = max(nc[1], 2) + mux = matrix(0, p, max_cols) + mux[2, 1] = 0.5 # alpha + mux[2, 2] = -0.3 # beta (penalizes distance from reference) + + Kxx = matrix(c(0, 0.2, 0.2, 0), p, p) + muy = 0 + Kyy = matrix(1.5) + Kxy = matrix(c(0.1, -0.1), p, q) + + result = sample_mixed_mrf_gibbs( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = muy, num_categories_r = nc, + variable_type_r = c("ordinal", "blume-capel"), + baseline_category_r = c(0L, 2L), iter = 200L, seed = 7L + ) + + expect_equal(dim(result$x), c(n, p)) + expect_true(all(result$x[, 1] >= 0 & result$x[, 1] <= nc[1])) + expect_true(all(result$x[, 2] >= 0 & result$x[, 2] <= nc[2])) +}) + + +# ============================================================================== +# 2. Parallel simulation (run_mixed_simulation_parallel) +# ============================================================================== + +test_that("run_mixed_simulation_parallel returns correct structure", { + p = 2L + q = 1L + ndraws_total = 3L + nc = c(2L, 2L) + mux_s = matrix(rep(c(0, 0.5, -0.3, 0, -0.2, 0.1), each = ndraws_total), + nrow = ndraws_total + ) + kxx_s = matrix(0.3, nrow = ndraws_total, ncol = 1) + muy_s = matrix(0, nrow = ndraws_total, ncol = 1) + kyy_s = matrix(1.5, nrow = ndraws_total, ncol = 1) + kxy_s = matrix(c(0.1, -0.1), nrow = ndraws_total, ncol = 2, byrow = TRUE) + + n_use = 2L + n_obs = 15L + res = run_mixed_simulation_parallel( + mux_samples = mux_s, kxx_samples = kxx_s, + muy_samples = muy_s, kyy_samples = kyy_s, kxy_samples = kxy_s, + draw_indices = 1:n_use, num_states = n_obs, + p = p, q = q, num_categories = nc, + variable_type_r = c("ordinal", "ordinal"), + baseline_category = c(0L, 0L), + iter = 100L, nThreads = 1L, seed = 1L, progress_type = 0L + ) + + expect_true(is.list(res)) + expect_equal(length(res), n_use) + + for(d in seq_len(n_use)) { + expect_equal(dim(res[[d]]$x), c(n_obs, p)) + expect_equal(dim(res[[d]]$y), c(n_obs, q)) + expect_true(all(res[[d]]$x >= 0 & res[[d]]$x <= 2)) + } +}) + + +# ============================================================================== +# 3. compute_conditional_mixed +# ============================================================================== + +test_that("compute_conditional_mixed: discrete probs sum to 1", { + p = 2 + q = 1 + n = 5 + nc = c(2L, 2L) + mux = matrix(c(0, 0.5, -0.3, 0, -0.2, 0.1), nrow = p, ncol = 3, byrow = TRUE) + Kxx = matrix(c(0, 0.3, 0.3, 0), p, p) + muy = 0 + Kyy = matrix(1.5) + Kxy = matrix(c(0.1, 0.2), p, q) + + x_obs = matrix(c(0L, 1L, 2L, 0L, 1L, 1L, 2L, 0L, 1L, 2L), + nrow = n, ncol = p + ) + y_obs = matrix(rnorm(n), nrow = n, ncol = q) + + # Predict first discrete variable (0-based index 0) + preds = compute_conditional_mixed( + x_observations = x_obs, y_observations = y_obs, + predict_vars = 0L, Kxx = Kxx, Kxy = Kxy, Kyy = Kyy, + mux = mux, muy = muy, num_categories = nc, + variable_type = c("ordinal", "ordinal"), + baseline_category = c(0L, 0L) + ) + + expect_equal(length(preds), 1) + expect_equal(nrow(preds[[1]]), n) + expect_equal(ncol(preds[[1]]), nc[1] + 1) # 3 categories + + row_sums = rowSums(preds[[1]]) + expect_true(all(abs(row_sums - 1) < 1e-8)) + expect_true(all(preds[[1]] >= 0)) +}) + +test_that("compute_conditional_mixed: continuous returns mean and sd", { + p = 1 + q = 2 + n = 5 + nc = c(2L) + mux = matrix(c(0, 0.5, -0.3), nrow = 1, ncol = 3) + Kxx = matrix(0) + muy = c(1.0, -0.5) + Kyy = matrix(c(2, 0.3, 0.3, 1.5), q, q) + Kxy = matrix(c(0.1, -0.1), 1, q) + + x_obs = matrix(c(0L, 1L, 2L, 1L, 0L), nrow = n, ncol = 1) + y_obs = matrix(rnorm(n * q), nrow = n, ncol = q) + + # Predict continuous variable at index p=1 (0-based) + preds = compute_conditional_mixed( + x_observations = x_obs, y_observations = y_obs, + predict_vars = as.integer(p), Kxx = Kxx, Kxy = Kxy, Kyy = Kyy, + mux = mux, muy = muy, num_categories = nc, + variable_type = "ordinal", baseline_category = 0L + ) + + expect_equal(length(preds), 1) + expect_equal(nrow(preds[[1]]), n) + expect_equal(ncol(preds[[1]]), 2) # mean, sd + expect_true(all(preds[[1]][, 2] > 0)) # sd > 0 +}) + +test_that("compute_conditional_mixed: mixed prediction variables", { + p = 2 + q = 2 + n = 5 + nc = c(2L, 2L) + mux = matrix(0, p, 3) + Kxx = matrix(c(0, 0.3, 0.3, 0), p, p) + muy = c(0.5, -0.2) + Kyy = diag(c(1.5, 1.8)) + Kxy = matrix(0.1, p, q) + + x_obs = matrix(sample(0:2, n * p, replace = TRUE), n, p) + storage.mode(x_obs) = "integer" + y_obs = matrix(rnorm(n * q), n, q) + + # Predict one discrete (0) and one continuous (p+0 = 2) + preds = compute_conditional_mixed( + x_observations = x_obs, y_observations = y_obs, + predict_vars = c(0L, 2L), Kxx = Kxx, Kxy = Kxy, Kyy = Kyy, + mux = mux, muy = muy, num_categories = nc, + variable_type = c("ordinal", "ordinal"), + baseline_category = c(0L, 0L) + ) + + expect_equal(length(preds), 2) + # First is discrete: n x 3 + expect_equal(ncol(preds[[1]]), nc[1] + 1) + expect_true(all(abs(rowSums(preds[[1]]) - 1) < 1e-8)) + # Second is continuous: n x 2 + expect_equal(ncol(preds[[2]]), 2) +}) + + +# ============================================================================== +# 4. simulate.bgms for mixed MRF (posterior-mean path) +# ============================================================================== + +test_that("simulate.bgms works for mixed MRF with posterior-mean", { + fit = get_bgms_fit_mixed_mrf_no_es() + args = extract_arguments(fit) + + n_sim = 30 + result = simulate(fit, nsim = n_sim, method = "posterior-mean", seed = 1) + + expect_true(is.matrix(result)) + expect_equal(nrow(result), n_sim) + expect_equal(ncol(result), args$num_variables) + expect_equal(colnames(result), args$data_columnnames) + + # Discrete columns should be non-negative integers + for(di in args$discrete_indices) { + vals = result[, di] + expect_true(all(vals >= 0), info = sprintf("col %d negative", di)) + expect_true( + all(vals == round(vals)), + info = sprintf("col %d not integer", di) + ) + } +}) + +test_that("simulate.bgms seed reproducibility for mixed MRF", { + fit = get_bgms_fit_mixed_mrf_no_es() + + r1 = simulate(fit, nsim = 10, method = "posterior-mean", seed = 42) + r2 = simulate(fit, nsim = 10, method = "posterior-mean", seed = 42) + + expect_equal(r1, r2) +}) + + +# ============================================================================== +# 5. predict.bgms for mixed MRF (posterior-mean path) +# ============================================================================== + +test_that("predict.bgms works for mixed MRF with posterior-mean", { + fit = get_bgms_fit_mixed_mrf_no_es() + args = extract_arguments(fit) + + newdata = get_prediction_data_mixed(n = 10) + result = predict(fit, newdata = newdata, type = "probabilities") + + expect_true(is.list(result)) + expect_equal(length(result), args$num_variables) + + for(j in seq_len(args$num_variables)) { + vname = args$data_columnnames[j] + expect_equal(nrow(result[[j]]), 10, info = sprintf("var %s nrow", vname)) + expect_false(anyNA(result[[j]]), info = sprintf("var %s has NAs", vname)) + + if(args$variable_type[j] %in% c("ordinal", "blume-capel")) { + row_sums = rowSums(result[[j]]) + expect_true( + all(abs(row_sums - 1) < 1e-6), + info = sprintf("var %s probs don't sum to 1", vname) + ) + } else { + expect_equal(ncol(result[[j]]), 2, info = sprintf("var %s ncol", vname)) + expect_true(all(result[[j]][, 2] > 0), + info = sprintf("var %s sd not positive", vname) + ) + } + } +}) + +test_that("predict.bgms response type works for mixed MRF", { + fit = get_bgms_fit_mixed_mrf_no_es() + args = extract_arguments(fit) + + newdata = get_prediction_data_mixed(n = 10) + result = predict(fit, newdata = newdata, type = "response") + + expect_true(is.matrix(result)) + expect_equal(dim(result), c(10, args$num_variables)) + + # Discrete columns should be integer-valued + for(di in args$discrete_indices) { + expect_true(all(result[, di] == round(result[, di])), + info = sprintf("response col %d not integer", di) + ) + } +}) + + +# ============================================================================== +# 6. Edge cases +# ============================================================================== + +test_that("sample_mixed_mrf_gibbs: single discrete variable (p=1, q=2)", { + result = sample_mixed_mrf_gibbs( + num_states = 50L, + Kxx_r = matrix(0), + Kxy_r = matrix(c(0.1, -0.1), 1, 2), + Kyy_r = diag(c(1.5, 2.0)), + mux_r = matrix(c(0, 0.5), 1, 2), + muy_r = c(0, 0), + num_categories_r = 1L, + variable_type_r = "ordinal", + baseline_category_r = 0L, + iter = 100L, seed = 3L + ) + + expect_equal(dim(result$x), c(50, 1)) + expect_equal(dim(result$y), c(50, 2)) + expect_true(all(result$x %in% 0:1)) +}) + +test_that("sample_mixed_mrf_gibbs: single continuous variable (p=2, q=1)", { + result = sample_mixed_mrf_gibbs( + num_states = 50L, + Kxx_r = matrix(c(0, 0.2, 0.2, 0), 2, 2), + Kxy_r = matrix(c(0.1, -0.1), 2, 1), + Kyy_r = matrix(2.0), + mux_r = matrix(c(0, 0.5, -0.3, 0, -0.2, 0.1), 2, 3, byrow = TRUE), + muy_r = 0.5, + num_categories_r = c(2L, 2L), + variable_type_r = c("ordinal", "ordinal"), + baseline_category_r = c(0L, 0L), + iter = 100L, seed = 5L + ) + + expect_equal(dim(result$x), c(50, 2)) + expect_equal(dim(result$y), c(50, 1)) +}) + +test_that("sample_mixed_mrf_gibbs: minimal mixed MRF (p=1, q=1)", { + p = 1L + q = 1L + n = 200L + nc = 2L + + Kxx = matrix(0, p, p) + Kxy = matrix(0.3, p, q) + Kyy = matrix(1.5, q, q) + mux = matrix(c(0, 0.5, -0.3), p, nc + 1) + muy = 0.2 + + result = sample_mixed_mrf_gibbs( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = muy, num_categories_r = nc, + variable_type_r = "ordinal", + baseline_category_r = 0L, iter = 200L, seed = 77L + ) + + expect_equal(dim(result$x), c(n, p)) + expect_equal(dim(result$y), c(n, q)) + expect_true(all(result$x >= 0 & result$x <= nc)) + expect_true(all(is.finite(result$y))) + + # With non-zero Kxy, discrete and continuous should be associated + group_means = tapply(result$y[, 1], result$x[, 1], mean) + expect_true(length(group_means) >= 2, + info = "at least 2 discrete categories observed" + ) +}) + + +# ============================================================================== +# 7. Edge cases (ported from mixedGM::test-edge-cases.R) +# ============================================================================== +# Structural smoke tests for boundary configurations: many categories, +# large p / small q, small p / large q, near-singular Kyy, and +# end-to-end bgm() fitting with edge-case data. +# ============================================================================== + +# ---- many categories --------------------------------------------------------- +test_that("sample_mixed_mrf_gibbs handles many categories (4+)", { + p = 2L + q = 1L + n = 200L + nc = c(4L, 3L) + + Kxx = matrix(c(0, 0.15, 0.15, 0), p, p) + Kxy = matrix(c(0.2, 0.15), p, q) + Kyy = matrix(1.5, q, q) + mux = matrix(0, p, max(nc)) + mux[1, 1:4] = c(-0.5, 0, 0.3, 0.8) + mux[2, 1:3] = c(-0.3, 0.2, 0.6) + muy = 0 + + result = sample_mixed_mrf_gibbs( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = muy, num_categories_r = nc, + variable_type_r = c("ordinal", "ordinal"), + baseline_category_r = c(0L, 0L), iter = 200L, seed = 33L + ) + + expect_equal(dim(result$x), c(n, p)) + expect_equal(dim(result$y), c(n, q)) + + # Discrete values in valid ranges + expect_true(all(result$x[, 1] >= 0 & result$x[, 1] <= nc[1])) + expect_true(all(result$x[, 2] >= 0 & result$x[, 2] <= nc[2])) + + # All categories should be observed with enough samples + expect_true(length(unique(result$x[, 1])) >= 3, + info = "variable 1: at least 3 of 5 categories observed" + ) +}) + +# ---- large p, small q ------------------------------------------------------- +test_that("sample_mixed_mrf_gibbs handles large p, small q (p=10, q=1)", { + p = 10L + q = 1L + n = 300L + nc = as.integer(rep(1, p)) + + # Sparse Kxx + Kxx = matrix(0, p, p) + Kxx[1, 2] = Kxx[2, 1] = 0.2 + Kxx[3, 4] = Kxx[4, 3] = 0.15 + + # Sparse Kxy + Kxy = matrix(0, p, q) + Kxy[1, 1] = 0.25 + Kxy[5, 1] = 0.2 + + Kyy = matrix(1.5, q, q) + mux = matrix(0, p, max(nc)) + muy = 0 + + result = sample_mixed_mrf_gibbs( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = muy, num_categories_r = nc, + variable_type_r = rep("ordinal", p), + baseline_category_r = rep(0L, p), iter = 200L, seed = 44L + ) + + expect_equal(dim(result$x), c(n, p)) + expect_equal(dim(result$y), c(n, q)) + expect_true(all(is.finite(result$x))) + expect_true(all(is.finite(result$y))) + expect_true(all(result$x >= 0 & result$x <= 1)) +}) + +# ---- small p, large q ------------------------------------------------------- +test_that("sample_mixed_mrf_gibbs handles small p, large q (p=1, q=5)", { + p = 1L + q = 5L + n = 300L + nc = 1L + + Kxx = matrix(0, p, p) + Kxy = matrix(c(0.2, 0.15, 0.1, 0.05, 0.25), p, q) + + # Sparse Kyy + Kyy = diag(1.5, q) + Kyy[1, 2] = Kyy[2, 1] = 0.2 + Kyy[3, 4] = Kyy[4, 3] = 0.15 + + mux = matrix(0, p, 1) + muy = rep(0, q) + + result = sample_mixed_mrf_gibbs( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = muy, num_categories_r = nc, + variable_type_r = "ordinal", + baseline_category_r = 0L, iter = 300L, seed = 55L + ) + + expect_equal(dim(result$x), c(n, p)) + expect_equal(dim(result$y), c(n, q)) + expect_true(all(is.finite(result$y))) + + # Kyy diagonal sets marginal variances (when Kxy ~ 0): var_j approx 1/Kyy_jj + for(j in 1:q) { + expected_sd = 1 / sqrt(Kyy[j, j]) + empirical_sd = sd(result$y[, j]) + expect_true( + abs(empirical_sd - expected_sd) / expected_sd < 0.4, + info = sprintf("y%d SD: expected %.3f, got %.3f", j, expected_sd, empirical_sd) + ) + } +}) + +# ---- near-singular Kyy ------------------------------------------------------ +test_that("sample_mixed_mrf_gibbs handles correlated Kyy", { + p = 1L + q = 2L + n = 200L + nc = 1L + + Kxx = matrix(0, p, p) + Kxy = matrix(c(0.2, 0.15), p, q) + + # Correlated Kyy with notable off-diagonal + Kyy = matrix(c(1.5, 0.5, 0.5, 1.5), q, q) + + mux = matrix(0, p, 1) + muy = c(0, 0) + + result = sample_mixed_mrf_gibbs( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = muy, num_categories_r = nc, + variable_type_r = "ordinal", + baseline_category_r = 0L, iter = 300L, seed = 66L + ) + + expect_equal(dim(result$y), c(n, q)) + expect_true(all(is.finite(result$y))) + + # With positive off-diagonal precision, y1 and y2 should be negatively correlated + r = cor(result$y[, 1], result$y[, 2]) + expect_true(r < 0.1, info = sprintf("expected negative or near-zero cor, got %.3f", r)) +}) + +# ---- bgm() end-to-end with many categories ---------------------------------- +test_that("bgm() fits mixed MRF with many categories", { + skip_on_cran() + p = 2L + q = 1L + n = 150L + nc = c(4L, 3L) + + Kxx = matrix(c(0, 0.15, 0.15, 0), p, p) + Kxy = matrix(c(0.2, 0.15), p, q) + Kyy = matrix(1.5, q, q) + mux = matrix(0, p, max(nc)) + mux[1, 1:4] = c(-0.5, 0, 0.3, 0.8) + mux[2, 1:3] = c(-0.3, 0.2, 0.6) + + generated = sample_mixed_mrf_gibbs( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = 0, num_categories_r = nc, + variable_type_r = c("ordinal", "ordinal"), + baseline_category_r = c(0L, 0L), iter = 500L, seed = 77L + ) + + dat = data.frame(generated$x, generated$y) + vt = c("ordinal", "ordinal", "continuous") + + fit = bgm(dat, + variable_type = vt, iter = 500, warmup = 250, + edge_selection = FALSE, verbose = FALSE, display_progress = "none" + ) + + expect_s3_class(fit, "bgms") + pw = fit$posterior_mean_pairwise + expect_equal(dim(pw), c(p + q, p + q)) + expect_true(all(is.finite(pw))) +}) + +# ---- bgm() end-to-end with large p, small q --------------------------------- +test_that("bgm() fits mixed MRF with large p, small q", { + skip_on_cran() + p = 6L + q = 1L + n = 200L + nc = as.integer(rep(1, p)) + + Kxx = matrix(0, p, p) + Kxx[1, 2] = Kxx[2, 1] = 0.2 + Kxy = matrix(0, p, q) + Kxy[1, 1] = 0.15 + Kyy = matrix(1.5, q, q) + mux = matrix(0, p, max(nc)) + + generated = sample_mixed_mrf_gibbs( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = 0, num_categories_r = nc, + variable_type_r = rep("ordinal", p), + baseline_category_r = rep(0L, p), iter = 500L, seed = 88L + ) + + dat = data.frame(generated$x, generated$y) + vt = c(rep("ordinal", p), "continuous") + + fit = bgm(dat, + variable_type = vt, iter = 500, warmup = 250, + edge_selection = FALSE, verbose = FALSE, display_progress = "none" + ) + + expect_s3_class(fit, "bgms") + expect_equal(nrow(fit$posterior_mean_pairwise), p + q) +}) + +# ---- bgm() end-to-end with small p, large q --------------------------------- +test_that("bgm() fits mixed MRF with small p, large q", { + skip_on_cran() + p = 1L + q = 4L + n = 200L + nc = 1L + + Kxx = matrix(0, p, p) + Kxy = matrix(c(0.2, 0.1, 0.05, 0.15), p, q) + Kyy = diag(1.5, q) + Kyy[1, 2] = Kyy[2, 1] = 0.2 + mux = matrix(0, p, 1) + + generated = sample_mixed_mrf_gibbs( + num_states = n, Kxx_r = Kxx, Kxy_r = Kxy, Kyy_r = Kyy, + mux_r = mux, muy_r = rep(0, q), num_categories_r = nc, + variable_type_r = "ordinal", + baseline_category_r = 0L, iter = 500L, seed = 99L + ) + + dat = data.frame(generated$x, generated$y) + vt = c("ordinal", rep("continuous", q)) + + fit = bgm(dat, + variable_type = vt, iter = 500, warmup = 250, + edge_selection = FALSE, verbose = FALSE, display_progress = "none" + ) + + expect_s3_class(fit, "bgms") + expect_equal(nrow(fit$posterior_mean_pairwise), p + q) +}) diff --git a/tests/testthat/test-scaffolding-fixtures.R b/tests/testthat/test-scaffolding-fixtures.R deleted file mode 100644 index 6ea1b910..00000000 --- a/tests/testthat/test-scaffolding-fixtures.R +++ /dev/null @@ -1,68 +0,0 @@ -# Null-coalescing operator for R < 4.4 compatibility -if(!exists("%||%", baseenv())) { - `%||%` = function(x, y) if(is.null(x)) y else x -} - -# ============================================================================== -# Golden-Snapshot Fixture Verification Tests -# ============================================================================== -# -# Phase A-0 of the scaffolding refactor (dev/scaffolding/plan.md). -# -# The original tests verified that check_model(), check_compare_model(), -# reformat_data(), and compare_reformat_data() reproduced golden-snapshot -# fixtures exactly. Those monolithic functions have now been deleted (Phase C.8) -# and their logic inlined into bgm_spec(). The golden-snapshot guardrails -# served their purpose during the transition. -# -# What remains: structural checks on the fixture set itself. -# ============================================================================== - -fixture_dir = file.path(testthat::test_path("..", ".."), "dev", "fixtures", "scaffolding") - -# Skip all tests if fixtures haven't been generated yet -skip_if_no_fixtures = function() { - # When running via devtools::test(), the working directory is tests/testthat/ - # The fixtures live at the package root under dev/fixtures/scaffolding/ - pkg_root = testthat::test_path("..", "..") - fixture_dir = file.path(pkg_root, "dev", "fixtures", "scaffolding") - manifest_path = file.path(fixture_dir, "manifest.rds") - if(!file.exists(manifest_path)) { - skip("Scaffolding fixtures not found. Run: Rscript dev/generate_scaffolding_fixtures.R") - } -} - -# Helper: load a fixture by id -load_fixture = function(id) { - path = file.path(fixture_dir, paste0(id, ".rds")) - if(!file.exists(path)) { - skip(paste("Fixture not found:", id)) - } - readRDS(path) -} - -# ============================================================================== -# Structural sanity checks on the fixture set -# ============================================================================== - -test_that("fixture manifest has expected number of cases", { - skip_if_no_fixtures() - manifest = readRDS(file.path(fixture_dir, "manifest.rds")) - expect_gte(nrow(manifest), 15) -}) - -test_that("fixture manifest covers both bgm and compare types", { - skip_if_no_fixtures() - manifest = readRDS(file.path(fixture_dir, "manifest.rds")) - expect_true("bgm" %in% manifest$type) - expect_true("compare" %in% manifest$type) -}) - -test_that("all fixture files listed in manifest exist on disk", { - skip_if_no_fixtures() - manifest = readRDS(file.path(fixture_dir, "manifest.rds")) - for(id in manifest$id) { - path = file.path(fixture_dir, paste0(id, ".rds")) - expect_true(file.exists(path), label = paste("File exists:", id)) - } -}) diff --git a/tests/testthat/test-simulate-predict-regression.R b/tests/testthat/test-simulate-predict-regression.R index de9f9932..c716ddff 100644 --- a/tests/testthat/test-simulate-predict-regression.R +++ b/tests/testthat/test-simulate-predict-regression.R @@ -16,129 +16,9 @@ # ------------------------------------------------------------------------------ -# Fixture Specifications (mirrored from test-methods.R) -# These produce parameterized test specs using session-cached fits from -# helper-fixtures.R. +# Fixture Specifications — defined in helper-fixtures.R +# get_bgms_fixtures(), get_bgmcompare_fixtures() # ------------------------------------------------------------------------------ -get_bgms_fixtures = function() { - list( - list( - label = "binary", get_fit = get_bgms_fit, - get_prediction_data = get_prediction_data_binary, - var_type = "binary", is_continuous = FALSE - ), - list( - label = "ordinal", get_fit = get_bgms_fit_ordinal, - get_prediction_data = get_prediction_data_ordinal, - var_type = "ordinal", is_continuous = FALSE - ), - list( - label = "single-chain", get_fit = get_bgms_fit_single_chain, - get_prediction_data = get_prediction_data_binary, - var_type = "binary", is_continuous = FALSE - ), - list( - label = "blume-capel", get_fit = get_bgms_fit_blumecapel, - get_prediction_data = get_prediction_data_ordinal, - var_type = "blume-capel", is_continuous = FALSE - ), - list( - label = "adaptive-metropolis", get_fit = get_bgms_fit_adaptive_metropolis, - get_prediction_data = get_prediction_data_binary, - var_type = "binary", is_continuous = FALSE - ), - list( - label = "hmc", get_fit = get_bgms_fit_hmc, - get_prediction_data = get_prediction_data_ordinal, - var_type = "ordinal", is_continuous = FALSE - ), - list( - label = "am-blumecapel", get_fit = get_bgms_fit_am_blumecapel, - get_prediction_data = get_prediction_data_ordinal, - var_type = "blume-capel", is_continuous = FALSE - ), - list( - label = "impute", get_fit = get_bgms_fit_impute, - get_prediction_data = get_prediction_data_ordinal, - var_type = "ordinal", is_continuous = FALSE - ), - list( - label = "standardize", get_fit = get_bgms_fit_standardize, - get_prediction_data = get_prediction_data_ordinal, - var_type = "ordinal", is_continuous = FALSE - ), - list( - label = "ggm", get_fit = get_bgms_fit_ggm, - get_prediction_data = get_prediction_data_ggm, - var_type = "continuous", is_continuous = TRUE - ), - list( - label = "ggm-no-es", get_fit = get_bgms_fit_ggm_no_es, - get_prediction_data = get_prediction_data_ggm, - var_type = "continuous", is_continuous = TRUE - ) - ) -} - -get_bgmcompare_fixtures = function() { - list( - list( - label = "binary", get_fit = get_bgmcompare_fit, - get_prediction_data = get_prediction_data_bgmcompare_binary, - var_type = "binary" - ), - list( - label = "ordinal", get_fit = get_bgmcompare_fit_ordinal, - get_prediction_data = get_prediction_data_bgmcompare_ordinal, - var_type = "ordinal" - ), - list( - label = "adaptive-metropolis", get_fit = get_bgmcompare_fit_adaptive_metropolis, - get_prediction_data = get_prediction_data_bgmcompare_binary, - var_type = "binary" - ), - list( - label = "hmc", get_fit = get_bgmcompare_fit_hmc, - get_prediction_data = get_prediction_data_bgmcompare_binary, - var_type = "binary" - ), - list( - label = "hmc-blume-capel", get_fit = get_bgmcompare_fit_hmc_blumecapel, - get_prediction_data = get_prediction_data_bgmcompare_blumecapel, - var_type = "blume-capel" - ), - list( - label = "blume-capel", get_fit = get_bgmcompare_fit_blumecapel, - get_prediction_data = get_prediction_data_bgmcompare_blumecapel, - var_type = "blume-capel" - ), - list( - label = "am-blume-capel", get_fit = get_bgmcompare_fit_am_blumecapel, - get_prediction_data = get_prediction_data_bgmcompare_blumecapel, - var_type = "blume-capel" - ), - list( - label = "impute", get_fit = get_bgmcompare_fit_impute, - get_prediction_data = get_prediction_data_bgmcompare_ordinal, - var_type = "ordinal" - ), - list( - label = "blume-capel-impute", get_fit = get_bgmcompare_fit_blumecapel_impute, - get_prediction_data = get_prediction_data_bgmcompare_blumecapel, - var_type = "blume-capel" - ), - list( - label = "beta-bernoulli", get_fit = get_bgmcompare_fit_beta_bernoulli, - get_prediction_data = get_prediction_data_bgmcompare_ordinal, - var_type = "ordinal" - ), - list( - label = "standardize", get_fit = get_bgmcompare_fit_standardize, - get_prediction_data = get_prediction_data_bgmcompare_ordinal, - var_type = "ordinal" - ) - ) -} # ============================================================================== @@ -187,6 +67,27 @@ test_that("bgms $arguments contains all fields needed by simulate/predict", { isTRUE(args$is_continuous), info = sprintf("%s: is_continuous should be TRUE for GGM", ctx) ) + } else if(isTRUE(spec$is_mixed)) { + # Mixed MRF: OMRF fields plus mixed-specific fields + for(field in BGMS_OMRF_FIELDS) { + expect_true( + field %in% names(args), + info = sprintf("%s: missing arguments$%s", ctx, field) + ) + } + for(field in c( + "is_mixed", "discrete_indices", "continuous_indices", + "num_discrete", "num_continuous", "is_ordinal", + "data_columnnames_discrete", "data_columnnames_continuous" + )) { + expect_true( + field %in% names(args), + info = sprintf("%s: missing mixed arguments$%s", ctx, field) + ) + } + expect_true(isTRUE(args$is_mixed), + info = sprintf("%s: is_mixed should be TRUE", ctx) + ) } else { # OMRF fits must also carry num_categories and baseline_category for(field in BGMS_OMRF_FIELDS) { @@ -199,7 +100,10 @@ test_that("bgms $arguments contains all fields needed by simulate/predict", { } }) -test_that("bgmCompare $arguments contains all fields needed by simulate/predict", { +test_that(paste( + "bgmCompare $arguments contains all fields", + "needed by simulate/predict" +), { for(spec in get_bgmcompare_fixtures()) { ctx = sprintf("[bgmCompare %s]", spec$label) fit = spec$get_fit() @@ -233,9 +137,6 @@ test_that("bgms fit objects have posterior_mean fields for simulate/predict", { expect_false(is.null(fit$posterior_mean_pairwise), info = paste(ctx, "missing posterior_mean_pairwise") ) - expect_false(is.null(fit$posterior_mean_main), - info = paste(ctx, "missing posterior_mean_main") - ) expect_true(is.matrix(fit$posterior_mean_pairwise), info = paste(ctx, "posterior_mean_pairwise not a matrix") ) @@ -245,6 +146,30 @@ test_that("bgms fit objects have posterior_mean fields for simulate/predict", { expect_equal(ncol(fit$posterior_mean_pairwise), p, info = paste(ctx, "posterior_mean_pairwise wrong ncol") ) + + if(isTRUE(args$is_continuous)) { + # GGM: no main effects; precision diagonal is on pairwise matrix + expect_null(fit$posterior_mean_main, + info = paste(ctx, "GGM posterior_mean_main should be NULL") + ) + expect_true(all(diag(fit$posterior_mean_pairwise) > 0), + info = paste(ctx, "GGM pairwise diagonal should be positive") + ) + } else if(isTRUE(spec$is_mixed)) { + expect_true(is.list(fit$posterior_mean_main), + info = paste(ctx, "mixed posterior_mean_main should be a list") + ) + expect_false(is.null(fit$posterior_mean_main$discrete), + info = paste(ctx, "missing posterior_mean_main$discrete") + ) + expect_false(is.null(fit$posterior_mean_main$continuous), + info = paste(ctx, "missing posterior_mean_main$continuous") + ) + } else { + expect_false(is.null(fit$posterior_mean_main), + info = paste(ctx, "missing posterior_mean_main") + ) + } } }) @@ -277,7 +202,8 @@ test_that("bgms fit objects have raw_samples for posterior-sample method", { # ============================================================================== # For each golden fixture, construct a bgm_spec from the frozen inputs and # verify that build_arguments() produces the same simulate/predict-critical -# values (num_categories, baseline_category, variable_type / is_ordinal_variable, +# values (num_categories, baseline_category, +# variable_type / is_ordinal_variable, # is_continuous) as the old pipeline's check_model + reformat_data output. # # This is fast: bgm_spec() + build_arguments() does no MCMC. @@ -460,14 +386,22 @@ test_that("simulate → predict roundtrip works for all bgms fixtures", { expect_true(is.matrix(simulated), info = paste(ctx, "simulate")) expect_equal(nrow(simulated), n_sim, info = paste(ctx, "nrow")) expect_equal(ncol(simulated), args$num_variables, info = paste(ctx, "ncol")) - expect_equal(colnames(simulated), args$data_columnnames, info = paste(ctx, "colnames")) + expect_equal( + colnames(simulated), args$data_columnnames, + info = paste(ctx, "colnames") + ) if(isTRUE(args$is_continuous)) { # GGM: predict returns list of mean/sd matrices colnames(simulated) = args$data_columnnames pred = predict(fit, newdata = simulated) - expect_true(is.list(pred), info = paste(ctx, "predict type")) - expect_equal(length(pred), args$num_variables, info = paste(ctx, "predict length")) + expect_true(is.list(pred), + info = paste(ctx, "predict type") + ) + expect_equal( + length(pred), args$num_variables, + info = paste(ctx, "predict length") + ) for(j in seq_along(pred)) { expect_equal(nrow(pred[[j]]), n_sim, info = sprintf("%s predict var %d nrow", ctx, j) @@ -476,11 +410,66 @@ test_that("simulate → predict roundtrip works for all bgms fixtures", { info = sprintf("%s predict var %d ncol", ctx, j) ) } + } else if(isTRUE(spec$is_mixed)) { + # Mixed MRF: predict returns list with + # discrete (probs) and continuous (mean/sd) + probs = predict( + fit, + newdata = simulated, + type = "probabilities" + ) + expect_true(is.list(probs), + info = paste(ctx, "predict type") + ) + expect_equal( + length(probs), args$num_variables, + info = paste(ctx, "predict length") + ) + + for(j in seq_len(args$num_variables)) { + vname = args$data_columnnames[j] + expect_equal(nrow(probs[[j]]), n_sim, + info = sprintf("%s predict %s nrow", ctx, vname) + ) + expect_false(anyNA(probs[[j]]), + info = sprintf("%s predict %s has NAs", ctx, vname) + ) + + if(args$variable_type[j] %in% c("ordinal", "blume-capel")) { + # Discrete: probability rows sum to 1 + row_sums = rowSums(probs[[j]]) + expect_true( + all(abs(row_sums - 1) < 1e-6), + info = sprintf("%s predict %s probs don't sum to 1", ctx, vname) + ) + } else { + # Continuous: 2-column (mean, sd) matrix + expect_equal(ncol(probs[[j]]), 2, + info = sprintf("%s predict %s ncol", ctx, vname) + ) + } + } + + # type = "response" should return a matrix + resp = predict(fit, newdata = simulated, type = "response") + expect_true(is.matrix(resp), info = paste(ctx, "response matrix")) + expect_equal(dim(resp), c(n_sim, args$num_variables), + info = paste(ctx, "response dim") + ) } else { # OMRF: predict returns list of probability matrices - probs = predict(fit, newdata = simulated, type = "probabilities") - expect_true(is.list(probs), info = paste(ctx, "predict type")) - expect_equal(length(probs), args$num_variables, info = paste(ctx, "predict length")) + probs = predict( + fit, + newdata = simulated, + type = "probabilities" + ) + expect_true(is.list(probs), + info = paste(ctx, "predict type") + ) + expect_equal( + length(probs), args$num_variables, + info = paste(ctx, "predict length") + ) for(j in seq_along(probs)) { expect_equal(nrow(probs[[j]]), n_sim, info = sprintf("%s predict var %d nrow", ctx, j) @@ -528,7 +517,10 @@ test_that("simulate → predict roundtrip works for all bgmCompare fixtures", { expect_true(is.matrix(simulated), info = paste(g_ctx, "simulate")) expect_equal(nrow(simulated), n_sim, info = paste(g_ctx, "nrow")) - expect_equal(ncol(simulated), args$num_variables, info = paste(g_ctx, "ncol")) + expect_equal( + ncol(simulated), args$num_variables, + info = paste(g_ctx, "ncol") + ) expect_equal(colnames(simulated), args$data_columnnames, info = paste(g_ctx, "colnames") ) @@ -542,7 +534,11 @@ test_that("simulate → predict roundtrip works for all bgmCompare fixtures", { ) # Predict - probs = predict(fit, newdata = simulated, group = g, type = "probabilities") + probs = predict( + fit, + newdata = simulated, + group = g, type = "probabilities" + ) expect_true(is.list(probs), info = paste(g_ctx, "predict type")) expect_equal(length(probs), args$num_variables, info = paste(g_ctx, "predict length") @@ -650,7 +646,9 @@ test_that("bgms $arguments field types are correct for simulate/predict", { args = extract_arguments(fit) p = args$num_variables - expect_true(is.numeric(args$num_variables) && length(args$num_variables) == 1, + expect_true( + is.numeric(args$num_variables) && + length(args$num_variables) == 1, info = paste(ctx, "num_variables") ) expect_true(args$num_variables >= 1, @@ -660,30 +658,78 @@ test_that("bgms $arguments field types are correct for simulate/predict", { expect_true(is.character(args$variable_type), info = paste(ctx, "variable_type character") ) - expect_true(all(args$variable_type %in% c("ordinal", "blume-capel", "continuous")), + expect_true( + all(args$variable_type %in% + c("ordinal", "blume-capel", "continuous")), info = paste(ctx, "variable_type values") ) - expect_true(is.character(args$data_columnnames) && length(args$data_columnnames) == p, + expect_true( + is.character(args$data_columnnames) && + length(args$data_columnnames) == p, info = paste(ctx, "data_columnnames length") ) - if(!isTRUE(spec$is_continuous)) { + if(!isTRUE(spec$is_continuous) && !isTRUE(spec$is_mixed)) { # OMRF-only fields - expect_true(is.numeric(args$num_categories) && length(args$num_categories) == p, + expect_true( + is.numeric(args$num_categories) && + length(args$num_categories) == p, info = paste(ctx, "num_categories length") ) expect_true(all(args$num_categories >= 1), info = paste(ctx, "num_categories >= 1") ) - expect_true(is.numeric(args$baseline_category) && length(args$baseline_category) == p, + expect_true( + is.numeric(args$baseline_category) && + length(args$baseline_category) == p, info = paste(ctx, "baseline_category length") ) } + + if(isTRUE(spec$is_mixed)) { + pd = args$num_discrete + qc = args$num_continuous + expect_equal( + pd + qc, p, + info = paste( + ctx, "num_discrete + num_continuous == p" + ) + ) + expect_true( + is.numeric(args$num_categories) && + length(args$num_categories) == pd, + info = paste( + ctx, + "mixed num_categories length == num_discrete" + ) + ) + expect_true( + is.numeric(args$baseline_category) && + length(args$baseline_category) == pd, + info = paste( + ctx, + "mixed baseline_category length == num_discrete" + ) + ) + expect_true( + is.numeric(args$discrete_indices) && + length(args$discrete_indices) == pd, + info = paste(ctx, "discrete_indices length") + ) + expect_true( + is.numeric(args$continuous_indices) && + length(args$continuous_indices) == qc, + info = paste(ctx, "continuous_indices length") + ) + } } }) -test_that("bgmCompare $arguments field types are correct for simulate/predict", { +test_that(paste( + "bgmCompare $arguments field types", + "are correct for simulate/predict" +), { for(spec in get_bgmcompare_fixtures()) { ctx = sprintf("[bgmCompare %s]", spec$label) fit = spec$get_fit() @@ -698,15 +744,21 @@ test_that("bgmCompare $arguments field types are correct for simulate/predict", info = paste(ctx, "num_variables") ) - expect_true(is.numeric(args$num_categories) && length(args$num_categories) == p, + expect_true( + is.numeric(args$num_categories) && + length(args$num_categories) == p, info = paste(ctx, "num_categories length") ) - expect_true(is.logical(args$is_ordinal_variable) && length(args$is_ordinal_variable) == p, + expect_true( + is.logical(args$is_ordinal_variable) && + length(args$is_ordinal_variable) == p, info = paste(ctx, "is_ordinal_variable") ) - expect_true(is.character(args$data_columnnames) && length(args$data_columnnames) == p, + expect_true( + is.character(args$data_columnnames) && + length(args$data_columnnames) == p, info = paste(ctx, "data_columnnames") ) diff --git a/tests/testthat/test-tolerance.R b/tests/testthat/test-tolerance.R index 691de16e..b7cb6c62 100644 --- a/tests/testthat/test-tolerance.R +++ b/tests/testthat/test-tolerance.R @@ -44,6 +44,18 @@ test_that("bgms outputs are numerically sane (stochastic-robust)", { dat = na.omit(Wenchuan)[1:40, 1:4] p = ncol(dat) + dat_ggm = matrix(rnorm(40 * 4), nrow = 40, ncol = 4) + colnames(dat_ggm) = paste0("V", 1:4) + p_ggm = ncol(dat_ggm) + + dat_mixed = data.frame( + d1 = sample(0:2, 40, replace = TRUE), + c1 = rnorm(40), + d2 = sample(0:2, 40, replace = TRUE), + c2 = rnorm(40) + ) + p_mixed = ncol(dat_mixed) + upper_vals = function(M) M[upper.tri(M)] specs = list( @@ -192,6 +204,138 @@ test_that("bgms outputs are numerically sane (stochastic-robust)", { ) } ) + ), + list( + label = "ggm_bgm", + fun_label = "bgm", + fun = bgms::bgm, + args = list( + x = dat_ggm, + variable_type = "continuous", + iter = 50, + warmup = 100, + chains = 1, + edge_selection = TRUE, + display_progress = "none" + ), + checks = list( + # indicator sanity (GGM also has indicators with edge_selection=TRUE) + function(res, ctx) { + fld = "posterior_mean_indicator" + M = res[[fld]] + + expect_true(is.matrix(M), info = sprintf("%s %s is not a matrix", ctx, fld)) + expect_equal( + dim(M), c(p_ggm, p_ggm), + info = sprintf("%s %s wrong dim", ctx, fld) + ) + expect_true( + all(is.na(M) | (M >= 0 & M <= 1)), + info = sprintf("%s %s has values outside [0,1]", ctx, fld) + ) + }, + + # pairwise sanity + symmetry + function(res, ctx) { + fld = "posterior_mean_pairwise" + M = res[[fld]] + + expect_true(is.matrix(M), info = sprintf("%s %s is not a matrix", ctx, fld)) + expect_equal( + dim(M), c(p_ggm, p_ggm), + info = sprintf("%s %s wrong dim", ctx, fld) + ) + expect_false(all(is.na(M)), info = sprintf("%s %s is all NA", ctx, fld)) + + asym = max(abs(M - t(M)), na.rm = TRUE) + expect_true( + asym <= 1e-8, + info = sprintf("%s %s asymmetry too large: %g", ctx, fld, asym) + ) + }, + + # main effects: GGM has no main; precision diagonal is on pairwise + function(res, ctx) { + # Precision diagonal should be on the pairwise matrix diagonal + fld = "posterior_mean_pairwise" + vals = diag(res[[fld]]) + + expect_true(!is.null(vals), info = sprintf("%s diag(%s) missing", ctx, fld)) + expect_equal(length(vals), p_ggm, + info = sprintf("%s diag(%s) wrong length", ctx, fld) + ) + expect_true(all(is.finite(vals)), + info = sprintf("%s diag(%s) has non-finite values", ctx, fld) + ) + } + ) + ), + list( + label = "mixed_mrf_bgm", + fun_label = "bgm", + fun = bgms::bgm, + args = list( + x = dat_mixed, + variable_type = c("ordinal", "continuous", "ordinal", "continuous"), + iter = 50, + warmup = 100, + chains = 1, + edge_selection = TRUE, + display_progress = "none" + ), + checks = list( + # indicator sanity + function(res, ctx) { + fld = "posterior_mean_indicator" + M = res[[fld]] + + expect_true(is.matrix(M), info = sprintf("%s %s is not a matrix", ctx, fld)) + expect_equal( + dim(M), c(p_mixed, p_mixed), + info = sprintf("%s %s wrong dim", ctx, fld) + ) + expect_true( + all(is.na(M) | (M >= 0 & M <= 1)), + info = sprintf("%s %s has values outside [0,1]", ctx, fld) + ) + }, + + # pairwise sanity + symmetry + function(res, ctx) { + fld = "posterior_mean_pairwise" + M = res[[fld]] + + expect_true(is.matrix(M), info = sprintf("%s %s is not a matrix", ctx, fld)) + expect_equal( + dim(M), c(p_mixed, p_mixed), + info = sprintf("%s %s wrong dim", ctx, fld) + ) + expect_false(all(is.na(M)), info = sprintf("%s %s is all NA", ctx, fld)) + + asym = max(abs(M - t(M)), na.rm = TRUE) + expect_true( + asym <= 1e-8, + info = sprintf("%s %s asymmetry too large: %g", ctx, fld, asym) + ) + }, + + # coarse aggregate for pairwise + function(res, ctx) { + fld = "posterior_mean_pairwise" + M = res[[fld]] + vals = abs(upper_vals(M)) + stat = mean(vals, na.rm = TRUE) + + expect_true( + is.finite(stat), + info = sprintf("%s %s mean(|upper|) not finite", ctx, fld) + ) + expect_true( + stat <= 2.0, + info = sprintf("%s %s mean(|upper|) too large: %0.3f", ctx, fld, stat) + ) + } + ) ) ) diff --git a/tests/testthat/test-validation-slow.R b/tests/testthat/test-validation-slow.R new file mode 100644 index 00000000..82f3179b --- /dev/null +++ b/tests/testthat/test-validation-slow.R @@ -0,0 +1,178 @@ +# ============================================================================== +# Gated slow validation tests +# ============================================================================== +# +# Environment-gated wrappers around mixed MRF validation assertions. These run +# in nightly CI (where +# BGMS_RUN_SLOW_TESTS=true) and are skipped during local devtools::test() +# and CRAN checks. +# +# Each test fits a small mixed MRF network with enough iterations to +# produce meaningful posterior estimates, then checks a quantitative +# recovery or agreement criterion. +# +# Shared helper functions live in tests/testthat/helper-validation.R. +# ============================================================================== + +# Load shared helpers once per file +helpers_path = file.path("tests", "testthat", "helper-validation.R") +if(!file.exists(helpers_path)) { + # When running from testthat, the working directory is tests/testthat/ + helpers_path = file.path("helper-validation.R") +} +helpers_available = file.exists(helpers_path) + + +# ============================================================================== +# Gate: skip unless BGMS_RUN_SLOW_TESTS=true +# ============================================================================== + +skip_slow = function() { + skip_if_not( + isTRUE(as.logical(Sys.getenv("BGMS_RUN_SLOW_TESTS", "false"))), + "Set BGMS_RUN_SLOW_TESTS=true to run slow validation tests" + ) + skip_if(!helpers_available, "Validation helpers not found") +} + + +# ============================================================================== +# 1. Parameter recovery (adapted from group1) +# ============================================================================== + +test_that("mixed MRF parameter recovery: cor > 0.8 (small network)", { + skip_slow() + source(helpers_path, local = TRUE) + + net = make_network(p = 2, q = 2, n_cat = c(1L, 2L), density = 1.0, seed = 101) + dat = generate_data(net, n = 2000, source = "bgms", seed = 201) + vt = c(rep("ordinal", 2), rep("continuous", 2)) + + fit = bgm(dat, + variable_type = vt, + pseudolikelihood = "conditional", edge_selection = FALSE, + iter = 10000, warmup = 5000, chains = 2, seed = 301 + ) + + true_blocks = list( + mux = net$mux, muy = net$muy, + Kxx = net$Kxx, Kxy = net$Kxy, Kyy = net$Kyy + ) + est_blocks = extract_bgms_blocks(fit, net) + + true_flat = flatten_params(true_blocks) + est_flat = flatten_params(est_blocks) + + r = cor(true_flat, est_flat) + expect_gt(r, 0.8, + label = sprintf("recovery correlation (%.4f)", r) + ) +}) + + +# ============================================================================== +# 2. Metropolis vs NUTS agreement (adapted from group2) +# ============================================================================== + +test_that("MH vs NUTS posterior agreement: cor > 0.95", { + skip_slow() + source(helpers_path, local = TRUE) + + net = make_network(p = 2, q = 2, n_cat = c(1L, 2L), density = 1.0, seed = 101) + dat = generate_data(net, n = 2000, source = "bgms", seed = 201) + vt = c(rep("ordinal", 2), rep("continuous", 2)) + + fit_mh = bgm(dat, + variable_type = vt, + pseudolikelihood = "conditional", + update_method = "adaptive-metropolis", + edge_selection = FALSE, + iter = 15000, warmup = 10000, chains = 2, seed = 401 + ) + + fit_nuts = bgm(dat, + variable_type = vt, + pseudolikelihood = "conditional", + update_method = "nuts", + edge_selection = FALSE, + iter = 5000, warmup = 3000, chains = 2, seed = 402 + ) + + est_mh = flatten_params(extract_bgms_blocks(fit_mh, net)) + est_nuts = flatten_params(extract_bgms_blocks(fit_nuts, net)) + + r = cor(est_mh, est_nuts) + expect_gt(r, 0.95, + label = sprintf("MH vs NUTS agreement (%.4f)", r) + ) +}) + + +# ============================================================================== +# 3. Conditional vs marginal PL agreement (adapted from group3) +# ============================================================================== + +test_that("conditional vs marginal PL agreement: cor > 0.90", { + skip_slow() + source(helpers_path, local = TRUE) + + net = make_network(p = 2, q = 2, n_cat = c(1L, 2L), density = 1.0, seed = 101) + dat = generate_data(net, n = 2000, source = "bgms", seed = 201) + vt = c(rep("ordinal", 2), rep("continuous", 2)) + + fit_cond = bgm(dat, + variable_type = vt, + pseudolikelihood = "conditional", edge_selection = FALSE, + iter = 10000, warmup = 5000, chains = 2, seed = 501 + ) + + fit_marg = bgm(dat, + variable_type = vt, + pseudolikelihood = "marginal", edge_selection = FALSE, + iter = 10000, warmup = 5000, chains = 2, seed = 502 + ) + + est_cond = flatten_params(extract_bgms_blocks(fit_cond, net)) + est_marg = flatten_params(extract_bgms_blocks(fit_marg, net)) + + r = cor(est_cond, est_marg) + expect_gt(r, 0.90, + label = sprintf("cond vs marg PL agreement (%.4f)", r) + ) +}) + + +# ============================================================================== +# 4. Estimate-simulate-re-estimate cycle (mixed MRF) +# ============================================================================== + +test_that("estimate-simulate-re-estimate cycle: cor > 0.7 (mixed MRF)", { + skip_slow() + source(helpers_path, local = TRUE) + + net = make_network(p = 2, q = 2, n_cat = c(1L, 2L), density = 1.0, seed = 101) + dat = generate_data(net, n = 2000, source = "bgms", seed = 201) + vt = c(rep("ordinal", 2), rep("continuous", 2)) + + fit1 = bgm(dat, + variable_type = vt, + edge_selection = FALSE, iter = 5000, warmup = 2000, + chains = 1, seed = 601 + ) + + simulated = simulate(fit1, nsim = 2000, seed = 701) + + fit2 = bgm(simulated, + variable_type = vt, + edge_selection = FALSE, iter = 5000, warmup = 2000, + chains = 1, seed = 801 + ) + + pw1 = as.vector(fit1$posterior_mean_pairwise) + pw2 = as.vector(fit2$posterior_mean_pairwise) + + r = cor(pw1, pw2) + expect_gt(r, 0.7, + label = sprintf("cycle pairwise correlation (%.4f)", r) + ) +})