Skip to content

Commit 3f29c9e

Browse files
authored
Wrap all error messages in PtrAnalysis and MaskAnalysis in LLVM_DEBUG (#357)
Change all instances of `emitError` to `emitRemark` in `PtrAnalysis` and `MaskAnalysis` while also wrapping them in `LLVM_DEBUG` to prevent confusion to users. These errors should already gracefully handled by triton-shared and should be used for debugging purposes only.
1 parent c3f2342 commit 3f29c9e

File tree

2 files changed

+152
-109
lines changed

2 files changed

+152
-109
lines changed

lib/Analysis/MaskAnalysis.cpp

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include "llvm/Support/LogicalResult.h"
2424
#include <cassert>
2525

26+
#define DEBUG_TYPE "mask-analysis"
27+
2628
namespace mlir {
2729

2830
namespace triton {
@@ -198,15 +200,18 @@ LogicalResult MaskState::addStates(const MaskState &lhsState,
198200
const MaskState &rhsState, Location loc,
199201
OpBuilder &builder) {
200202
if (lhsState.scalar && rhsState.scalar) {
201-
InFlightDiagnostic diag =
202-
emitError(loc) << "Unexpected case where both lhs and rhs are scalars";
203+
LLVM_DEBUG({
204+
InFlightDiagnostic diag =
205+
emitRemark(loc, "Unexpected case where both lhs and rhs are scalars");
206+
});
203207
return failure();
204208
}
205209

206210
if (!lhsState.scalar && !rhsState.scalar) {
207-
InFlightDiagnostic diag =
208-
emitError(loc)
209-
<< "Unsupported scenario where neither lhs nor rhs is a scalar";
211+
LLVM_DEBUG({
212+
InFlightDiagnostic diag = emitRemark(
213+
loc, "Unsupported scenario where neither lhs nor rhs is a scalar");
214+
});
210215
return failure();
211216
}
212217

@@ -222,16 +227,19 @@ LogicalResult MaskState::minStateScalar(const MaskState &lhsState,
222227
// Conjunction where both sides are scalar should not be done after splats. We
223228
// should ensure that code generation pushes the splat as late as possible.
224229
if (lhsState.scalar && rhsState.scalar) {
225-
InFlightDiagnostic diag =
226-
emitError(loc) << "Unexpected case where both lhs and rhs are scalars";
230+
LLVM_DEBUG({
231+
InFlightDiagnostic diag =
232+
emitRemark(loc, "Unexpected case where both lhs and rhs are scalars");
233+
});
227234
return failure();
228235
}
229236

230237
// Caller should ensure that at least one side is scalar.
231238
if (!lhsState.scalar && !rhsState.scalar) {
232-
InFlightDiagnostic diag =
233-
emitError(loc)
234-
<< "Unexpected case where both lhs and rhs are not scalars";
239+
LLVM_DEBUG({
240+
InFlightDiagnostic diag = emitRemark(
241+
loc, "Unexpected case where both lhs and rhs are not scalars");
242+
});
235243
return failure();
236244
}
237245

@@ -262,9 +270,10 @@ LogicalResult MaskState::minStates(const MaskState &lhsState,
262270
const MaskState &rhsState, Location loc,
263271
OpBuilder &builder) {
264272
if (lhsState.getRank() != rhsState.getRank()) {
265-
InFlightDiagnostic diag =
266-
emitError(loc)
267-
<< "Unexpected case where lhs and rhs have different ranks";
273+
LLVM_DEBUG({
274+
InFlightDiagnostic diag = emitRemark(
275+
loc, "Unexpected case where lhs and rhs have different ranks");
276+
});
268277
return failure();
269278
}
270279

@@ -497,7 +506,8 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
497506
if (cmpOp.getPredicate() != arith::CmpIPredicate::slt &&
498507
cmpOp.getPredicate() != arith::CmpIPredicate::ult &&
499508
cmpOp.getPredicate() != arith::CmpIPredicate::sge) {
500-
InFlightDiagnostic diag = emitError(loc) << "Unsupported cmpi";
509+
LLVM_DEBUG(
510+
{ InFlightDiagnostic diag = emitRemark(loc, "Unsupported cmpi"); });
501511
return failure();
502512
}
503513

@@ -514,8 +524,10 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
514524
// the comparison evaluates to true.
515525
if (cmpOp.getPredicate() == arith::CmpIPredicate::sge &&
516526
!(rhsState.scalar && hasConstZero(rhsState.scalar))) {
517-
InFlightDiagnostic diag = emitError(loc)
518-
<< "Unsupported cmpi with rhs not equal to 0";
527+
LLVM_DEBUG({
528+
InFlightDiagnostic diag =
529+
emitRemark(loc, "Unsupported cmpi with rhs not equal to 0");
530+
});
519531
return failure();
520532
}
521533

@@ -524,9 +536,11 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
524536
auto dimIntAttr = getIntAttr(lhsState.dims[i]);
525537
if (!dimIntAttr || dimIntAttr.value() != 1) {
526538
if (cmpDim != -1) {
527-
InFlightDiagnostic diag = emitError(loc)
528-
<< "Unsupported cmpi with more than one "
529-
"dimension with size larger than 1";
539+
LLVM_DEBUG({
540+
InFlightDiagnostic diag =
541+
emitRemark(loc, "Unsupported cmpi with more than one dimension "
542+
"with size larger than 1");
543+
});
530544
return failure();
531545
}
532546
cmpDim = i;
@@ -690,10 +704,11 @@ LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp,
690704
auto stride = (end - start + shape[0] - 1) / shape[0];
691705

692706
if (stride != 1) {
693-
InFlightDiagnostic diag =
694-
emitError(loc)
695-
<< "stride must be 1 for make_range whose result is used "
696-
"as load or store masks";
707+
LLVM_DEBUG({
708+
InFlightDiagnostic diag = emitRemark(
709+
loc, "stride must be 1 for make_range whose result is used "
710+
"as load or store masks");
711+
});
697712
return failure();
698713
}
699714

@@ -743,9 +758,10 @@ LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, const Location loc,
743758
auto dstShape = cast<ShapedType>(dst.getType()).getShape();
744759

745760
if (!isa<IntegerType>(src.getType())) {
746-
InFlightDiagnostic diag =
747-
emitError(loc)
748-
<< "splat source must be an integer scalar for load/store masks";
761+
LLVM_DEBUG({
762+
InFlightDiagnostic diag = emitRemark(
763+
loc, "splat source must be an integer scalar for load/store masks");
764+
});
749765
return failure();
750766
}
751767

0 commit comments

Comments
 (0)