Skip to content

set_variant(): try next requested variant if one fails to import #1522

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 50 additions & 18 deletions src/python/alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,16 @@ static nb::object variant_module(nb::handle variant) {

/// Sets the variant
static void set_variant(nb::args args) {
nb::object new_variant{};
for (auto arg : args) {
// Find the first valid & compiled variant in the arguments
nb::list valid_variants{};
for (const auto &arg : args) {
if (PyDict_Contains(variant_modules, arg.ptr()) == 1) {
new_variant = nb::borrow(arg);
break;
// Variant is at least compiled, we can attempt to use it.
valid_variants.append(arg);
}
}

if (!new_variant) {
if (valid_variants.size() == 0) {
// None of the requested variants are compiled.
nb::object all_args(nb::str(", ").attr("join")(args));
nb::object all_variants(
nb::str(", ").attr("join")(nb::steal(PyDict_Keys(variant_modules))));
Expand All @@ -116,8 +116,36 @@ static void set_variant(nb::args args) {
);
}

if (!curr_variant.equal(new_variant)) {
nb::object new_variant_module = variant_module(new_variant);
nb::object old_variant = curr_variant;
// For each requested _and_ available variant, in order of preference.
for (size_t i = 0; i < valid_variants.size(); ++i) {
const auto &requested_variant = valid_variants[i];
bool is_last = (i == valid_variants.size() - 1);

if (requested_variant.equal(old_variant)) {
// We're already using this variant, no need to do anything.
break;
}

nb::object new_variant_module;
try {
new_variant_module = variant_module(requested_variant);
} catch (const nb::python_error &e) {
// The variant failed to import, this could happen e.g. if the
// CUDA driver is installed, but there is no GPU available.
// We only allow such failures as long as we have more variants to try.
if (!is_last && e.matches(PyExc_ImportError)) {
const auto mi = nb::module_::import_("mitsuba");
mi.attr("Log")(
mi.attr("LogLevel").attr("Debug"),
nb::str("The requested variant \"{}\" could not be loaded, "
"attempting the next one. The exception was:\n{}\n")
.format(requested_variant, e.what()));
continue;
} else {
throw e;
}
}

nb::dict variant_dict = new_variant_module.attr("__dict__");
for (const auto &k : variant_dict.keys())
Expand All @@ -126,21 +154,23 @@ static void set_variant(nb::args args) {
Safe_PyDict_SetItem(mi_dict, k.ptr(),
PyDict_GetItem(variant_dict.ptr(), k.ptr()));

nb::object old_variant = curr_variant;
curr_variant = requested_variant;
break;
}


// Need to update curr_variant = mi.variant() before reloading internal plugins
curr_variant = new_variant;
if (new_variant.attr("startswith")(nb::make_tuple("llvm_", "cuda_"))) {
if (!curr_variant.equal(old_variant)) {
// Reload internal plugins
if (curr_variant.attr("startswith")(nb::make_tuple("llvm_", "cuda_"))) {
nb::module_ mi_python = nb::module_::import_("mitsuba.python.ad.integrators");
nb::steal(PyImport_ReloadModule(mi_python.ptr()));
}

// Only invoke callbacks after Mitsuba plugins have reloaded as there
// may be a dependency
// Only invoke user-provided callbacks after Mitsuba plugins have reloaded,
// as there may be a dependency
const auto &callbacks = nb::borrow<nb::set>(variant_change_callbacks);
for (const auto &cb : callbacks)
cb(old_variant, new_variant);

cb(old_variant, curr_variant);
}
}

Expand Down Expand Up @@ -232,20 +262,22 @@ NB_MODULE(mitsuba_alias, m) {
mi_dict = m.attr("__dict__").ptr();
nb::object mi_ext = import_with_deepbind_if_necessary("mitsuba.mitsuba_ext");
nb::dict mitsuba_ext_dict = mi_ext.attr("__dict__");
for (const auto &k : mitsuba_ext_dict.keys())
for (const auto &k : mitsuba_ext_dict.keys()) {
if (!nb::bool_(k.attr("startswith")("__")) &&
!nb::bool_(k.attr("endswith")("__"))) {
Safe_PyDict_SetItem(mi_dict, k.ptr(), mitsuba_ext_dict[k].ptr());
}
}

// Import contents of `mitsuba.python` into top-level `mitsuba` module
nb::object mi_python = nb::module_::import_("mitsuba.python");
nb::dict mitsuba_python_dict = mi_python.attr("__dict__");
for (const auto &k : mitsuba_python_dict.keys())
for (const auto &k : mitsuba_python_dict.keys()) {
if (!nb::bool_(k.attr("startswith")("__")) &&
!nb::bool_(k.attr("endswith")("__"))) {
Safe_PyDict_SetItem(mi_dict, k.ptr(), mitsuba_python_dict[k].ptr());
}
}

/// Cleanup static variables, this is called when the interpreter is exiting
auto atexit = nb::module_::import_("atexit");
Expand Down
10 changes: 10 additions & 0 deletions src/python/main_v.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ using Caster = nb::object(*)(mitsuba::Object *);
Caster cast_object = nullptr;

NB_MODULE(MI_VARIANT_NAME, m) {
/* scoped */ {
// Before loading everything in and creating a lot of references to
// various objects, we ensure that this backend can be initialized
// without issues by creating a simple variable.
// If initialization fails, an exception will be raised, which the user
// can catch and handle if desired.
// Leaving initialization to fail later would lead to reference leaks.
MI_VARIANT_FLOAT(0);
Comment on lines +124 to +130
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should hopefully be enough to avoid leaks that have been happening when failing to import a variant.

}

m.attr("__name__") = "mitsuba";

// Create sub-modules
Expand Down