Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/mppi/controllers/controller.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,7 @@ public:
{
bool change_seed = p.seed_ != params_.seed_;
bool change_num_timesteps = p.num_timesteps_ != params_.num_timesteps_;
bool change_dt = p.dt_ != params_.dt_;
// bool change_std_dev = p.control_std_dev_ != params_.control_std_dev_;
params_ = p;
if (change_num_timesteps)
Expand All @@ -847,6 +848,10 @@ public:
{
setSeedCUDARandomNumberGen(params_.seed_);
}
if (change_dt)
{
fb_controller_->setDt(p.dt_);
}
}

int getNumIters() const
Expand Down
28 changes: 28 additions & 0 deletions include/mppi/core/base_plant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class BasePlant
using COST_PARAMS_T = typename COST_T::COST_PARAMS_T;
using TEMPLATED_CONTROLLER = CONTROLLER_T;
using CONTROLLER_PARAMS_T = typename CONTROLLER_T::TEMPLATED_PARAMS;
using SAMPLER_T = typename CONTROLLER_T::TEMPLATED_SAMPLING;
using SAMPLER_PARAMS_T = typename CONTROLLER_T::TEMPLATED_SAMPLING_PARAMS;

// Feedback related aliases
using FB_STATE_T = typename CONTROLLER_T::TEMPLATED_FEEDBACK::TEMPLATED_FEEDBACK_STATE;
Expand All @@ -61,10 +63,13 @@ class BasePlant
std::mutex cost_params_guard_;
CONTROLLER_PARAMS_T controller_params_;
std::mutex controller_params_guard_;
SAMPLER_PARAMS_T sampler_params_;
std::mutex sampler_params_guard_;

std::atomic<bool> has_new_dynamics_params_{ false };
std::atomic<bool> has_new_cost_params_{ false };
std::atomic<bool> has_new_controller_params_{ false };
std::atomic<bool> has_new_sampler_params_{ false };
std::atomic<bool> enabled_{ false };

// Values needed
Expand Down Expand Up @@ -332,6 +337,10 @@ class BasePlant
{
return has_new_controller_params_;
};
virtual bool hasNewSamplerParams()
{
return has_new_sampler_params_;
};

virtual DYN_PARAMS_T getNewDynamicsParams(bool set_flag = false)
{
Expand All @@ -348,6 +357,11 @@ class BasePlant
has_new_controller_params_ = set_flag;
return controller_params_;
}
virtual SAMPLER_PARAMS_T getNewSamplerParams(bool set_flag = false)
{
has_new_sampler_params_ = set_flag;
return sampler_params_;
}

virtual void setDynamicsParams(const DYN_PARAMS_T& params)
{
Expand All @@ -367,6 +381,12 @@ class BasePlant
controller_params_ = params;
has_new_controller_params_ = true;
}
virtual void setSamplerParams(const SAMPLER_PARAMS_T& params)
{
std::lock_guard<std::mutex> guard(sampler_params_guard_);
sampler_params_ = params;
has_new_sampler_params_ = true;
}

virtual void setLogger(const mppi::util::MPPILoggerPtr& logger)
{
Expand Down Expand Up @@ -423,6 +443,14 @@ class BasePlant
CONTROLLER_PARAMS_T controller_params = getNewControllerParams();
controller_->setParams(controller_params);
}
// Update sampler params
if (hasNewSamplerParams())
{
std::lock_guard<std::mutex> guard(sampler_params_guard_);
changed = true;
SAMPLER_PARAMS_T sampler_params = getNewSamplerParams();
controller_->setSamplingParams(sampler_params);
}
return changed;
}

Expand Down
150 changes: 93 additions & 57 deletions include/mppi/utils/logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
* Created by Bogdan on 11/01/2023
*/

#include <cstdarg>
#include <cstdio>
#include <vector>
#include <memory>
#include <string>
#include <vector>

namespace mppi
{
Expand Down Expand Up @@ -89,91 +89,127 @@ class MPPILogger
}

/**
* @brief Log debug messages to the output stream in green if the log level is set for DEBUG
* @param fmt Format string (if additional arguments are passed) or message to display
* @brief Log debug messages using virtual debug_impl() method
*
* @tparam ...Args variadic template type of args used in the format string
* @param fmt format string used in printf
* @param args extra args used by the format string fmt
*/
template <typename... Args>
void debug(const char* fmt, Args const&... args)
{
std::string message = format_string(fmt, args...);
this->debug_impl(message);
}

/**
* @brief Log info messages using virtual info_impl() method
*
* @tparam ...Args variadic template type of args used in the format string
* @param fmt format string used in printf
* @param args extra args used by the format string fmt
*/
template <typename... Args>
void info(const char* fmt, Args const&... args)
{
std::string message = format_string(fmt, args...);
this->info_impl(message);
}

/**
* @brief Log warning messages using virtual warning_impl() method
*
* @tparam ...Args variadic template type of args used in the format string
* @param fmt format string used in printf
* @param args extra args used by the format string fmt
*/
virtual void debug(const char* fmt, ...)
template <typename... Args>
void warning(const char* fmt, Args const&... args)
{
std::string message = format_string(fmt, args...);
this->warning_impl(message);
}

/**
* @brief Log errror messages using virtual errror_impl() method
*
* @tparam ...Args variadic template type of args used in the format string
* @param fmt format string used in printf
* @param args extra args used by the format string fmt
*/
template <typename... Args>
void error(const char* fmt, Args const&... args)
{
std::string message = format_string(fmt, args...);
this->error_impl(message);
}

protected:
LOG_LEVEL log_level_ = GLOBAL_LOG_LEVEL;
std::FILE* output_stream_ = stdout;

virtual void debug_impl(const std::string& message)
{
if (log_level_ <= LOG_LEVEL::DEBUG)
{
std::va_list argptr;
va_start(argptr, fmt);
surround_fprintf(output_stream_, GREEN, RESET, fmt, argptr);
va_end(argptr);
surround_fprintf(output_stream_, GREEN, RESET, message);
}
}

/**
* @brief Log info messages to the output stream in cyan if the log level is set for INFO
* @param fmt Format string (if additional arguments are passed) or message to display
*/
virtual void info(const char* fmt, ...)
virtual void info_impl(const std::string& message)
{
if (log_level_ <= LOG_LEVEL::INFO)
{
std::va_list argptr;
va_start(argptr, fmt);
surround_fprintf(output_stream_, CYAN, RESET, fmt, argptr);
va_end(argptr);
surround_fprintf(output_stream_, CYAN, RESET, message);
}
}

/**
* @brief Log debug messages to the output stream in yellow if the log level is set for WARNING
* @param fmt Format string (if additional arguments are passed) or message to display
*/
virtual void warning(const char* fmt, ...)
virtual void warning_impl(const std::string& message)
{
if (log_level_ <= LOG_LEVEL::WARNING)
{
std::va_list argptr;
va_start(argptr, fmt);
surround_fprintf(output_stream_, YELLOW, RESET, fmt, argptr);
va_end(argptr);
surround_fprintf(output_stream_, YELLOW, RESET, message);
}
}

/**
* @brief Log debug messages to the output stream in red if the log level is set for ERROR
* @param fmt Format string (if additional arguments are passed) or message to display
*/
virtual void error(const char* fmt, ...)
virtual void error_impl(const std::string& message)
{
if (log_level_ <= LOG_LEVEL::ERROR)
{
std::va_list argptr;
va_start(argptr, fmt);
surround_fprintf(output_stream_, RED, RESET, fmt, argptr);
va_end(argptr);
surround_fprintf(output_stream_, RED, RESET, message);
}
}

protected:
LOG_LEVEL log_level_ = GLOBAL_LOG_LEVEL;
std::FILE* output_stream_ = stdout;
/**
* @brief Print message to stream with coloring defined by prefix
*
* @param fstream where the message will be printed to
* @param prefix prefix string to print before message. Expected to be a color code
* @param suffix suffix string to print after message. Expected to be a color reset code
* @param message actual message to be printed
*/
virtual void surround_fprintf(std::FILE* fstream, const char* prefix, const char* suffix, const std::string& message)
{
std::fprintf(fstream, "%s%s%s", prefix, message.c_str(), suffix);
}

/**
* @brief Prints a colored output to a provided fstream. It does this by first creating the formatted string
* as a std::vector<char> so that it can be used as an input to fprintf with a different format string
* @brief create a string out of format string and variable number of additional arguments
*
* @tparam ...Args variadic template type for extra arguments passed to format_string()
* @param fmt format string defining how to display additional arguments
* @param args additional arguments for formatting
*
* @param fstream file stream to write output to
* @param color color code to use on provided string
* @param fmt format string
* @param ... extra variables for format string
* @return std::string containing formatted text
*/
virtual void surround_fprintf(std::FILE* fstream, const char* prefix, const char* suffix, const char* fmt,
std::va_list args)
template <typename... Args>
std::string format_string(const char* fmt, Args const&... args)
{
// introducing a second copy of the args as calling vsnprintf leaves args in an indeterminate state
std::va_list args_cpy;
va_copy(args_cpy, args);
// figure out size of formatted string, also uses up args
std::vector<char> buf(1 + std::vsnprintf(nullptr, 0, fmt, args));
// Fill buffer with formatted string using copy of the args
std::vsnprintf(buf.data(), buf.size(), fmt, args_cpy);
va_end(args_cpy);
// print formatted string but colored
std::fprintf(fstream, "%s%s%s", prefix, buf.data(), suffix);
// figure out size of formatted string
std::vector<char> buf(1 + std::snprintf(nullptr, 0, fmt, args...));
// Fill buffer with formatted string
std::snprintf(buf.data(), buf.size(), fmt, args...);
return std::string(buf.data());
}
};

Expand Down