Implementing an argmax in drjit #1562
-
Hi, I am trying to implement a BSDF that stores several externally evaluated samples and returns the one closest to the input directions. This is the `eval` code in NumPy, which searches for the most similar sample and gathers the associated color:

```python
def eval(self, ctx, si, wo, active):
    wi = np.array(si.wi)
    wo = np.array(wo)
    cosi = np.abs(np.dot(wi, self.cartesian_wis))
    coso = np.abs(np.dot(wo, self.cartesian_wos))
    crossi = np.abs(np.dot(wi, self.cartesian_wos))
    crosso = np.abs(np.dot(wo, self.cartesian_wis))
    # Also check the cross similarity because of Helmholtz reciprocity
    similarities = np.maximum(cosi + coso, crossi + crosso)
    idx = int(np.argmax(similarities))
    return mi.Color3f(self.data[:, idx]) * mi.Frame3f.cos_theta(wo)
```

However, this only works in scalar mode and is therefore extremely slow. I am trying to implement it with Dr.Jit, but I don't know how to implement the final argmax reduction. This is my Dr.Jit code:

```python
def eval(self, ctx, si, wo, active):
    active &= ...  # Check that the input is valid
    wi = si.wi
    cosi = dr.abs_dot(wi, self.cartesian_wis)
    coso = dr.abs_dot(wo, self.cartesian_wos)
    crossi = dr.abs_dot(wi, self.cartesian_wos)
    crosso = dr.abs_dot(wo, self.cartesian_wis)
    similarities = dr.maximum(cosi + coso, crossi + crosso)
    # Perform the argmax operation
    max_val = dr.max(similarities, axis=0)
    max_mask = similarities == max_val
    filtered = self.data_indices & max_mask
    idx = dr.sum(filtered, axis=0)
    result = dr.gather(Array3f, source=self.data_flat, index=idx)
    return mi.Color3f(result) * mi.Frame3f.cos_theta(wo)
```

There is no built-in argmax, so I have tried to implement it by combining several reductions. However, since the variables are symbolic, these reductions throw errors because they cannot be evaluated. I would be grateful if someone could help me. Thanks in advance!
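A note on the NumPy version: its argmax runs over the N stored samples for a single query, so the same logic can be vectorized over many queries at once by taking the argmax along the sample axis. The sketch below assumes `cartesian_wis`, `cartesian_wos` and `data` are (3, N) arrays (as the indexing `self.data[:, idx]` suggests) and uses a hypothetical free function `eval_batched` instead of the BSDF method.

```python
import numpy as np


def eval_batched(wi, wo, cartesian_wis, cartesian_wos, data):
    """Nearest-sample lookup for M query pairs at once (hypothetical helper).

    wi, wo: (M, 3) unit directions in the local shading frame.
    cartesian_wis, cartesian_wos: (3, N) stored sample directions.
    data: (3, N) stored RGB values, one column per sample.
    Returns an (M, 3) array of colors scaled by cos_theta(wo).
    """
    cosi = np.abs(wi @ cartesian_wis)    # (M, N)
    coso = np.abs(wo @ cartesian_wos)    # (M, N)
    crossi = np.abs(wi @ cartesian_wos)  # (M, N), roles swapped (reciprocity)
    crosso = np.abs(wo @ cartesian_wis)  # (M, N)
    similarities = np.maximum(cosi + coso, crossi + crosso)
    # The argmax runs along the sample axis: one independent result per query.
    idx = np.argmax(similarities, axis=1)  # (M,)
    colors = data[:, idx].T                # (M, 3)
    # In the local frame, cos_theta(wo) is simply the z component of wo.
    return colors * wo[:, 2:3]
```

Each query's result depends only on its own row of `similarities`, which is what makes a per-lane formulation possible.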
-
Hi @hsunekichi

You might have seen this discussion in the Dr.Jit repository: mitsuba-renderer/drjit#375

You'll also want to have a look at this, slightly more general, page: https://drjit.readthedocs.io/en/stable/eval.html

Fundamentally, you'll never be able to perform a symbolic horizontal reduction and read its result in the same kernel.

Note that we do have a `mode="symbolic"` option on most of our horizontal reductions, but these produce a side effect, which means that when you try to access the result, an evaluation is triggered.

The reason lies in the execution model: not all threads are alive at the same time. At any point in time, assuming you're running a very wide kernel, only a subset of threads is running, and once they're done we can move on to the next subset, until all threads have been executed. A reduction inherently requires some level of synchronization: all threads must come together to produce a result before they can continue their work. This is not compatible with the execution model.
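For this particular BSDF, though, the argmax is taken over the N stored samples for each ray, not across rays, so it may not need a horizontal reduction at all. A minimal sketch of that idea (an assumption on top of the reply, not taken from it): loop over the samples and keep a running best similarity and index per lane. It is written in NumPy so it can be checked against `np.argmax`; in Dr.Jit the `np.where` calls would presumably correspond to `dr.select` and the final color lookup to `dr.gather`, with the Python loop over a small, fixed N being traced iteration by iteration into a single kernel. The helper name `nearest_sample_index` is hypothetical.

```python
import numpy as np


def nearest_sample_index(wi, wo, cartesian_wis, cartesian_wos):
    """Per-lane argmax over stored samples using only element-wise ops.

    wi, wo: (M, 3) query directions, one lane per row.
    cartesian_wis, cartesian_wos: (3, N) stored sample directions.
    Returns an (M,) integer array with the best sample index per lane.
    """
    num_lanes = wi.shape[0]
    best_sim = np.full(num_lanes, -np.inf)
    best_idx = np.zeros(num_lanes, dtype=np.int64)
    for i in range(cartesian_wis.shape[1]):
        s_wi = cartesian_wis[:, i]
        s_wo = cartesian_wos[:, i]
        # Similarity of every lane to sample i only: no cross-lane communication.
        direct = np.abs(wi @ s_wi) + np.abs(wo @ s_wo)
        cross = np.abs(wi @ s_wo) + np.abs(wo @ s_wi)  # Helmholtz reciprocity
        sim = np.maximum(direct, cross)
        # Strict '>' keeps the first maximum on ties, matching np.argmax.
        better = sim > best_sim
        best_sim = np.where(better, sim, best_sim)
        best_idx = np.where(better, i, best_idx)
    return best_idx
```

Each lane only ever reads its own running state, so no thread has to wait for any other; the cost is one pass per stored sample, which is reasonable when N is small. Unlike the masked-sum trick in the question, which adds up the indices of all tied maxima, the strict comparison yields a single index per lane.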