@@ -29,14 +29,9 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
2929
3030template <typename T>
3131static llvm::raw_ostream &operator <<(llvm::raw_ostream &ss,
32- std::vector<T> arry ) {
32+ std::vector<T> array ) {
3333 ss << " [" ;
34- for (auto [idx, a] : llvm::enumerate (arry)) {
35- if (idx != 0 ) {
36- ss << " , " ;
37- }
38- ss << a;
39- }
34+ llvm::interleaveComma (array, ss);
4035 ss << " ]" ;
4136 return ss;
4237}
@@ -174,7 +169,7 @@ std::vector<MatmulConfig>
174169filterConfigByCostModel (ArrayRef<MatmulConfig> configs,
175170 linalg::LinalgOp &linalgOp, ArrayRef<uint32_t > shape,
176171 SystemDesc &sysDesc, const CostModelFn &costModel,
177- float eliminationRatio = 0.5 , float threshold = -1 ) {
172+ float preserveRatio = 0.5 , float threshold = -1 ) {
178173 std::vector<MatmulConfig> result;
179174 std::vector<float > costs;
180175 std::vector<size_t > idx;
@@ -185,8 +180,7 @@ filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
185180 std::stable_sort (idx.begin (), idx.end (), [&costs](size_t i1, size_t i2) {
186181 return costs[i1] < costs[i2];
187182 });
188- double thresholdCost =
189- costs[idx[(size_t )(eliminationRatio * configs.size ())]];
183+ double thresholdCost = costs[idx[(size_t )(preserveRatio * configs.size ())]];
190184 thresholdCost =
191185 threshold < thresholdCost && threshold > 0 ? threshold : thresholdCost;
192186 for (size_t i = 0 ; i < configs.size (); i++) {
@@ -210,6 +204,11 @@ std::vector<MatmulConfig>
210204prepareConfigCandidates (Operation *root, SystemDesc &sysDesc,
211205 ArrayRef<uint32_t > shape,
212206 ArrayRef<uint32_t > givenInnermostBlock) {
207+ if (shape.size () < 3 ) {
208+ LLVM_DEBUG (llvm::dbgs ()
209+ << " The shape is invalid, no candidate is generated\n " );
210+ return {};
211+ }
213212 std::vector<MatmulConfig> configs;
214213 uint32_t threads = sysDesc.getNumThreads ();
215214 std::vector<uint32_t > MThreadsCandidates =
@@ -290,6 +289,21 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
290289 return configs;
291290}
292291
292+ bool validateConfig (const MatmulConfig &cfg) {
293+ if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
294+ cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
295+ cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
296+ cfg.innerMostKBlock <= 0 ) {
297+ return false ;
298+ }
299+ if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
300+ cfg.NBlock % cfg.innerMostNBlock != 0 ||
301+ cfg.KBlock % cfg.innerMostKBlock != 0 ) {
302+ return false ;
303+ }
304+ return true ;
305+ }
306+
293307// read the config from the attributes for tuning
294308bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
295309 size_t cfgItemCnt = 0 ;
@@ -323,7 +337,12 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
323337 cfgItemCnt++;
324338 }
325339 }
326- return cfgItemCnt == 9 ;
340+ if (validateConfig (config)) {
341+ return cfgItemCnt == 9 ;
342+ } else {
343+ LLVM_DEBUG (llvm::dbgs () << " The predefined config is invalid\n " );
344+ return false ;
345+ }
327346}
328347
329348// Analyze the workload and system description to generate the default config
0 commit comments