Skip to content

Commit 38c4a3c

Browse files
committed
fix compilation of dr.matvec() within loops
1 parent b0a6e21 commit 38c4a3c

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

src/coop_vec.cpp

+20
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,15 @@
1717
#include "cuda.h"
1818
#include <drjit-core/nanostl.h>
1919

20+
static uint32_t unwrap(uint32_t index) {
21+
while (true) {
22+
const Variable *v = jitc_var(index);
23+
if (v->kind != (uint32_t) VarKind::LoopPhi)
24+
return index;
25+
index = borrow(v->dep[3]);
26+
}
27+
}
28+
2029
uint32_t jitc_coop_vec_pack(uint32_t n, const uint32_t *in) {
2130
if (n == 0)
2231
jitc_raise("jit_coop_vec_pack(): vector cannot be empty!");
@@ -473,6 +482,10 @@ uint32_t jitc_coop_vec_matvec(uint32_t A_index,
473482
output_length, input_length, x_v->array_length);
474483
}
475484

485+
A_index = unwrap(A_index);
486+
if (b_index)
487+
b_index = unwrap(b_index);
488+
476489
Ref a_ptr, b_ptr;
477490
{
478491
void *p = nullptr;
@@ -578,6 +591,9 @@ uint32_t jitc_coop_vec_accum(uint32_t target_, uint32_t target_size,
578591
"float16 precision on the CUDA/OptiX backend.");
579592
}
580593

594+
if (target_)
595+
target_ = unwrap(target_);
596+
581597
Ref target = borrow(target_);
582598
if (!target) {
583599
uint64_t z = 0;
@@ -626,6 +642,10 @@ uint32_t jitc_coop_vec_outer_product_accum(uint32_t target_,
626642
JitBackend backend;
627643
VarType vt;
628644
uint32_t size;
645+
646+
if (target_)
647+
target_ = unwrap(target_);
648+
629649
{
630650
const Variable *v_a = jitc_var(a),
631651
*v_b = jitc_var(b);

0 commit comments

Comments
 (0)