Skip to content

Commit 5bf0fef

Browse files
authored
Extend legalize-quant-to-math pass to support composite op (#2723)
**Note to the reviewers:** This PR is based on #2722. The changes related to this PR are localized in `stablehlo/transforms/StablehloLegalizeQuantToMath.cpp` and `stablehlo/tests/transforms/stablehlo_legalize_quant_to_int.mlir` files only. ## Summary The `quant-to-math` legalization of a composite op can be realized as: 1. Apply legalization to its decomposition. 2. Convert the quantized signature of the composite op to the integer signature. Note that both 1 and 2 are achieved __almost for free__ by the existing patterns. * (1) By virtue of the fact that the existing pass applies to every func in the module * (2) As part of the [ConvertGenericOp](https://github.com/openxla/stablehlo/blob/03597b1e592129f0c79e99e5ed65dac7ebee240f/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp#L1310) conversion pattern. Together with #2722, we can do something like ### Step 1 ``` $ cat input.mlir func.func @decompose_composite_op(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> tensor<2xf32> { %0 = stablehlo.uniform_quantize %arg0 : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.1:2>> %1 = stablehlo.uniform_quantize %arg1 : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.1:2>> %2 = stablehlo.add %0, %1 : tensor<2x!quant.uniform<i8:f32, 0.1:2>> %3 = stablehlo.uniform_dequantize %2 : (tensor<2x!quant.uniform<i8:f32, 0.1:2>>) -> tensor<2xf32> return %3 : tensor<2xf32> } ``` ### Step 2: Apply the https://github.com/openxla/stablehlo/blob/3a0cd9d12166d8426777206339b8562be64c55bc/stablehlo/transforms/Passes.td#L413 pass As part of applying the pass, we provide the attribute names for quantize/dequantize composites. 
With that we may get something like ```mlir func.func @decompose_composite_ops(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { %0 = stablehlo.composite "stablehlo.uniform_quantize" %arg0 {composite_attributes = {expressed_type = f32, scale = 1.000000e-01 : f64, storage_type = i8, storage_type_max = 127 : i64, storage_type_min = -128 : i64, zero_point = 0 : i64}, decomposition = @stablehlo.uniform_quantize.impl_0} : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> %1 = stablehlo.composite "stablehlo.uniform_quantize" %arg1 {composite_attributes = {expressed_type = f32, scale = 1.000000e-01 : f64, storage_type = i8, storage_type_max = 127 : i64, storage_type_min = -128 : i64, zero_point = 0 : i64}, decomposition = @stablehlo.uniform_quantize.impl} : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> %2 = stablehlo.composite "stablehlo.add" %0, %1 {decomposition = @stablehlo.add.impl} : (tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>, tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> %3 = stablehlo.composite "stablehlo.uniform_dequantize" %2 {composite_attributes = {expressed_type = f32, scale = 1.000000e-01 : f64, storage_type = i8, storage_type_max = 127 : i64, storage_type_min = -128 : i64, zero_point = 0 : i64}, decomposition = @stablehlo.uniform_dequantize.impl} : (tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2xf32> return %3 : tensor<2xf32> } func.func private @stablehlo.uniform_dequantize.impl(%arg0: tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2xf32> { %0 = stablehlo.uniform_dequantize %arg0 : (tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2xf32> return %0 : tensor<2xf32> } func.func private @stablehlo.add.impl(%arg0: tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>, %arg1: tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> { %0 = stablehlo.add %arg0, %arg1 : 
tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> return %0 : tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> } func.func private @stablehlo.uniform_quantize.impl(%arg0: tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> { %0 = stablehlo.uniform_quantize %arg0 : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> return %0 : tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> } func.func private @stablehlo.uniform_quantize.impl_0(%arg0: tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> { %0 = stablehlo.uniform_quantize %arg0 : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> return %0 : tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> } ``` ### Step 3: Apply `stablehlo-legalize-quant-to-int` pass ```mlir func.func @decompose_composite_ops(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { %0 = stablehlo.composite "stablehlo.uniform_quantize" %arg0 {composite_attributes = {expressed_type = f32, scale = 1.000000e-01 : f64, storage_type = i8, storage_type_max = 127 : i64, storage_type_min = -128 : i64, zero_point = 0 : i64}, decomposition = @stablehlo.uniform_quantize.impl_0} : (tensor<2xf32>) -> tensor<2xi8> %1 = stablehlo.composite "stablehlo.uniform_quantize" %arg1 {composite_attributes = {expressed_type = f32, scale = 1.000000e-01 : f64, storage_type = i8, storage_type_max = 127 : i64, storage_type_min = -128 : i64, zero_point = 0 : i64}, decomposition = @stablehlo.uniform_quantize.impl} : (tensor<2xf32>) -> tensor<2xi8> %2 = stablehlo.composite "stablehlo.add" %0, %1 {decomposition = @stablehlo.add.impl} : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8> %3 = stablehlo.composite "stablehlo.uniform_dequantize" %2 {composite_attributes = {expressed_type = f32, scale = 1.000000e-01 : f64, storage_type = i8, storage_type_max = 127 : i64, storage_type_min = -128 : i64, zero_point = 0 : i64}, decomposition = @stablehlo.uniform_dequantize.impl} : (tensor<2xi8>) -> tensor<2xf32> return %3 : tensor<2xf32> } 
func.func private @stablehlo.uniform_dequantize.impl(%arg0: tensor<2xi8>) -> tensor<2xf32> { // ... decomposition of stablehlo.uniform_dequantize } func.func private @stablehlo.add.impl(%arg0: tensor<2xi8>, %arg1: tensor<2xi8>) -> tensor<2xi8> { // decomposition of quantized stablehlo.add } func.func private @stablehlo.uniform_quantize.impl(%arg0: tensor<2xf32>) -> tensor<2xi8> { // ... decomposition of stablehlo.uniform_quantize } func.func private @stablehlo.uniform_quantize.impl_0(%arg0: tensor<2xf32>) -> tensor<2xi8> { // ... decomposition of stablehlo.uniform_quantize } ``` cc @mahmoud-abuzaina
1 parent 66f90d5 commit 5bf0fef

File tree

2 files changed

+63
-7
lines changed

2 files changed

+63
-7
lines changed

stablehlo/tests/transforms/stablehlo_legalize_quant_to_int.mlir

+55
Original file line numberDiff line numberDiff line change
@@ -2745,3 +2745,58 @@ func.func @conv3d_ncdhw(
27452745
-> tensor<128x128x26x26x26x!quant.uniform<i32:f32, 1.000000e+00:5>>
27462746
return
27472747
}
2748+
2749+
// -----
2750+
2751+
// CHECK-LABEL: func.func @decompose_composite_ops
2752+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>,
2753+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xf32>) -> tensor<2xf32> {
2754+
func.func @decompose_composite_ops(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> tensor<2xf32> {
2755+
// CHECK: %[[QUANT_0:.*]] = stablehlo.composite "stablehlo.uniform_quantize" %[[VAL_0]]
2756+
// CHECK-SAME: : (tensor<2xf32>) -> tensor<2xi8>
2757+
%0 = stablehlo.composite "stablehlo.uniform_quantize" %arg0 {decomposition = @stablehlo.uniform_quantize.impl} : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
2758+
2759+
// CHECK: %[[QUANT_1:.*]] = stablehlo.composite "stablehlo.uniform_quantize" %[[VAL_1]]
2760+
// CHECK-SAME: : (tensor<2xf32>) -> tensor<2xi8>
2761+
%1 = stablehlo.composite "stablehlo.uniform_quantize" %arg1 {decomposition = @stablehlo.uniform_quantize.impl1} : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
2762+
2763+
// CHECK: %[[ADD:.*]] = stablehlo.composite "stablehlo.add" %[[QUANT_0]], %[[QUANT_1]]
2764+
// CHECK-SAME: : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
2765+
%2 = stablehlo.composite "stablehlo.add" %0, %1 {decomposition = @stablehlo.add.impl} : (tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>, tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
2766+
2767+
// CHECK: %[[DEQUANT_0:.*]] = stablehlo.composite "stablehlo.uniform_dequantize" %[[ADD]]
2768+
// CHECK-SAME: : (tensor<2xi8>) -> tensor<2xf32>
2769+
%3 = stablehlo.composite "stablehlo.uniform_dequantize" %2 {decomposition = @stablehlo.uniform_dequantize.impl} : (tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2xf32>
2770+
2771+
// CHECK: return %[[DEQUANT_0]] : tensor<2xf32>
2772+
return %3 : tensor<2xf32>
2773+
}
2774+
2775+
// CHECK-LABEL: func.func private @stablehlo.uniform_quantize.impl
2776+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xi8> {
2777+
func.func private @stablehlo.uniform_quantize.impl(%arg0 : tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> {
2778+
%0 = stablehlo.uniform_quantize %arg0 : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
2779+
return %0: tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
2780+
}
2781+
2782+
// CHECK-LABEL: func.func private @stablehlo.uniform_quantize.impl1
2783+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xi8> {
2784+
func.func private @stablehlo.uniform_quantize.impl1(%arg0 : tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> {
2785+
%0 = stablehlo.uniform_quantize %arg0 : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
2786+
return %0 : tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
2787+
}
2788+
2789+
// CHECK-LABEL: func.func private @stablehlo.add.impl
2790+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi8>,
2791+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi8>) -> tensor<2xi8> {
2792+
func.func private @stablehlo.add.impl(%arg0 : tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>, %arg1 : tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2x!quant.uniform<i8:f32, 1.000000e-01>> {
2793+
%0 = stablehlo.add %arg0, %arg1 : tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
2794+
return %0 : tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>
2795+
}
2796+
2797+
// CHECK-LABEL: func.func private @stablehlo.uniform_dequantize.impl
2798+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi8>) -> tensor<2xf32> {
2799+
func.func private @stablehlo.uniform_dequantize.impl(%arg0 : tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2xf32> {
2800+
%0 = stablehlo.uniform_dequantize %arg0 : (tensor<2x!quant.uniform<i8:f32, 1.000000e-01>>) -> tensor<2xf32>
2801+
return %0: tensor<2xf32>
2802+
}

stablehlo/transforms/StablehloLegalizeQuantToMath.cpp

+8-7
Original file line numberDiff line numberDiff line change
@@ -1318,13 +1318,14 @@ class ConvertGenericOp : public ConversionPattern {
13181318
ConversionPatternRewriter &rewriter) const override {
13191319
// This pattern only handle selected ops.
13201320
if (!isa<stablehlo::BitcastConvertOp, stablehlo::BroadcastInDimOp,
1321-
stablehlo::ConcatenateOp, stablehlo::ConstantOp,
1322-
stablehlo::DynamicReshapeOp, stablehlo::DynamicSliceOp,
1323-
stablehlo::GatherOp, stablehlo::MaxOp, stablehlo::MinOp,
1324-
stablehlo::PadOp, stablehlo::ReduceWindowOp, stablehlo::ReshapeOp,
1325-
stablehlo::ReturnOp, stablehlo::SelectOp, stablehlo::SliceOp,
1326-
stablehlo::TransposeOp, stablehlo::GetDimensionSizeOp,
1327-
stablehlo::DynamicBroadcastInDimOp>(op)) {
1321+
stablehlo::CompositeOp, stablehlo::ConcatenateOp,
1322+
stablehlo::ConstantOp, stablehlo::DynamicReshapeOp,
1323+
stablehlo::DynamicSliceOp, stablehlo::GatherOp, stablehlo::MaxOp,
1324+
stablehlo::MinOp, stablehlo::PadOp, stablehlo::ReduceWindowOp,
1325+
stablehlo::ReshapeOp, stablehlo::ReturnOp, stablehlo::SelectOp,
1326+
stablehlo::SliceOp, stablehlo::TransposeOp,
1327+
stablehlo::GetDimensionSizeOp, stablehlo::DynamicBroadcastInDimOp>(
1328+
op)) {
13281329
return rewriter.notifyMatchFailure(
13291330
op, "Unsupported op for performing type change");
13301331
}

0 commit comments

Comments
 (0)