Skip to content

Commit

Permalink
add module with device to aievec tests
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Jan 30, 2025
1 parent 483d9ec commit e59dbe9
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 64 deletions.
20 changes: 12 additions & 8 deletions compiler/plugins/target/AMD-AIE/aievec/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,17 +663,20 @@ class MatMulOpConversion
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();

llvm::errs() << "In MatMulOpConversion match and rewrite" << '\n';

DataPathConfiguration configuration;
Value lhs = op.getLhs();

assert(!isa<arith::ExtSIOp>(lhs.getDefiningOp()) &&
!isa<arith::ExtUIOp>(lhs.getDefiningOp()) &&
"Not supporting extsi/extui here (should be canonned already)");
if (auto d = lhs.getDefiningOp()) {
assert(!isa<arith::ExtSIOp>(d) && !isa<arith::ExtUIOp>(d) &&
"Not supporting extsi/extui here (should be canonned already)");
}

Value rhs = op.getRhs();
assert(!isa<arith::ExtSIOp>(rhs.getDefiningOp()) &&
!isa<arith::ExtUIOp>(rhs.getDefiningOp()) &&
"Not supporting extsi/extui here (should be canonned already)");
if (auto d = rhs.getDefiningOp()) {
assert(!isa<arith::ExtSIOp>(d) && !isa<arith::ExtUIOp>(d) &&
"Not supporting extsi/extui here (should be canonned already)");
}

Value acc = op.getAcc();
auto accVecTy = cast<VectorType>(acc.getType());
Expand Down Expand Up @@ -1040,7 +1043,8 @@ struct ConvertAIEVecToLLVMPass
if (!maybeDevice.has_value()) {
getOperation()->emitOpError(
"No AMDAIEDevice found in the target attribute.");
assert(false && "No AMDAIEDevice found in the target attribute.");
signalPassFailure();
// assert(false && "No AMDAIEDevice found in the target attribute.");
}

AMDAIE::AMDAIEDevice device = maybeDevice.value();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,38 @@ struct LowerVectorContractionOpToAIEVecMatMulPattern
return b.create<vector::ShapeCastOp>(v.getLoc(), newVecTy, v).getResult();
}

LogicalResult foo(vector::ContractionOp contractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto getMatMulOperand = [&](Value v) {
Value sourceOfWidening = getSourceOfWideningOp(v).value_or(nullptr);
v = sourceOfWidening ? sourceOfWidening : v;
v = reshapeLeadingUnitDims(rewriter, v);
return v;
};

// TODO(newling) keep pushing on this:

Value lhs = getMatMulOperand(adaptor.getLhs());
Value rhs = getMatMulOperand(adaptor.getRhs());

Type accType = adaptor.getAcc().getType();
Value acc = getMatMulOperand(adaptor.getAcc());

auto mm = rewriter.create<aievec::MatMulOp>(contractOp.getLoc(),
acc.getType(), lhs, rhs, acc);

Value result =
rewriter.create<vector::ShapeCastOp>(contractOp.getLoc(), accType, mm);

rewriter.replaceOp(contractOp, result);
return success();
}

LogicalResult matchAndRewrite(
vector::ContractionOp contractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return foo(contractOp, adaptor, rewriter);

auto lhs = reshapeLeadingUnitDims(rewriter, adaptor.getLhs());
auto rhs = reshapeLeadingUnitDims(rewriter, adaptor.getRhs());
auto acc = reshapeLeadingUnitDims(rewriter, adaptor.getAcc());
Expand Down
83 changes: 28 additions & 55 deletions compiler/plugins/target/AMD-AIE/aievec/test/matmul.mlir
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
// RUN: iree-opt %s -split-input-file -convert-aievec-to-llvm | FileCheck %s

func.func @matmul(%A : vector<4x8xbf16>, %B : vector<8x4xbf16>,
%C : vector<4x4xf32>) -> vector<4x4xf32> {
%0 = aievec.matmul %A, %B, %C : vector<4x8xbf16>, vector<8x4xbf16>
into vector<4x4xf32>
return %0 : vector<4x4xf32>
}
#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}>
module attributes {hal.executable.target = #foo} {

// CHECK-LABEL: @matmul
// CHECK-LABEL: @matmulbf16bf16f32
// CHECK-SAME: %[[A:.*]]: vector<4x8xbf16>
// CHECK-SAME: %[[B:.*]]: vector<8x4xbf16>
// CHECK-SAME: %[[C:.*]]: vector<4x4xf32>
Expand All @@ -27,17 +23,14 @@ func.func @matmul(%A : vector<4x8xbf16>, %B : vector<8x4xbf16>,
// CHECK: %[[R:.*]] = vector.shape_cast %[[BCR]] :
// CHECK-SAME: vector<16xf32> to vector<4x4xf32>
// CHECK: return %[[R]] : vector<4x4xf32>

// -----

func.func @matmul(%A : vector<4x8xi8>, %B : vector<8x8xi8>,
%C : vector<4x8xi32>) -> vector<4x8xi32> {
%0 = aievec.matmul %A, %B, %C : vector<4x8xi8>, vector<8x8xi8>
into vector<4x8xi32>
return %0 : vector<4x8xi32>
func.func @matmulbf16bf16f32(%A : vector<4x8xbf16>, %B : vector<8x4xbf16>,
%C : vector<4x4xf32>) -> vector<4x4xf32> {
%0 = aievec.matmul %A, %B, %C : vector<4x8xbf16>, vector<8x4xbf16>
into vector<4x4xf32>
return %0 : vector<4x4xf32>
}

// CHECK-LABEL: @matmul
// CHECK-LABEL: @matmuli8i8i32
// CHECK-SAME: %[[A:.*]]: vector<4x8xi8>
// CHECK-SAME: %[[B:.*]]: vector<8x8xi8>
// CHECK-SAME: %[[C:.*]]: vector<4x8xi32>
Expand All @@ -64,47 +57,27 @@ func.func @matmul(%A : vector<4x8xi8>, %B : vector<8x8xi8>,
// CHECK: %[[R:.*]] = vector.shape_cast %[[BCR]] :
// CHECK-SAME: vector<32xi32> to vector<4x8xi32>
// CHECK: return %[[R]] : vector<4x8xi32>
func.func @matmuli8i8i32(%A : vector<4x8xi8>, %B : vector<8x8xi8>,
%C : vector<4x8xi32>) -> vector<4x8xi32> {
%0 = aievec.matmul %A, %B, %C : vector<4x8xi8>, vector<8x8xi8>
into vector<4x8xi32>
return %0 : vector<4x8xi32>
}
}

// -----

func.func @matmul(%A : vector<4x2xi32>, %B : vector<2x4xi16>,
%C : vector<4x4xi64>) -> vector<4x4xi64> {
%0 = aievec.matmul %A, %B, %C : vector<4x2xi32>, vector<2x4xi16>
into vector<4x4xi64>
return %0 : vector<4x4xi64>
// TODO(newling)

#foo = #hal.executable.target<"foo", "foo", {target_device = "npu4"}>
module attributes {hal.executable.target = #foo} {

// CHECK-LABEL: @matmuli8i8i32
func.func @matmuli8i8i32(%A : vector<8x8xi8>, %B : vector<8x8xi8>,
%C : vector<8x8xi32>) -> vector<8x8xi32> {
%0 = aievec.matmul %A, %B, %C : vector<8x8xi8>, vector<8x8xi8>
into vector<8x8xi32>
return %0 : vector<8x8xi32>
}

// CHECK-LABEL: @matmul
// CHECK-SAME: %[[A:.*]]: vector<4x2xi32>
// CHECK-SAME: %[[B:.*]]: vector<2x4xi16>
// CHECK-SAME: %[[C:.*]]: vector<4x4xi64>
// CHECK: %[[FA:.*]] = vector.shape_cast %[[A]] :
// CHECK-SAME: vector<4x2xi32> to vector<8xi32>
// CHECK: %[[FB:.*]] = vector.shape_cast %[[B]] :
// CHECK-SAME: vector<2x4xi16> to vector<8xi16>
// CHECK: %[[FC:.*]] = vector.shape_cast %[[C]] :
// CHECK-SAME: vector<4x4xi64> to vector<16xi64>
// CHECK: %[[CONF:.*]] = llvm.mlir.constant(770 : i32) : i32
// CHECK: %[[C0I32:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[IFA2512b:.*]] = llvm.bitcast %[[FA]] : vector<8xi32> to
// CHECK-SAME: vector<8xi32>
// CHECK: %[[IFA:.*]] = "xllvm.intr.aie2.set.I512.I256"(%[[IFA2512b]],
// CHECK-SAME: %[[C0I32]]) : (vector<8xi32>, i32) ->
// CHECK-SAME: vector<16xi32>
// CHECK: %[[BCA:.*]] = llvm.bitcast %[[IFA]] : vector<16xi32> to
// CHECK-SAME: vector<64xi8>
// CHECK: %[[IFB2512b:.*]] = llvm.bitcast %[[FB]] : vector<8xi16> to
// CHECK-SAME: vector<4xi32>
// CHECK: %[[IFB:.*]] = "xllvm.intr.aie2.set.I512.I128"(%[[IFB2512b]]) :
// CHECK-SAME: (vector<4xi32>) -> vector<16xi32>
// CHECK: %[[BCB:.*]] = llvm.bitcast %[[IFB]] : vector<16xi32> to
// CHECK-SAME: vector<16xi32>
// CHECK: %[[RACC:.*]] =
// CHECK-SAME: "xllvm.intr.aie2.I512.I512.ACC1024.acc64.mac.conf"(
// CHECK-SAME: %[[BCA]], %[[BCB]], %[[FC]], %[[CONF]]) :
// CHECK-SAME: (vector<64xi8>, vector<16xi32>, vector<16xi64>, i32)
// CHECK-SAME: -> vector<16xi64>
// CHECK: %[[BCR:.*]] = llvm.bitcast %[[RACC]] : vector<16xi64> to vector<16xi64>
// CHECK: %[[R:.*]] = vector.shape_cast %[[BCR]] :
// CHECK-SAME: vector<16xi64> to vector<4x4xi64>
// CHECK: return %[[R]] : vector<4x4xi64>
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
// CHECK-SAME: %[[V0:[a-zA-Z0-9]+]]: vector<16xbf16>,
// CHECK-SAME: %[[V1:.*]]: vector<16xbf16>,
// CHECK-SAME: %[[V2:.*]]: vector<16xf32>)

#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}>
module attributes {hal.executable.target = #foo} {
func.func @mac_flat_vec(%v0 : vector<16xbf16>,
%v1 : vector<16xbf16>,
%v2 : vector<16xf32>) -> vector<16xf32> {
Expand Down Expand Up @@ -37,13 +40,17 @@ func.func @mac_flat_vec(%v0 : vector<16xbf16>,
%0 = aievec.mac_elem %v0, %v1, %v2 : vector<16xbf16>, vector<16xbf16>, vector<16xf32>
return %0 : vector<16xf32>
}
}

// -----

// CHECK-LABEL: mac_2d_vec
// CHECK-SAME: %[[V02D:[a-zA-Z0-9]+]]: vector<4x4xbf16>,
// CHECK-SAME: %[[V12D:.*]]: vector<4x4xbf16>,
// CHECK-SAME: %[[V22D:.*]]: vector<4x4xf32>)

#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}>
module attributes {hal.executable.target = #foo} {
func.func @mac_2d_vec(%v0 : vector<4x4xbf16>,
%v1 : vector<4x4xbf16>,
%v2 : vector<4x4xf32>) -> vector<4x4xf32> {
Expand Down Expand Up @@ -87,3 +94,4 @@ func.func @mac_2d_vec(%v0 : vector<4x4xbf16>,
%0 = aievec.mac_elem %v0, %v1, %v2 : vector<4x4xbf16>, vector<4x4xbf16>, vector<4x4xf32>
return %0 : vector<4x4xf32>
}
}
16 changes: 16 additions & 0 deletions compiler/plugins/target/AMD-AIE/aievec/test/test-shuffle.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

// CHECK-LABEL: @shuffle_single_operand_nocast
// CHECK-SAME: %[[LHS:.*]]: vector<16xi32>

#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}>
module attributes {hal.executable.target = #foo} {
func.func @shuffle_single_operand_nocast(%lhs : vector<16xi32>)
-> vector<16xi32> {
// CHECK: %[[M:.*]] = llvm.mlir.constant(34 : i32) : i32
Expand All @@ -12,12 +15,16 @@ func.func @shuffle_single_operand_nocast(%lhs : vector<16xi32>)
// CHECK: return %[[R]] : vector<16xi32>
return %0 : vector<16xi32>
}
}

// -----

// CHECK-LABEL: @shuffle_two_operands_nocast
// CHECK-SAME: %[[LHS:.*]]: vector<16xi32>,
// CHECK-SAME: %[[RHS:.*]]: vector<16xi32>

#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}>
module attributes {hal.executable.target = #foo} {
func.func @shuffle_two_operands_nocast(%lhs : vector<16xi32>,
%rhs : vector<16xi32>)
-> vector<16xi32> {
Expand All @@ -28,11 +35,15 @@ func.func @shuffle_two_operands_nocast(%lhs : vector<16xi32>,
// CHECK: return %[[R]] : vector<16xi32>
return %0 : vector<16xi32>
}
}

// -----

// CHECK-LABEL: @shuffle_single_operand_cast
// CHECK-SAME: %[[V:.*]]: vector<32xbf16>

#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}>
module attributes {hal.executable.target = #foo} {
func.func @shuffle_single_operand_cast(%lhs : vector<32xbf16>)
-> vector<32xbf16> {
// CHECK: %[[M:.*]] = llvm.mlir.constant(42 : i32) : i32
Expand All @@ -45,12 +56,16 @@ func.func @shuffle_single_operand_cast(%lhs : vector<32xbf16>)
// CHECK: return %[[R]] : vector<32xbf16>
return %0 : vector<32xbf16>
}
}

// -----

// CHECK-LABEL: @shuffle_two_operands_cast
// CHECK-SAME: %[[LV:.*]]: vector<32xbf16>,
// CHECK-SAME: %[[RV:.*]]: vector<32xbf16>

#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}>
module attributes {hal.executable.target = #foo} {
func.func @shuffle_two_operands_cast(%lhs : vector<32xbf16>,
%rhs : vector<32xbf16>)
-> vector<32xbf16> {
Expand All @@ -64,3 +79,4 @@ func.func @shuffle_two_operands_cast(%lhs : vector<32xbf16>,
// CHECK: return %[[R]] : vector<32xbf16>
return %0 : vector<32xbf16>
}
}
Loading

0 comments on commit e59dbe9

Please sign in to comment.