From 502aa5ceb0868fd2f56224cfeb77f1a98935fffa Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre@amd.com>
Date: Mon, 6 May 2024 15:08:17 +0200
Subject: [PATCH] TosaToLinalg: Allow to skip TOSA validation

---
 .../mlir/Conversion/TosaToLinalg/TosaToLinalg.h    |  2 +-
 .../Conversion/TosaToLinalg/TosaToLinalgPass.cpp   | 14 ++++++++------
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index cae0408c3d163..6b23d4c82359c 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -39,7 +39,7 @@ void addTosaToLinalgPasses(
     const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions =
         TosaToLinalgNamedOptions(),
     // Note: Default to 'none' level unless otherwise specified.
-    tosa::TosaValidationOptions const &validationOptions = {
+    std::optional<tosa::TosaValidationOptions> validationOptions = tosa::TosaValidationOptions{
         tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
 
 /// Populates TOSA to linalg pipelines
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 091998b7af361..ca0a41207dcce 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -90,7 +90,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
 void mlir::tosa::addTosaToLinalgPasses(
     OpPassManager &pm, const TosaToLinalgOptions &options,
     const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
-    tosa::TosaValidationOptions const &validationOptions) {
+    std::optional<tosa::TosaValidationOptions> validationOptions) {
   // Optional decompositions are designed to benefit linalg.
   if (!options.disableTosaDecompositions)
     pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
@@ -107,7 +107,8 @@ void mlir::tosa::addTosaToLinalgPasses(
   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
       tosaFoldOptions));
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
-  pm.addPass(tosa::createTosaValidation(validationOptions));
+  if (validationOptions)
+    pm.addPass(tosa::createTosaValidation(*validationOptions));
   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
 }
 
@@ -124,11 +125,12 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
       [](OpPassManager &pm) {
         TosaToLinalgOptions tosaToLinalgOptions;
         TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
+        TosaValidationOptions validationOptions;
+        validationOptions.profile = tosa::TosaProfileEnum::BaseInference;
+        validationOptions.StrictOperationSpecAlignment = true;
+        validationOptions.level = tosa::TosaLevelEnum::EightK;
         tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
                                     tosaToLinalgNamedOptions,
-                                    /* validationOptions = */
-                                    {tosa::TosaProfileEnum::BaseInference,
-                                     /* StrictOperationSpecAlignment = */ true,
-                                     tosa::TosaLevelEnum::EightK});
+                                    validationOptions);
       });
 }