@@ -2745,3 +2745,58 @@ func.func @conv3d_ncdhw(
2745
2745
-> tensor <128 x128 x26 x26 x26 x!quant.uniform <i32 :f32 , 1.000000e+00 :5 >>
2746
2746
return
2747
2747
}
2748
+
2749
+ // -----
2750
+
2751
+ // CHECK-LABEL: func.func @decompose_composite_ops
2752
+ // CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>,
2753
+ // CHECK-SAME: %[[VAL_1:.*]]: tensor<2xf32>) -> tensor<2xf32> {
2754
+ func.func @decompose_composite_ops (%arg0 : tensor <2 xf32 >, %arg1 : tensor <2 xf32 >) -> tensor <2 xf32 > {
2755
+ // CHECK: %[[QUANT_0:.*]] = stablehlo.composite "stablehlo.uniform_quantize" %[[VAL_0]]
2756
+ // CHECK-SAME: : (tensor<2xf32>) -> tensor<2xi8>
2757
+ %0 = stablehlo.composite " stablehlo.uniform_quantize" %arg0 {decomposition = @stablehlo.uniform_quantize.impl } : (tensor <2 xf32 >) -> tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>
2758
+
2759
+ // CHECK: %[[QUANT_1:.*]] = stablehlo.composite "stablehlo.uniform_quantize" %[[VAL_1]]
2760
+ // CHECK-SAME: : (tensor<2xf32>) -> tensor<2xi8>
2761
+ %1 = stablehlo.composite " stablehlo.uniform_quantize" %arg1 {decomposition = @stablehlo.uniform_quantize.impl1 } : (tensor <2 xf32 >) -> tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>
2762
+
2763
+ // CHECK: %[[ADD:.*]] = stablehlo.composite "stablehlo.add" %[[QUANT_0]], %[[QUANT_1]]
2764
+ // CHECK-SAME: : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
2765
+ %2 = stablehlo.composite " stablehlo.add" %0 , %1 {decomposition = @stablehlo.add.impl } : (tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>, tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>) -> tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>
2766
+
2767
+ // CHECK: %[[DEQUANT_0:.*]] = stablehlo.composite "stablehlo.uniform_dequantize" %[[ADD]]
2768
+ // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xf32>
2769
+ %3 = stablehlo.composite " stablehlo.uniform_dequantize" %2 {decomposition = @stablehlo.uniform_dequantize.impl } : (tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>) -> tensor <2 xf32 >
2770
+
2771
+ // CHECK: return %[[DEQUANT_0]] : tensor<2xf32>
2772
+ return %3 : tensor <2 xf32 >
2773
+ }
2774
+
2775
+ // CHECK-LABEL: func.func private @stablehlo.uniform_quantize.impl
2776
+ // CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xi8> {
2777
+ func.func private @stablehlo.uniform_quantize.impl (%arg0 : tensor <2 xf32 >) -> tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >> {
2778
+ %0 = stablehlo.uniform_quantize %arg0 : (tensor <2 xf32 >) -> tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>
2779
+ return %0: tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>
2780
+ }
2781
+
2782
+ // CHECK-LABEL: func.func private @stablehlo.uniform_quantize.impl1
2783
+ // CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xi8> {
2784
+ func.func private @stablehlo.uniform_quantize.impl1 (%arg0 : tensor <2 xf32 >) -> tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >> {
2785
+ %0 = stablehlo.uniform_quantize %arg0 : (tensor <2 xf32 >) -> tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>
2786
+ return %0 : tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>
2787
+ }
2788
+
2789
+ // CHECK-LABEL: func.func private @stablehlo.add.impl
2790
+ // CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi8>,
2791
+ // CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi8>) -> tensor<2xi8> {
2792
+ func.func private @stablehlo.add.impl (%arg0 : tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>, %arg1 : tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>) -> tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >> {
2793
+ %0 = stablehlo.add %arg0 , %arg1 : tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>
2794
+ return %0 : tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>
2795
+ }
2796
+
2797
+ // CHECK-LABEL: func.func private @stablehlo.uniform_dequantize.impl
2798
+ // CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi8>) -> tensor<2xf32> {
2799
+ func.func private @stablehlo.uniform_dequantize.impl (%arg0 : tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>) -> tensor <2 xf32 > {
2800
+ %0 = stablehlo.uniform_dequantize %arg0 : (tensor <2 x!quant.uniform <i8 :f32 , 1.000000e-01 >>) -> tensor <2 xf32 >
2801
+ return %0: tensor <2 xf32 >
2802
+ }
0 commit comments