Skip to content

Commit 5c5f28e

Browse files
authored
Fix issue when broadcasting a reduction with different dimensions as the input (#4408)
1 parent 985305a commit 5c5f28e

File tree

3 files changed

+75
-1
lines changed

3 files changed

+75
-1
lines changed

src/shape_transform_descriptor.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,40 @@ shape_transform_descriptor shape_transform_descriptor::create(const std::vector<
185185
return result;
186186
}
187187

188+
static bool is_broadcast_only(const std::vector<dimension>& src_dims,
189+
const std::vector<dimension>& dst_dims)
190+
{
191+
return std::equal(src_dims.begin(),
192+
src_dims.end(),
193+
dst_dims.begin(),
194+
dst_dims.end(),
195+
[](const auto& src_dim, const auto& dst_dim) {
196+
if(src_dim.subdimensions.size() != dst_dim.subdimensions.size())
197+
return false;
198+
auto match_sub_dim = [](const dimension::sub& src_sub,
199+
const dimension::sub& dst_sub) {
200+
if(src_sub.len == 1)
201+
return true;
202+
return src_sub.len == dst_sub.len;
203+
};
204+
auto [src_it, dst_it] = std::mismatch(src_dim.subdimensions.begin(),
205+
src_dim.subdimensions.end(),
206+
dst_dim.subdimensions.begin(),
207+
dst_dim.subdimensions.end(),
208+
match_sub_dim);
209+
if(src_it == src_dim.subdimensions.end())
210+
return true;
211+
// One mismatch is fine as long as the dimension is still the same size
212+
if(src_dim.len() != dst_dim.len())
213+
return false;
214+
return std::equal(std::next(src_it),
215+
src_dim.subdimensions.end(),
216+
std::next(dst_it),
217+
dst_dim.subdimensions.end(),
218+
match_sub_dim);
219+
});
220+
}
221+
188222
shape_transform_descriptor shape_transform_descriptor::rebase(const std::vector<std::size_t>& dims,
189223
bool broadcast) const
190224
{
@@ -233,6 +267,8 @@ shape_transform_descriptor shape_transform_descriptor::rebase(const std::vector<
233267
}
234268
// TODO: Only simplify if the subs was changed
235269
result.simplify();
270+
if(broadcast and not is_broadcast_only(dimensions, result.dimensions))
271+
return {};
236272

237273
return result;
238274
}

src/simplify_reshapes.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,9 @@ struct find_op_shape_transform_op
263263
auto desc = desc1.rebase(x_ins->inputs().front()->get_shape().lens(), true);
264264
if(not desc.empty())
265265
return desc;
266-
if(not is_reduce(x_ins))
266+
if(not is_reduce(x_ins) or any_of(ops, [](const operation& op) {
267+
return contains({"broadcast", "multibroadcast"}, op.name());
268+
}))
267269
return desc1;
268270
// Find a broadcast to append to improve the reduction analysis
269271
auto output_path = get_output_path(input_ins);

test/simplify_reshapes_test.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2666,6 +2666,42 @@ TEST_CASE(reduce_broadcast_reshape_pointwise2)
26662666
EXPECT(m1.sort() == m2.sort());
26672667
}
26682668

2669+
// A reduce_sum followed by a transpose and a multibroadcast whose output lens
// ({1, 3, 256, 256}) differ from the lens of the reduction input
// ({1, 128, 128, 3}). After the pass the transpose is expected to come before
// the reduction, with the reduction axes changed from {1, 2} to {2, 3}, while
// the multibroadcast and the pointwise ops stay as they were.
TEST_CASE(reduce_transpose_broadcast_pointwise_diff_size)
{
    const auto in_shape  = migraphx::shape{migraphx::shape::float_type, {1, 128, 128, 3}};
    const auto out_shape = migraphx::shape{migraphx::shape::float_type, {1, 3, 256, 256}};

    // Module handed to the pass: reduce -> transpose -> multibroadcast -> add -> relu
    migraphx::module actual;
    {
        auto px = actual.add_parameter("x", in_shape);
        auto py = actual.add_parameter("y", out_shape);
        auto reduced =
            actual.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1, 2}}}), px);
        auto permuted = actual.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), reduced);
        auto expanded = actual.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", out_shape.lens()}}), permuted);
        auto sum       = actual.add_instruction(migraphx::make_op("add"), expanded, py);
        auto activated = actual.add_instruction(migraphx::make_op("relu"), sum);
        actual.add_return({activated});
    }
    run_pass(actual);

    // Expected module: transpose -> reduce -> multibroadcast -> add -> relu
    migraphx::module expected;
    {
        auto px = expected.add_parameter("x", in_shape);
        auto py = expected.add_parameter("y", out_shape);
        auto permuted = expected.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), px);
        auto reduced = expected.add_instruction(
            migraphx::make_op("reduce_sum", {{"axes", {2, 3}}}), permuted);
        auto expanded = expected.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", out_shape.lens()}}), reduced);
        auto sum       = expected.add_instruction(migraphx::make_op("add"), expanded, py);
        auto activated = expected.add_instruction(migraphx::make_op("relu"), sum);
        expected.add_return({activated});
    }
    EXPECT(actual.sort() == expected.sort());
}
2704+
26692705
TEST_CASE(transpose_contiguous_reshape_binary_packed)
26702706
{
26712707
migraphx::module m1;

0 commit comments

Comments
 (0)