2323#include " llvm/Support/LogicalResult.h"
2424#include < cassert>
2525
26+ #define DEBUG_TYPE " mask-analysis"
27+
2628namespace mlir {
2729
2830namespace triton {
@@ -198,15 +200,18 @@ LogicalResult MaskState::addStates(const MaskState &lhsState,
198200 const MaskState &rhsState, Location loc,
199201 OpBuilder &builder) {
200202 if (lhsState.scalar && rhsState.scalar ) {
201- InFlightDiagnostic diag =
202- emitError (loc) << " Unexpected case where both lhs and rhs are scalars" ;
203+ LLVM_DEBUG ({
204+ InFlightDiagnostic diag =
205+ emitRemark (loc, " Unexpected case where both lhs and rhs are scalars" );
206+ });
203207 return failure ();
204208 }
205209
206210 if (!lhsState.scalar && !rhsState.scalar ) {
207- InFlightDiagnostic diag =
208- emitError (loc)
209- << " Unsupported scenario where neither lhs nor rhs is a scalar" ;
211+ LLVM_DEBUG ({
212+ InFlightDiagnostic diag = emitRemark (
213+ loc, " Unsupported scenario where neither lhs nor rhs is a scalar" );
214+ });
210215 return failure ();
211216 }
212217
@@ -222,16 +227,19 @@ LogicalResult MaskState::minStateScalar(const MaskState &lhsState,
222227 // Conjunction where both sides are scalar should not be done after splats. We
223228 // should ensure that code generation pushes the splat as late as possible.
224229 if (lhsState.scalar && rhsState.scalar ) {
225- InFlightDiagnostic diag =
226- emitError (loc) << " Unexpected case where both lhs and rhs are scalars" ;
230+ LLVM_DEBUG ({
231+ InFlightDiagnostic diag =
232+ emitRemark (loc, " Unexpected case where both lhs and rhs are scalars" );
233+ });
227234 return failure ();
228235 }
229236
230237 // Caller should ensure that at least one side is scalar.
231238 if (!lhsState.scalar && !rhsState.scalar ) {
232- InFlightDiagnostic diag =
233- emitError (loc)
234- << " Unexpected case where both lhs and rhs are not scalars" ;
239+ LLVM_DEBUG ({
240+ InFlightDiagnostic diag = emitRemark (
241+ loc, " Unexpected case where both lhs and rhs are not scalars" );
242+ });
235243 return failure ();
236244 }
237245
@@ -262,9 +270,10 @@ LogicalResult MaskState::minStates(const MaskState &lhsState,
262270 const MaskState &rhsState, Location loc,
263271 OpBuilder &builder) {
264272 if (lhsState.getRank () != rhsState.getRank ()) {
265- InFlightDiagnostic diag =
266- emitError (loc)
267- << " Unexpected case where lhs and rhs have different ranks" ;
273+ LLVM_DEBUG ({
274+ InFlightDiagnostic diag = emitRemark (
275+ loc, " Unexpected case where lhs and rhs have different ranks" );
276+ });
268277 return failure ();
269278 }
270279
@@ -497,7 +506,8 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
497506 if (cmpOp.getPredicate () != arith::CmpIPredicate::slt &&
498507 cmpOp.getPredicate () != arith::CmpIPredicate::ult &&
499508 cmpOp.getPredicate () != arith::CmpIPredicate::sge) {
500- InFlightDiagnostic diag = emitError (loc) << " Unsupported cmpi" ;
509+ LLVM_DEBUG (
510+ { InFlightDiagnostic diag = emitRemark (loc, " Unsupported cmpi" ); });
501511 return failure ();
502512 }
503513
@@ -514,8 +524,10 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
514524 // the comparison evaluates to true.
515525 if (cmpOp.getPredicate () == arith::CmpIPredicate::sge &&
516526 !(rhsState.scalar && hasConstZero (rhsState.scalar ))) {
517- InFlightDiagnostic diag = emitError (loc)
518- << " Unsupported cmpi with rhs not equal to 0" ;
527+ LLVM_DEBUG ({
528+ InFlightDiagnostic diag =
529+ emitRemark (loc, " Unsupported cmpi with rhs not equal to 0" );
530+ });
519531 return failure ();
520532 }
521533
@@ -524,9 +536,11 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
524536 auto dimIntAttr = getIntAttr (lhsState.dims [i]);
525537 if (!dimIntAttr || dimIntAttr.value () != 1 ) {
526538 if (cmpDim != -1 ) {
527- InFlightDiagnostic diag = emitError (loc)
528- << " Unsupported cmpi with more than one "
529- " dimension with size larger than 1" ;
539+ LLVM_DEBUG ({
540+ InFlightDiagnostic diag =
541+ emitRemark (loc, " Unsupported cmpi with more than one dimension "
542+ " with size larger than 1" );
543+ });
530544 return failure ();
531545 }
532546 cmpDim = i;
@@ -690,10 +704,11 @@ LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp,
690704 auto stride = (end - start + shape[0 ] - 1 ) / shape[0 ];
691705
692706 if (stride != 1 ) {
693- InFlightDiagnostic diag =
694- emitError (loc)
695- << " stride must be 1 for make_range whose result is used "
696- " as load or store masks" ;
707+ LLVM_DEBUG ({
708+ InFlightDiagnostic diag = emitRemark (
709+ loc, " stride must be 1 for make_range whose result is used "
710+ " as load or store masks" );
711+ });
697712 return failure ();
698713 }
699714
@@ -743,9 +758,10 @@ LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, const Location loc,
743758 auto dstShape = cast<ShapedType>(dst.getType ()).getShape ();
744759
745760 if (!isa<IntegerType>(src.getType ())) {
746- InFlightDiagnostic diag =
747- emitError (loc)
748- << " splat source must be an integer scalar for load/store masks" ;
761+ LLVM_DEBUG ({
762+ InFlightDiagnostic diag = emitRemark (
763+ loc, " splat source must be an integer scalar for load/store masks" );
764+ });
749765 return failure ();
750766 }
751767
0 commit comments