|
10 | 10 | #define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H |
11 | 11 |
|
12 | 12 | #include "gc/Dialect/Linalgx/LinalgxOps.h" |
| 13 | +#include "mlir/Dialect/DLTI/DLTI.h" |
13 | 14 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
14 | | -#include <cstring> |
| 15 | +#include "mlir/Interfaces/DataLayoutInterfaces.h" |
15 | 16 |
|
16 | 17 | namespace mlir { |
17 | 18 | namespace gc { |
18 | 19 |
|
19 | 20 | using namespace mlir; |
20 | 21 |
|
21 | | -// A mock for the taget information |
22 | | -// TODO: replace it with upstream hardware description model |
23 | 22 | struct SystemDesc { |
24 | | - |
25 | | - static int getPositiveIntFromStr(char *str, int defaultValue = 1) { |
26 | | - if (!str || strlen(str) == 0 || str[0] > '9' || str[0] < '0') { |
27 | | - return defaultValue; |
28 | | - } |
29 | | - auto val = std::stoi(str); |
30 | | - return val > 0 ? val : defaultValue; |
31 | | - } |
32 | | - |
33 | 23 | // get runtime OMP_NUM_THREADS |
34 | 24 | uint32_t getNumThreads() { |
35 | | - char *numThreads = getenv("OMP_NUM_THREADS"); |
36 | | - return getPositiveIntFromStr(numThreads, 1); |
| 25 | + std::optional<Attribute> numThreads = layout.getDevicePropertyValue( |
| 26 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 27 | + Builder(ctx).getStringAttr("num_threads")); |
| 28 | + if (numThreads && isa<IntegerAttr>(*numThreads)) { |
| 29 | + return dyn_cast<IntegerAttr>(*numThreads).getInt(); |
| 30 | + } |
| 31 | + return 1; |
37 | 32 | } |
38 | 33 | // get cache size by cacheLevel |
39 | 34 | size_t getCacheSize(uint8_t cacheLevel) { |
40 | 35 | if (cacheLevel == 1) { |
41 | | - char *cacheSize = getenv("L1_CACHE_SIZE"); |
42 | | - return getPositiveIntFromStr(cacheSize, 0); |
| 36 | + std::optional<Attribute> cacheSize = layout.getDevicePropertyValue( |
| 37 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 38 | + Builder(ctx).getStringAttr("L1_cache_size_in_bytes")); |
| 39 | + if (cacheSize && isa<IntegerAttr>(*cacheSize)) { |
| 40 | + return dyn_cast<IntegerAttr>(*cacheSize).getInt(); |
| 41 | + } |
43 | 42 | } else if (cacheLevel == 2) { |
44 | | - char *cacheSize = getenv("L2_CACHE_SIZE"); |
45 | | - return getPositiveIntFromStr(cacheSize, 0); |
| 43 | + std::optional<Attribute> cacheSize = layout.getDevicePropertyValue( |
| 44 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 45 | + Builder(ctx).getStringAttr("L2_cache_size_in_bytes")); |
| 46 | + if (cacheSize && isa<IntegerAttr>(*cacheSize)) { |
| 47 | + return dyn_cast<IntegerAttr>(*cacheSize).getInt(); |
| 48 | + } |
46 | 49 | } else if (cacheLevel == 3) { |
47 | | - char *cacheSize = getenv("L3_CACHE_SIZE"); |
48 | | - return getPositiveIntFromStr(cacheSize, 0); |
| 50 | + std::optional<Attribute> cacheSize = layout.getDevicePropertyValue( |
| 51 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 52 | + Builder(ctx).getStringAttr("L3_cache_size_in_bytes")); |
| 53 | + if (cacheSize && isa<IntegerAttr>(*cacheSize)) { |
| 54 | + return dyn_cast<IntegerAttr>(*cacheSize).getInt(); |
| 55 | + } |
49 | 56 | } |
50 | 57 | return 0; |
51 | 58 | } |
52 | 59 |
|
53 | 60 | // get the maximum vector length in bits |
54 | 61 | size_t getMaxVectorLength() { |
55 | | - char *maxVectorLanes = getenv("MAX_VECTOR_LENGTH"); |
56 | | - return getPositiveIntFromStr(maxVectorLanes, 512); |
| 62 | + std::optional<Attribute> maxVectorLength = layout.getDevicePropertyValue( |
| 63 | + Builder(ctx).getStringAttr("CPU" /* device ID*/), |
| 64 | + Builder(ctx).getStringAttr("max_vector_width")); |
| 65 | + if (maxVectorLength && isa<IntegerAttr>(*maxVectorLength)) { |
| 66 | + return dyn_cast<IntegerAttr>(*maxVectorLength).getInt(); |
| 67 | + } |
| 68 | + return 512; |
57 | 69 | } |
| 70 | + |
| 71 | + SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {} |
| 72 | + |
| 73 | +private: |
| 74 | + DataLayout layout; |
| 75 | + MLIRContext *ctx; |
58 | 76 | }; |
59 | 77 |
|
60 | 78 | // The configuration for matmul tiling |
61 | 79 | // TODO: support batch matmul |
// Tiling configuration produced by the analysis for a single matmul.
// TODO: support batch matmul
struct MatmulConfig {
  // How many threads are distributed along the M, N and K dimensions.
  uint32_t MThreads, NThreads, KThreads;
  // Outer tile sizes along M, N and K, used to pick the per-thread loop
  // tiling.
  uint32_t MBlock, NBlock, KBlock;
  // Innermost block sizes along M, N and K; these map directly onto the
  // generated brgemm kernel.
  uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
};
72 | 90 |
|
// Role of a dimension in the matmul iteration space: Batch, M (parallel rows),
// N (parallel columns), K (reduction).
enum DimType { Batch, M, N, K };
|
0 commit comments