Skip to content

Commit

Permalink
[FixBug][Ansor] Fixing BroadcastShape function
Browse files Browse the repository at this point in the history
  • Loading branch information
thaisacs committed Feb 7, 2025
1 parent 3eb5ad6 commit d7dc989
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions include/tvm/topi/detail/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
};

for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
// TODO(@icemelon9): Need to revisit this part
const IntImmNode* static_size1 = shape1[s1_size - i].as<IntImmNode>();
const IntImmNode* static_size2 = shape2[s2_size - i].as<IntImmNode>();
DataType common_type = CommonType(shape1[s1_size - i].dtype(), shape2[s2_size - i].dtype());
Expand Down Expand Up @@ -92,10 +91,12 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
bh.vars1.push_front(bh.all_vars[0]);
bh.vars2.push_front(bh.all_vars[0]);
} else {
ICHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and "
<< shape2[s2_size - i]
<< " in: " << tvm::Array<tvm::PrimExpr>(shape1.begin(), shape1.end()) << " and "
<< tvm::Array<tvm::PrimExpr>(shape2.begin(), shape2.end());
LOG(WARNING) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and "
<< shape2[s2_size - i] << ". Automatically trimming the larger dimension.";
auto min_dim = tvm::IntImm(common_type, std::min(static_size1->value, static_size2->value));
bh.common_shape.push_front(cast_if_needed(common_type, min_dim));
bh.vars1.push_front(bh.all_vars[0]);
bh.vars2.push_front(bh.all_vars[0]);
}
}
// Remaining dimensions whether on shape1 or shape2 can always be completed
Expand Down

0 comments on commit d7dc989

Please sign in to comment.