Skip to content

Commit 240eebb

Browse files
committed
set_variant(): try next requested variant upon ImportError
If the user specifies several variants in order: mi.set_variant("cuda_ad_rgb", "llvm_ad_rgb") and they are all compiled (available) but some fail to import, then keep trying with the next requested variant instead of throwing an exception. This could happen e.g. when requested a CUDA variant with CUDA installed but no GPU available, or when LLVM is not installed.
1 parent a052419 commit 240eebb

File tree

1 file changed

+45
-16
lines changed

1 file changed

+45
-16
lines changed

src/python/alias.cpp

+45-16
Original file line numberDiff line numberDiff line change
@@ -95,16 +95,16 @@ static nb::object variant_module(nb::handle variant) {
9595

9696
/// Sets the variant
9797
static void set_variant(nb::args args) {
98-
nb::object new_variant{};
99-
for (auto arg : args) {
100-
// Find the first valid & compiled variant in the arguments
98+
nb::list valid_variants{};
99+
for (const auto &arg : args) {
101100
if (PyDict_Contains(variant_modules, arg.ptr()) == 1) {
102-
new_variant = nb::borrow(arg);
103-
break;
101+
// Variant is at least compiled, we can attempt to use it.
102+
valid_variants.append(nb::borrow(arg));
104103
}
105104
}
106105

107-
if (!new_variant) {
106+
if (valid_variants.size() == 0) {
107+
// None of the requested variants are compiled.
108108
nb::object all_args(nb::str(", ").attr("join")(args));
109109
nb::object all_variants(
110110
nb::str(", ").attr("join")(nb::steal(PyDict_Keys(variant_modules))));
@@ -116,8 +116,35 @@ static void set_variant(nb::args args) {
116116
);
117117
}
118118

119-
if (!curr_variant.equal(new_variant)) {
120-
nb::object new_variant_module = variant_module(new_variant);
119+
nb::object old_variant = curr_variant;
120+
// For each requested _and_ available variant, in order of preference.
121+
for (size_t i = 0; i < valid_variants.size(); ++i) {
122+
const auto &requested_variant = valid_variants[i];
123+
bool is_last = (i == valid_variants.size() - 1);
124+
125+
if (requested_variant.equal(old_variant)) {
126+
// We're already using this variant, no need to do anything.
127+
break;
128+
}
129+
130+
nb::object new_variant_module;
131+
try {
132+
new_variant_module = variant_module(requested_variant);
133+
} catch (const nb::python_error &e) {
134+
// The variant failed to import, this could happen e.g. if the
135+
// CUDA driver is installed, but there is no GPU available.
136+
// We only allow such failures as long as we have more variants to try.
137+
if (!is_last && e.matches(PyExc_ImportError)) {
138+
// nb::str var_str = requested_variant;
139+
// jit_log(LogLevel::Debug,
140+
// "The requested variant \"{}\" could not be loaded, "
141+
// "attempting the next one. The exception was: {}",
142+
// var_str.c_str(), e.what());
143+
continue;
144+
} else {
145+
throw e;
146+
}
147+
}
121148

122149
nb::dict variant_dict = new_variant_module.attr("__dict__");
123150
for (const auto &k : variant_dict.keys())
@@ -126,21 +153,23 @@ static void set_variant(nb::args args) {
126153
Safe_PyDict_SetItem(mi_dict, k.ptr(),
127154
PyDict_GetItem(variant_dict.ptr(), k.ptr()));
128155

129-
nb::object old_variant = curr_variant;
156+
curr_variant = requested_variant;
157+
break;
158+
}
159+
130160

131-
// Need to update curr_variant = mi.variant() before reloading internal plugins
132-
curr_variant = new_variant;
133-
if (new_variant.attr("startswith")(nb::make_tuple("llvm_", "cuda_"))) {
161+
if (!curr_variant.equal(old_variant)) {
162+
// Reload internal plugins
163+
if (curr_variant.attr("startswith")(nb::make_tuple("llvm_", "cuda_"))) {
134164
nb::module_ mi_python = nb::module_::import_("mitsuba.python.ad.integrators");
135165
nb::steal(PyImport_ReloadModule(mi_python.ptr()));
136166
}
137167

138-
// Only invoke callbacks after Mitsuba plugins have reloaded as there
139-
// may be a dependency
168+
// Only invoke user-provided callbacks after Mitsuba plugins have reloaded,
169+
// as there may be a dependency
140170
const auto &callbacks = nb::borrow<nb::set>(variant_change_callbacks);
141171
for (const auto &cb : callbacks)
142-
cb(old_variant, new_variant);
143-
172+
cb(old_variant, curr_variant);
144173
}
145174
}
146175

0 commit comments

Comments
 (0)