diff --git a/R/RcppExports.R b/R/RcppExports.R index 1f228e62..69628ca3 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -1,12 +1,12 @@ # Generated by using Rcpp::compileAttributes() -> do not edit by hand # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 -run_bgmCompare_parallel <- function(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs) { - .Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs) +run_bgmCompare_parallel <- function(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type) { + .Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type) } -run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed) { - .Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed) +run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) { + .Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) } get_explog_switch <- function() { diff --git a/R/bgm.R b/R/bgm.R index 3c99aec7..e9a99a75 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -237,7 +237,7 @@ bgm = function( dirichlet_alpha = 1, lambda = 1, na_action = c("listwise", "impute"), - display_progress = TRUE, + display_progress = c("per-chain", "total", "none"), update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"), target_accept, hmc_num_leapfrogs = 100, @@ -324,7 +324,7 @@ bgm = function( ".")) #Check display_progress ------------------------------------------------------ - display_progress = check_logical(display_progress, "display_progress") + progress_type = progress_type_from_display_progress(display_progress) #Format the data input ------------------------------------------------------- data = reformat_data(x = x, @@ -428,7 +428,7 @@ bgm = function( target_accept = target_accept, pairwise_stats = pairwise_stats, hmc_num_leapfrogs = hmc_num_leapfrogs, nuts_max_depth = nuts_max_depth, learn_mass_matrix = learn_mass_matrix, num_chains = chains, - nThreads = cores, seed = seed + nThreads = cores, seed = seed, progress_type = progress_type ) # Main output handler in the wrapper function @@ -457,5 +457,9 @@ bgm = function( output$nuts_diag = nuts_diag } + userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt")) + if (userInterrupt) + warning("Stopped sampling after user interrupt, results are likely uninterpretable.") + return(output) } \ No newline at end of file diff --git a/R/bgmCompare.R b/R/bgmCompare.R index db91ceb4..9f102ae8 100644 --- a/R/bgmCompare.R +++ b/R/bgmCompare.R @@ -122,7 +122,7 @@ bgmCompare = function( iter = 1e3, burnin = 1e3, na_action = c("listwise", "impute"), - display_progress = TRUE, + display_progress = c("per-chain", "total", "none"), update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"), target_accept, hmc_num_leapfrogs = 100, @@ -202,7 +202,8 @@ bgmCompare = function( } # Check display_progress - display_progress = check_logical(display_progress, "display_progress") + progress_type = progress_type_from_display_progress(display_progress) + ## Format data data = compare_reformat_data( @@ -340,7 +341,8 @@ bgmCompare = function( inclusion_probability = model$inclusion_probability_difference, num_chains = chains, nThreads = cores, seed = seed, - update_method = update_method, hmc_num_leapfrogs = hmc_num_leapfrogs + update_method = update_method, hmc_num_leapfrogs = hmc_num_leapfrogs, + progress_type = progress_type ) # Main output handler in the wrapper function diff --git a/R/function_input_utils.R b/R/function_input_utils.R index 2bb2a3c9..c41daab6 100644 --- a/R/function_input_utils.R +++ b/R/function_input_utils.R @@ -463,4 +463,15 @@ check_compare_model = function( inclusion_probability_difference = inclusion_probability_difference ) ) -} \ No newline at end of file +} + +progress_type_from_display_progress <- function(display_progress = c("per-chain", "total", "none")) { + if (is.logical(display_progress) && length(display_progress) == 1) { + if (is.na(display_progress)) + stop("The display_progress argument must be a single logical value, but not NA.") + display_progress = if (display_progress) "per-chain" else "none" + } else { + display_progress = match.arg(display_progress) + } + return(if (display_progress == "per-chain") 2L else if (display_progress == "total") 1L else 0L) +} diff --git a/src/Makevars.win b/src/Makevars.win index bfa74448..7d9a2d0f 100644 --- a/src/Makevars.win +++ b/src/Makevars.win @@ -1,4 +1,4 @@ -CXX_STD = CXX14 +CXX_STD = CXX20 PKG_CPPFLAGS = \ $(shell "$(R_HOME)\bin\Rscript.exe" -e "cat(RcppParallel::CxxFlags())") diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index e829cc45..9fa0affe 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -12,8 +12,8 @@ Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif // run_bgmCompare_parallel -Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector& counts_per_category, const std::vector& blume_capel_stats, const std::vector& pairwise_stats, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int burnin, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads, int seed, const std::string& update_method, int hmc_num_leapfrogs); -RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP pairwise_statsSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP update_methodSEXP, SEXP hmc_num_leapfrogsSEXP) { +Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector& counts_per_category, const std::vector& blume_capel_stats, const std::vector& pairwise_stats, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int burnin, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads, int seed, const std::string& update_method, int hmc_num_leapfrogs, int progress_type); +RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP pairwise_statsSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP update_methodSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP progress_typeSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -52,13 +52,14 @@ BEGIN_RCPP Rcpp::traits::input_parameter< int >::type seed(seedSEXP); Rcpp::traits::input_parameter< const std::string& >::type update_method(update_methodSEXP); Rcpp::traits::input_parameter< int >::type hmc_num_leapfrogs(hmc_num_leapfrogsSEXP); - rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs)); + Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP); + rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type)); return rcpp_result_gen; END_RCPP } // run_bgm_parallel -Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int burnin, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, uint64_t seed); -RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP pairwise_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP pairwise_statsSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP) { +Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int burnin, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, uint64_t seed, int progress_type); +RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP pairwise_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP pairwise_statsSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP progress_typeSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -93,7 +94,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< int >::type num_chains(num_chainsSEXP); Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP); Rcpp::traits::input_parameter< uint64_t >::type seed(seedSEXP); - rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed)); + Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP); + rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type)); return rcpp_result_gen; END_RCPP } @@ -179,8 +181,8 @@ END_RCPP } static const R_CallMethodDef CallEntries[] = { - {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 35}, - {"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 31}, + {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 36}, + {"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 32}, {"_bgms_get_explog_switch", (DL_FUNC) &_bgms_get_explog_switch, 0}, {"_bgms_rcpp_ieee754_exp", (DL_FUNC) &_bgms_rcpp_ieee754_exp, 1}, {"_bgms_rcpp_ieee754_log", (DL_FUNC) &_bgms_rcpp_ieee754_log, 1}, diff --git a/src/bgmCompare_parallel.cpp b/src/bgmCompare_parallel.cpp index b9b0648c..577449e1 100644 --- a/src/bgmCompare_parallel.cpp +++ b/src/bgmCompare_parallel.cpp @@ -6,6 +6,9 @@ #include #include #include "rng_utils.h" +#include "progress_manager.h" +#include "sampler_output.h" +#include "mcmc_adaptation.h" using namespace Rcpp; using namespace RcppParallel; @@ -130,6 +133,7 @@ struct GibbsCompareChainRunner : public Worker { const std::vector& chain_rngs; const std::string& update_method; const int hmc_num_leapfrogs; + ProgressManager& pm; // output std::vector& results; @@ -167,6 +171,7 @@ struct GibbsCompareChainRunner : public Worker { const std::vector& chain_rngs, const std::string& update_method, const int hmc_num_leapfrogs, + ProgressManager& pm, std::vector& results ) : observations(observations), @@ -202,6 +207,7 @@ struct GibbsCompareChainRunner : public Worker { chain_rngs(chain_rngs), update_method(update_method), hmc_num_leapfrogs(hmc_num_leapfrogs), + pm(pm), results(results) {} @@ -257,7 +263,8 @@ struct GibbsCompareChainRunner : public Worker { inclusion_probability, rng, update_method, - hmc_num_leapfrogs + hmc_num_leapfrogs, + pm ); out.result = result; @@ -376,7 +383,8 @@ Rcpp::List run_bgmCompare_parallel( int nThreads, int seed, const std::string& update_method, - int hmc_num_leapfrogs + int hmc_num_leapfrogs, + int progress_type ) { std::vector results(num_chains); @@ -386,6 +394,10 @@ Rcpp::List run_bgmCompare_parallel( chain_rngs[c] = SafeRNG(seed + c); } + // only used to determine the total no. burnin iterations, a bit hacky + WarmupSchedule warmup_schedule_temp(burnin, difference_selection, (update_method != "adaptive-metropolis")); + int total_burnin = warmup_schedule_temp.total_burnin; + ProgressManager pm(num_chains, iter, total_burnin, 50, progress_type); GibbsCompareChainRunner worker( observations, num_groups, @@ -397,7 +409,7 @@ Rcpp::List run_bgmCompare_parallel( pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, chain_rngs, update_method, hmc_num_leapfrogs, - results + pm, results ); { @@ -426,9 +438,12 @@ Rcpp::List run_bgmCompare_parallel( if (r.has_indicator) { chain_out["indicator_samples"] = r.indicator_samples; } + chain_out["userInterrupt"] = r.userInterrupt; output[i] = chain_out; } } + pm.finish(); + return output; } \ No newline at end of file diff --git a/src/bgmCompare_sampler.cpp b/src/bgmCompare_sampler.cpp index 264898ac..8ddf730b 100644 --- a/src/bgmCompare_sampler.cpp +++ b/src/bgmCompare_sampler.cpp @@ -13,6 +13,7 @@ #include "rng_utils.h" #include "sampler_output.h" #include +#include "progress_manager.h" using namespace Rcpp; @@ -1576,7 +1577,8 @@ SamplerOutput run_gibbs_sampler_bgmCompare( arma::mat inclusion_probability, SafeRNG& rng, const std::string& update_method, - const int hmc_num_leapfrogs + const int hmc_num_leapfrogs, + ProgressManager& pm ) { // --- Setup: dimensions and storage structures const int num_variables = observations.n_cols; @@ -1645,14 +1647,13 @@ SamplerOutput run_gibbs_sampler_bgmCompare( const int print_every = std::max(1, total_iter / 10); // --- Main Gibbs sampling loop + bool userInterrupt = false; for (int iteration = 0; iteration < total_iter; iteration++) { - if (iteration % print_every == 0) { - tbb::mutex::scoped_lock lock(get_print_mutex()); - std::cout - << "[bgm] chain " << chain_id - << " iteration " << iteration - << " / " << total_iter - << std::endl; + + pm.update(chain_id - 1); + if (pm.shouldExit()) { + userInterrupt = true; + break; } // Shuffle update order of edge indices @@ -1753,5 +1754,7 @@ SamplerOutput run_gibbs_sampler_bgmCompare( } else { out.indicator_samples = arma::imat(); } + out.userInterrupt = userInterrupt; + return out; } \ No newline at end of file diff --git a/src/bgmCompare_sampler.h b/src/bgmCompare_sampler.h index 0672d3fb..132bd957 100644 --- a/src/bgmCompare_sampler.h +++ b/src/bgmCompare_sampler.h @@ -1,10 +1,11 @@ #pragma once #include -#include "sampler_output.h" #include +struct SamplerOutput; struct SafeRNG; +class ProgressManager; SamplerOutput run_gibbs_sampler_bgmCompare( int chain_id, @@ -40,5 +41,6 @@ SamplerOutput run_gibbs_sampler_bgmCompare( arma::mat inclusion_probability, SafeRNG& rng, const std::string& update_method, - const int hmc_num_leapfrogs + const int hmc_num_leapfrogs, + ProgressManager& pm ); \ No newline at end of file diff --git a/src/bgm_parallel.cpp b/src/bgm_parallel.cpp index 052a36b7..b3ac1447 100644 --- a/src/bgm_parallel.cpp +++ b/src/bgm_parallel.cpp @@ -6,6 +6,8 @@ #include #include #include "rng_utils.h" +#include "progress_manager.h" +#include "mcmc_adaptation.h" using namespace Rcpp; using namespace RcppParallel; @@ -87,6 +89,7 @@ struct GibbsChainRunner : public Worker { // Wrapped RNG engines const std::vector& chain_rngs; + ProgressManager& pm; // output buffer std::vector& results; @@ -121,6 +124,7 @@ struct GibbsChainRunner : public Worker { int nuts_max_depth, bool learn_mass_matrix, const std::vector& chain_rngs, + ProgressManager& pm, std::vector& results ) : observations(observations), @@ -152,6 +156,7 @@ struct GibbsChainRunner : public Worker { nuts_max_depth(nuts_max_depth), learn_mass_matrix(learn_mass_matrix), chain_rngs(chain_rngs), + pm(pm), results(results) {} @@ -194,7 +199,8 @@ struct GibbsChainRunner : public Worker { hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, - rng + rng, + pm ); out.result = result; @@ -291,7 +297,8 @@ Rcpp::List run_bgm_parallel( bool learn_mass_matrix, int num_chains, int nThreads, - uint64_t seed + uint64_t seed, + int progress_type ) { std::vector results(num_chains); @@ -301,6 +308,11 @@ Rcpp::List run_bgm_parallel( chain_rngs[c] = SafeRNG(seed + c); } + // only used to determine the total no. burnin iterations, a bit hacky + WarmupSchedule warmup_schedule_temp(burnin, edge_selection, (update_method != "adaptive-metropolis")); + int total_burnin = warmup_schedule_temp.total_burnin; + ProgressManager pm(num_chains, iter, total_burnin, 50, progress_type); + GibbsChainRunner worker( observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, @@ -309,7 +321,7 @@ Rcpp::List run_bgm_parallel( na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, - chain_rngs, results + chain_rngs, pm, results ); { @@ -329,5 +341,7 @@ Rcpp::List run_bgm_parallel( } } + pm.finish(); + return output; } diff --git a/src/bgm_sampler.cpp b/src/bgm_sampler.cpp index 562348ee..0fb2f10b 100644 --- a/src/bgm_sampler.cpp +++ b/src/bgm_sampler.cpp @@ -10,9 +10,9 @@ #include "mcmc_nuts.h" #include "mcmc_rwm.h" #include "mcmc_utils.h" -#include "print_mutex.h" #include "sbm_edge_prior.h" #include "rng_utils.h" +#include "progress_manager.h" using namespace Rcpp; @@ -1197,7 +1197,8 @@ Rcpp::List run_gibbs_sampler_bgm( const int hmc_num_leapfrogs, const int nuts_max_depth, const bool learn_mass_matrix, - SafeRNG& rng + SafeRNG& rng, + ProgressManager& pm ) { // --- Setup: dimensions and storage structures const int num_variables = observations.n_cols; @@ -1296,17 +1297,15 @@ Rcpp::List run_gibbs_sampler_bgm( ); const int total_iter = warmup_schedule.total_burnin + iter; - const int print_every = std::max(1, total_iter / 10); + bool userInterrupt = false; // --- Main Gibbs sampling loop for (int iteration = 0; iteration < total_iter; iteration++) { - if (iteration % print_every == 0) { - tbb::mutex::scoped_lock lock(get_print_mutex()); - std::cout - << "[bgm] chain " << chain_id - << " iteration " << iteration - << " / " << total_iter - << std::endl; + + pm.update(chain_id - 1); + if (pm.shouldExit()) { + userInterrupt = true; + break; } // Shuffle update order of edge indices @@ -1424,6 +1423,7 @@ Rcpp::List run_gibbs_sampler_bgm( out["allocations"] = allocation_samples; } + out["userInterrupt"] = userInterrupt; out["chain_id"] = chain_id; return out; } \ No newline at end of file diff --git a/src/bgm_sampler.h b/src/bgm_sampler.h index 7bb1403d..42234481 100644 --- a/src/bgm_sampler.h +++ b/src/bgm_sampler.h @@ -1,8 +1,8 @@ #pragma once #include - // forward declaration struct SafeRNG; +class ProgressManager; Rcpp::List run_gibbs_sampler_bgm( int chain_id, @@ -34,5 +34,6 @@ Rcpp::List run_gibbs_sampler_bgm( const int hmc_num_leapfrogs, const int nuts_max_depth, const bool learn_mass_matrix, - SafeRNG& rng + SafeRNG& rng, + ProgressManager& pm ); \ No newline at end of file diff --git a/src/progress_manager.cpp b/src/progress_manager.cpp new file mode 100644 index 00000000..f7a43836 --- /dev/null +++ b/src/progress_manager.cpp @@ -0,0 +1,415 @@ +#include "progress_manager.h" + +ProgressManager::ProgressManager(int nChains_, int nIter_, int nWarmup_, int printEvery_, int progress_type_, bool useUnicode_) + : nChains(nChains_), nIter(nIter_ + nWarmup_), nWarmup(nWarmup_), progress(nChains_), printEvery(printEvery_), + progress_type(progress_type_), useUnicode(useUnicode_) { // +2 for total and time lines + + for (int i = 0; i < nChains; i++) progress[i] = 0; + start = Clock::now(); + lastPrint = Clock::now(); + + // Check if we're in RStudio + Rcpp::Environment base("package:base"); + Rcpp::Function getOption = base["getOption"]; + Rcpp::Function Sysgetenv = base["Sys.getenv"]; + SEXP s = Sysgetenv("RSTUDIO"); + isRStudio = Rcpp::as(s) == "1"; + + no_spaces_for_total = 3 + static_cast(std::log10(nChains)); + if (progress_type == 1) no_spaces_for_total = 1; // no need to align, so one space is fine + total_padding = std::string(no_spaces_for_total, ' '); + + if (isRStudio) { + consoleWidth = getConsoleWidth(); + lineWidth = std::max(10, std::min(consoleWidth - 25, 70)); + } else { + // For terminal, use default console width + consoleWidth = 80; + lineWidth = 70; + } + + // cleverly determine barwidth so no line wrapping occurs + if (lineWidth <= 5) { + // TODO: we don't want to print anything in this case + barWidth = 0; + } else if (lineWidth < 20) { + barWidth = lineWidth - 10; + } else if (lineWidth < 40) { + barWidth = lineWidth - 15; + } else { // > 40 + barWidth = std::min(40, lineWidth - 30); + } + + if (isRStudio) { + barWidth = std::max(10, barWidth - 20); // minimum bar width of 10 for RStudio + } + + // Set up theme + setupTheme(); + update_prefixes(consoleWidth); +} + +void ProgressManager::update(int chainId) { + progress[chainId]++; + + // Only chain 0 actually does the printing/ checking for user interrupts + if (chainId != 0) return; + + if (progress[chainId] % printEvery == 0) { + auto now = Clock::now(); + std::chrono::duration sinceLast = now - lastPrint; + + // Throttle printing to avoid spamming + if (progress_type != 0 && sinceLast.count() >= 0.5) { + print(); + lastPrint = now; + } + } + + // Check for user interrupts and console width changes less frequently to reduce overhead + if (chainId == 0 && progress[chainId] % (printEvery * 5) == 0) { + needsToExit = checkInterrupt(); + // would be a nice feature, but not working atm + // Also check for console width changes occasionally + // checkConsoleWidthChange(); + } +} + +void ProgressManager::finish() { + + if (progress_type == 0) return; // No progress display + + // Mark all chains as complete and print one final time + for (int i = 0; i < nChains; i++) { + progress[i] = nIter; + } + print(); + +} + +bool ProgressManager::shouldExit() const { + return needsToExit; +} + +void ProgressManager::checkConsoleWidthChange() { + if (!isRStudio) return; + + Rcpp::Environment base("package:base"); + Rcpp::Function getOption = base["getOption"]; + int currentWidth = getConsoleWidth(); + + if (prevConsoleWidth == -1) { + // First time, just store the current width + prevConsoleWidth = consoleWidth; + widthChanged = false; + return; + } + + if (currentWidth != consoleWidth && currentWidth > 0) { + // Width has changed + prevConsoleWidth = consoleWidth; + consoleWidth = currentWidth; + widthChanged = true; + } else { + widthChanged = false; + } +} + +int ProgressManager::getConsoleWidth() const { + Rcpp::Environment base("package:base"); + Rcpp::Function getOption = base["getOption"]; + SEXP s = getOption("width", 0); + int width = Rcpp::as(s); + // Remove the +3 adjustment to test actual console width + return width + 3; +} + +std::string ProgressManager::formatProgressBar(int chainId, int current, int total, double fraction, bool isTotal) const { + std::ostringstream builder; + + double exactFilled = fraction * barWidth; + int filled = std::max(0, std::min(int(exactFilled), barWidth)); + + // Build progress bar with theme + std::string progressBar = lhsToken; + + // Add filled tokens + for (int i = 0; i < filled; i++) { + progressBar += filledToken; + } + + // Add partial token if needed + if (filled < barWidth) { + double partialAmount = exactFilled - filled; + if (partialAmount > 0) { + if (partialAmount > 0.5) { + progressBar += partialTokenMore; + } else { + progressBar += partialTokenLess; + } + filled++; // Account for the partial token + } + } + + // Add empty tokens + for (int i = filled; i < barWidth; i++) { + progressBar += emptyToken; + } + + progressBar += rhsToken; + + // store the current length of the progress bar without any additional text + size_t currentWidth = progressBar.length(); + + if (isTotal) { + + std::string warmupOrSampling = isWarmupPhase() ? "(Warmup)" : "(Sampling)"; + builder << total_prefix << total_padding << warmupOrSampling << ": " << progressBar << " " << current << "/" << total + << " (" << std::fixed << std::setprecision(1) << fraction * 100 << "%)"; + + } else { + + std::string warmupOrSampling = isWarmupPhase(chainId - 1) ? " (Warmup)" : " (Sampling)"; + builder << chain_prefix << " " << chainId << warmupOrSampling << ": " << progressBar << " " << current << "/" << total + << " (" << std::fixed << std::setprecision(1) << fraction * 100 << "%)"; + } + + std::string output = builder.str(); + + if (isRStudio && progress_type == 2) { + + currentWidth = output.length() + 2 + barWidth - currentWidth; + // Pad each line to exactly lineWidth characters (before adding \n) + if (currentWidth < lineWidth) { + output += std::string(lineWidth - currentWidth, ' '); + } else if (currentWidth > lineWidth) { + output = output.substr(0, lineWidth); + } + } + + return output; +} + +// std::string ProgressManager::formatTimeInfo(double elapsed, double eta) { +// std::ostringstream builder; +// builder << "Elapsed: " << elapsed << "s | ETA: " << eta << "s"; +// return builder.str(); +// } + +std::string ProgressManager::formatTimeInfo(double elapsed, double eta) const { + std::ostringstream builder; + builder << "Elapsed: " << formatDuration(elapsed) << " | ETA: " << formatDuration(eta); + return builder.str(); +} + +// Add this helper function to the class +std::string ProgressManager::formatDuration(double seconds) const { + if (seconds < 0) { + return "0s"; + } + + // Convert to different units + if (seconds < 60) { + // Less than 1 minute: show seconds + return std::to_string(static_cast(std::round(seconds))) + "s"; + } + else if (seconds < 3600) { + // Less than 1 hour: show minutes and seconds + int mins = static_cast(seconds / 60); + int secs = static_cast(seconds) % 60; + if (secs == 0) { + return std::to_string(mins) + "m"; + } else { + return std::to_string(mins) + "m " + std::to_string(secs) + "s"; + } + } + else if (seconds < 86400) { + // Less than 1 day: show hours and minutes + int hours = static_cast(seconds / 3600); + int mins = static_cast((seconds - hours * 3600) / 60); + if (mins == 0) { + return std::to_string(hours) + "h"; + } else { + return std::to_string(hours) + "h " + std::to_string(mins) + "m"; + } + } + else { + // 1 day or more: show days and hours + int days = static_cast(seconds / 86400); + int hours = static_cast((seconds - days * 86400) / 3600); + if (hours == 0) { + return std::to_string(days) + "d"; + } else { + return std::to_string(days) + "d " + std::to_string(hours) + "h"; + } + } +} + +void ProgressManager::setupTheme() { + // should be a struct of some kind... + if (useUnicode) { + // Unicode theme + lhsToken = "⦗"; + rhsToken = "⦘"; + filledToken = "\033[38;5;73m━\033[39m"; // Blue filled + partialTokenMore = "\033[38;5;73m━\033[39m"; // Blue partial (> 0.5) + partialTokenLess = "\033[37m╺\033[39m"; // Gray partial (< 0.5) + emptyToken = "\033[37m━\033[39m"; // Gray empty + } else { + // Classic theme + lhsToken = "["; + rhsToken = "]"; + filledToken = "="; + emptyToken = " "; + partialTokenMore = " "; + partialTokenLess = " "; + } +} + +void ProgressManager::print() { + std::lock_guard lock(printMutex); + + auto now = Clock::now(); + double elapsed = std::chrono::duration_cast(now - start).count(); + + int totalWork = nChains * nIter; + int done = std::reduce(progress.begin(), progress.end()); + double fracTotal = double(done) / totalWork; + // should actually be the eta of the slowest chain! + double eta = (fracTotal > 0) ? elapsed / fracTotal - elapsed : 0.0; + + std::ostringstream out; + // int totalChars = 0; + + // if this is not the first print, delete previous content + if (progress_type == 2) { + + if (lastPrintedChars > 0) { + + if (isRStudio) { + out << "\x1b[" << std::to_string((1 + lineWidth) * lastPrintedLines) << "D"; + // out << "\x1b[" << std::to_string(lastPrintedChars) << "D"; + } else { + // Move cursor up to start of our content and clear everything + for (int i = 0; i < lastPrintedLines; i++) { + out << "\x1b[1A\x1b[2K"; // Move up one line and clear entire line + } + } + } + + // Print progress for each chain + for (int i = 0; i < nChains; i++) { + double frac = double(progress[i]) / nIter; + std::string chainProgress = formatProgressBar(i + 1, progress[i], nIter, frac); + out << chainProgress << "\n"; + // totalChars += chainProgress.length() + 1; // +1 for newline + } + + // Print total progress + std::string totalProgress = formatProgressBar(0, done, totalWork, fracTotal, true); + out << totalProgress << "\n"; + // totalChars += totalProgress.length() + 1; // +1 for newline + + // Print time info + std::string timeInfo = formatTimeInfo(elapsed, eta); + maybePadToLength(timeInfo); + out << timeInfo << "\n"; + // totalChars += timeInfo.length() + 1; // +1 for newline + + // Track total lines printed (chains + total + time) + lastPrintedLines = nChains + 2; // used in a generic terminal + lastPrintedChars = 1;//totalChars; // used by RStudio + + } else if (progress_type == 1) { + + // Print total progress + std::string totalProgress = formatProgressBar(0, done, totalWork, fracTotal, true); + + // Print time info + totalProgress += " " + formatTimeInfo(elapsed, eta); + + if (done < totalWork) { + out << totalProgress << "\r"; + } else { + out << totalProgress << "\n"; + } + + // we do not set lastPrintedChars or lastPrintedLines here since we always overwrite the same line + + } + + Rcpp::Rcout << out.str(); + +} + +void ProgressManager::update_prefixes(int width) { + if (width < 20) { + chain_prefix = "C"; + total_prefix = "T"; + } else if (width < 30) { + chain_prefix = "Ch"; + total_prefix = "Tot"; + } else { + chain_prefix = "Chain"; + total_prefix = "Total"; + } +} + +void ProgressManager::maybePadToLength(std::string& content) const { + if (!isRStudio) return; + + // Pad each line to exactly lineWidth characters (before adding \n) + if (content.length() < lineWidth) { + content += std::string(lineWidth - content.length(), ' '); + } else if (content.length() > lineWidth) { + content = content.substr(0, lineWidth); + } +} + + +// Example usage/ test with RcppParallel +// #include +// // Worker functor for RcppParallel +// struct ChainWorker : public RcppParallel::Worker { +// int nIter; +// ProgressManager ± +// bool display_progress; + +// ChainWorker(int nIter_, ProgressManager &pm_, bool display_progress_) +// : nIter(nIter_), pm(pm_), display_progress(display_progress_) {} + +// void operator()(std::size_t begin, std::size_t end) { + +// auto chainId = begin; + +// for (int i = 0; i < nIter; i++) { +// // ---- Simulated work ---- +// std::this_thread::sleep_for(std::chrono::milliseconds(20)); + +// // ---- Update state ---- +// pm.update(chainId); +// if (pm.shouldExit()) break; +// // if (Progress::check_abort()) Rcpp::checkUserInterrupt(); +// } +// } +// }; + + +// // [[Rcpp::export]] +// void runMCMC_parallel(int nChains = 4, int nIter = 100, int nWarmup = 100, int progress_type = 2, bool useUnicode = false) { + +// ProgressManager pm(nChains, nIter, nWarmup, 10, progress_type, useUnicode); +// ChainWorker worker(nIter + nWarmup, pm, true); + +// // Run each chain in parallel +// RcppParallel::parallelFor(0, nChains, worker); + +// pm.finish(); + +// if (pm.shouldExit()) { +// Rcpp::Rcout << "\nComputation interrupted by user.\n"; +// } else { +// Rcpp::Rcout << "\nAll chains finished!\n"; +// } +// } + diff --git a/src/progress_manager.h b/src/progress_manager.h new file mode 100644 index 00000000..80ac526a --- /dev/null +++ b/src/progress_manager.h @@ -0,0 +1,123 @@ +#ifndef PROGRESS_MANAGER_H +#define PROGRESS_MANAGER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using Clock = std::chrono::steady_clock; + +// Interrupt checking functions +// https://github.com/kforner/rcpp_progress/blob/d851ac62fd0314239e852392de7face5fa4bf48e/inst/include/interrupts.hpp#L24-L31 +static void chkIntFn(void *dummy) { + R_CheckUserInterrupt(); +} + +// this will call the above in a top-level context so it won't longjmp-out of your context +inline bool checkInterrupt() { + return (R_ToplevelExec(chkIntFn, NULL) == FALSE); +} + +/** + * @brief Multi-chain progress bar manager for MCMC computations + * + * This class provides a thread-safe progress bar that works in both RStudio + * console and terminal environments. It supports Unicode theming with colored + * progress indicators and proper cursor positioning. + * + * Key features: + * - Multi-chain progress tracking with atomic operations + * - RStudio vs terminal environment detection and adaptation + * - Unicode and classic theming options + * - ANSI color support with proper visual length calculations + * - Thread-safe printing with mutex protection + * - Console width adaptation and change detection + * - User interrupt checking + */ +class ProgressManager { + +public: + + ProgressManager(int nChains_, int nIter_, int nWarmup_, int printEvery_ = 10, int progress_type = 2, bool useUnicode_ = true); + void update(int chainId); + void finish(); + bool shouldExit() const; + +private: + + void checkConsoleWidthChange(); + int getConsoleWidth() const; + std::string formatProgressBar(int chainId, int current, int total, double fraction, bool isTotal = false) const; + std::string formatTimeInfo(double elapsed, double eta) const; + std::string formatDuration(double seconds) const; + void setupTheme(); + + bool isWarmupPhase() const { + for (auto c : progress) + if (c < nWarmup) + return true; + return false; + } + bool isWarmupPhase(const int chain_id) const { + return progress[chain_id] < nWarmup; + } + + void print(); + + void update_prefixes(int width); + + void maybePadToLength(std::string& content) const; + + // Configuration parameters + int nChains; ///< Number of parallel chains + int nIter; ///< TOTAL Iterations per chain + int nWarmup; ///< Warmup iterations per chain + int printEvery; ///< Print frequency + int no_spaces_for_total; ///< Spacing for total line alignment + int lastPrintedLines = 0; ///< Lines printed in last update + int lastPrintedChars = 0; ///< Characters printed in last update (RStudio) + int consoleWidth = 80; ///< Current console width + int lineWidth = 80; ///< Target line width for content + int prevConsoleWidth = -1; ///< Previous console width for change detection + + // Environment and state flags + bool isRStudio = false; ///< Whether running in RStudio console + bool needsToExit = false; ///< User interrupt flag + bool widthChanged = false; ///< Console width changed flag + + // Visual configuration + int barWidth = 40; ///< Progress bar width in characters + int progress_type = 2; ///< Progress bar style type (0 = "none", 1 = "total", 2 = "per-chain") + bool useUnicode = true; ///< Use Unicode vs ASCII theme + + // Theme tokens + std::string lhsToken; ///< Left bracket/delimiter + std::string rhsToken; ///< Right bracket/delimiter + std::string filledToken; ///< Filled progress character + std::string emptyToken; ///< Empty progress character + std::string partialTokenMore; ///< Partial progress (>50%) + std::string partialTokenLess; ///< Partial progress (<50%) + std::string chain_prefix; ///< Chain label prefix + std::string total_prefix; ///< Total label prefix + std::string total_padding; ///< Padding for total line alignment + + // progress tracking + std::vector progress; ///< Per-chain progress counters + + // Timing + Clock::time_point start; ///< Start time + std::chrono::time_point lastPrint; ///< Last print time + + // Thread synchronization + std::mutex printMutex; ///< Mutex for thread-safe printing +}; + +#endif // PROGRESS_MANAGER_H \ No newline at end of file diff --git a/src/sampler_output.h b/src/sampler_output.h index f36d8e1d..8f899d00 100644 --- a/src/sampler_output.h +++ b/src/sampler_output.h @@ -30,6 +30,7 @@ struct SamplerOutput { arma::vec energy_samples; int chain_id; bool has_indicator; + bool userInterrupt; }; #endif