1+ // ===-- MatmulConfigAnalysis.h - the analysis for matmul config -*- C++ -*-===//
2+ //
3+ // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+ // See https://llvm.org/LICENSE.txt for license information.
5+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+ //
7+ // ===----------------------------------------------------------------------===//
8+
9+ #ifndef MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
10+ #define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
11+
12+ #include " gc/Dialect/Linalgx/LinalgxOps.h"
13+ #include " mlir/Dialect/DLTI/DLTI.h"
14+ #include " mlir/Dialect/Linalg/IR/Linalg.h"
15+ #include " mlir/Interfaces/DataLayoutInterfaces.h"
16+
17+ namespace mlir {
18+ namespace gc {
19+
20+ using namespace mlir ;
21+
22+ // The configuration for matmul tiling
23+ // TODO: support batch matmul
24+ struct MatmulConfig {
25+ // The number of threads distributed to M, N, K
26+ uint32_t MThreads, NThreads, KThreads;
27+ // The outer block size for M, N, K which will be used to decide the loop tile
28+ // size in single thread
29+ uint32_t MBlock, NBlock, KBlock;
30+ // The innermost block size for M, N, K which will be directly converted to
31+ // brgemm.
32+ uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
33+ };
34+
35+ enum DimType { Batch, M, N, K };
36+
37+ // Extract the index of the given DimType in the DimType list
38+ inline SmallVector<unsigned > extractDimTypeIdx (ArrayRef<DimType> tyList,
39+ DimType ty) {
40+ SmallVector<unsigned > idxList;
41+ for (auto [idx, type] : llvm::enumerate (tyList)) {
42+ if (type == ty) {
43+ idxList.push_back (idx);
44+ }
45+ }
46+ return idxList;
47+ }
48+
49+ // Get the operand dim type for every operand for the given linalg op
50+ inline FailureOr<SmallVector<SmallVector<DimType>>>
51+ getOprandDimType (linalg::LinalgOp &linalgOp) {
52+ // TODO: replace the linalgx op with generic op
53+ if (llvm::isa<linalg::MatmulOp>(linalgOp)) {
54+ return SmallVector<SmallVector<DimType>>{
55+ SmallVector<DimType>{DimType::M, DimType::K},
56+ SmallVector<DimType>{DimType::K, DimType::N},
57+ SmallVector<DimType>{DimType::M, DimType::N}};
58+ } else if (llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) {
59+ return SmallVector<SmallVector<DimType>>{
60+ SmallVector<DimType>{DimType::M, DimType::K},
61+ SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
62+ DimType::K},
63+ SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
64+ } else if (llvm::isa<linalgx::Mm4DVnniOp>(linalgOp)) {
65+ return SmallVector<SmallVector<DimType>>{
66+ SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
67+ SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
68+ DimType::K},
69+ SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
70+ } else if (llvm::isa<linalg::BatchMatmulOp>(linalgOp)) {
71+ return SmallVector<SmallVector<DimType>>{
72+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
73+ SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
74+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
75+ } else if (llvm::isa<linalg::MatmulTransposeAOp>(linalgOp)) {
76+ return SmallVector<SmallVector<DimType>>{
77+ SmallVector<DimType>{DimType::K, DimType::M},
78+ SmallVector<DimType>{DimType::K, DimType::N},
79+ SmallVector<DimType>{DimType::M, DimType::N}};
80+ } else if (llvm::isa<linalg::MatmulTransposeBOp>(linalgOp)) {
81+ return SmallVector<SmallVector<DimType>>{
82+ SmallVector<DimType>{DimType::M, DimType::K},
83+ SmallVector<DimType>{DimType::N, DimType::K},
84+ SmallVector<DimType>{DimType::M, DimType::N}};
85+ } else if (llvm::isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
86+ return SmallVector<SmallVector<DimType>>{
87+ SmallVector<DimType>{DimType::Batch, DimType::K, DimType::M},
88+ SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
89+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
90+ } else if (llvm::isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
91+ return SmallVector<SmallVector<DimType>>{
92+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
93+ SmallVector<DimType>{DimType::Batch, DimType::N, DimType::K},
94+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
95+ }
96+ return failure ();
97+ }
98+
99+ // The analysis to extract the matmul configuration from the given linalg op
100+ struct MatmulConfigAnalysis {
101+ public:
102+ explicit MatmulConfigAnalysis (Operation *root);
103+ MatmulConfig getConfig () { return config; }
104+
105+ private:
106+ MatmulConfig config;
107+ };
108+
109+ } // namespace gc
110+ } // namespace mlir
111+
112+ #endif
0 commit comments