From 671d3e107825e78af26092f4dc217f613937f44f Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Mon, 22 Sep 2025 09:43:40 +0200 Subject: [PATCH 1/6] parallel progress bar and interrupt --- ppt2.R | 4 + src/RcppExports.cpp | 14 ++ src/bgm_parallel.cpp | 16 +- src/bgm_sampler.cpp | 20 +- src/bgm_sampler.h | 5 +- src/progress_manager.cpp | 398 +++++++++++++++++++++++++++++++++++++++ src/progress_manager.h | 119 ++++++++++++ 7 files changed, 559 insertions(+), 17 deletions(-) create mode 100644 ppt2.R create mode 100644 src/progress_manager.cpp create mode 100644 src/progress_manager.h diff --git a/ppt2.R b/ppt2.R new file mode 100644 index 00000000..60ec82d8 --- /dev/null +++ b/ppt2.R @@ -0,0 +1,4 @@ +Rcpp::sourceCpp("src/progress_manager.cpp") +runMCMC_parallel(3, 500) +runMCMC_parallel(3, 500, useUnicode = FALSE) +runMCMC_parallel(3, 500, useUnicode = FALSE, display_progress = TRUE) \ No newline at end of file diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index e829cc45..97f0f246 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -177,6 +177,19 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// runMCMC_parallel +void runMCMC_parallel(int nChains, int nIter, bool display_progress, bool useUnicode); +RcppExport SEXP _bgms_runMCMC_parallel(SEXP nChainsSEXP, SEXP nIterSEXP, SEXP display_progressSEXP, SEXP useUnicodeSEXP) { +BEGIN_RCPP + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< int >::type nChains(nChainsSEXP); + Rcpp::traits::input_parameter< int >::type nIter(nIterSEXP); + Rcpp::traits::input_parameter< bool >::type display_progress(display_progressSEXP); + Rcpp::traits::input_parameter< bool >::type useUnicode(useUnicodeSEXP); + runMCMC_parallel(nChains, nIter, display_progress, useUnicode); + return R_NilValue; +END_RCPP +} static const R_CallMethodDef CallEntries[] = { {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 35}, @@ -187,6 +200,7 @@ static const R_CallMethodDef CallEntries[] = { {"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 6}, {"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 8}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, + {"_bgms_runMCMC_parallel", (DL_FUNC) &_bgms_runMCMC_parallel, 4}, {NULL, NULL, 0} }; diff --git a/src/bgm_parallel.cpp b/src/bgm_parallel.cpp index 052a36b7..cf423ed3 100644 --- a/src/bgm_parallel.cpp +++ b/src/bgm_parallel.cpp @@ -6,6 +6,7 @@ #include #include #include "rng_utils.h" +#include "progress_manager.h" using namespace Rcpp; using namespace RcppParallel; @@ -87,6 +88,7 @@ struct GibbsChainRunner : public Worker { // Wrapped RNG engines const std::vector& chain_rngs; + ProgressManager& pm; // output buffer std::vector& results; @@ -121,7 +123,8 @@ struct GibbsChainRunner : public Worker { int nuts_max_depth, bool learn_mass_matrix, const std::vector& chain_rngs, - std::vector& results + std::vector& results, + ProgressManager& pm ) : observations(observations), num_categories(num_categories), @@ -152,7 +155,8 @@ struct GibbsChainRunner : public Worker { nuts_max_depth(nuts_max_depth), learn_mass_matrix(learn_mass_matrix), chain_rngs(chain_rngs), - results(results) + results(results), + pm(pm) {} void operator()(std::size_t begin, std::size_t end) { @@ -194,7 +198,8 @@ struct GibbsChainRunner : public Worker { hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, - rng + rng, + pm ); out.result = result; @@ -301,6 +306,7 @@ Rcpp::List run_bgm_parallel( chain_rngs[c] = SafeRNG(seed + c); } + ProgressManager pm(num_chains, burnin + iter, 50); GibbsChainRunner worker( observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, @@ -308,8 +314,8 @@ Rcpp::List run_bgm_parallel( counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, - pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, - chain_rngs, results + sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, + chain_rngs, results, pm ); { diff --git a/src/bgm_sampler.cpp b/src/bgm_sampler.cpp index 562348ee..86f8ae7b 100644 --- a/src/bgm_sampler.cpp +++ b/src/bgm_sampler.cpp @@ -10,8 +10,13 @@ #include "mcmc_nuts.h" #include "mcmc_rwm.h" #include "mcmc_utils.h" +<<<<<<< HEAD #include "print_mutex.h" #include "sbm_edge_prior.h" +======= +// #include "print_mutex.h" +#include "gibbs_functions_edge_prior.h" +>>>>>>> 9f0ad36 (parallel progress bar and interrupt) #include "rng_utils.h" using namespace Rcpp; @@ -1197,7 +1202,8 @@ Rcpp::List run_gibbs_sampler_bgm( const int hmc_num_leapfrogs, const int nuts_max_depth, const bool learn_mass_matrix, - SafeRNG& rng + SafeRNG& rng, + ProgressManager& pm ) { // --- Setup: dimensions and storage structures const int num_variables = observations.n_cols; @@ -1296,18 +1302,12 @@ Rcpp::List run_gibbs_sampler_bgm( ); const int total_iter = warmup_schedule.total_burnin + iter; - const int print_every = std::max(1, total_iter / 10); // --- Main Gibbs sampling loop for (int iteration = 0; iteration < total_iter; iteration++) { - if (iteration % print_every == 0) { - tbb::mutex::scoped_lock lock(get_print_mutex()); - std::cout - << "[bgm] chain " << chain_id - << " iteration " << iteration - << " / " << total_iter - << std::endl; - } + + pm.update(chain_id - 1); + if (pm.shouldExit()) break; // Shuffle update order of edge indices order = arma_randperm(rng, num_pairwise); diff --git a/src/bgm_sampler.h b/src/bgm_sampler.h index 7bb1403d..6c3c1fe7 100644 --- a/src/bgm_sampler.h +++ b/src/bgm_sampler.h @@ -1,6 +1,6 @@ #pragma once #include - +#include "progress_manager.h" // forward declaration struct SafeRNG; @@ -34,5 +34,6 @@ Rcpp::List run_gibbs_sampler_bgm( const int hmc_num_leapfrogs, const int nuts_max_depth, const bool learn_mass_matrix, - SafeRNG& rng + SafeRNG& rng, + ProgressManager& pm ); \ No newline at end of file diff --git a/src/progress_manager.cpp b/src/progress_manager.cpp new file mode 100644 index 00000000..6562ee15 --- /dev/null +++ b/src/progress_manager.cpp @@ -0,0 +1,398 @@ +#include "progress_manager.h" + + +// ProgressManager Implementation + +ProgressManager::ProgressManager(int nChains_, int nIter_, int printEvery_, bool useUnicode_) + : nChains(nChains_), nIter(nIter_), progress(nChains_), printEvery(printEvery_), + lastLineLengths(nChains_ + 2, 0), useUnicode(useUnicode_) { // +2 for total and time lines + + for (int i = 0; i < nChains; i++) progress[i] = 0; + start = Clock::now(); + lastPrint.store(start); + + // Check if we're in RStudio + Environment base("package:base"); + Function getOption = base["getOption"]; + Function Sysgetenv = base["Sys.getenv"]; + SEXP s = Sysgetenv("RSTUDIO"); + isRStudio = as(s) == "1"; + + no_spaces_for_total = 3 + static_cast(std::log10(nChains)); + + if (isRStudio) { + consoleWidth = getConsoleWidth(); + lineWidth = std::max(10, std::min(consoleWidth - 25, 70)); + } else { + // For terminal, use default console width + consoleWidth = 80; + lineWidth = 70; + } + + // cleverly determine barwidth so no line wrapping occurs + if (lineWidth <= 5) { + // TODO: we don't want to print anything in this case + barWidth = 0; + } else if (lineWidth < 20) { + barWidth = lineWidth - 10; + } else if (lineWidth < 40) { + barWidth = lineWidth - 15; + } else { // > 40 + barWidth = std::min(40, lineWidth - 30); + } + + if (isRStudio) { + barWidth = std::max(10, barWidth - 20); // minimum bar width of 10 for RStudio + } + + // Set up theme + setupTheme(); + update_prefixes(consoleWidth); +} + +ProgressManager::~ProgressManager() { + // Skip destructor print to avoid double printing + // The final state is already shown by the last regular print +} + +void ProgressManager::update(int chainId) { + progress[chainId]++; + totalDone++; + + // Only chain 0 handles printing to avoid race conditions + if (chainId == 0 && progress[chainId] % printEvery == 0) { + auto now = Clock::now(); + auto lastPrintTime = lastPrint.load(); + std::chrono::duration sinceLast = now - lastPrintTime; + + // Throttle printing to avoid spamming + if (sinceLast.count() >= 0.5) { + print(); + lastPrint.store(now); + } + } + + // Check for user interrupts and console width changes less frequently to reduce overhead + if (chainId == 0 && progress[chainId] % (printEvery * 5) == 0) { + needsToExit = checkInterrupt(); + // Also check for console width changes occasionally + checkConsoleWidthChange(); + } +} + +bool ProgressManager::shouldExit() const { + return needsToExit; +} + +void ProgressManager::checkConsoleWidthChange() { + if (!isRStudio) return; + + Environment base("package:base"); + Function getOption = base["getOption"]; + int currentWidth = getConsoleWidth(); + + if (prevConsoleWidth == -1) { + // First time, just store the current width + prevConsoleWidth = consoleWidth; + widthChanged = false; + return; + } + + if (currentWidth != consoleWidth && currentWidth > 0) { + // Width has changed + prevConsoleWidth = consoleWidth; + consoleWidth = currentWidth; + widthChanged = true; + } else { + widthChanged = false; + } +} + +int ProgressManager::getConsoleWidth() { + Environment base("package:base"); + Function getOption = base["getOption"]; + SEXP s = getOption("width", 0); + int width = as(s); + // Remove the +3 adjustment to test actual console width + return width + 3; +} + +std::string ProgressManager::formatProgressBar(int chainId, int current, int total, double fraction, bool isTotal) { + std::ostringstream builder; + + double exactFilled = fraction * barWidth; + int filled = std::max(0, std::min(int(exactFilled), barWidth)); + + // Build progress bar with theme + std::string progressBar = lhsToken; + + // Add filled tokens + for (int i = 0; i < filled; i++) { + progressBar += filledToken; + } + + // Add partial token if needed + if (filled < barWidth) { + double partialAmount = exactFilled - filled; + if (partialAmount > 0) { + if (partialAmount > 0.5) { + progressBar += partialTokenMore; + } else { + progressBar += partialTokenLess; + } + filled++; // Account for the partial token + } + } + + // Add empty tokens + for (int i = filled; i < barWidth; i++) { + progressBar += emptyToken; + } + + progressBar += rhsToken; + + if (isTotal) { + builder << total_prefix << ":" << std::string(no_spaces_for_total, ' ') << progressBar << " " << current << "/" << total + << " (" << std::fixed << std::setprecision(1) << fraction * 100 << "%)"; + } else { + builder << chain_prefix << " " << chainId << ": " << progressBar << " " << current << "/" << total + << " (" << std::fixed << std::setprecision(1) << fraction * 100 << "%)"; + } + + return builder.str(); +} + +std::string ProgressManager::formatTimeInfo(int elapsed, int eta) { + std::ostringstream builder; + builder << "Elapsed: " << elapsed << "s | ETA: " << eta << "s"; + return builder.str(); +} + +void ProgressManager::setupTheme() { + if (useUnicode) { + // Unicode theme + lhsToken = "❨"; + rhsToken = "❩"; + filledToken = "\x1b[34m━\x1b[0m"; // Blue filled + emptyToken = "\x1b[37m━\x1b[0m"; // Gray empty + partialTokenMore = "\x1b[34m╸\x1b[0m"; // Blue partial (> 0.5) + partialTokenLess = "\x1b[37m╺\x1b[0m"; // Gray partial (< 0.5) + } else { + // Classic theme + lhsToken = "["; + rhsToken = "]"; + filledToken = "="; + emptyToken = " "; + partialTokenMore = " "; + partialTokenLess = " "; + } +} + +size_t ProgressManager::getVisualLength(const std::string& str) { + size_t visualLength = 0; + bool inEscapeSequence = false; + + for (size_t i = 0; i < str.length(); i++) { + if (str[i] == '\x1b' && i + 1 < str.length() && str[i + 1] == '[') { + inEscapeSequence = true; + i++; // Skip the '[' + } else if (inEscapeSequence && str[i] == 'm') { + inEscapeSequence = false; + } else if (!inEscapeSequence) { + visualLength++; + } + } + + return visualLength; +} + +void ProgressManager::print() { + std::lock_guard lock(printMutex); + + auto now = Clock::now(); + double elapsed = std::chrono::duration_cast(now - start).count(); + + int totalWork = nChains * nIter; + int done = totalDone; + double fracTotal = double(done) / totalWork; + // should actually be the eta of the slowest chain! + double eta = (fracTotal > 0) ? elapsed / fracTotal - elapsed : 0.0; + + std::ostringstream out; + int totalChars = 0; + int lineIndex = 0; + + if (isRStudio) { + // if this is not the first print, delete previous content + if (lastPrintedChars > 0) { + out << "\x1b[" << std::to_string(lastPrintedChars) << "D"; + } + + std::string clearedLine; + + // Print progress for each chain + for (int i = 0; i < nChains; i++) { + double frac = double(progress[i]) / nIter; + std::string chainProgress = formatProgressBar(i + 1, progress[i], nIter, frac); + + // Pad each line to exactly lineWidth characters (before adding \n) + if (chainProgress.length() < lineWidth) { + chainProgress += std::string(lineWidth - chainProgress.length(), ' '); + } else if (chainProgress.length() > lineWidth) { + chainProgress = chainProgress.substr(0, lineWidth); + } + + // Add newline to every line (including the last one) + clearedLine = chainProgress + "\n"; + out << clearedLine; + + lastLineLengths[lineIndex] = clearedLine.length(); // This will be lineWidth + 1 + totalChars += lastLineLengths[lineIndex]; + lineIndex++; + } + + // Print total progress + std::string totalProgress = formatProgressBar(0, done, totalWork, fracTotal, true); + + // Pad total progress line to exactly lineWidth characters (before adding \n) + if (totalProgress.length() < lineWidth) { + totalProgress += std::string(lineWidth - totalProgress.length(), ' '); + } else if (totalProgress.length() > lineWidth) { + totalProgress = totalProgress.substr(0, lineWidth); + } + + clearedLine = totalProgress + "\n"; + out << clearedLine; + + lastLineLengths[lineIndex] = clearedLine.length(); // This will be lineWidth + 1 + totalChars += lastLineLengths[lineIndex]; + lineIndex++; + + // Print time info + std::string timeInfo = formatTimeInfo(int(elapsed), int(eta)); + + // Pad time info line to exactly lineWidth characters (before adding \n) + if (timeInfo.length() < lineWidth) { + timeInfo += std::string(lineWidth - timeInfo.length(), ' '); + } else if (timeInfo.length() > lineWidth) { + timeInfo = timeInfo.substr(0, lineWidth); + } + + clearedLine = timeInfo + "\n"; + out << clearedLine; + + lastLineLengths[lineIndex] = clearedLine.length(); // This will be lineWidth + 1 + totalChars += lastLineLengths[lineIndex]; + + // Track characters and lines for next cursor movement + lastPrintedChars = totalChars; + lastPrintedLines = lineIndex; + + std::string out_str = out.str(); + + assert(out_str.length() == static_cast(totalChars)); + Rcpp::Rcout << out_str; + + } else { + // Terminal: Use carriage return and clear lines approach + if (lastPrintedLines > 0) { + // Move cursor up to start of our content and clear everything + for (int i = 0; i < lastPrintedLines; i++) { + out << "\x1b[1A\x1b[2K"; // Move up one line and clear entire line + } + } + + // Print progress for each chain + for (int i = 0; i < nChains; i++) { + double frac = double(progress[i]) / nIter; + std::string chainProgress = formatProgressBar(i + 1, progress[i], nIter, frac); + out << chainProgress << "\n"; + } + + // Print total progress + std::string totalProgress = formatProgressBar(0, done, totalWork, fracTotal, true); + out << totalProgress << "\n"; + + // Print time info + std::string timeInfo = formatTimeInfo(int(elapsed), int(eta)); + out << timeInfo << "\n"; + + // Track total lines printed (chains + total + time) + lastPrintedLines = nChains + 2; + + Rcpp::Rcout << out.str(); + } +} + +void ProgressManager::update_prefixes(int width) { + if (width < 20) { + chain_prefix = "C"; + total_prefix = "T"; + } else if (width < 30) { + chain_prefix = "Ch"; + total_prefix = "Tot"; + } else { + chain_prefix = "Chain"; + total_prefix = "Total"; + } +} + +std::string ProgressManager::addPadding(const std::string& content) { + int paddingNeeded = consoleWidth - content.length(); + if (paddingNeeded > 0) { + return content + std::string(paddingNeeded, ' '); + } else { + return content; + } +} + +std::string ProgressManager::clearLineLeftovers(const std::string& newContent, int oldLineLength) { + if (newContent.length() < oldLineLength) { + // Add spaces to clear old content + return newContent + std::string(oldLineLength - newContent.length(), ' '); + } + return newContent; +} + +// Worker functor for RcppParallel +struct ChainWorker : public Worker { + int nIter; + ProgressManager ± + bool display_progress; + + ChainWorker(int nIter_, ProgressManager &pm_, bool display_progress_) + : nIter(nIter_), pm(pm_), display_progress(display_progress_) {} + + void operator()(std::size_t begin, std::size_t end) { + + auto chainId = begin; + + for (int i = 0; i < nIter; i++) { + // ---- Simulated work ---- + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + + // ---- Update state ---- + pm.update(chainId); + if (pm.shouldExit()) break; + // if (Progress::check_abort()) Rcpp::checkUserInterrupt(); + } + } +}; + + +// [[Rcpp::export]] +void runMCMC_parallel(int nChains = 4, int nIter = 100, bool display_progress = true, bool useUnicode = true) { + + ProgressManager pm(nChains, nIter, 10, useUnicode); + ChainWorker worker(nIter, pm, display_progress); + + // Run each chain in parallel + parallelFor(0, nChains, worker); + + if (pm.shouldExit()) { + Rcpp::Rcout << "\nComputation interrupted by user.\n"; + } else { + Rcpp::Rcout << "\nAll chains finished!\n"; + } +} + diff --git a/src/progress_manager.h b/src/progress_manager.h new file mode 100644 index 00000000..23537032 --- /dev/null +++ b/src/progress_manager.h @@ -0,0 +1,119 @@ +#ifndef PROGRESS_MANAGER_H +#define PROGRESS_MANAGER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace Rcpp; +using namespace RcppParallel; +using Clock = std::chrono::steady_clock; + +// Interrupt checking functions +// https://github.com/kforner/rcpp_progress/blob/d851ac62fd0314239e852392de7face5fa4bf48e/inst/include/interrupts.hpp#L24-L31 +static void chkIntFn(void *dummy) { + R_CheckUserInterrupt(); +} + +// this will call the above in a top-level context so it won't longjmp-out of your context +inline bool checkInterrupt() { + return (R_ToplevelExec(chkIntFn, NULL) == FALSE); +} + +/** + * @brief Multi-chain progress bar manager for MCMC computations + * + * This class provides a thread-safe progress bar that works in both RStudio + * console and terminal environments. It supports Unicode theming with colored + * progress indicators and proper cursor positioning. + * + * Key features: + * - Multi-chain progress tracking with atomic operations + * - RStudio vs terminal environment detection and adaptation + * - Unicode and classic theming options + * - ANSI color support with proper visual length calculations + * - Thread-safe printing with mutex protection + * - Console width adaptation and change detection + * - User interrupt checking + */ +class ProgressManager { + +public: + + ProgressManager(int nChains_, int nIter_, int printEvery_ = 10, bool useUnicode_ = false); + ~ProgressManager(); + void update(int chainId); + bool shouldExit() const; + +private: + + void checkConsoleWidthChange(); + int getConsoleWidth(); + std::string formatProgressBar(int chainId, int current, int total, double fraction, bool isTotal = false); + std::string formatTimeInfo(int elapsed, int eta); + void setupTheme(); + size_t getVisualLength(const std::string& str); + + void print(); + + void update_prefixes(int width); + + std::string addPadding(const std::string& content); + + std::string clearLineLeftovers(const std::string& newContent, int oldLineLength); + + // Configuration parameters + int nChains; ///< Number of parallel chains + int nIter; ///< Iterations per chain + int printEvery; ///< Print frequency + int no_spaces_for_total; ///< Spacing for total line alignment + int lastPrintedLines = 0; ///< Lines printed in last update + int lastPrintedChars = 0; ///< Characters printed in last update (RStudio) + int consoleWidth = 80; ///< Current console width + int lineWidth = 80; ///< Target line width for content + int prevConsoleWidth = -1; ///< Previous console width for change detection + + // Environment and state flags + bool isRStudio = false; ///< Whether running in RStudio console + bool needsToExit = false; ///< User interrupt flag + bool widthChanged = false; ///< Console width changed flag + + // Visual configuration + int barWidth = 40; ///< Progress bar width in characters + bool useUnicode = true; ///< Use Unicode vs ASCII theme + + // Theme tokens + std::string lhsToken; ///< Left bracket/delimiter + std::string rhsToken; ///< Right bracket/delimiter + std::string filledToken; ///< Filled progress character + std::string emptyToken; ///< Empty progress character + std::string partialTokenMore; ///< Partial progress (>50%) + std::string partialTokenLess; ///< Partial progress (<50%) + std::string chain_prefix; ///< Chain label prefix + std::string total_prefix; ///< Total label prefix + + // Line tracking for cursor positioning + std::vector lastLineLengths; ///< Track length of each printed line + + // Thread-safe progress tracking + std::vector> progress; ///< Per-chain progress counters + std::atomic totalDone{0}; ///< Total completed iterations + + // Timing + Clock::time_point start; ///< Start time + std::atomic> lastPrint; ///< Last print time + + // Thread synchronization + std::mutex printMutex; ///< Mutex for thread-safe printing +}; + +#endif // PROGRESS_MANAGER_H \ No newline at end of file From 6198ad314929fdaeb4c46bb2661709acc43b3907 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Mon, 22 Sep 2025 13:11:09 +0200 Subject: [PATCH 2/6] let's rebase first --- R/RcppExports.R | 12 +- R/bgm.R | 18 ++- src/RcppExports.cpp | 57 ++------ src/bgm_parallel.cpp | 12 +- src/bgm_sampler.cpp | 7 +- src/progress_manager.cpp | 290 ++++++++++++++++++--------------------- src/progress_manager.h | 40 +++--- 7 files changed, 203 insertions(+), 233 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index 1f228e62..01cf2a38 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -13,12 +13,8 @@ get_explog_switch <- function() { .Call(`_bgms_get_explog_switch`) } -rcpp_ieee754_exp <- function(x) { - .Call(`_bgms_rcpp_ieee754_exp`, x) -} - -rcpp_ieee754_log <- function(x) { - .Call(`_bgms_rcpp_ieee754_log`, x) +run_bgm_parallel <- function(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) { + .Call(`_bgms_run_bgm_parallel`, observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) } sample_omrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, iter) { @@ -33,3 +29,7 @@ compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) { .Call(`_bgms_compute_Vn_mfm_sbm`, no_variables, dirichlet_alpha, t_max, lambda) } +runMCMC_parallel <- function(nChains = 4L, nIter = 100L, nWarmup = 100L, progress_type = 2L, useUnicode = FALSE) { + invisible(.Call(`_bgms_runMCMC_parallel`, nChains, nIter, nWarmup, progress_type, useUnicode)) +} + diff --git a/R/bgm.R b/R/bgm.R index 3c99aec7..3c20206c 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -237,7 +237,7 @@ bgm = function( dirichlet_alpha = 1, lambda = 1, na_action = c("listwise", "impute"), - display_progress = TRUE, + display_progress = c("per-chain", "total", "none"), update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"), target_accept, hmc_num_leapfrogs = 100, @@ -324,7 +324,15 @@ bgm = function( ".")) #Check display_progress ------------------------------------------------------ - display_progress = check_logical(display_progress, "display_progress") + if (is.logical(display_progress) && length(display_progress) == 1) { + if (is.na(display_progress)) + stop("The display_progress argument must be a single logical value, but not NA.") + display_progress = if (display_progress) "per-chain" else "none" + } else { + display_progress = match.arg(display_progress) + } + + progress_type = if (display_progress == "per-chain") 2L else if (display_progress == "total") 1L else 0L #Format the data input ------------------------------------------------------- data = reformat_data(x = x, @@ -428,7 +436,7 @@ bgm = function( target_accept = target_accept, pairwise_stats = pairwise_stats, hmc_num_leapfrogs = hmc_num_leapfrogs, nuts_max_depth = nuts_max_depth, learn_mass_matrix = learn_mass_matrix, num_chains = chains, - nThreads = cores, seed = seed + nThreads = cores, seed = seed, progress_type = progress_type ) # Main output handler in the wrapper function @@ -457,5 +465,9 @@ bgm = function( output$nuts_diag = nuts_diag } + userInterrupt = any(vapply(out, `[[`, "userInterrupt", logical(1L))) + if (userInterrupt) + warning("Stopped sampling after user interrupt, results are likely uninterpretable.") + return(output) } \ No newline at end of file diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 97f0f246..3f048fff 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -57,8 +57,8 @@ BEGIN_RCPP END_RCPP } // run_bgm_parallel -Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int burnin, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, uint64_t seed); -RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP pairwise_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP pairwise_statsSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP) { +Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double interaction_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int burnin, const arma::imat& num_obs_categories, const arma::imat& sufficient_blume_capel, double threshold_alpha, double threshold_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& reference_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& sufficient_pairwise, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, uint64_t seed, int progress_type); +RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP interaction_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP threshold_alphaSEXP, SEXP threshold_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP reference_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP sufficient_pairwiseSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP progress_typeSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -93,39 +93,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< int >::type num_chains(num_chainsSEXP); Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP); Rcpp::traits::input_parameter< uint64_t >::type seed(seedSEXP); - rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed)); - return rcpp_result_gen; -END_RCPP -} -// get_explog_switch -Rcpp::String get_explog_switch(); -RcppExport SEXP _bgms_get_explog_switch() { -BEGIN_RCPP - Rcpp::RObject rcpp_result_gen; - Rcpp::RNGScope rcpp_rngScope_gen; - rcpp_result_gen = Rcpp::wrap(get_explog_switch()); - return rcpp_result_gen; -END_RCPP -} -// rcpp_ieee754_exp -Rcpp::NumericVector rcpp_ieee754_exp(Rcpp::NumericVector x); -RcppExport SEXP _bgms_rcpp_ieee754_exp(SEXP xSEXP) { -BEGIN_RCPP - Rcpp::RObject rcpp_result_gen; - Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< Rcpp::NumericVector >::type x(xSEXP); - rcpp_result_gen = Rcpp::wrap(rcpp_ieee754_exp(x)); - return rcpp_result_gen; -END_RCPP -} -// rcpp_ieee754_log -Rcpp::NumericVector rcpp_ieee754_log(Rcpp::NumericVector x); -RcppExport SEXP _bgms_rcpp_ieee754_log(SEXP xSEXP) { -BEGIN_RCPP - Rcpp::RObject rcpp_result_gen; - Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< Rcpp::NumericVector >::type x(xSEXP); - rcpp_result_gen = Rcpp::wrap(rcpp_ieee754_log(x)); + Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP); + rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type)); return rcpp_result_gen; END_RCPP } @@ -178,29 +147,27 @@ BEGIN_RCPP END_RCPP } // runMCMC_parallel -void runMCMC_parallel(int nChains, int nIter, bool display_progress, bool useUnicode); -RcppExport SEXP _bgms_runMCMC_parallel(SEXP nChainsSEXP, SEXP nIterSEXP, SEXP display_progressSEXP, SEXP useUnicodeSEXP) { +void runMCMC_parallel(int nChains, int nIter, int nWarmup, int progress_type, bool useUnicode); +RcppExport SEXP _bgms_runMCMC_parallel(SEXP nChainsSEXP, SEXP nIterSEXP, SEXP nWarmupSEXP, SEXP progress_typeSEXP, SEXP useUnicodeSEXP) { BEGIN_RCPP Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< int >::type nChains(nChainsSEXP); Rcpp::traits::input_parameter< int >::type nIter(nIterSEXP); - Rcpp::traits::input_parameter< bool >::type display_progress(display_progressSEXP); + Rcpp::traits::input_parameter< int >::type nWarmup(nWarmupSEXP); + Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP); Rcpp::traits::input_parameter< bool >::type useUnicode(useUnicodeSEXP); - runMCMC_parallel(nChains, nIter, display_progress, useUnicode); + runMCMC_parallel(nChains, nIter, nWarmup, progress_type, useUnicode); return R_NilValue; END_RCPP } static const R_CallMethodDef CallEntries[] = { - {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 35}, - {"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 31}, - {"_bgms_get_explog_switch", (DL_FUNC) &_bgms_get_explog_switch, 0}, - {"_bgms_rcpp_ieee754_exp", (DL_FUNC) &_bgms_rcpp_ieee754_exp, 1}, - {"_bgms_rcpp_ieee754_log", (DL_FUNC) &_bgms_rcpp_ieee754_log, 1}, + {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 33}, + {"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 32}, {"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 6}, {"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 8}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, - {"_bgms_runMCMC_parallel", (DL_FUNC) &_bgms_runMCMC_parallel, 4}, + {"_bgms_runMCMC_parallel", (DL_FUNC) &_bgms_runMCMC_parallel, 5}, {NULL, NULL, 0} }; diff --git a/src/bgm_parallel.cpp b/src/bgm_parallel.cpp index cf423ed3..0f168d9c 100644 --- a/src/bgm_parallel.cpp +++ b/src/bgm_parallel.cpp @@ -7,6 +7,7 @@ #include #include "rng_utils.h" #include "progress_manager.h" +#include "mcmc_adaptation.h" using namespace Rcpp; using namespace RcppParallel; @@ -296,7 +297,8 @@ Rcpp::List run_bgm_parallel( bool learn_mass_matrix, int num_chains, int nThreads, - uint64_t seed + uint64_t seed, + int progress_type ) { std::vector results(num_chains); @@ -306,7 +308,11 @@ Rcpp::List run_bgm_parallel( chain_rngs[c] = SafeRNG(seed + c); } - ProgressManager pm(num_chains, burnin + iter, 50); + // only used to determine the total no. burnin iterations, a bit hacky + WarmupSchedule warmup_schedule_temp(burnin, edge_selection, (update_method != "adaptive-metropolis")); + int total_burnin = warmup_schedule_temp.total_burnin; + + ProgressManager pm(num_chains, iter + total_burnin, total_burnin, 50, progress_type); GibbsChainRunner worker( observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, @@ -335,5 +341,7 @@ Rcpp::List run_bgm_parallel( } } + pm.finish(); + return output; } diff --git a/src/bgm_sampler.cpp b/src/bgm_sampler.cpp index 86f8ae7b..ae5d93a8 100644 --- a/src/bgm_sampler.cpp +++ b/src/bgm_sampler.cpp @@ -1303,11 +1303,15 @@ Rcpp::List run_gibbs_sampler_bgm( const int total_iter = warmup_schedule.total_burnin + iter; + bool userInterrupt = false; // --- Main Gibbs sampling loop for (int iteration = 0; iteration < total_iter; iteration++) { pm.update(chain_id - 1); - if (pm.shouldExit()) break; + if (pm.shouldExit()) { + userInterrupt = true; + break; + } // Shuffle update order of edge indices order = arma_randperm(rng, num_pairwise); @@ -1424,6 +1428,7 @@ Rcpp::List run_gibbs_sampler_bgm( out["allocations"] = allocation_samples; } + out["userInterrupt"] = userInterrupt; out["chain_id"] = chain_id; return out; } \ No newline at end of file diff --git a/src/progress_manager.cpp b/src/progress_manager.cpp index 6562ee15..6191dede 100644 --- a/src/progress_manager.cpp +++ b/src/progress_manager.cpp @@ -1,24 +1,23 @@ #include "progress_manager.h" - -// ProgressManager Implementation - -ProgressManager::ProgressManager(int nChains_, int nIter_, int printEvery_, bool useUnicode_) - : nChains(nChains_), nIter(nIter_), progress(nChains_), printEvery(printEvery_), - lastLineLengths(nChains_ + 2, 0), useUnicode(useUnicode_) { // +2 for total and time lines +ProgressManager::ProgressManager(int nChains_, int nIter_, int nWarmup_, int printEvery_, int progress_type_, bool useUnicode_) + : nChains(nChains_), nIter(nIter_ + nWarmup_), nWarmup(nWarmup_), progress(nChains_), printEvery(printEvery_), + progress_type(progress_type_), useUnicode(useUnicode_) { // +2 for total and time lines for (int i = 0; i < nChains; i++) progress[i] = 0; start = Clock::now(); - lastPrint.store(start); + lastPrint = Clock::now(); // Check if we're in RStudio - Environment base("package:base"); - Function getOption = base["getOption"]; - Function Sysgetenv = base["Sys.getenv"]; + Rcpp::Environment base("package:base"); + Rcpp::Function getOption = base["getOption"]; + Rcpp::Function Sysgetenv = base["Sys.getenv"]; SEXP s = Sysgetenv("RSTUDIO"); - isRStudio = as(s) == "1"; + isRStudio = Rcpp::as(s) == "1"; no_spaces_for_total = 3 + static_cast(std::log10(nChains)); + if (progress_type == 1) no_spaces_for_total = 0; // no total line + total_padding = std::string(no_spaces_for_total, ' '); if (isRStudio) { consoleWidth = getConsoleWidth(); @@ -50,25 +49,20 @@ ProgressManager::ProgressManager(int nChains_, int nIter_, int printEvery_, bool update_prefixes(consoleWidth); } -ProgressManager::~ProgressManager() { - // Skip destructor print to avoid double printing - // The final state is already shown by the last regular print -} - void ProgressManager::update(int chainId) { progress[chainId]++; - totalDone++; - // Only chain 0 handles printing to avoid race conditions - if (chainId == 0 && progress[chainId] % printEvery == 0) { + // Only chain 0 actually does the printing/ checking for user interrupts + if (chainId != 0) return; + + if (progress[chainId] % printEvery == 0) { auto now = Clock::now(); - auto lastPrintTime = lastPrint.load(); - std::chrono::duration sinceLast = now - lastPrintTime; + std::chrono::duration sinceLast = now - lastPrint; // Throttle printing to avoid spamming - if (sinceLast.count() >= 0.5) { + if (progress_type != 0 && sinceLast.count() >= 0.5) { print(); - lastPrint.store(now); + lastPrint = now; } } @@ -80,6 +74,18 @@ void ProgressManager::update(int chainId) { } } +void ProgressManager::finish() { + + if (progress_type == 0) return; // No progress display + + // Mark all chains as complete and print one final time + for (int i = 0; i < nChains; i++) { + progress[i] = nIter; + } + print(); + +} + bool ProgressManager::shouldExit() const { return needsToExit; } @@ -87,8 +93,8 @@ bool ProgressManager::shouldExit() const { void ProgressManager::checkConsoleWidthChange() { if (!isRStudio) return; - Environment base("package:base"); - Function getOption = base["getOption"]; + Rcpp::Environment base("package:base"); + Rcpp::Function getOption = base["getOption"]; int currentWidth = getConsoleWidth(); if (prevConsoleWidth == -1) { @@ -109,57 +115,62 @@ void ProgressManager::checkConsoleWidthChange() { } int ProgressManager::getConsoleWidth() { - Environment base("package:base"); - Function getOption = base["getOption"]; + Rcpp::Environment base("package:base"); + Rcpp::Function getOption = base["getOption"]; SEXP s = getOption("width", 0); - int width = as(s); + int width = Rcpp::as(s); // Remove the +3 adjustment to test actual console width return width + 3; } std::string ProgressManager::formatProgressBar(int chainId, int current, int total, double fraction, bool isTotal) { - std::ostringstream builder; + std::ostringstream builder; - double exactFilled = fraction * barWidth; - int filled = std::max(0, std::min(int(exactFilled), barWidth)); + double exactFilled = fraction * barWidth; + int filled = std::max(0, std::min(int(exactFilled), barWidth)); - // Build progress bar with theme - std::string progressBar = lhsToken; + // Build progress bar with theme + std::string progressBar = lhsToken; - // Add filled tokens - for (int i = 0; i < filled; i++) { - progressBar += filledToken; - } + // Add filled tokens + for (int i = 0; i < filled; i++) { + progressBar += filledToken; + } - // Add partial token if needed - if (filled < barWidth) { - double partialAmount = exactFilled - filled; - if (partialAmount > 0) { - if (partialAmount > 0.5) { - progressBar += partialTokenMore; - } else { - progressBar += partialTokenLess; + // Add partial token if needed + if (filled < barWidth) { + double partialAmount = exactFilled - filled; + if (partialAmount > 0) { + if (partialAmount > 0.5) { + progressBar += partialTokenMore; + } else { + progressBar += partialTokenLess; + } + filled++; // Account for the partial token } - filled++; // Account for the partial token } - } - // Add empty tokens - for (int i = filled; i < barWidth; i++) { - progressBar += emptyToken; - } + // Add empty tokens + for (int i = filled; i < barWidth; i++) { + progressBar += emptyToken; + } - progressBar += rhsToken; + progressBar += rhsToken; - if (isTotal) { - builder << total_prefix << ":" << std::string(no_spaces_for_total, ' ') << progressBar << " " << current << "/" << total - << " (" << std::fixed << std::setprecision(1) << fraction * 100 << "%)"; - } else { - builder << chain_prefix << " " << chainId << ": " << progressBar << " " << current << "/" << total - << " (" << std::fixed << std::setprecision(1) << fraction * 100 << "%)"; - } + if (isTotal) { - return builder.str(); + std::string warmupOrSampling = isWarmupPhase() ? "(Warmup)" : "(Sampling)"; + builder << total_prefix << total_padding << warmupOrSampling << ": " << progressBar << " " << current << "/" << total + << " (" << std::fixed << std::setprecision(1) << fraction * 100 << "%)"; + + } else { + + std::string warmupOrSampling = isWarmupPhase(chainId - 1) ? " (Warmup)" : " (Sampling)"; + builder << chain_prefix << " " << chainId << warmupOrSampling << ": " << progressBar << " " << current << "/" << total + << " (" << std::fixed << std::setprecision(1) << fraction * 100 << "%)"; + } + + return builder.str(); } std::string ProgressManager::formatTimeInfo(int elapsed, int eta) { @@ -173,10 +184,14 @@ void ProgressManager::setupTheme() { // Unicode theme lhsToken = "❨"; rhsToken = "❩"; - filledToken = "\x1b[34m━\x1b[0m"; // Blue filled - emptyToken = "\x1b[37m━\x1b[0m"; // Gray empty - partialTokenMore = "\x1b[34m╸\x1b[0m"; // Blue partial (> 0.5) - partialTokenLess = "\x1b[37m╺\x1b[0m"; // Gray partial (< 0.5) + // filledToken = "\x1b[34m━\x1b[0m"; // Blue filled + // emptyToken = "\x1b[37m━\x1b[0m"; // Gray empty + // partialTokenMore = "\x1b[34m╸\x1b[0m"; // Blue partial (> 0.5) + // partialTokenLess = "\x1b[37m╺\x1b[0m"; // Gray partial (< 0.5) + filledToken = "━"; // Blue filled + emptyToken = " "; // Gray empty + partialTokenMore = "╸"; // Blue partial (> 0.5) + partialTokenLess = " "; // Gray partial (< 0.5) } else { // Classic theme lhsToken = "["; @@ -213,7 +228,7 @@ void ProgressManager::print() { double elapsed = std::chrono::duration_cast(now - start).count(); int totalWork = nChains * nIter; - int done = totalDone; + int done = std::reduce(progress.begin(), progress.end()); double fracTotal = double(done) / totalWork; // should actually be the eta of the slowest chain! double eta = (fracTotal > 0) ? elapsed / fracTotal - elapsed : 0.0; @@ -222,106 +237,67 @@ void ProgressManager::print() { int totalChars = 0; int lineIndex = 0; - if (isRStudio) { - // if this is not the first print, delete previous content - if (lastPrintedChars > 0) { - out << "\x1b[" << std::to_string(lastPrintedChars) << "D"; - } - - std::string clearedLine; - - // Print progress for each chain - for (int i = 0; i < nChains; i++) { - double frac = double(progress[i]) / nIter; - std::string chainProgress = formatProgressBar(i + 1, progress[i], nIter, frac); - - // Pad each line to exactly lineWidth characters (before adding \n) - if (chainProgress.length() < lineWidth) { - chainProgress += std::string(lineWidth - chainProgress.length(), ' '); - } else if (chainProgress.length() > lineWidth) { - chainProgress = chainProgress.substr(0, lineWidth); - } - - // Add newline to every line (including the last one) - clearedLine = chainProgress + "\n"; - out << clearedLine; - - lastLineLengths[lineIndex] = clearedLine.length(); // This will be lineWidth + 1 - totalChars += lastLineLengths[lineIndex]; - lineIndex++; - } - - // Print total progress - std::string totalProgress = formatProgressBar(0, done, totalWork, fracTotal, true); - - // Pad total progress line to exactly lineWidth characters (before adding \n) - if (totalProgress.length() < lineWidth) { - totalProgress += std::string(lineWidth - totalProgress.length(), ' '); - } else if (totalProgress.length() > lineWidth) { - totalProgress = totalProgress.substr(0, lineWidth); - } - - clearedLine = totalProgress + "\n"; - out << clearedLine; - - lastLineLengths[lineIndex] = clearedLine.length(); // This will be lineWidth + 1 - totalChars += lastLineLengths[lineIndex]; - lineIndex++; + // if this is not the first print, delete previous content + if (lastPrintedChars > 0) { - // Print time info - std::string timeInfo = formatTimeInfo(int(elapsed), int(eta)); - - // Pad time info line to exactly lineWidth characters (before adding \n) - if (timeInfo.length() < lineWidth) { - timeInfo += std::string(lineWidth - timeInfo.length(), ' '); - } else if (timeInfo.length() > lineWidth) { - timeInfo = timeInfo.substr(0, lineWidth); - } - - clearedLine = timeInfo + "\n"; - out << clearedLine; - - lastLineLengths[lineIndex] = clearedLine.length(); // This will be lineWidth + 1 - totalChars += lastLineLengths[lineIndex]; - - // Track characters and lines for next cursor movement - lastPrintedChars = totalChars; - lastPrintedLines = lineIndex; - - std::string out_str = out.str(); - - assert(out_str.length() == static_cast(totalChars)); - Rcpp::Rcout << out_str; - - } else { - // Terminal: Use carriage return and clear lines approach - if (lastPrintedLines > 0) { + if (isRStudio) { + out << "\x1b[" << std::to_string(lastPrintedChars) << "D"; + } else { // Move cursor up to start of our content and clear everything for (int i = 0; i < lastPrintedLines; i++) { out << "\x1b[1A\x1b[2K"; // Move up one line and clear entire line } } + } + if (progress_type == 2) { // Print progress for each chain for (int i = 0; i < nChains; i++) { double frac = double(progress[i]) / nIter; std::string chainProgress = formatProgressBar(i + 1, progress[i], nIter, frac); + maybePadToLength(chainProgress); out << chainProgress << "\n"; + totalChars += chainProgress.length() + 1; // +1 for newline } // Print total progress std::string totalProgress = formatProgressBar(0, done, totalWork, fracTotal, true); + maybePadToLength(totalProgress); out << totalProgress << "\n"; + totalChars += totalProgress.length() + 1; // +1 for newline // Print time info std::string timeInfo = formatTimeInfo(int(elapsed), int(eta)); + maybePadToLength(timeInfo); out << timeInfo << "\n"; + totalChars += timeInfo.length() + 1; // +1 for newline // Track total lines printed (chains + total + time) - lastPrintedLines = nChains + 2; + lastPrintedLines = nChains + 2; // used in a generic terminal + lastPrintedChars = totalChars; // used by RStudio + + } else if (progress_type == 1) { + + // Print total progress + std::string totalProgress = formatProgressBar(0, done, totalWork, fracTotal, true); + maybePadToLength(totalProgress); + + // Print time info + totalProgress += " " + formatTimeInfo(int(elapsed), int(eta)); + maybePadToLength(totalProgress); + + if (done < totalWork) { + out << totalProgress << "\r"; + } else { + out << totalProgress << "\n"; + } + + // we do not set lastPrintedChars or lastPrintedLines here since we always overwrite the same line - Rcpp::Rcout << out.str(); } + + Rcpp::Rcout << out.str(); + } void ProgressManager::update_prefixes(int width) { @@ -337,25 +313,22 @@ void ProgressManager::update_prefixes(int width) { } } -std::string ProgressManager::addPadding(const std::string& content) { - int paddingNeeded = consoleWidth - content.length(); - if (paddingNeeded > 0) { - return content + std::string(paddingNeeded, ' '); - } else { - return content; - } -} +void ProgressManager::maybePadToLength(std::string& content) const { + if (!isRStudio) return; -std::string ProgressManager::clearLineLeftovers(const std::string& newContent, int oldLineLength) { - if (newContent.length() < oldLineLength) { - // Add spaces to clear old content - return newContent + std::string(oldLineLength - newContent.length(), ' '); + // Pad each line to exactly lineWidth characters (before adding \n) + if (content.length() < lineWidth) { + content += std::string(lineWidth - content.length(), ' '); + } else if (content.length() > lineWidth) { + content = content.substr(0, lineWidth); } - return newContent; } + +// Example usage/ test with RcppParallel +#include // Worker functor for RcppParallel -struct ChainWorker : public Worker { +struct ChainWorker : public RcppParallel::Worker { int nIter; ProgressManager ± bool display_progress; @@ -381,13 +354,14 @@ struct ChainWorker : public Worker { // [[Rcpp::export]] -void runMCMC_parallel(int nChains = 4, int nIter = 100, bool display_progress = true, bool useUnicode = true) { +void runMCMC_parallel(int nChains = 4, int nIter = 100, int nWarmup = 100, int progress_type = 2, bool useUnicode = false) { - ProgressManager pm(nChains, nIter, 10, useUnicode); - ChainWorker worker(nIter, pm, display_progress); + int nTotal = nIter + nWarmup; + ProgressManager pm(nChains, nTotal, nWarmup, 10, progress_type, useUnicode); + ChainWorker worker(nTotal, pm, true); // Run each chain in parallel - parallelFor(0, nChains, worker); + RcppParallel::parallelFor(0, nChains, worker); if (pm.shouldExit()) { Rcpp::Rcout << "\nComputation interrupted by user.\n"; diff --git a/src/progress_manager.h b/src/progress_manager.h index 23537032..375bb2a5 100644 --- a/src/progress_manager.h +++ b/src/progress_manager.h @@ -2,20 +2,17 @@ #define PROGRESS_MANAGER_H #include -#include #include -#include #include #include #include -#include #include +#include #include #include #include +#include -using namespace Rcpp; -using namespace RcppParallel; using Clock = std::chrono::steady_clock; // Interrupt checking functions @@ -49,9 +46,9 @@ class ProgressManager { public: - ProgressManager(int nChains_, int nIter_, int printEvery_ = 10, bool useUnicode_ = false); - ~ProgressManager(); + ProgressManager(int nChains_, int nIter_, int nWarmup_, int printEvery_ = 10, int progress_type = 2, bool useUnicode_ = false); void update(int chainId); + void finish(); bool shouldExit() const; private: @@ -63,17 +60,26 @@ class ProgressManager { void setupTheme(); size_t getVisualLength(const std::string& str); + bool isWarmupPhase() const { + for (auto c : progress) + if (c < nWarmup) + return true; + return false; + } + bool isWarmupPhase(const int chain_id) const { + return progress[chain_id] < nWarmup; + } + void print(); void update_prefixes(int width); - std::string addPadding(const std::string& content); - - std::string clearLineLeftovers(const std::string& newContent, int oldLineLength); + void maybePadToLength(std::string& content) const; // Configuration parameters int nChains; ///< Number of parallel chains - int nIter; ///< Iterations per chain + int nIter; ///< TOTAL Iterations per chain + int nWarmup; ///< Warmup iterations per chain int printEvery; ///< Print frequency int no_spaces_for_total; ///< Spacing for total line alignment int lastPrintedLines = 0; ///< Lines printed in last update @@ -89,6 +95,7 @@ class ProgressManager { // Visual configuration int barWidth = 40; ///< Progress bar width in characters + int progress_type = 2; ///< Progress bar style type (0 = "none", 1 = "total", 2 = "per-chain") bool useUnicode = true; ///< Use Unicode vs ASCII theme // Theme tokens @@ -100,17 +107,14 @@ class ProgressManager { std::string partialTokenLess; ///< Partial progress (<50%) std::string chain_prefix; ///< Chain label prefix std::string total_prefix; ///< Total label prefix + std::string total_padding; ///< Padding for total line alignment - // Line tracking for cursor positioning - std::vector lastLineLengths; ///< Track length of each printed line - - // Thread-safe progress tracking - std::vector> progress; ///< Per-chain progress counters - std::atomic totalDone{0}; ///< Total completed iterations + // progress tracking + std::vector progress; ///< Per-chain progress counters // Timing Clock::time_point start; ///< Start time - std::atomic> lastPrint; ///< Last print time + std::chrono::time_point lastPrint; ///< Last print time // Thread synchronization std::mutex printMutex; ///< Mutex for thread-safe printing From 160772326b5036210eecb12202e4eb0af919b8d4 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Mon, 22 Sep 2025 13:16:01 +0200 Subject: [PATCH 3/6] fix rebase --- R/RcppExports.R | 21 +++++++++---- src/RcppExports.cpp | 73 +++++++++++++++++++++++++++++++++------------ 2 files changed, 69 insertions(+), 25 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index 01cf2a38..b010d6a6 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -5,16 +5,25 @@ run_bgmCompare_parallel <- function(observations, num_groups, counts_per_categor .Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs) } +<<<<<<< HEAD run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed) { .Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed) +======= +run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) { + .Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) +>>>>>>> eaa3c70 (fix rebase) } get_explog_switch <- function() { .Call(`_bgms_get_explog_switch`) } -run_bgm_parallel <- function(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) { - .Call(`_bgms_run_bgm_parallel`, observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) +rcpp_ieee754_exp <- function(x) { + .Call(`_bgms_rcpp_ieee754_exp`, x) +} + +rcpp_ieee754_log <- function(x) { + .Call(`_bgms_rcpp_ieee754_log`, x) } sample_omrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, iter) { @@ -25,11 +34,11 @@ sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interact .Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter) } -compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) { - .Call(`_bgms_compute_Vn_mfm_sbm`, no_variables, dirichlet_alpha, t_max, lambda) -} - runMCMC_parallel <- function(nChains = 4L, nIter = 100L, nWarmup = 100L, progress_type = 2L, useUnicode = FALSE) { invisible(.Call(`_bgms_runMCMC_parallel`, nChains, nIter, nWarmup, progress_type, useUnicode)) } +compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) { + .Call(`_bgms_compute_Vn_mfm_sbm`, no_variables, dirichlet_alpha, t_max, lambda) +} + diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 3f048fff..4fd77d61 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -57,8 +57,8 @@ BEGIN_RCPP END_RCPP } // run_bgm_parallel -Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double interaction_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int burnin, const arma::imat& num_obs_categories, const arma::imat& sufficient_blume_capel, double threshold_alpha, double threshold_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& reference_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& sufficient_pairwise, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, uint64_t seed, int progress_type); -RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP interaction_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP threshold_alphaSEXP, SEXP threshold_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP reference_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP sufficient_pairwiseSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP progress_typeSEXP) { +Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int burnin, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, uint64_t seed, int progress_type); +RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP pairwise_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP pairwise_statsSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP progress_typeSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -94,7 +94,39 @@ BEGIN_RCPP Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP); Rcpp::traits::input_parameter< uint64_t >::type seed(seedSEXP); Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP); - rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type)); + rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type)); + return rcpp_result_gen; +END_RCPP +} +// get_explog_switch +Rcpp::String get_explog_switch(); +RcppExport SEXP _bgms_get_explog_switch() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(get_explog_switch()); + return rcpp_result_gen; +END_RCPP +} +// rcpp_ieee754_exp +Rcpp::NumericVector rcpp_ieee754_exp(Rcpp::NumericVector x); +RcppExport SEXP _bgms_rcpp_ieee754_exp(SEXP xSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::NumericVector >::type x(xSEXP); + rcpp_result_gen = Rcpp::wrap(rcpp_ieee754_exp(x)); + return rcpp_result_gen; +END_RCPP +} +// rcpp_ieee754_log +Rcpp::NumericVector rcpp_ieee754_log(Rcpp::NumericVector x); +RcppExport SEXP _bgms_rcpp_ieee754_log(SEXP xSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::NumericVector >::type x(xSEXP); + rcpp_result_gen = Rcpp::wrap(rcpp_ieee754_log(x)); return rcpp_result_gen; END_RCPP } @@ -132,20 +164,6 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } -// compute_Vn_mfm_sbm -arma::vec compute_Vn_mfm_sbm(arma::uword no_variables, double dirichlet_alpha, arma::uword t_max, double lambda); -RcppExport SEXP _bgms_compute_Vn_mfm_sbm(SEXP no_variablesSEXP, SEXP dirichlet_alphaSEXP, SEXP t_maxSEXP, SEXP lambdaSEXP) { -BEGIN_RCPP - Rcpp::RObject rcpp_result_gen; - Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< arma::uword >::type no_variables(no_variablesSEXP); - Rcpp::traits::input_parameter< double >::type dirichlet_alpha(dirichlet_alphaSEXP); - Rcpp::traits::input_parameter< arma::uword >::type t_max(t_maxSEXP); - Rcpp::traits::input_parameter< double >::type lambda(lambdaSEXP); - rcpp_result_gen = Rcpp::wrap(compute_Vn_mfm_sbm(no_variables, dirichlet_alpha, t_max, lambda)); - return rcpp_result_gen; -END_RCPP -} // runMCMC_parallel void runMCMC_parallel(int nChains, int nIter, int nWarmup, int progress_type, bool useUnicode); RcppExport SEXP _bgms_runMCMC_parallel(SEXP nChainsSEXP, SEXP nIterSEXP, SEXP nWarmupSEXP, SEXP progress_typeSEXP, SEXP useUnicodeSEXP) { @@ -160,14 +178,31 @@ BEGIN_RCPP return R_NilValue; END_RCPP } +// compute_Vn_mfm_sbm +arma::vec compute_Vn_mfm_sbm(arma::uword no_variables, double dirichlet_alpha, arma::uword t_max, double lambda); +RcppExport SEXP _bgms_compute_Vn_mfm_sbm(SEXP no_variablesSEXP, SEXP dirichlet_alphaSEXP, SEXP t_maxSEXP, SEXP lambdaSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< arma::uword >::type no_variables(no_variablesSEXP); + Rcpp::traits::input_parameter< double >::type dirichlet_alpha(dirichlet_alphaSEXP); + Rcpp::traits::input_parameter< arma::uword >::type t_max(t_maxSEXP); + Rcpp::traits::input_parameter< double >::type lambda(lambdaSEXP); + rcpp_result_gen = Rcpp::wrap(compute_Vn_mfm_sbm(no_variables, dirichlet_alpha, t_max, lambda)); + return rcpp_result_gen; +END_RCPP +} static const R_CallMethodDef CallEntries[] = { - {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 33}, + {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 35}, {"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 32}, + {"_bgms_get_explog_switch", (DL_FUNC) &_bgms_get_explog_switch, 0}, + {"_bgms_rcpp_ieee754_exp", (DL_FUNC) &_bgms_rcpp_ieee754_exp, 1}, + {"_bgms_rcpp_ieee754_log", (DL_FUNC) &_bgms_rcpp_ieee754_log, 1}, {"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 6}, {"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 8}, - {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, {"_bgms_runMCMC_parallel", (DL_FUNC) &_bgms_runMCMC_parallel, 5}, + {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, {NULL, NULL, 0} }; From 33d55f934e5a19df05f4be6555604afb2b3b1051 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Mon, 22 Sep 2025 15:08:40 +0200 Subject: [PATCH 4/6] almost done --- R/RcppExports.R | 8 +- R/bgm.R | 14 +-- R/bgmCompare.R | 8 +- R/function_input_utils.R | 13 +- src/Makevars.win | 2 +- src/RcppExports.cpp | 24 +--- src/bgmCompare_parallel.cpp | 21 +++- src/bgmCompare_sampler.cpp | 19 +-- src/bgmCompare_sampler.h | 6 +- src/bgm_parallel.cpp | 14 +-- src/bgm_sampler.cpp | 7 +- src/bgm_sampler.h | 2 +- src/progress_manager.cpp | 229 +++++++++++++++++++++--------------- src/progress_manager.h | 10 +- src/sampler_output.h | 1 + 15 files changed, 212 insertions(+), 166 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index b010d6a6..3918c7c8 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -1,8 +1,8 @@ # Generated by using Rcpp::compileAttributes() -> do not edit by hand # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 -run_bgmCompare_parallel <- function(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs) { - .Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs) +run_bgmCompare_parallel <- function(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type) { + .Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type) } <<<<<<< HEAD @@ -34,10 +34,6 @@ sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interact .Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter) } -runMCMC_parallel <- function(nChains = 4L, nIter = 100L, nWarmup = 100L, progress_type = 2L, useUnicode = FALSE) { - invisible(.Call(`_bgms_runMCMC_parallel`, nChains, nIter, nWarmup, progress_type, useUnicode)) -} - compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) { .Call(`_bgms_compute_Vn_mfm_sbm`, no_variables, dirichlet_alpha, t_max, lambda) } diff --git a/R/bgm.R b/R/bgm.R index 3c20206c..e9a99a75 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -237,7 +237,7 @@ bgm = function( dirichlet_alpha = 1, lambda = 1, na_action = c("listwise", "impute"), - display_progress = c("per-chain", "total", "none"), + display_progress = c("per-chain", "total", "none"), update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"), target_accept, hmc_num_leapfrogs = 100, @@ -324,15 +324,7 @@ bgm = function( ".")) #Check display_progress ------------------------------------------------------ - if (is.logical(display_progress) && length(display_progress) == 1) { - if (is.na(display_progress)) - stop("The display_progress argument must be a single logical value, but not NA.") - display_progress = if (display_progress) "per-chain" else "none" - } else { - display_progress = match.arg(display_progress) - } - - progress_type = if (display_progress == "per-chain") 2L else if (display_progress == "total") 1L else 0L + progress_type = progress_type_from_display_progress(display_progress) #Format the data input ------------------------------------------------------- data = reformat_data(x = x, @@ -465,7 +457,7 @@ bgm = function( output$nuts_diag = nuts_diag } - userInterrupt = any(vapply(out, `[[`, "userInterrupt", logical(1L))) + userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt")) if (userInterrupt) warning("Stopped sampling after user interrupt, results are likely uninterpretable.") diff --git a/R/bgmCompare.R b/R/bgmCompare.R index db91ceb4..9f102ae8 100644 --- a/R/bgmCompare.R +++ b/R/bgmCompare.R @@ -122,7 +122,7 @@ bgmCompare = function( iter = 1e3, burnin = 1e3, na_action = c("listwise", "impute"), - display_progress = TRUE, + display_progress = c("per-chain", "total", "none"), update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"), target_accept, hmc_num_leapfrogs = 100, @@ -202,7 +202,8 @@ bgmCompare = function( } # Check display_progress - display_progress = check_logical(display_progress, "display_progress") + progress_type = progress_type_from_display_progress(display_progress) + ## Format data data = compare_reformat_data( @@ -340,7 +341,8 @@ bgmCompare = function( inclusion_probability = model$inclusion_probability_difference, num_chains = chains, nThreads = cores, seed = seed, - update_method = update_method, hmc_num_leapfrogs = hmc_num_leapfrogs + update_method = update_method, hmc_num_leapfrogs = hmc_num_leapfrogs, + progress_type = progress_type ) # Main output handler in the wrapper function diff --git a/R/function_input_utils.R b/R/function_input_utils.R index 2bb2a3c9..c41daab6 100644 --- a/R/function_input_utils.R +++ b/R/function_input_utils.R @@ -463,4 +463,15 @@ check_compare_model = function( inclusion_probability_difference = inclusion_probability_difference ) ) -} \ No newline at end of file +} + +progress_type_from_display_progress <- function(display_progress = c("per-chain", "total", "none")) { + if (is.logical(display_progress) && length(display_progress) == 1) { + if (is.na(display_progress)) + stop("The display_progress argument must be a single logical value, but not NA.") + display_progress = if (display_progress) "per-chain" else "none" + } else { + display_progress = match.arg(display_progress) + } + return(if (display_progress == "per-chain") 2L else if (display_progress == "total") 1L else 0L) +} diff --git a/src/Makevars.win b/src/Makevars.win index bfa74448..7d9a2d0f 100644 --- a/src/Makevars.win +++ b/src/Makevars.win @@ -1,4 +1,4 @@ -CXX_STD = CXX14 +CXX_STD = CXX20 PKG_CPPFLAGS = \ $(shell "$(R_HOME)\bin\Rscript.exe" -e "cat(RcppParallel::CxxFlags())") diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 4fd77d61..9fa0affe 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -12,8 +12,8 @@ Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif // run_bgmCompare_parallel -Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector& counts_per_category, const std::vector& blume_capel_stats, const std::vector& pairwise_stats, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int burnin, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads, int seed, const std::string& update_method, int hmc_num_leapfrogs); -RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP pairwise_statsSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP update_methodSEXP, SEXP hmc_num_leapfrogsSEXP) { +Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector& counts_per_category, const std::vector& blume_capel_stats, const std::vector& pairwise_stats, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int burnin, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads, int seed, const std::string& update_method, int hmc_num_leapfrogs, int progress_type); +RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP pairwise_statsSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP update_methodSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP progress_typeSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -52,7 +52,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< int >::type seed(seedSEXP); Rcpp::traits::input_parameter< const std::string& >::type update_method(update_methodSEXP); Rcpp::traits::input_parameter< int >::type hmc_num_leapfrogs(hmc_num_leapfrogsSEXP); - rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs)); + Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP); + rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type)); return rcpp_result_gen; END_RCPP } @@ -164,20 +165,6 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } -// runMCMC_parallel -void runMCMC_parallel(int nChains, int nIter, int nWarmup, int progress_type, bool useUnicode); -RcppExport SEXP _bgms_runMCMC_parallel(SEXP nChainsSEXP, SEXP nIterSEXP, SEXP nWarmupSEXP, SEXP progress_typeSEXP, SEXP useUnicodeSEXP) { -BEGIN_RCPP - Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< int >::type nChains(nChainsSEXP); - Rcpp::traits::input_parameter< int >::type nIter(nIterSEXP); - Rcpp::traits::input_parameter< int >::type nWarmup(nWarmupSEXP); - Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP); - Rcpp::traits::input_parameter< bool >::type useUnicode(useUnicodeSEXP); - runMCMC_parallel(nChains, nIter, nWarmup, progress_type, useUnicode); - return R_NilValue; -END_RCPP -} // compute_Vn_mfm_sbm arma::vec compute_Vn_mfm_sbm(arma::uword no_variables, double dirichlet_alpha, arma::uword t_max, double lambda); RcppExport SEXP _bgms_compute_Vn_mfm_sbm(SEXP no_variablesSEXP, SEXP dirichlet_alphaSEXP, SEXP t_maxSEXP, SEXP lambdaSEXP) { @@ -194,14 +181,13 @@ END_RCPP } static const R_CallMethodDef CallEntries[] = { - {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 35}, + {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 36}, {"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 32}, {"_bgms_get_explog_switch", (DL_FUNC) &_bgms_get_explog_switch, 0}, {"_bgms_rcpp_ieee754_exp", (DL_FUNC) &_bgms_rcpp_ieee754_exp, 1}, {"_bgms_rcpp_ieee754_log", (DL_FUNC) &_bgms_rcpp_ieee754_log, 1}, {"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 6}, {"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 8}, - {"_bgms_runMCMC_parallel", (DL_FUNC) &_bgms_runMCMC_parallel, 5}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, {NULL, NULL, 0} }; diff --git a/src/bgmCompare_parallel.cpp b/src/bgmCompare_parallel.cpp index b9b0648c..577449e1 100644 --- a/src/bgmCompare_parallel.cpp +++ b/src/bgmCompare_parallel.cpp @@ -6,6 +6,9 @@ #include #include #include "rng_utils.h" +#include "progress_manager.h" +#include "sampler_output.h" +#include "mcmc_adaptation.h" using namespace Rcpp; using namespace RcppParallel; @@ -130,6 +133,7 @@ struct GibbsCompareChainRunner : public Worker { const std::vector& chain_rngs; const std::string& update_method; const int hmc_num_leapfrogs; + ProgressManager& pm; // output std::vector& results; @@ -167,6 +171,7 @@ struct GibbsCompareChainRunner : public Worker { const std::vector& chain_rngs, const std::string& update_method, const int hmc_num_leapfrogs, + ProgressManager& pm, std::vector& results ) : observations(observations), @@ -202,6 +207,7 @@ struct GibbsCompareChainRunner : public Worker { chain_rngs(chain_rngs), update_method(update_method), hmc_num_leapfrogs(hmc_num_leapfrogs), + pm(pm), results(results) {} @@ -257,7 +263,8 @@ struct GibbsCompareChainRunner : public Worker { inclusion_probability, rng, update_method, - hmc_num_leapfrogs + hmc_num_leapfrogs, + pm ); out.result = result; @@ -376,7 +383,8 @@ Rcpp::List run_bgmCompare_parallel( int nThreads, int seed, const std::string& update_method, - int hmc_num_leapfrogs + int hmc_num_leapfrogs, + int progress_type ) { std::vector results(num_chains); @@ -386,6 +394,10 @@ Rcpp::List run_bgmCompare_parallel( chain_rngs[c] = SafeRNG(seed + c); } + // only used to determine the total no. burnin iterations, a bit hacky + WarmupSchedule warmup_schedule_temp(burnin, difference_selection, (update_method != "adaptive-metropolis")); + int total_burnin = warmup_schedule_temp.total_burnin; + ProgressManager pm(num_chains, iter, total_burnin, 50, progress_type); GibbsCompareChainRunner worker( observations, num_groups, @@ -397,7 +409,7 @@ Rcpp::List run_bgmCompare_parallel( pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, chain_rngs, update_method, hmc_num_leapfrogs, - results + pm, results ); { @@ -426,9 +438,12 @@ Rcpp::List run_bgmCompare_parallel( if (r.has_indicator) { chain_out["indicator_samples"] = r.indicator_samples; } + chain_out["userInterrupt"] = r.userInterrupt; output[i] = chain_out; } } + pm.finish(); + return output; } \ No newline at end of file diff --git a/src/bgmCompare_sampler.cpp b/src/bgmCompare_sampler.cpp index 264898ac..8ddf730b 100644 --- a/src/bgmCompare_sampler.cpp +++ b/src/bgmCompare_sampler.cpp @@ -13,6 +13,7 @@ #include "rng_utils.h" #include "sampler_output.h" #include +#include "progress_manager.h" using namespace Rcpp; @@ -1576,7 +1577,8 @@ SamplerOutput run_gibbs_sampler_bgmCompare( arma::mat inclusion_probability, SafeRNG& rng, const std::string& update_method, - const int hmc_num_leapfrogs + const int hmc_num_leapfrogs, + ProgressManager& pm ) { // --- Setup: dimensions and storage structures const int num_variables = observations.n_cols; @@ -1645,14 +1647,13 @@ SamplerOutput run_gibbs_sampler_bgmCompare( const int print_every = std::max(1, total_iter / 10); // --- Main Gibbs sampling loop + bool userInterrupt = false; for (int iteration = 0; iteration < total_iter; iteration++) { - if (iteration % print_every == 0) { - tbb::mutex::scoped_lock lock(get_print_mutex()); - std::cout - << "[bgm] chain " << chain_id - << " iteration " << iteration - << " / " << total_iter - << std::endl; + + pm.update(chain_id - 1); + if (pm.shouldExit()) { + userInterrupt = true; + break; } // Shuffle update order of edge indices @@ -1753,5 +1754,7 @@ SamplerOutput run_gibbs_sampler_bgmCompare( } else { out.indicator_samples = arma::imat(); } + out.userInterrupt = userInterrupt; + return out; } \ No newline at end of file diff --git a/src/bgmCompare_sampler.h b/src/bgmCompare_sampler.h index 0672d3fb..132bd957 100644 --- a/src/bgmCompare_sampler.h +++ b/src/bgmCompare_sampler.h @@ -1,10 +1,11 @@ #pragma once #include -#include "sampler_output.h" #include +struct SamplerOutput; struct SafeRNG; +class ProgressManager; SamplerOutput run_gibbs_sampler_bgmCompare( int chain_id, @@ -40,5 +41,6 @@ SamplerOutput run_gibbs_sampler_bgmCompare( arma::mat inclusion_probability, SafeRNG& rng, const std::string& update_method, - const int hmc_num_leapfrogs + const int hmc_num_leapfrogs, + ProgressManager& pm ); \ No newline at end of file diff --git a/src/bgm_parallel.cpp b/src/bgm_parallel.cpp index 0f168d9c..b3ac1447 100644 --- a/src/bgm_parallel.cpp +++ b/src/bgm_parallel.cpp @@ -124,8 +124,8 @@ struct GibbsChainRunner : public Worker { int nuts_max_depth, bool learn_mass_matrix, const std::vector& chain_rngs, - std::vector& results, - ProgressManager& pm + ProgressManager& pm, + std::vector& results ) : observations(observations), num_categories(num_categories), @@ -156,8 +156,8 @@ struct GibbsChainRunner : public Worker { nuts_max_depth(nuts_max_depth), learn_mass_matrix(learn_mass_matrix), chain_rngs(chain_rngs), - results(results), - pm(pm) + pm(pm), + results(results) {} void operator()(std::size_t begin, std::size_t end) { @@ -311,8 +311,8 @@ Rcpp::List run_bgm_parallel( // only used to determine the total no. burnin iterations, a bit hacky WarmupSchedule warmup_schedule_temp(burnin, edge_selection, (update_method != "adaptive-metropolis")); int total_burnin = warmup_schedule_temp.total_burnin; + ProgressManager pm(num_chains, iter, total_burnin, 50, progress_type); - ProgressManager pm(num_chains, iter + total_burnin, total_burnin, 50, progress_type); GibbsChainRunner worker( observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, @@ -320,8 +320,8 @@ Rcpp::List run_bgm_parallel( counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, - sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, - chain_rngs, results, pm + pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, + chain_rngs, pm, results ); { diff --git a/src/bgm_sampler.cpp b/src/bgm_sampler.cpp index ae5d93a8..0fb2f10b 100644 --- a/src/bgm_sampler.cpp +++ b/src/bgm_sampler.cpp @@ -10,14 +10,9 @@ #include "mcmc_nuts.h" #include "mcmc_rwm.h" #include "mcmc_utils.h" -<<<<<<< HEAD -#include "print_mutex.h" #include "sbm_edge_prior.h" -======= -// #include "print_mutex.h" -#include "gibbs_functions_edge_prior.h" ->>>>>>> 9f0ad36 (parallel progress bar and interrupt) #include "rng_utils.h" +#include "progress_manager.h" using namespace Rcpp; diff --git a/src/bgm_sampler.h b/src/bgm_sampler.h index 6c3c1fe7..42234481 100644 --- a/src/bgm_sampler.h +++ b/src/bgm_sampler.h @@ -1,8 +1,8 @@ #pragma once #include -#include "progress_manager.h" // forward declaration struct SafeRNG; +class ProgressManager; Rcpp::List run_gibbs_sampler_bgm( int chain_id, diff --git a/src/progress_manager.cpp b/src/progress_manager.cpp index 6191dede..f7a43836 100644 --- a/src/progress_manager.cpp +++ b/src/progress_manager.cpp @@ -16,7 +16,7 @@ ProgressManager::ProgressManager(int nChains_, int nIter_, int nWarmup_, int pri isRStudio = Rcpp::as(s) == "1"; no_spaces_for_total = 3 + static_cast(std::log10(nChains)); - if (progress_type == 1) no_spaces_for_total = 0; // no total line + if (progress_type == 1) no_spaces_for_total = 1; // no need to align, so one space is fine total_padding = std::string(no_spaces_for_total, ' '); if (isRStudio) { @@ -69,8 +69,9 @@ void ProgressManager::update(int chainId) { // Check for user interrupts and console width changes less frequently to reduce overhead if (chainId == 0 && progress[chainId] % (printEvery * 5) == 0) { needsToExit = checkInterrupt(); + // would be a nice feature, but not working atm // Also check for console width changes occasionally - checkConsoleWidthChange(); + // checkConsoleWidthChange(); } } @@ -114,7 +115,7 @@ void ProgressManager::checkConsoleWidthChange() { } } -int ProgressManager::getConsoleWidth() { +int ProgressManager::getConsoleWidth() const { Rcpp::Environment base("package:base"); Rcpp::Function getOption = base["getOption"]; SEXP s = getOption("width", 0); @@ -123,7 +124,7 @@ int ProgressManager::getConsoleWidth() { return width + 3; } -std::string ProgressManager::formatProgressBar(int chainId, int current, int total, double fraction, bool isTotal) { +std::string ProgressManager::formatProgressBar(int chainId, int current, int total, double fraction, bool isTotal) const { std::ostringstream builder; double exactFilled = fraction * barWidth; @@ -157,6 +158,9 @@ std::string ProgressManager::formatProgressBar(int chainId, int current, int tot progressBar += rhsToken; + // store the current length of the progress bar without any additional text + size_t currentWidth = progressBar.length(); + if (isTotal) { std::string warmupOrSampling = isWarmupPhase() ? "(Warmup)" : "(Sampling)"; @@ -170,57 +174,98 @@ std::string ProgressManager::formatProgressBar(int chainId, int current, int tot << " (" << std::fixed << std::setprecision(1) << fraction * 100 << "%)"; } - return builder.str(); + std::string output = builder.str(); + + if (isRStudio && progress_type == 2) { + + currentWidth = output.length() + 2 + barWidth - currentWidth; + // Pad each line to exactly lineWidth characters (before adding \n) + if (currentWidth < lineWidth) { + output += std::string(lineWidth - currentWidth, ' '); + } else if (currentWidth > lineWidth) { + output = output.substr(0, lineWidth); + } + } + + return output; } -std::string ProgressManager::formatTimeInfo(int elapsed, int eta) { +// std::string ProgressManager::formatTimeInfo(double elapsed, double eta) { +// std::ostringstream builder; +// builder << "Elapsed: " << elapsed << "s | ETA: " << eta << "s"; +// return builder.str(); +// } + +std::string ProgressManager::formatTimeInfo(double elapsed, double eta) const { std::ostringstream builder; - builder << "Elapsed: " << elapsed << "s | ETA: " << eta << "s"; + builder << "Elapsed: " << formatDuration(elapsed) << " | ETA: " << formatDuration(eta); return builder.str(); } +// Add this helper function to the class +std::string ProgressManager::formatDuration(double seconds) const { + if (seconds < 0) { + return "0s"; + } + + // Convert to different units + if (seconds < 60) { + // Less than 1 minute: show seconds + return std::to_string(static_cast(std::round(seconds))) + "s"; + } + else if (seconds < 3600) { + // Less than 1 hour: show minutes and seconds + int mins = static_cast(seconds / 60); + int secs = static_cast(seconds) % 60; + if (secs == 0) { + return std::to_string(mins) + "m"; + } else { + return std::to_string(mins) + "m " + std::to_string(secs) + "s"; + } + } + else if (seconds < 86400) { + // Less than 1 day: show hours and minutes + int hours = static_cast(seconds / 3600); + int mins = static_cast((seconds - hours * 3600) / 60); + if (mins == 0) { + return std::to_string(hours) + "h"; + } else { + return std::to_string(hours) + "h " + std::to_string(mins) + "m"; + } + } + else { + // 1 day or more: show days and hours + int days = static_cast(seconds / 86400); + int hours = static_cast((seconds - days * 86400) / 3600); + if (hours == 0) { + return std::to_string(days) + "d"; + } else { + return std::to_string(days) + "d " + std::to_string(hours) + "h"; + } + } +} + void ProgressManager::setupTheme() { + // should be a struct of some kind... if (useUnicode) { // Unicode theme - lhsToken = "❨"; - rhsToken = "❩"; - // filledToken = "\x1b[34m━\x1b[0m"; // Blue filled - // emptyToken = "\x1b[37m━\x1b[0m"; // Gray empty - // partialTokenMore = "\x1b[34m╸\x1b[0m"; // Blue partial (> 0.5) - // partialTokenLess = "\x1b[37m╺\x1b[0m"; // Gray partial (< 0.5) - filledToken = "━"; // Blue filled - emptyToken = " "; // Gray empty - partialTokenMore = "╸"; // Blue partial (> 0.5) - partialTokenLess = " "; // Gray partial (< 0.5) + lhsToken = "⦗"; + rhsToken = "⦘"; + filledToken = "\033[38;5;73m━\033[39m"; // Blue filled + partialTokenMore = "\033[38;5;73m━\033[39m"; // Blue partial (> 0.5) + partialTokenLess = "\033[37m╺\033[39m"; // Gray partial (< 0.5) + emptyToken = "\033[37m━\033[39m"; // Gray empty } else { // Classic theme - lhsToken = "["; - rhsToken = "]"; - filledToken = "="; - emptyToken = " "; + lhsToken = "["; + rhsToken = "]"; + filledToken = "="; + emptyToken = " "; partialTokenMore = " "; partialTokenLess = " "; } } -size_t ProgressManager::getVisualLength(const std::string& str) { - size_t visualLength = 0; - bool inEscapeSequence = false; - - for (size_t i = 0; i < str.length(); i++) { - if (str[i] == '\x1b' && i + 1 < str.length() && str[i + 1] == '[') { - inEscapeSequence = true; - i++; // Skip the '[' - } else if (inEscapeSequence && str[i] == 'm') { - inEscapeSequence = false; - } else if (!inEscapeSequence) { - visualLength++; - } - } - - return visualLength; -} - void ProgressManager::print() { std::lock_guard lock(printMutex); @@ -234,57 +279,54 @@ void ProgressManager::print() { double eta = (fracTotal > 0) ? elapsed / fracTotal - elapsed : 0.0; std::ostringstream out; - int totalChars = 0; - int lineIndex = 0; + // int totalChars = 0; // if this is not the first print, delete previous content - if (lastPrintedChars > 0) { + if (progress_type == 2) { - if (isRStudio) { - out << "\x1b[" << std::to_string(lastPrintedChars) << "D"; - } else { - // Move cursor up to start of our content and clear everything - for (int i = 0; i < lastPrintedLines; i++) { - out << "\x1b[1A\x1b[2K"; // Move up one line and clear entire line + if (lastPrintedChars > 0) { + + if (isRStudio) { + out << "\x1b[" << std::to_string((1 + lineWidth) * lastPrintedLines) << "D"; + // out << "\x1b[" << std::to_string(lastPrintedChars) << "D"; + } else { + // Move cursor up to start of our content and clear everything + for (int i = 0; i < lastPrintedLines; i++) { + out << "\x1b[1A\x1b[2K"; // Move up one line and clear entire line + } } } - } - if (progress_type == 2) { // Print progress for each chain for (int i = 0; i < nChains; i++) { double frac = double(progress[i]) / nIter; std::string chainProgress = formatProgressBar(i + 1, progress[i], nIter, frac); - maybePadToLength(chainProgress); out << chainProgress << "\n"; - totalChars += chainProgress.length() + 1; // +1 for newline + // totalChars += chainProgress.length() + 1; // +1 for newline } // Print total progress std::string totalProgress = formatProgressBar(0, done, totalWork, fracTotal, true); - maybePadToLength(totalProgress); out << totalProgress << "\n"; - totalChars += totalProgress.length() + 1; // +1 for newline + // totalChars += totalProgress.length() + 1; // +1 for newline // Print time info - std::string timeInfo = formatTimeInfo(int(elapsed), int(eta)); + std::string timeInfo = formatTimeInfo(elapsed, eta); maybePadToLength(timeInfo); out << timeInfo << "\n"; - totalChars += timeInfo.length() + 1; // +1 for newline + // totalChars += timeInfo.length() + 1; // +1 for newline // Track total lines printed (chains + total + time) lastPrintedLines = nChains + 2; // used in a generic terminal - lastPrintedChars = totalChars; // used by RStudio + lastPrintedChars = 1;//totalChars; // used by RStudio } else if (progress_type == 1) { // Print total progress std::string totalProgress = formatProgressBar(0, done, totalWork, fracTotal, true); - maybePadToLength(totalProgress); // Print time info - totalProgress += " " + formatTimeInfo(int(elapsed), int(eta)); - maybePadToLength(totalProgress); + totalProgress += " " + formatTimeInfo(elapsed, eta); if (done < totalWork) { out << totalProgress << "\r"; @@ -326,47 +368,48 @@ void ProgressManager::maybePadToLength(std::string& content) const { // Example usage/ test with RcppParallel -#include -// Worker functor for RcppParallel -struct ChainWorker : public RcppParallel::Worker { - int nIter; - ProgressManager ± - bool display_progress; +// #include +// // Worker functor for RcppParallel +// struct ChainWorker : public RcppParallel::Worker { +// int nIter; +// ProgressManager ± +// bool display_progress; - ChainWorker(int nIter_, ProgressManager &pm_, bool display_progress_) - : nIter(nIter_), pm(pm_), display_progress(display_progress_) {} +// ChainWorker(int nIter_, ProgressManager &pm_, bool display_progress_) +// : nIter(nIter_), pm(pm_), display_progress(display_progress_) {} - void operator()(std::size_t begin, std::size_t end) { +// void operator()(std::size_t begin, std::size_t end) { - auto chainId = begin; +// auto chainId = begin; - for (int i = 0; i < nIter; i++) { - // ---- Simulated work ---- - std::this_thread::sleep_for(std::chrono::milliseconds(20)); +// for (int i = 0; i < nIter; i++) { +// // ---- Simulated work ---- +// std::this_thread::sleep_for(std::chrono::milliseconds(20)); - // ---- Update state ---- - pm.update(chainId); - if (pm.shouldExit()) break; - // if (Progress::check_abort()) Rcpp::checkUserInterrupt(); - } - } -}; +// // ---- Update state ---- +// pm.update(chainId); +// if (pm.shouldExit()) break; +// // if (Progress::check_abort()) Rcpp::checkUserInterrupt(); +// } +// } +// }; -// [[Rcpp::export]] -void runMCMC_parallel(int nChains = 4, int nIter = 100, int nWarmup = 100, int progress_type = 2, bool useUnicode = false) { +// // [[Rcpp::export]] +// void runMCMC_parallel(int nChains = 4, int nIter = 100, int nWarmup = 100, int progress_type = 2, bool useUnicode = false) { - int nTotal = nIter + nWarmup; - ProgressManager pm(nChains, nTotal, nWarmup, 10, progress_type, useUnicode); - ChainWorker worker(nTotal, pm, true); +// ProgressManager pm(nChains, nIter, nWarmup, 10, progress_type, useUnicode); +// ChainWorker worker(nIter + nWarmup, pm, true); - // Run each chain in parallel - RcppParallel::parallelFor(0, nChains, worker); +// // Run each chain in parallel +// RcppParallel::parallelFor(0, nChains, worker); - if (pm.shouldExit()) { - Rcpp::Rcout << "\nComputation interrupted by user.\n"; - } else { - Rcpp::Rcout << "\nAll chains finished!\n"; - } -} +// pm.finish(); + +// if (pm.shouldExit()) { +// Rcpp::Rcout << "\nComputation interrupted by user.\n"; +// } else { +// Rcpp::Rcout << "\nAll chains finished!\n"; +// } +// } diff --git a/src/progress_manager.h b/src/progress_manager.h index 375bb2a5..80ac526a 100644 --- a/src/progress_manager.h +++ b/src/progress_manager.h @@ -46,7 +46,7 @@ class ProgressManager { public: - ProgressManager(int nChains_, int nIter_, int nWarmup_, int printEvery_ = 10, int progress_type = 2, bool useUnicode_ = false); + ProgressManager(int nChains_, int nIter_, int nWarmup_, int printEvery_ = 10, int progress_type = 2, bool useUnicode_ = true); void update(int chainId); void finish(); bool shouldExit() const; @@ -54,11 +54,11 @@ class ProgressManager { private: void checkConsoleWidthChange(); - int getConsoleWidth(); - std::string formatProgressBar(int chainId, int current, int total, double fraction, bool isTotal = false); - std::string formatTimeInfo(int elapsed, int eta); + int getConsoleWidth() const; + std::string formatProgressBar(int chainId, int current, int total, double fraction, bool isTotal = false) const; + std::string formatTimeInfo(double elapsed, double eta) const; + std::string formatDuration(double seconds) const; void setupTheme(); - size_t getVisualLength(const std::string& str); bool isWarmupPhase() const { for (auto c : progress) diff --git a/src/sampler_output.h b/src/sampler_output.h index f36d8e1d..8f899d00 100644 --- a/src/sampler_output.h +++ b/src/sampler_output.h @@ -30,6 +30,7 @@ struct SamplerOutput { arma::vec energy_samples; int chain_id; bool has_indicator; + bool userInterrupt; }; #endif From 029e7cfe663c59642839ed82e45b4da8268bbdf9 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Mon, 22 Sep 2025 15:09:55 +0200 Subject: [PATCH 5/6] cleanup --- ppt2.R | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 ppt2.R diff --git a/ppt2.R b/ppt2.R deleted file mode 100644 index 60ec82d8..00000000 --- a/ppt2.R +++ /dev/null @@ -1,4 +0,0 @@ -Rcpp::sourceCpp("src/progress_manager.cpp") -runMCMC_parallel(3, 500) -runMCMC_parallel(3, 500, useUnicode = FALSE) -runMCMC_parallel(3, 500, useUnicode = FALSE, display_progress = TRUE) \ No newline at end of file From 9c77cf2bfe6ab15b6ca03761486beadc4ebaf583 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Mon, 22 Sep 2025 15:14:40 +0200 Subject: [PATCH 6/6] fix rebase again --- R/RcppExports.R | 5 ----- 1 file changed, 5 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index 3918c7c8..69628ca3 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -5,13 +5,8 @@ run_bgmCompare_parallel <- function(observations, num_groups, counts_per_categor .Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type) } -<<<<<<< HEAD -run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed) { - .Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed) -======= run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) { .Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) ->>>>>>> eaa3c70 (fix rebase) } get_explog_switch <- function() {