Skip to content

Commit 11e5e7d

Browse files
James Molloycopybara-github
James Molloy
authored andcommitted
[xls][mlir] Fix bug in array-to-bits, add support for ArrayConcatOp and speed up
We weren't declaring ArrayIndexOp and friends illegal, so the rewriter sometimes would not apply our rewrite patterns. Also noticed that ArrayConcatOp was missing. Also noticed it was super slow on a large module due to no parallelization, so added that. Then noticed that XLS verification was super slow (dominant in multithreading as we verify per thread), so fixed that. PiperOrigin-RevId: 697543700
1 parent 64e6d11 commit 11e5e7d

File tree

4 files changed

+58
-16
lines changed

4 files changed

+58
-16
lines changed

xls/contrib/mlir/IR/xls_ops.cc

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -583,8 +583,9 @@ LogicalResult SpawnOp::verifySymbolUses(SymbolTableCollection& symbolTable) {
583583

584584
namespace {
585585
LogicalResult verifyChannelUsingOp(Operation* op, SymbolRefAttr channelAttr,
586-
Type elementType) {
587-
auto chanOp = SymbolTable::lookupNearestSymbolFrom<ChanOp>(op, channelAttr);
586+
Type elementType,
587+
SymbolTableCollection& symbolTable) {
588+
auto chanOp = symbolTable.lookupNearestSymbolFrom<ChanOp>(op, channelAttr);
588589
if (!chanOp) {
589590
return op->emitOpError("channel symbol not found: ") << channelAttr;
590591
}
@@ -607,19 +608,21 @@ LogicalResult verifyStructuredChannelUsingOp(Operation* op, Value channel,
607608

608609
} // namespace
609610

610-
LogicalResult BlockingReceiveOp::verify() {
611+
LogicalResult BlockingReceiveOp::verifySymbolUses(
612+
SymbolTableCollection& symbolTable) {
611613
return verifyChannelUsingOp(getOperation(), getChannelAttr(),
612-
getResult().getType());
614+
getResult().getType(), symbolTable);
613615
}
614616

615-
LogicalResult NonblockingReceiveOp::verify() {
617+
LogicalResult NonblockingReceiveOp::verifySymbolUses(
618+
SymbolTableCollection& symbolTable) {
616619
return verifyChannelUsingOp(getOperation(), getChannelAttr(),
617-
getResult().getType());
620+
getResult().getType(), symbolTable);
618621
}
619622

620-
LogicalResult SendOp::verify() {
623+
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection& symbolTable) {
621624
return verifyChannelUsingOp(getOperation(), getChannelAttr(),
622-
getData().getType());
625+
getData().getType(), symbolTable);
623626
}
624627

625628
LogicalResult SBlockingReceiveOp::verify() {

xls/contrib/mlir/IR/xls_ops.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,8 @@ def Xls_SignExtOp : Xls_UnaryOp<"sign_ext", [Pure, SameOperandsAndResultShape]>
418418
def Xls_BlockingReceiveOp : Xls_Op<"blocking_receive", [
419419
TensorArrayTypeFungible,
420420
PredicatableOpInterface,
421-
CallOpInterface]> {
421+
CallOpInterface,
422+
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
422423
let summary = "Receives a data value from a specified channel.";
423424
let description = [{
424425
Receives a data value from a specified channel. The type of the data value
@@ -440,7 +441,6 @@ def Xls_BlockingReceiveOp : Xls_Op<"blocking_receive", [
440441
let assemblyFormat = [{
441442
$tkn `,` ($predicate^ `,`)? $channel attr-dict `:` type($result)
442443
}];
443-
let hasVerifier = 1;
444444
let extraClassDeclaration = [{
445445
::mlir::Value getCondition() {
446446
return getPredicate();
@@ -464,7 +464,8 @@ def Xls_BlockingReceiveOp : Xls_Op<"blocking_receive", [
464464
def Xls_NonblockingReceiveOp : Xls_Op<"nonblocking_receive", [
465465
TensorArrayTypeFungible,
466466
PredicatableOpInterface,
467-
CallOpInterface]> {
467+
CallOpInterface,
468+
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
468469
let summary = "Receives a data value from a specified channel.";
469470
let description = [{
470471
Receives a data value from a specified channel. The type of the data value
@@ -488,7 +489,6 @@ def Xls_NonblockingReceiveOp : Xls_Op<"nonblocking_receive", [
488489
let assemblyFormat = [{
489490
$tkn `,` ($predicate^ `,`)? $channel attr-dict `:` type($result)
490491
}];
491-
let hasVerifier = 1;
492492
let extraClassDeclaration = [{
493493
::mlir::Value getCondition() {
494494
return getPredicate();
@@ -514,7 +514,8 @@ def Xls_NonblockingReceiveOp : Xls_Op<"nonblocking_receive", [
514514
def Xls_SendOp : Xls_Op<"send", [
515515
TensorArrayTypeFungible,
516516
PredicatableOpInterface,
517-
CallOpInterface]> {
517+
CallOpInterface,
518+
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
518519
let summary = "Sends data to a specified channel.";
519520
let description = [{
520521
Sends data to a specified channel. The type of the data values is determined
@@ -535,7 +536,6 @@ def Xls_SendOp : Xls_Op<"send", [
535536
let assemblyFormat = [{
536537
$tkn `,` $data `,` ($predicate^ `,`)? $channel attr-dict `:` type($data)
537538
}];
538-
let hasVerifier = 1;
539539
let extraClassDeclaration = [{
540540
::mlir::Value getCondition() {
541541
return getPredicate();

xls/contrib/mlir/testdata/array_to_bits.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,14 @@ func.func @call_dslx(%arg0: !xls.array<4 x i32>) -> !xls.array<4 x f32> attribut
263263
%8 = xls.array %1, %3, %5, %7 : (f32, f32, f32, f32) -> !xls.array<4 x f32>
264264
return %8 : !xls.array<4 x f32>
265265
}
266+
267+
// CHECK-LABEL: func.func @array_concat(
268+
// CHECK-SAME: %[[VAL_0:.*]]: i64,
269+
// CHECK-SAME: %[[VAL_1:.*]]: i64) -> i128 attributes {xls = true} {
270+
// CHECK: %[[VAL_2:.*]] = xls.concat %[[VAL_0]], %[[VAL_1]] : (i64, i64) -> i128
271+
// CHECK: return %[[VAL_2]] : i128
272+
// CHECK: }
273+
func.func @array_concat(%arg0: !xls.array<2 x i32>, %arg1: !xls.array<2 x i32>) -> !xls.array<4 x i32> attributes {xls = true} {
274+
%0 = "xls.array_concat"(%arg0, %arg1) : (!xls.array<2 x i32>, !xls.array<2 x i32>) -> !xls.array<4 x i32>
275+
return %0 : !xls.array<4 x i32>
276+
}

xls/contrib/mlir/transforms/array_to_bits.cc

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/include/mlir/IR/OpDefinition.h"
2929
#include "mlir/include/mlir/IR/OperationSupport.h"
3030
#include "mlir/include/mlir/IR/PatternMatch.h"
31+
#include "mlir/include/mlir/IR/Threading.h"
3132
#include "mlir/include/mlir/IR/TypeUtilities.h"
3233
#include "mlir/include/mlir/IR/ValueRange.h"
3334
#include "mlir/include/mlir/IR/Visitors.h"
@@ -296,6 +297,24 @@ class LegalizeArrayZeroPattern : public OpConversionPattern<ArrayZeroOp> {
296297
}
297298
};
298299

300+
class LegalizeArrayConcatPattern : public OpConversionPattern<ArrayConcatOp> {
301+
using OpConversionPattern::OpConversionPattern;
302+
303+
LogicalResult matchAndRewrite(
304+
ArrayConcatOp op, OpAdaptor adaptor,
305+
ConversionPatternRewriter& rewriter) const override {
306+
(void)adaptor;
307+
SmallVector<Value> operands =
308+
CoerceFloats(adaptor.getOperands(), rewriter, op);
309+
if (operands.empty() && !adaptor.getOperands().empty()) {
310+
return failure();
311+
}
312+
rewriter.replaceOpWithNewOp<ConcatOp>(
313+
op, typeConverter->convertType(op.getType()), operands);
314+
return success();
315+
}
316+
};
317+
299318
class ArrayToBitsPass : public impl::ArrayToBitsPassBase<ArrayToBitsPass> {
300319
public:
301320
void runOnOperation() override {
@@ -308,7 +327,9 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase<ArrayToBitsPass> {
308327
return all_of(op->getOperandTypes(), is_legal) &&
309328
all_of(op->getResultTypes(), is_legal);
310329
});
311-
target.addIllegalOp<VectorizedCallOp>();
330+
target.addIllegalOp<VectorizedCallOp, ArrayOp, ArrayUpdateOp, ArraySliceOp,
331+
ArrayIndexOp, ArrayIndexStaticOp, ArrayZeroOp,
332+
ArrayConcatOp>();
312333
RewritePatternSet chanPatterns(&getContext());
313334
chanPatterns.add<LegalizeChanOpPattern>(typeConverter, &getContext());
314335
FrozenRewritePatternSet frozenChanPatterns(std::move(chanPatterns));
@@ -323,6 +344,7 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase<ArrayToBitsPass> {
323344
LegalizeArrayIndexPattern,
324345
LegalizeArrayIndexStaticPattern,
325346
LegalizeArrayZeroPattern,
347+
LegalizeArrayConcatPattern,
326348
LegalizeGenericOpPattern
327349
// clang-format on
328350
>(typeConverter, &getContext());
@@ -338,10 +360,11 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase<ArrayToBitsPass> {
338360
});
339361
FrozenRewritePatternSet frozenRegionPatterns(std::move(regionPatterns));
340362

363+
SmallVector<XlsRegionOpInterface> regions;
341364
getOperation()->walk([&](Operation* op) {
342365
if (auto interface = dyn_cast<XlsRegionOpInterface>(op)) {
343366
if (interface.isSupportedRegion()) {
344-
runOnOperation(interface, target, frozenRegionPatterns);
367+
regions.push_back(interface);
345368
return WalkResult::skip();
346369
}
347370
} else if (auto chanOp = dyn_cast<ChanOp>(op)) {
@@ -350,6 +373,11 @@ class ArrayToBitsPass : public impl::ArrayToBitsPassBase<ArrayToBitsPass> {
350373
}
351374
return WalkResult::advance();
352375
});
376+
377+
mlir::parallelForEach(
378+
&getContext(), regions, [&](XlsRegionOpInterface interface) {
379+
runOnOperation(interface, target, frozenRegionPatterns);
380+
});
353381
}
354382

355383
void runOnOperation(ChanOp operation, ConversionTarget& target,

0 commit comments

Comments
 (0)