|
9 | 9 | #include "gc/Dialect/Linalgx/LinalgxOps.h"
|
10 | 10 | #include "gc/Dialect/Linalgx/LinalgxDialect.h"
|
11 | 11 | #include "mlir/IR/OpImplementation.h"
|
| 12 | +#include <utility> |
12 | 13 |
|
13 | 14 | //===----------------------------------------------------------------------===//
|
14 | 15 | // Builder helper from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
|
@@ -613,6 +614,58 @@ void MultiBatchMatmulOp::getEffects(
|
613 | 614 | getDpsInits());
|
614 | 615 | }
|
615 | 616 |
|
| 617 | +//===----------------------------------------------------------------------===// |
| 618 | +// ScaledDotProductAttentionOp |
| 619 | +//===----------------------------------------------------------------------===// |
| 620 | + |
| 621 | +LogicalResult ScaledDotProductAttentionOp::verify() { return success(); } |
| 622 | + |
/// Decompose scaled-dot-product attention into a sequence of primitive
/// linalg/linalgx ops. With ins = {query, key, value, mask}:
///
///   1. kT  = transpose(key, perm = [0, 1, 3, 2])
///   2. qk  = multi_batch_matmul(query, kT)
///   3. qkm = add(qk, mask)           // elementwise additive mask
///   4. out = multi_batch_matmul(qkm, value)
///
/// NOTE(review): despite the op's name, no softmax and no 1/sqrt(d) scaling
/// step is emitted here -- confirm this is intentional (e.g. handled by a
/// later pass) before relying on this decomposition for numerics.
FailureOr<SmallVector<Value>>
ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
  // Materialize the replacement ops immediately before this op.
  b.setInsertionPoint(*this);
  Location loc = getLoc();
  // Operand order: query, key, value, attention mask.
  Value query = getInputs()[0], key = getInputs()[1], value = getInputs()[2],
        mask = getInputs()[3];
  auto dtype = cast<RankedTensorType>(query.getType()).getElementType();
  auto shape = cast<RankedTensorType>(query.getType()).getShape();
  // The indexing below assumes all inputs are rank-4 ranked tensors with
  // static dims (presumably [batch, heads, seq, head_dim]) and that query
  // and key share the same seq length -- TODO confirm with the op verifier.

  // 1. kT: swap the last two dims of `key`.
  SmallVector<int64_t> permutation{0, 1, 3, 2};
  SmallVector<int64_t> transposeShape{shape[0], shape[1], shape[3], shape[2]};
  auto transposeOut = b.create<tensor::EmptyOp>(loc, transposeShape, dtype);
  auto transpose = b.create<linalg::TransposeOp>(
      /*location=*/loc,
      /*inputs=*/key,
      /*outputs=*/transposeOut,
      /*permutation=*/permutation);

  // 2. qk = query @ kT; both trailing dims use query's shape[2].
  SmallVector<int64_t> matmulQKShape{shape[0], shape[1], shape[2], shape[2]};
  auto matmulQKOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
  auto matmulQK = b.create<linalgx::MultiBatchMatmulOp>(
      /*location=*/loc, matmulQKOut.getResult().getType(),
      /*inputs=*/ValueRange{query, transpose->getResult(0)},
      /*outputs=*/ValueRange{matmulQKOut.getResult()});

  // 3. Apply the attention mask additively, elementwise.
  auto addOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
  auto add = b.create<linalg::AddOp>(
      /*location=*/loc, addOut.getResult().getType(),
      /*inputs=*/ValueRange{matmulQK->getResult(0), mask},
      /*outputs=*/ValueRange{addOut.getResult()});

  // 4. Multiply by `value`; the result reuses the query's shape.
  auto matmulVOut = b.create<tensor::EmptyOp>(loc, shape, dtype);
  auto matmulV = b.create<linalgx::MultiBatchMatmulOp>(
      /*location=*/loc, matmulVOut.getResult().getType(),
      /*inputs=*/ValueRange{add->getResult(0), value},
      /*outputs=*/ValueRange{matmulVOut.getResult()});
  // Single replacement value: the decomposed attention output.
  return SmallVector<Value>{matmulV.getResults()[0]};
}
| 668 | + |
616 | 669 | /////// Operations corresponding to library calls defined with Tablegen ////////
|
617 | 670 |
|
618 | 671 | #define GET_OP_CLASSES
|
|
0 commit comments