diff --git a/src/python/alias.cpp b/src/python/alias.cpp index 0164a0225..448115284 100644 --- a/src/python/alias.cpp +++ b/src/python/alias.cpp @@ -95,16 +95,16 @@ static nb::object variant_module(nb::handle variant) { /// Sets the variant static void set_variant(nb::args args) { - nb::object new_variant{}; - for (auto arg : args) { - // Find the first valid & compiled variant in the arguments + nb::list valid_variants{}; + for (const auto &arg : args) { if (PyDict_Contains(variant_modules, arg.ptr()) == 1) { - new_variant = nb::borrow(arg); - break; + // Variant is at least compiled, we can attempt to use it. + valid_variants.append(arg); } } - if (!new_variant) { + if (valid_variants.size() == 0) { + // None of the requested variants are compiled. nb::object all_args(nb::str(", ").attr("join")(args)); nb::object all_variants( nb::str(", ").attr("join")(nb::steal(PyDict_Keys(variant_modules)))); @@ -116,8 +116,36 @@ static void set_variant(nb::args args) { ); } - if (!curr_variant.equal(new_variant)) { - nb::object new_variant_module = variant_module(new_variant); + nb::object old_variant = curr_variant; + // For each requested _and_ available variant, in order of preference. + for (size_t i = 0; i < valid_variants.size(); ++i) { + const auto &requested_variant = valid_variants[i]; + bool is_last = (i == valid_variants.size() - 1); + + if (requested_variant.equal(old_variant)) { + // We're already using this variant, no need to do anything. + break; + } + + nb::object new_variant_module; + try { + new_variant_module = variant_module(requested_variant); + } catch (const nb::python_error &e) { + // The variant failed to import, this could happen e.g. if the + // CUDA driver is installed, but there is no GPU available. + // We only allow such failures as long as we have more variants to try. + if (!is_last && e.matches(PyExc_ImportError)) { + const auto mi = nb::module_::import_("mitsuba"); + mi.attr("Log")( + mi.attr("LogLevel").attr("Debug"), + nb::str("The requested variant \"{}\" could not be loaded, " + "attempting the next one. The exception was:\n{}\n") + .format(requested_variant, e.what())); + continue; + } else { + throw e; + } + } nb::dict variant_dict = new_variant_module.attr("__dict__"); for (const auto &k : variant_dict.keys()) @@ -126,21 +154,23 @@ static void set_variant(nb::args args) { Safe_PyDict_SetItem(mi_dict, k.ptr(), PyDict_GetItem(variant_dict.ptr(), k.ptr())); - nb::object old_variant = curr_variant; + curr_variant = requested_variant; + break; + } + - // Need to update curr_variant = mi.variant() before reloading internal plugins - curr_variant = new_variant; - if (new_variant.attr("startswith")(nb::make_tuple("llvm_", "cuda_"))) { + if (!curr_variant.equal(old_variant)) { + // Reload internal plugins + if (curr_variant.attr("startswith")(nb::make_tuple("llvm_", "cuda_"))) { nb::module_ mi_python = nb::module_::import_("mitsuba.python.ad.integrators"); nb::steal(PyImport_ReloadModule(mi_python.ptr())); } - // Only invoke callbacks after Mitsuba plugins have reloaded as there - // may be a dependency + // Only invoke user-provided callbacks after Mitsuba plugins have reloaded, + // as there may be a dependency const auto &callbacks = nb::borrow(variant_change_callbacks); for (const auto &cb : callbacks) - cb(old_variant, new_variant); - + cb(old_variant, curr_variant); } } @@ -232,20 +262,22 @@ NB_MODULE(mitsuba_alias, m) { mi_dict = m.attr("__dict__").ptr(); nb::object mi_ext = import_with_deepbind_if_necessary("mitsuba.mitsuba_ext"); nb::dict mitsuba_ext_dict = mi_ext.attr("__dict__"); - for (const auto &k : mitsuba_ext_dict.keys()) + for (const auto &k : mitsuba_ext_dict.keys()) { if (!nb::bool_(k.attr("startswith")("__")) && !nb::bool_(k.attr("endswith")("__"))) { Safe_PyDict_SetItem(mi_dict, k.ptr(), mitsuba_ext_dict[k].ptr()); } + } // Import contents of `mitsuba.python` into top-level `mitsuba` module nb::object mi_python = nb::module_::import_("mitsuba.python"); nb::dict mitsuba_python_dict = mi_python.attr("__dict__"); - for (const auto &k : mitsuba_python_dict.keys()) + for (const auto &k : mitsuba_python_dict.keys()) { if (!nb::bool_(k.attr("startswith")("__")) && !nb::bool_(k.attr("endswith")("__"))) { Safe_PyDict_SetItem(mi_dict, k.ptr(), mitsuba_python_dict[k].ptr()); } + } /// Cleanup static variables, this is called when the interpreter is exiting auto atexit = nb::module_::import_("atexit"); diff --git a/src/python/main_v.cpp b/src/python/main_v.cpp index 200ad09ee..dd1c2e617 100644 --- a/src/python/main_v.cpp +++ b/src/python/main_v.cpp @@ -120,6 +120,16 @@ using Caster = nb::object(*)(mitsuba::Object *); Caster cast_object = nullptr; NB_MODULE(MI_VARIANT_NAME, m) { + /* scoped */ { + // Before loading everything in and creating a lot of references to + // various objects, we ensure that this backend can be initialized + // without issues by creating a simple variable. + // If initialization fails, an exception will be raised, which the user + // can catch and handle if desired. + // Leaving initialization to fail later would lead to reference leaks. + MI_VARIANT_FLOAT(0); + } + m.attr("__name__") = "mitsuba"; // Create sub-modules