Try some QD8-BF16 Experiments #11466

Draft · wants to merge 2 commits into main
Changes from all commits
2 changes: 2 additions & 0 deletions backends/xnnpack/operators/node_visitor.py
@@ -271,6 +271,8 @@ def get_per_channel_dtype(
if force_fp32
else XNNDatatype.xnn_datatype_fp16
)
elif node_dtype is not None and node_dtype == torch.bfloat16:
dtype = XNNDatatype.xnn_datatype_bf16

return dtype

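For context, the bf16 branch added above extends the existing fp32/fp16 selection in `get_per_channel_dtype`. A minimal sketch of the resulting mapping, assuming a simplified signature (the real helper also covers the quantized int8/int4 paths):

```python
import torch

# Simplified, hypothetical mirror of get_per_channel_dtype()'s float handling.
def per_channel_float_dtype(node_dtype, force_fp32=False):
    if node_dtype == torch.float16 and not force_fp32:
        return "xnn_datatype_fp16"
    if node_dtype == torch.bfloat16:  # branch added by this diff
        return "xnn_datatype_bf16"
    return "xnn_datatype_fp32"

assert per_channel_float_dtype(torch.bfloat16) == "xnn_datatype_bf16"
```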
1 change: 1 addition & 0 deletions backends/xnnpack/partition/config/xnnpack_config.py
@@ -219,6 +219,7 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
def _check_node_has_valid_dtype(self, node):
valid_dtypes = {
torch.float32,
torch.bfloat16,
torch.float16,
torch.int8,
torch.qint8,
50 changes: 12 additions & 38 deletions backends/xnnpack/runtime/XNNCompiler.cpp
@@ -97,7 +97,10 @@ std::pair<float, float> getOutputMinMax(const NodePtr node) noexcept {
}

/*
Converts flatbuffer xnn data type to xnnpack data type
Converts flatbuffer xnn data type to xnnpack data type.

NOTE:
Flatbuffer Enum Values are not the same as XNNPACK's datatype enum values.
*/
xnn_datatype getDataType(const DataType& data_type) {
switch (data_type) {
@@ -121,6 +124,14 @@ xnn_datatype getDataType(const DataType& data_type) {
return xnn_datatype::xnn_datatype_qdint8;
case DataType::xnn_datatype_qbint4:
return xnn_datatype::xnn_datatype_qbint4;
case DataType::xnn_datatype_qpint8:
return xnn_datatype::xnn_datatype_qpint8;
case DataType::xnn_datatype_int32:
return xnn_datatype::xnn_datatype_int32;
case DataType::xnn_datatype_pfp32:
return xnn_datatype::xnn_datatype_pfp32;
case DataType::xnn_datatype_bf16:
return xnn_datatype::xnn_datatype_bf16;
default:
return xnn_datatype::xnn_datatype_invalid;
}
@@ -1888,42 +1899,6 @@ Error defineStaticSliceNode(
return Error::Ok;
}

/*
Defines Scaled Dot Product Attention (SDPA) node into the subgraph,
using the remapped ids to map the serialized ids,
to the new ids generated when defining the tensor value
*/
Error defineScaledDotProductAttentionNode(
xnn_subgraph_t subgraph_ptr,
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
const NodePtr node,
const fb_xnnpack::XNNGraph* graph) noexcept {
MAYBE_UNUSED(graph);

auto graph_node = node->xnode_union_as_XNNScaledDotProductAttention();

xnn_status status = xnn_define_scaled_dot_product_attention(
subgraph_ptr,
xnn_attention_logits_cap_type_none, // cap_type
nullptr, // cap_value - not used
remapped_ids.at(graph_node->query_id()),
remapped_ids.at(graph_node->key_id()),
remapped_ids.at(graph_node->value_id()),
remapped_ids.at(graph_node->scale_id()),
remapped_ids.at(graph_node->mask_id()),
remapped_ids.at(graph_node->output_id()),
graph_node->flags());

ET_CHECK_OR_RETURN_ERROR(
status == xnn_status_success,
Internal,
"Failed to create SDPA node %i with code: %s",
node->debug_handle(),
xnn_status_to_string(status));

return Error::Ok;
}

/*
Defines batch matrix multiply node into the subgraph,
using the remapped ids to map the serialized ids,
@@ -2023,7 +1998,6 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
_DEFINE(Concatenate4)
_DEFINE(Concatenate5)
_DEFINE(StaticSlice)
_DEFINE(ScaledDotProductAttention)
_DEFINE(BatchMatrixMultiply)
case fb_xnnpack::XNodeUnion::NONE:
default: // Adding here as a catch all, just in case
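The NOTE added to `getDataType` is the reason the function stays an explicit switch: the serialized `XNNDatatype` values are pinned by the ExecuTorch schema, while XNNPACK's own `xnn_datatype` enum is defined independently and may number or reorder its entries differently across versions. A rough sketch of the idea, with made-up runtime values purely for illustration:

```python
# Serialized values come from schema.fbs and must stay stable.
SERIALIZED = {"fp32": 1, "bf16": 14}   # bf16 = 14 per this PR's schema; fp32 shown for illustration

# Hypothetical runtime values: XNNPACK is free to number its enum differently,
# so blindly casting a serialized 14 could select the wrong runtime datatype.
RUNTIME = {"fp32": 3, "bf16": 20}      # made-up numbers

def get_data_type(serialized_value):
    name = {v: k for k, v in SERIALIZED.items()}.get(serialized_value)
    return RUNTIME.get(name, "invalid")  # mirrors the switch's default case

assert get_data_type(14) == 20
```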
9 changes: 9 additions & 0 deletions backends/xnnpack/serialization/runtime_schema.fbs
@@ -29,6 +29,15 @@ enum XNNDatatype : short {
xnn_datatype_qdint8 = 9,
/// Quantized 4-bit signed integer with shared blockwise quantization parameters.
xnn_datatype_qbint4 = 10,
/// Dynamically quantized 8-bit signed integers packed with their per-row
/// quantization parameters.
xnn_datatype_qpint8 = 11,
/// 32-bit signed integers.
xnn_datatype_int32 = 12,
/// IEEE754 single-precision packed floating-point.
xnn_datatype_pfp32 = 13,
/// BFloat16, i.e. the upper 16 bits of a float32.
xnn_datatype_bf16 = 14,
}

// type of quantization
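As a quick aside on the `xnn_datatype_bf16` comment above: bfloat16 keeps float32's sign bit and 8-bit exponent and truncates the mantissa to 7 bits, so its bit pattern is (up to rounding) the upper half of the float32 word. A standalone sketch, not part of the diff:

```python
import numpy as np
import torch

x = np.array([3.14159, -0.0078125], dtype=np.float32)

# bf16-by-truncation: zero the lower 16 bits of each float32 word.
truncated = (x.view(np.uint32) & np.uint32(0xFFFF0000)).view(np.float32)

# PyTorch's conversion additionally rounds to nearest-even, so the two
# agree to within one bf16 ulp.
via_torch = torch.from_numpy(x).to(torch.bfloat16).to(torch.float32).numpy()
print(truncated, via_torch)
```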
9 changes: 9 additions & 0 deletions backends/xnnpack/serialization/schema.fbs
@@ -29,6 +29,15 @@ enum XNNDatatype : short {
xnn_datatype_qdint8 = 9,
/// Quantized 4-bit signed integer with shared blockwise quantization parameters.
xnn_datatype_qbint4 = 10,
/// Dynamically quantized 8-bit signed integers packed with their per-row
/// quantization parameters.
xnn_datatype_qpint8 = 11,
/// 32-bit signed integers.
xnn_datatype_int32 = 12,
/// IEEE754 single-precision packed floating-point.
xnn_datatype_pfp32 = 13,
/// BFloat16, i.e. the upper 16 bits of a float32.
xnn_datatype_bf16 = 14,
}

// type of quantization
4 changes: 4 additions & 0 deletions backends/xnnpack/serialization/xnnpack_graph_schema.py
@@ -413,6 +413,10 @@ class XNNDatatype(IntEnum):
xnn_datatype_qcint4 = 8
xnn_datatype_qdint8 = 9
xnn_datatype_qbint4 = 10
xnn_datatype_qpint8 = 11
xnn_datatype_int32 = 12
xnn_datatype_pfp32 = 13
xnn_datatype_bf16 = 14


@dataclass
7 changes: 5 additions & 2 deletions backends/xnnpack/test/ops/test_linear.py
@@ -63,7 +63,7 @@ def __init__(
self.ic = input_channels
self.oc = output_channels

assert dtype in [torch.float, torch.half], "Unsupported op dtype"
assert dtype in [torch.bfloat16, torch.float, torch.half], "Unsupported op dtype"
self.op_dtype = dtype
self.in_size = in_size

@@ -432,6 +432,7 @@ def _test_groupwise_dq_linear(
)
.to_executorch()
.serialize()
.dump_artifact("/Users/maxren/Desktop/oss/executorch/linear_qd8_bf16.pte")
.run_method_and_compare_outputs(atol=atol, rtol=rtol)
)

@@ -676,7 +677,6 @@ def _test_qd8_per_token_weight_per_channel_group_int4(
M_sizes = [1, 2, 17, 31]
K_sizes = [32, 32, 64, 128]
bl_sizes = [32, 32, 32, 64]
N_sizes = [2, 17, 92, 128]

for input_rank in range(2, 4):
for use_bias in [True, False]:
@@ -831,6 +831,9 @@ def test_linear_qd8_f16_per_token_weight_per_channel_group_int4(self):
def test_linear_qd8_f32_per_token_weight_per_channel_group_int4(self):
self._test_qd8_per_token_weight_per_channel_group_int4(dtype=torch.float)

def test_linear_qd8_bf16_per_token_weight_per_channel_group_int4(self):
self._test_qd8_per_token_weight_per_channel_group_int4(dtype=torch.bfloat16)

@unittest.skipIf(
not torchao_installed, "Per Channel Group Quantization Required TorchAO"
)
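For readers new to the naming: "QD8-BF16" roughly means per-token dynamically quantized int8 activations feeding a linear whose inputs and outputs are handled in bfloat16 (the test above additionally uses groupwise int4 weights, which are omitted from this sketch). A rough eager-mode illustration of the activation side, using symmetric quantization for brevity and none of the repo's Tester/quantizer APIs:

```python
import torch
import torch.nn.functional as F

def qd8_bf16_linear_sketch(x_bf16, w_bf16, b=None):
    # Per-token (per-row) symmetric dynamic quantization of activations to int8.
    x = x_bf16.to(torch.float32)
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -127, 127)
    # Dequantize and run the matmul in bf16; XNNPACK would fuse these steps
    # into a single dynamically quantized GEMM.
    return F.linear((x_q * scale).to(torch.bfloat16), w_bf16, b)

x = torch.randn(4, 32, dtype=torch.bfloat16)
w = torch.randn(8, 32, dtype=torch.bfloat16)
err = (qd8_bf16_linear_sketch(x, w) - F.linear(x, w)).abs().max()
print(err)  # small: only int8 quantization noise plus bf16 rounding
```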
7 changes: 6 additions & 1 deletion backends/xnnpack/test/tester/tester.py
@@ -536,7 +536,7 @@ def fn(x):
random_inputs.append(
torch.randn(input_shapes[arg_idx]).to(
dtype=self.example_inputs[arg_idx].dtype
)
)*100
)

yield tuple(random_inputs)
@@ -714,6 +714,9 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
assert (
ref.shape == model.shape
), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}"
print(f"actual dtype: {model.dtype}, ref dtype: {ref.dtype}")
print(model)
print(ref)
assert torch.allclose(
model,
ref,
@@ -773,6 +776,8 @@ def _calculate_reference_output(
return the quantization scale as well.
"""

cqp = torch.ops.torchao.choose_qparams_affine.default(*inputs, 'ASYMMETRIC', [1, 32], torch.int8, None, None, None, torch.float32, torch.int8)
print(f"inv_scale: {1/cqp[0]}, zero_point: {cqp[1]}")
# Locate the output node.
output_node = None
for node in program.graph.nodes:
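The `choose_qparams_affine` call added above is a debugging probe that prints the activation quantization parameters. Roughly, asymmetric int8 qparams for a block are picked from the block's min/max; below is a plain-PyTorch approximation (a hypothetical stand-in, not torchao's kernel, and the block shape is just an example):

```python
import torch

def asymmetric_int8_qparams(x, block_cols=32):
    # Map each block's [min, max] (forced to include 0) onto the int8
    # range [-128, 127]; analogous in spirit to choose_qparams_affine.
    rows = x.reshape(-1, block_cols)
    lo = torch.minimum(rows.min(dim=-1).values, torch.zeros(rows.shape[0]))
    hi = torch.maximum(rows.max(dim=-1).values, torch.zeros(rows.shape[0]))
    scale = ((hi - lo) / 255.0).clamp(min=1e-9)
    zero_point = (-128 - torch.round(lo / scale)).to(torch.int32)
    return scale, zero_point

s, zp = asymmetric_int8_qparams(torch.randn(4, 32))
print(1.0 / s, zp)  # analogous to the inv_scale / zero_point printed above
```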
2 changes: 1 addition & 1 deletion backends/xnnpack/third-party/XNNPACK
Submodule XNNPACK updated 12278 files
10 changes: 5 additions & 5 deletions examples/models/llama/export_llama_lib.py
@@ -815,11 +815,11 @@ def _to_edge_and_lower_llama_xnnpack(

modelname = f"xnnpack_dq_{modelname}"

if xnnpack_extended_ops:
partitioners.append(
get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
)
modelname = f"xnnpack_{modelname}"
# if xnnpack_extended_ops:
# partitioners.append(
# get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
# )
# modelname = f"xnnpack_{modelname}"

logging.info("Lowering model using following partitioner(s): ")
for partitioner in partitioners: