//===-- MatmulConfigAnalysis.h - the analysis for matmul config -*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
1111
#include "gc/Dialect/Linalgx/LinalgxOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include <cerrno>
#include <climits>
#include <cstdlib>
#include <cstring>
#include <llvm/Support/Debug.h>
#include <memory>
#include <numeric>
2115
2216namespace mlir {
2317namespace gc {
2418
2519using namespace mlir ;
2620
// A mock for the target information
// TODO: replace it with upstream hardware description model
struct SystemDesc {

  // Parse a positive integer from `str` (typically a getenv() result).
  // Returns `defaultValue` when `str` is null, does not start with a digit,
  // overflows `int`, or parses to a non-positive value.
  static int getPositiveIntFromStr(char *str, int defaultValue = 1) {
    // The first-character digit test also rejects null/empty strings and
    // anything with a leading sign or whitespace, matching stoi's would-be
    // parse of the valid cases we accept.
    if (!str || str[0] < '0' || str[0] > '9') {
      return defaultValue;
    }
    // Use strtol instead of std::stoi: stoi throws std::out_of_range on
    // overflow (e.g. OMP_NUM_THREADS=99999999999999999999), which is fatal in
    // exception-free LLVM-style builds. strtol reports ERANGE instead.
    errno = 0;
    long val = std::strtol(str, nullptr, 10);
    if (errno == ERANGE || val <= 0 || val > INT_MAX) {
      return defaultValue;
    }
    return static_cast<int>(val);
  }

  // get runtime OMP_NUM_THREADS; defaults to 1 when unset or invalid
  uint32_t getNumThreads() {
    char *numThreads = getenv("OMP_NUM_THREADS");
    return getPositiveIntFromStr(numThreads, 1);
  }

  // get cache size by cacheLevel from L{1,2,3}_CACHE_SIZE; returns 0 for an
  // unknown level or when the corresponding variable is unset/invalid
  size_t getCacheSize(uint8_t cacheLevel) {
    if (cacheLevel < 1 || cacheLevel > 3) {
      return 0;
    }
    // One table lookup replaces the per-level if/else ladder.
    static const char *const cacheEnvNames[] = {
        "L1_CACHE_SIZE", "L2_CACHE_SIZE", "L3_CACHE_SIZE"};
    return getPositiveIntFromStr(getenv(cacheEnvNames[cacheLevel - 1]), 0);
  }

  // get the maximum vector length in bits; defaults to 512
  size_t getMaxVectorLength() {
    char *maxVectorLanes = getenv("MAX_VECTOR_LENGTH");
    return getPositiveIntFromStr(maxVectorLanes, 512);
  }
};
6058};
6159
// The configuration for matmul tiling
// TODO: support batch matmul
struct MatmulConfig {
  // The number of threads distributed to the M, N, K loop dimensions.
  uint32_t MThreads, NThreads, KThreads;
  // The innermost block size for M, N, K which will be directly converted to
  // brgemm.
  uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
  // The outer block size for M, N, K which will be used to decide the loop
  // tile size in a single thread.
  uint32_t MBlock, NBlock, KBlock;
};
6972
// Role of a loop dimension in a (batch) matmul: the batch dim, the parallel
// M/N dims, or the reduction dim K.
enum DimType { Batch, M, N, K };
7174
72- [[maybe_unused]] static SmallVector<unsigned >
73- extractDimTypeIdx (ArrayRef<DimType> tyList, DimType ty) {
75+ // Extract the index of the given DimType in the DimType list
76+ inline SmallVector<unsigned > extractDimTypeIdx (ArrayRef<DimType> tyList,
77+ DimType ty) {
7478 SmallVector<unsigned > idxList;
7579 for (auto [idx, type] : llvm::enumerate (tyList)) {
7680 if (type == ty) {
@@ -80,9 +84,11 @@ extractDimTypeIdx(ArrayRef<DimType> tyList, DimType ty) {
8084 return idxList;
8185}
8286
83- static FailureOr<SmallVector<SmallVector<DimType>>>
87+ // Get the operand dim type for every operand for the given linalg op
88+ inline FailureOr<SmallVector<SmallVector<DimType>>>
8489getOprandDimType (linalg::LinalgOp &linalgOp) {
85- if (isa<linalg::MatmulOp>(linalgOp)) {
90+ // TODO: replace the linalgx op with generic op
91+ if (llvm::isa<linalg::MatmulOp>(linalgOp)) {
8692 return SmallVector<SmallVector<DimType>>{
8793 SmallVector<DimType>{DimType::M, DimType::K},
8894 SmallVector<DimType>{DimType::K, DimType::N},
@@ -104,10 +110,31 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
104110 SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
105111 SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
106112 SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
113+ } else if (llvm::isa<linalg::MatmulTransposeAOp>(linalgOp)) {
114+ return SmallVector<SmallVector<DimType>>{
115+ SmallVector<DimType>{DimType::K, DimType::M},
116+ SmallVector<DimType>{DimType::K, DimType::N},
117+ SmallVector<DimType>{DimType::M, DimType::N}};
118+ } else if (llvm::isa<linalg::MatmulTransposeBOp>(linalgOp)) {
119+ return SmallVector<SmallVector<DimType>>{
120+ SmallVector<DimType>{DimType::M, DimType::K},
121+ SmallVector<DimType>{DimType::N, DimType::K},
122+ SmallVector<DimType>{DimType::M, DimType::N}};
123+ } else if (llvm::isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
124+ return SmallVector<SmallVector<DimType>>{
125+ SmallVector<DimType>{DimType::Batch, DimType::K, DimType::M},
126+ SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
127+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
128+ } else if (llvm::isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
129+ return SmallVector<SmallVector<DimType>>{
130+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
131+ SmallVector<DimType>{DimType::Batch, DimType::N, DimType::K},
132+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
107133 }
108134 return failure ();
109135}
110136
137+ // The analysis to extract the matmul configuration from the given linalg op
111138struct MatmulConfigAnalysis {
112139public:
113140 explicit MatmulConfigAnalysis (Operation *root);
0 commit comments