From 3b751e01842c494f35bb8b90acf158a9be6c229c Mon Sep 17 00:00:00 2001 From: Max Ren Date: Mon, 26 May 2025 23:55:53 -0700 Subject: [PATCH 1/2] [QD8-BF16-QB4] Update XNNPACK flatbuffer with new XNNPACK Datatypes --- backends/xnnpack/runtime/XNNCompiler.cpp | 13 ++++++++++++- backends/xnnpack/serialization/runtime_schema.fbs | 9 +++++++++ backends/xnnpack/serialization/schema.fbs | 9 +++++++++ .../xnnpack/serialization/xnnpack_graph_schema.py | 4 ++++ 4 files changed, 34 insertions(+), 1 deletion(-) diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 445744e991..d412e18cc1 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -97,7 +97,10 @@ std::pair getOutputMinMax(const NodePtr node) noexcept { } /* -Converts flatbuffer xnn data type to xnnpack data type +Converts flatbuffer xnn data type to xnnpack data type. + +NOTE: +Flatbuffer Enum Values are not the same as XNNPACK's datatype enum values. */ xnn_datatype getDataType(const DataType& data_type) { switch (data_type) { @@ -121,6 +124,14 @@ xnn_datatype getDataType(const DataType& data_type) { return xnn_datatype::xnn_datatype_qdint8; case DataType::xnn_datatype_qbint4: return xnn_datatype::xnn_datatype_qbint4; + case DataType::xnn_datatype_qpint8: + return xnn_datatype::xnn_datatype_qpint8; + case DataType::xnn_datatype_int32: + return xnn_datatype::xnn_datatype_int32; + case DataType::xnn_datatype_pfp32: + return xnn_datatype::xnn_datatype_pfp32; + case DataType::xnn_datatype_bf16: + return xnn_datatype::xnn_datatype_bf16; default: return xnn_datatype::xnn_datatype_invalid; } diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index f10ba3d1b8..99f9e4e5fb 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -29,6 +29,15 @@ enum XNNDatatype : short { xnn_datatype_qdint8 = 9, /// Quantized 4-bit signed integer with shared blockwise quantization parameters. xnn_datatype_qbint4 = 10, + /// Dynamically quantized 8-bit signed integers packed with their per-row + /// quantization parameters. + xnn_datatype_qpint8 = 11, + /// 32-bit signed integers. + xnn_datatype_int32 = 12, + /// IEEE754 single-precision packed floating-point. + xnn_datatype_pfp32 = 13, + /// BFloat16, i.e. the upper 16 bits of a float32. + xnn_datatype_bf16 = 14, } // type of quantization diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index 565eb4c3bb..e3ed4061e9 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -29,6 +29,15 @@ enum XNNDatatype : short { xnn_datatype_qdint8 = 9, /// Quantized 4-bit signed integer with shared blockwise quantization parameters. xnn_datatype_qbint4 = 10, + /// Dynamically quantized 8-bit signed integers packed with their per-row + /// quantization parameters. + xnn_datatype_qpint8 = 11, + /// 32-bit signed integers. + xnn_datatype_int32 = 12, + /// IEEE754 single-precision packed floating-point. + xnn_datatype_pfp32 = 13, + /// BFloat16, i.e. the upper 16 bits of a float32. 
+ xnn_datatype_bf16 = 14, } // type of quantization diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index 2a3ccaf2a0..4e23e199de 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -413,6 +413,10 @@ class XNNDatatype(IntEnum): xnn_datatype_qcint4 = 8 xnn_datatype_qdint8 = 9 xnn_datatype_qbint4 = 10 + xnn_datatype_qpint8 = 11 + xnn_datatype_int32 = 12 + xnn_datatype_pfp32 = 13 + xnn_datatype_bf16 = 14 @dataclass From d4845a03d5a59f2a36e86f9159e2bcbbfdc2bec0 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Fri, 6 Jun 2025 18:58:43 -0700 Subject: [PATCH 2/2] [WIP] Test out the new bf16 kernels --- backends/xnnpack/operators/node_visitor.py | 2 + .../partition/config/xnnpack_config.py | 1 + backends/xnnpack/runtime/XNNCompiler.cpp | 37 ------------------- backends/xnnpack/test/ops/test_linear.py | 7 +++- backends/xnnpack/test/tester/tester.py | 7 +++- backends/xnnpack/third-party/XNNPACK | 2 +- backends/xnnpack/third-party/pthreadpool | 2 +- examples/models/llama/export_llama_lib.py | 10 ++--- 8 files changed, 21 insertions(+), 47 deletions(-) diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 8470184d80..8c980870d0 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -271,6 +271,8 @@ def get_per_channel_dtype( if force_fp32 else XNNDatatype.xnn_datatype_fp16 ) + elif node_dtype is not None and node_dtype == torch.bfloat16: + dtype = XNNDatatype.xnn_datatype_bf16 return dtype diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py index df6067a7d6..a40be5bfa8 100644 --- a/backends/xnnpack/partition/config/xnnpack_config.py +++ b/backends/xnnpack/partition/config/xnnpack_config.py @@ -219,6 +219,7 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes): def _check_node_has_valid_dtype(self, node): valid_dtypes = { torch.float32, + torch.bfloat16, torch.float16, torch.int8, torch.qint8, diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index d412e18cc1..0aa92d1279 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -1899,42 +1899,6 @@ Error defineStaticSliceNode( return Error::Ok; } -/* -Defines Scaled Dot Product Attention (SDPA) node into the subgraph, -using the remapped ids to map the serialized ids, -to the new ids generated when defining the tensor value -*/ -Error defineScaledDotProductAttentionNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNScaledDotProductAttention(); - - xnn_status status = xnn_define_scaled_dot_product_attention( - subgraph_ptr, - xnn_attention_logits_cap_type_none, // cap_type - nullptr, // cap_value - not used - remapped_ids.at(graph_node->query_id()), - remapped_ids.at(graph_node->key_id()), - remapped_ids.at(graph_node->value_id()), - remapped_ids.at(graph_node->scale_id()), - remapped_ids.at(graph_node->mask_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create SDPA node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return 
Error::Ok; -} - /* Defines batch matrix multiply node into the subgraph, using the remapped ids to map the serialized ids, @@ -2034,7 +1998,6 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(Concatenate4) _DEFINE(Concatenate5) _DEFINE(StaticSlice) - _DEFINE(ScaledDotProductAttention) _DEFINE(BatchMatrixMultiply) case fb_xnnpack::XNodeUnion::NONE: default: // Adding here as a catch all, just in case diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py index 421e59c0b0..dcdd05633b 100644 --- a/backends/xnnpack/test/ops/test_linear.py +++ b/backends/xnnpack/test/ops/test_linear.py @@ -63,7 +63,7 @@ def __init__( self.ic = input_channels self.oc = output_channels - assert dtype in [torch.float, torch.half], "Unsupported op dtype" + assert dtype in [torch.bfloat16, torch.float, torch.half], "Unsupported op dtype" self.op_dtype = dtype self.in_size = in_size @@ -432,6 +432,7 @@ def _test_groupwise_dq_linear( ) .to_executorch() .serialize() + .dump_artifact("/Users/maxren/Desktop/oss/executorch/linear_qd8_bf16.pte") .run_method_and_compare_outputs(atol=atol, rtol=rtol) ) @@ -676,7 +677,6 @@ def _test_qd8_per_token_weight_per_channel_group_int4( M_sizes = [1, 2, 17, 31] K_sizes = [32, 32, 64, 128] bl_sizes = [32, 32, 32, 64] - N_sizes = [2, 17, 92, 128] for input_rank in range(2, 4): for use_bias in [True, False]: @@ -831,6 +831,9 @@ def test_linear_qd8_f16_per_token_weight_per_channel_group_int4(self): def test_linear_qd8_f32_per_token_weight_per_channel_group_int4(self): self._test_qd8_per_token_weight_per_channel_group_int4(dtype=torch.float) + def test_linear_qd8_bf16_per_token_weight_per_channel_group_int4(self): + self._test_qd8_per_token_weight_per_channel_group_int4(dtype=torch.bfloat16) + @unittest.skipIf( not torchao_installed, "Per Channel Group Quantization Required TorchAO" ) diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index dcdafebd6f..0f54ea01ed 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -536,7 +536,7 @@ def fn(x): random_inputs.append( torch.randn(input_shapes[arg_idx]).to( dtype=self.example_inputs[arg_idx].dtype - ) + )*100 ) yield tuple(random_inputs) @@ -714,6 +714,9 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): assert ( ref.shape == model.shape ), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}" + print(f"actual dtype: {model.dtype}, ref dtype: {ref.dtype}") + print(model) + print(ref) assert torch.allclose( model, ref, @@ -773,6 +776,8 @@ def _calculate_reference_output( return the quantization scale as well. """ + cqp = torch.ops.torchao.choose_qparams_affine.default(*inputs, 'ASYMMETRIC', [1, 32], torch.int8, None, None, None, torch.float32, torch.int8) + print(f"inv_scale: {1/cqp[0]}, zero_point: {cqp[1]}") # Locate the output node. 
output_node = None for node in program.graph.nodes: diff --git a/backends/xnnpack/third-party/XNNPACK b/backends/xnnpack/third-party/XNNPACK index 4ea82e595b..4b106fa608 160000 --- a/backends/xnnpack/third-party/XNNPACK +++ b/backends/xnnpack/third-party/XNNPACK @@ -1 +1 @@ -Subproject commit 4ea82e595b36106653175dcb04b2aa532660d0d8 +Subproject commit 4b106fa60892b33b2bddba11dbdb64550e2dfb3a diff --git a/backends/xnnpack/third-party/pthreadpool b/backends/xnnpack/third-party/pthreadpool index 4fe0e1e183..dcc9f28589 160000 --- a/backends/xnnpack/third-party/pthreadpool +++ b/backends/xnnpack/third-party/pthreadpool @@ -1 +1 @@ -Subproject commit 4fe0e1e183925bf8cfa6aae24237e724a96479b8 +Subproject commit dcc9f28589066af0dbd4555579281230abbf74dd diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 3a3102886f..17e93f0ba4 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -815,11 +815,11 @@ def _to_edge_and_lower_llama_xnnpack( modelname = f"xnnpack_dq_{modelname}" - if xnnpack_extended_ops: - partitioners.append( - get_xnnpack_partitioner(dynamic_quant_only_partitioner=False) - ) - modelname = f"xnnpack_{modelname}" + # if xnnpack_extended_ops: + # partitioners.append( + # get_xnnpack_partitioner(dynamic_quant_only_partitioner=False) + # ) + # modelname = f"xnnpack_{modelname}" logging.info("Lowering model using following partitioner(s): ") for partitioner in partitioners:
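
A minimal sketch (not part of the patch) of a sanity check that the Python XNNDatatype values extended in xnnpack_graph_schema.py stay in sync with the declarations in schema.fbs / runtime_schema.fbs. As the NOTE added to XNNCompiler.cpp points out, the flatbuffer enum values are not guaranteed to match XNNPACK's own datatype enum, so getDataType() maps them explicitly rather than casting; keeping the serialized values aligned with the .fbs files is what makes that mapping safe. The `executorch.` import prefix is an assumption about the package layout.

    # Sketch only: keep the Python serialization enum in sync with the .fbs schema.
    # The import path assumes the usual executorch package layout.
    from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
        XNNDatatype,
    )

    # Values added by this series; they must match schema.fbs / runtime_schema.fbs
    # because XNNCompiler.cpp's getDataType() switches on the flatbuffer value,
    # not on XNNPACK's runtime enum.
    expected = {
        "xnn_datatype_qpint8": 11,
        "xnn_datatype_int32": 12,
        "xnn_datatype_pfp32": 13,
        "xnn_datatype_bf16": 14,
    }

    for name, value in expected.items():
        assert XNNDatatype[name].value == value, f"{name} drifted from the schema"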
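
And a sketch of how the new bf16 path could be exercised in isolation via the test method added to backends/xnnpack/test/ops/test_linear.py. The `TestLinear` class name and the module import prefix are assumptions; only the method name comes from the patch, and any test runner that discovers the module works equally well.

    # Sketch only: run the new QD8 (bf16 activations) + 4-bit groupwise-weight
    # linear test added in this series. "TestLinear" is assumed to be the
    # unittest.TestCase class hosting the new method.
    import unittest

    suite = unittest.TestLoader().loadTestsFromName(
        "executorch.backends.xnnpack.test.ops.test_linear."
        "TestLinear.test_linear_qd8_bf16_per_token_weight_per_channel_group_int4"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)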