     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
+    MixedW8A32ConvPattern,
     MixedW8A32LinearPattern,
     ReluPattern0,
     ReluPattern1,

@@ -478,6 +479,52 @@ def get_args_and_kwargs_softmax(
         out_zero_point_tensor,
     )
     kwargs = {}
+
+    return args, kwargs
+
+
+def get_args_and_kwargs_mixed_w8a32_conv(
+    graph_module: GraphModule,
+    other_inputs: List[fx.Node],
+    weights_inputs: List[fx.Node],
+    dequants_weights: List[fx.Node],
+    bias_inputs: List[fx.Node],
+    dequants_biases: List[fx.Node],
+    op_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    # Non-default stride, padding, dilation, and groups are not supported yet.
+    if len(op_node.args) > 3:
+        assert op_node.args[3] == [1]  # Stride
+    if len(op_node.args) > 4:
+        assert op_node.args[4] == [0]  # Padding
+    if len(op_node.args) > 5:
+        assert op_node.args[5] == [1]  # Dilation
+    if len(op_node.args) > 6:
+        assert op_node.args[6] == 1  # Groups
+
+    assert len(dequants_weights) == 1
+    assert len(dequants_biases) == 1
+    W_scale_ = dequants_weights[0].args[1]
+    B_scale_ = dequants_biases[0].args[1]
+
+    transposed_inputs = graph_module.graph.call_function(
+        torch.ops.aten.permute.default,
+        (other_inputs[0], [0, 2, 1]),  # activations: NCL -> NLC
+    )
+    transposed_weights = graph_module.graph.call_function(
+        torch.ops.aten.permute.default,
+        (weights_inputs[0], [2, 0, 1]),  # weights: (out_ch, in_ch, kernel) -> (kernel, out_ch, in_ch)
+    )
+
+    args = (
+        transposed_inputs,
+        transposed_weights,
+        W_scale_,
+        bias_inputs[0],
+        B_scale_,
+    )
+    kwargs = {}
+
     return args, kwargs


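For intuition on the two permute calls above, here is a minimal standalone sketch in plain PyTorch (not part of this change; the shapes are made-up placeholders) showing how the activation and weight layouts are rearranged before being handed to the fused op:

import torch

# Hypothetical conv1d shapes: batch=2, in_ch=4, length=10, out_ch=8, kernel=3.
x = torch.randn(2, 4, 10)  # activations in NCL layout
w = torch.randn(8, 4, 3)   # weights as (out_ch, in_ch, kernel)

print(x.permute(0, 2, 1).shape)  # torch.Size([2, 10, 4])  NCL -> NLC
print(w.permute(2, 0, 1).shape)  # torch.Size([3, 8, 4])   -> (kernel, out_ch, in_ch)
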
@@ -650,6 +697,16 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     bias_inputs,
                     dequants_biases,
                 )
+            elif isinstance(pattern, MixedW8A32ConvPattern):
+                args, kwargs = get_args_and_kwargs_mixed_w8a32_conv(
+                    graph_module,
+                    other_inputs,
+                    weights_inputs,
+                    dequants_weights,
+                    bias_inputs,
+                    dequants_biases,
+                    op_node,
+                )

             fused = graph_module.graph.call_function(
                 pattern.replacement_op(),
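A side note on the stride/padding/dilation/groups asserts in get_args_and_kwargs_mixed_w8a32_conv: they restrict fusion to the default torch.nn.Conv1d configuration. A small illustrative check of those defaults (plain PyTorch, not part of this change):

import torch

# Only the Conv1d defaults asserted above are accepted by the fused path:
# stride=[1], padding=[0], dilation=[1], groups=1.
conv = torch.nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3)
print(conv.stride, conv.padding, conv.dilation, conv.groups)  # (1,) (0,) (1,) 1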