@@ -95,16 +95,16 @@ static nb::object variant_module(nb::handle variant) {
95
95
96
96
// / Sets the variant
97
97
static void set_variant (nb::args args) {
98
- nb::object new_variant{};
99
- for (auto arg : args) {
100
- // Find the first valid & compiled variant in the arguments
98
+ nb::list valid_variants{};
99
+ for (const auto &arg : args) {
101
100
if (PyDict_Contains (variant_modules, arg.ptr ()) == 1 ) {
102
- new_variant = nb::borrow (arg);
103
- break ;
101
+ // Variant is at least compiled, we can attempt to use it.
102
+ valid_variants. append ( nb::borrow (arg)) ;
104
103
}
105
104
}
106
105
107
- if (!new_variant) {
106
+ if (valid_variants.size () == 0 ) {
107
+ // None of the requested variants are compiled.
108
108
nb::object all_args (nb::str (" , " ).attr (" join" )(args));
109
109
nb::object all_variants (
110
110
nb::str (" , " ).attr (" join" )(nb::steal (PyDict_Keys (variant_modules))));
@@ -116,8 +116,35 @@ static void set_variant(nb::args args) {
116
116
);
117
117
}
118
118
119
- if (!curr_variant.equal (new_variant)) {
120
- nb::object new_variant_module = variant_module (new_variant);
119
+ nb::object old_variant = curr_variant;
120
+ // For each requested _and_ available variant, in order of preference.
121
+ for (size_t i = 0 ; i < valid_variants.size (); ++i) {
122
+ const auto &requested_variant = valid_variants[i];
123
+ bool is_last = (i == valid_variants.size () - 1 );
124
+
125
+ if (requested_variant.equal (old_variant)) {
126
+ // We're already using this variant, no need to do anything.
127
+ break ;
128
+ }
129
+
130
+ nb::object new_variant_module;
131
+ try {
132
+ new_variant_module = variant_module (requested_variant);
133
+ } catch (const nb::python_error &e) {
134
+ // The variant failed to import, this could happen e.g. if the
135
+ // CUDA driver is installed, but there is no GPU available.
136
+ // We only allow such failures as long as we have more variants to try.
137
+ if (!is_last && e.matches (PyExc_ImportError)) {
138
+ // nb::str var_str = requested_variant;
139
+ // jit_log(LogLevel::Debug,
140
+ // "The requested variant \"{}\" could not be loaded, "
141
+ // "attempting the next one. The exception was: {}",
142
+ // var_str.c_str(), e.what());
143
+ continue ;
144
+ } else {
145
+ throw e;
146
+ }
147
+ }
121
148
122
149
nb::dict variant_dict = new_variant_module.attr (" __dict__" );
123
150
for (const auto &k : variant_dict.keys ())
@@ -126,21 +153,23 @@ static void set_variant(nb::args args) {
126
153
Safe_PyDict_SetItem (mi_dict, k.ptr (),
127
154
PyDict_GetItem (variant_dict.ptr (), k.ptr ()));
128
155
129
- nb::object old_variant = curr_variant;
156
+ curr_variant = requested_variant;
157
+ break ;
158
+ }
159
+
130
160
131
- // Need to update curr_variant = mi.variant() before reloading internal plugins
132
- curr_variant = new_variant;
133
- if (new_variant .attr (" startswith" )(nb::make_tuple (" llvm_" , " cuda_" ))) {
161
+ if (! curr_variant. equal (old_variant)) {
162
+ // Reload internal plugins
163
+ if (curr_variant .attr (" startswith" )(nb::make_tuple (" llvm_" , " cuda_" ))) {
134
164
nb::module_ mi_python = nb::module_::import_ (" mitsuba.python.ad.integrators" );
135
165
nb::steal (PyImport_ReloadModule (mi_python.ptr ()));
136
166
}
137
167
138
- // Only invoke callbacks after Mitsuba plugins have reloaded as there
139
- // may be a dependency
168
+ // Only invoke user-provided callbacks after Mitsuba plugins have reloaded,
169
+ // as there may be a dependency
140
170
const auto &callbacks = nb::borrow<nb::set>(variant_change_callbacks);
141
171
for (const auto &cb : callbacks)
142
- cb (old_variant, new_variant);
143
-
172
+ cb (old_variant, curr_variant);
144
173
}
145
174
}
146
175
0 commit comments