
Commit 3a98706

Marco Giordano authored and facebook-github-bot committed
Including mixed quant Conv1D op in Jarvis (#14865)
Summary:

# Context

With the goal of porting mHML to ExecuTorch, a few operators are still missing. The main focus is on improving performance for the operators used by the model.

# Summary

This diff includes a general and a HiFi4-optimized Convolution 1D operator. Specifically, it adds both a standard Convolution 1D implementation and a version optimized for HiFi4 DSPs, ensuring better performance on supported hardware.

Reviewed By: skrtskrtfb

Differential Revision: D81652570
1 parent f64c864 commit 3a98706

File tree: 5 files changed, +163 −0 lines changed
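For orientation before the per-file diffs: cadence::quantized_w8a32_conv is a mixed-quantization conv1d, i.e. fp32 activations with int8 weights and bias that each carry a per-tensor float scale. Below is a minimal sketch of the implied semantics, assuming symmetric quantization and the standard NCL layout (the fused op itself consumes permuted layouts, as the fusion-pass diff shows). It illustrates the math only; it is not the HiFi4 kernel.

```python
import torch
import torch.nn.functional as F

def w8a32_conv1d_reference(
    x: torch.Tensor,         # fp32 activations, NCL: [batch, in_ch, length]
    weight_q: torch.Tensor,  # int8 weights: [out_ch, in_ch, kernel]
    w_scale: float,
    bias_q: torch.Tensor,    # quantized bias (assumed int8 here)
    b_scale: float,
) -> torch.Tensor:
    # Dequantize weights and bias by their per-tensor scales, then run a
    # plain fp32 conv1d with the only configuration the fused op supports:
    # stride=1, padding=0, dilation=1, groups=1.
    w = weight_q.float() * w_scale
    b = bias_q.float() * b_scale
    return F.conv1d(x, w, b, stride=1, padding=0, dilation=1, groups=1)
```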

backends/cadence/aot/functions_hifi.yaml

Lines changed: 5 additions & 0 deletions
@@ -553,3 +553,8 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::HiFi::quantized_w8a32_linear_out
+
+- func: cadence::quantized_w8a32_conv.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::quantized_w8a32_conv_out

backends/cadence/aot/ops_registrations.py

Lines changed: 35 additions & 0 deletions
@@ -571,6 +571,12 @@
     "quantized_w8a32_linear.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_w8a32_conv(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale) -> Tensor"
+)
+lib.define(
+    "quantized_w8a32_conv.out(Tensor input, Tensor weight, float w_scale, Tensor bias, float b_scale, *, Tensor(a!) output) -> Tensor(a!)"
+)
 
 # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
 aten_lib = Library("aten", "FRAGMENT")
@@ -2589,3 +2595,32 @@ def quantized_w8a32_linear_meta(
     assert src_shape[-1] == weight_shape[-1]
     src_shape[-1] = weight_shape[0]
     return src.new_empty(src_shape, dtype=src.dtype)
+
+
+@register_fake("cadence::quantized_w8a32_conv")
+def quantized_w8a32_conv_meta(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    w_scale: float,
+    bias: torch.Tensor,
+    b_scale: float,
+) -> torch.Tensor:
+    # src arrives channel-last (permuted NCL -> NLC by the fusion pass):
+    # [batch, in_length, in_channels]
+    # weight arrives permuted: [kernel_dim, out_channels, in_channels]
+    # output is allocated empty with shape [batch, out_ch, in_length - kernel_dim + 1]
+    assert len(src.shape) == 3
+
+    kernel_size, out_channels, in_channels = weight.shape
+    assert in_channels == src.shape[-1]
+
+    # Compute the output tensor size
+    output_size = get_conv1d_output_size(
+        src.permute(0, 2, 1).shape,
+        out_channels,
+        stride=1,
+        padding=0,
+        dilation=1,
+        kernel_size=kernel_size,
+        channel_last=False,
+    )
+    return src.new_empty(output_size, dtype=src.dtype)
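A quick sanity check on the meta function's shape math, assuming get_conv1d_output_size implements the standard conv1d output-length formula (its internals are not part of this diff):

```python
# Shapes as the meta function receives them (already permuted by the fusion pass):
batch, in_length, in_channels = 2, 16, 8   # src is NLC
kernel_size, out_channels = 3, 4           # weight is [kernel, out_ch, in_ch]

# Standard conv1d length arithmetic with stride=1, padding=0, dilation=1:
out_length = (in_length + 2 * 0 - 1 * (kernel_size - 1) - 1) // 1 + 1
assert out_length == in_length - kernel_size + 1 == 14

# The fake tensor is therefore allocated as [batch, out_channels, out_length],
# i.e. [2, 4, 14], matching the comment in quantized_w8a32_conv_meta.
```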

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 57 additions & 0 deletions
@@ -24,6 +24,7 @@
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
+    MixedW8A32ConvPattern,
     MixedW8A32LinearPattern,
     ReluPattern0,
     ReluPattern1,
@@ -478,6 +479,52 @@ def get_args_and_kwargs_softmax(
         out_zero_point_tensor,
     )
     kwargs = {}
+
+    return args, kwargs
+
+
+def get_args_and_kwargs_mixed_w8a32_conv(
+    graph_module: GraphModule,
+    other_inputs: List[fx.Node],
+    weights_inputs: List[fx.Node],
+    dequants_weights: List[fx.Node],
+    bias_inputs: List[fx.Node],
+    dequants_biases: List[fx.Node],
+    op_node: fx.Node,
+) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
+    # Stride, padding, dilation, groups not supported yet
+    if len(op_node.args) > 3:
+        assert op_node.args[3] == [1]  # Stride
+    if len(op_node.args) > 4:
+        assert op_node.args[4] == [0]  # Padding
+    if len(op_node.args) > 5:
+        assert op_node.args[5] == [1]  # Dilation
+    if len(op_node.args) > 6:
+        assert op_node.args[6] == 1  # Groups
+
+    assert len(dequants_weights) == 1
+    assert len(dequants_biases) == 1
+    W_scale_ = dequants_weights[0].args[1]
+    B_scale_ = dequants_biases[0].args[1]
+
+    transposed_inputs = graph_module.graph.call_function(
+        torch.ops.aten.permute.default,
+        (other_inputs[0], [0, 2, 1]),  # NCL -> NLC
+    )
+    transposed_weights = graph_module.graph.call_function(
+        torch.ops.aten.permute.default,
+        (weights_inputs[0], [2, 0, 1]),  # [out_ch, in_ch, kernel] -> [kernel, out_ch, in_ch]
+    )
+
+    args = (
+        transposed_inputs,
+        transposed_weights,
+        W_scale_,
+        bias_inputs[0],
+        B_scale_,
+    )
+    kwargs = {}
+
     return args, kwargs
 
 
@@ -650,6 +697,16 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     bias_inputs,
                     dequants_biases,
                 )
+            elif isinstance(pattern, MixedW8A32ConvPattern):
+                args, kwargs = get_args_and_kwargs_mixed_w8a32_conv(
+                    graph_module,
+                    other_inputs,
+                    weights_inputs,
+                    dequants_weights,
+                    bias_inputs,
+                    dequants_biases,
+                    op_node,
+                )
 
             fused = graph_module.graph.call_function(
                 pattern.replacement_op(),
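The two permute nodes inserted by get_args_and_kwargs_mixed_w8a32_conv can be checked shape-wise in plain PyTorch (a sketch of the same layout transforms, outside the fx graph):

```python
import torch

x = torch.randn(2, 8, 16)   # activations, NCL: [batch=2, in_ch=8, length=16]
w = torch.randn(4, 8, 3)    # conv1d weights: [out_ch=4, in_ch=8, kernel=3]

x_nlc = x.permute(0, 2, 1)  # NCL -> NLC
w_perm = w.permute(2, 0, 1) # [out_ch, in_ch, kernel] -> [kernel, out_ch, in_ch]

assert x_nlc.shape == (2, 16, 8)
assert w_perm.shape == (3, 4, 8)
# These are exactly the layouts quantized_w8a32_conv_meta unpacks:
# src.shape[-1] == in_channels, weight.shape == (kernel, out_ch, in_ch).
```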

backends/cadence/aot/quantizer/patterns.py

Lines changed: 62 additions & 0 deletions
@@ -599,3 +599,65 @@ def get_anchors(
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_w8a32_linear.default
+
+
+class MixedW8A32ConvPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Tuple[PartitionAnchors, fx.Node]:
+        # pyre-ignore[29]
+        conv_layer = fused_partition[0].nodes[-1]
+
+        # Bail if the arguments have different shapes than expected.
+        # Stride, padding, dilation and groups are not supported.
+        if len(conv_layer.args) != 3 or len(conv_layer.kwargs) > 0:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                conv_layer,
+            )
+
+        cnn_weights = conv_layer.args[1]
+        if "tensor_meta" in cnn_weights.meta:
+            cnn_weights_shape = cnn_weights.meta["tensor_meta"].shape
+            # Bail if the channels are not a multiple of 4 (SIMD)
+            if cnn_weights_shape[0] % 4 != 0:
+                return (
+                    PartitionAnchors(
+                        empty=True,
+                    ),
+                    conv_layer,
+                )
+            if cnn_weights_shape[1] % 4 != 0:
+                return (
+                    PartitionAnchors(
+                        empty=True,
+                    ),
+                    conv_layer,
+                )
+            # Bail if the kernel size is not 3
+            if cnn_weights_shape[2] != 3:
+                return (
+                    PartitionAnchors(
+                        empty=True,
+                    ),
+                    conv_layer,
+                )
+
+        return (
+            PartitionAnchors(
+                inputs=[],
+                weights=[(conv_layer, 1)],
+                biases=[(conv_layer, 2)],
+                output=[],
+                others=[(conv_layer, 0)],
+            ),
+            conv_layer,
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_w8a32_conv.default
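Taken together, the bail-outs above mean a conv1d is only fused when it uses the default stride/padding/dilation/groups, both channel counts are multiples of 4 (to match the HiFi4 SIMD width), and the kernel size is exactly 3. A small restatement as a predicate (illustrative, not code from the diff):

```python
def is_w8a32_conv_eligible(
    out_ch: int, in_ch: int, kernel: int,
    stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1,
) -> bool:
    # Mirrors the checks in MixedW8A32ConvPattern.get_anchors.
    return (
        (stride, padding, dilation, groups) == (1, 0, 1, 1)
        and out_ch % 4 == 0
        and in_ch % 4 == 0
        and kernel == 3
    )

assert is_w8a32_conv_eligible(out_ch=4, in_ch=8, kernel=3)
assert not is_w8a32_conv_eligible(out_ch=4, in_ch=8, kernel=5)  # wrong kernel size
assert not is_w8a32_conv_eligible(out_ch=6, in_ch=8, kernel=3)  # out_ch not a multiple of 4
```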

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 4 additions & 0 deletions
@@ -24,6 +24,7 @@
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
+    MixedW8A32ConvPattern,
     MixedW8A32LinearPattern,
     QuantizationPattern,
     ReluPattern0,
@@ -321,6 +322,9 @@ def __init__(self) -> None:
         quantizers.append(
             CadenceAtenQuantizer(MixedW8A32LinearPattern(), qconfig_A32W8sym)
         )
+        quantizers.append(
+            CadenceAtenQuantizer(MixedW8A32ConvPattern(), qconfig_A32W8sym)
+        )
         super().__init__(quantizers)
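For context, pattern-based quantizers like this one are normally driven through the PT2E flow. A hedged sketch follows, where CadenceW8A32MixedQuantizer is a hypothetical stand-in for the quantizer class whose __init__ is extended above (its real name is not visible in this diff), and assuming a PyTorch version that provides torch.export.export_for_training:

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

class TinyConvNet(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # kernel_size=3 with channel counts that are multiples of 4, so
        # MixedW8A32ConvPattern would consider this layer eligible.
        self.conv = torch.nn.Conv1d(in_channels=8, out_channels=4, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)

model = TinyConvNet().eval()
sample = (torch.randn(1, 8, 16),)  # NCL input

exported = torch.export.export_for_training(model, sample).module()
quantizer = CadenceW8A32MixedQuantizer()  # hypothetical stand-in, see above
prepared = prepare_pt2e(exported, quantizer)
prepared(*sample)                  # calibration pass
quantized = convert_pt2e(prepared)
```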
