Skip to content

Commit

Permalink
[FixBug][Ansor] Fixing BroadcastShape function
Browse files Browse the repository at this point in the history
  • Loading branch information
thaisacs committed Feb 6, 2025
1 parent 3eb5ad6 commit 234b87c
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions include/tvm/topi/detail/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
};

for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
// TODO(@icemelon9): Need to revisit this part
const IntImmNode* static_size1 = shape1[s1_size - i].as<IntImmNode>();
const IntImmNode* static_size2 = shape2[s2_size - i].as<IntImmNode>();
DataType common_type = CommonType(shape1[s1_size - i].dtype(), shape2[s2_size - i].dtype());
Expand All @@ -72,38 +71,37 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
bh.vars1.push_front(bh.all_vars[0]);
bh.vars2.push_front(bh.all_vars[0]);
} else if (topi::detail::EqualCheck(one, shape1[s1_size - i])) {
ICHECK(!topi::detail::EqualCheck(one, shape2[s2_size - i]));
bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i]));
bh.vars2.push_front(bh.all_vars[0]);
} else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) {
bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
bh.vars1.push_front(bh.all_vars[0]);
} else if (!static_size1 && !static_size2) {
} else if (!static_size1 || !static_size2) {
bh.common_shape.push_front(
cast_if_needed(common_type, max(shape1[s1_size - i], shape2[s2_size - i])));
bh.vars1.push_front(bh.all_vars[0]);
bh.vars2.push_front(bh.all_vars[0]);
} else if (!static_size1) {
bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i]));
bh.vars2.push_front(bh.all_vars[0]);
bh.vars1.push_front(bh.all_vars[0]);
} else if (!static_size2) {
bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
bh.vars1.push_front(bh.all_vars[0]);
bh.vars2.push_front(bh.all_vars[0]);
} else {
ICHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and "
<< shape2[s2_size - i]
<< " in: " << tvm::Array<tvm::PrimExpr>(shape1.begin(), shape1.end()) << " and "
<< tvm::Array<tvm::PrimExpr>(shape2.begin(), shape2.end());
if (static_size1->value != static_size2->value) {
LOG(WARNING) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and "
<< shape2[s2_size - i] << ". Automatically cutting the larger dimension.";
auto min_dim = tvm::IntImm(common_type, std::min(static_size1->value, static_size2->value));
bh.common_shape.push_front(cast_if_needed(common_type, min_dim));
bh.vars1.push_front(bh.all_vars[0]);
bh.vars2.push_front(bh.all_vars[0]);
} else {
bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
bh.vars1.push_front(bh.all_vars[0]);
bh.vars2.push_front(bh.all_vars[0]);
}
}
}
// Remaining dimensions whether on shape1 or shape2 can always be completed

auto max_size = std::max(s1_size, s2_size);
auto& shape = (s1_size > s2_size) ? shape1 : shape2;
auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2;
for (; i <= max_size; ++i) {
bh.all_vars.push_front(tvm::tir::Var("v", shape[max_size - 1].dtype()));
bh.all_vars.push_front(tvm::tir::Var("v", shape[max_size - i].dtype()));
bh.common_shape.push_front(shape[max_size - i]);
vars.push_front(bh.all_vars[0]);
}
Expand Down

0 comments on commit 234b87c

Please sign in to comment.