@@ -61,7 +61,6 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
61
61
};
62
62
63
63
for (i = 1 ; i <= std::min (s1_size, s2_size); ++i) {
64
- // TODO(@icemelon9): Need to revisit this part
65
64
const IntImmNode* static_size1 = shape1[s1_size - i].as <IntImmNode>();
66
65
const IntImmNode* static_size2 = shape2[s2_size - i].as <IntImmNode>();
67
66
DataType common_type = CommonType (shape1[s1_size - i].dtype (), shape2[s2_size - i].dtype ());
@@ -92,10 +91,12 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
92
91
bh.vars1 .push_front (bh.all_vars [0 ]);
93
92
bh.vars2 .push_front (bh.all_vars [0 ]);
94
93
} else {
95
- ICHECK (false ) << " Incompatible broadcast dims: " << shape1[s1_size - i] << " and "
96
- << shape2[s2_size - i]
97
- << " in: " << tvm::Array<tvm::PrimExpr>(shape1.begin (), shape1.end ()) << " and "
98
- << tvm::Array<tvm::PrimExpr>(shape2.begin (), shape2.end ());
94
+ LOG (WARNING) << " Incompatible broadcast dims: " << shape1[s1_size - i] << " and "
95
+ << shape2[s2_size - i] << " . Automatically cutting the larger dimension." ;
96
+ auto min_dim = tvm::IntImm (common_type, std::min (static_size1->value , static_size2->value ));
97
+ bh.common_shape .push_front (cast_if_needed (common_type, min_dim));
98
+ bh.vars1 .push_front (bh.all_vars [0 ]);
99
+ bh.vars2 .push_front (bh.all_vars [0 ]);
99
100
}
100
101
}
101
102
// Remaining dimensions whether on shape1 or shape2 can always be completed
0 commit comments