Fix: auto-strip ONNX post-training quantization on model load #1025
Open
philippelaporteconcordia wants to merge 2 commits into
…zkonduit#942)

ONNX Runtime's PTQ inserts QuantizeLinear, DequantizeLinear, DynamicQuantizeLinear, MatMulInteger, ConvInteger, and QLinear* operators that tract cannot analyse, producing an opaque "Failed analyse for node ... ConvHir" panic. EZKL already quantizes internally via the `scale` run argument, so a pre-quantized model is both redundant and unsupported.

Scan the parsed InferenceModel right after `model_for_read` and surface a `GraphError::UnsupportedQuantizationOps` listing the offending nodes, with guidance to export the float model or strip the Q/DQ pairs. Cover both the static `QuantizeLinear`/`DequantizeLinear` pattern and the dynamic `DynamicQuantizeLinear`/`MatMulInteger` pattern from the issue.

Add a new tests/quantization_detection.rs integration suite with two checked-in fixtures under tests/assets/:

- quantized_qdq.onnx: minimal QuantizeLinear -> DequantizeLinear -> Conv graph exercising the static-PTQ path.
- quantized_dynamic.onnx: DynamicQuantizeLinear -> MatMulInteger graph mirroring the issue's onnxruntime quantize_dynamic output.

Each fixture is loaded via Model::new and the test asserts the new UnsupportedQuantizationOps variant is returned with a non-empty report.

Verified end-to-end on the issue's face_landmark_quantized.onnx: the detector reports 44 offending nodes with the actionable message instead of the original tract panic. The existing `quantize_dequantize` example is unaffected (its exported ONNX is a folded float Gemm with no Q/DQ ops).
…onduit#942)

ONNX Runtime's PTQ inserts QuantizeLinear, DequantizeLinear, DynamicQuantizeLinear, MatMulInteger, ConvInteger, and QLinear* operators that tract cannot analyse, producing an opaque "Failed analyse for node ... ConvHir" panic. EZKL already quantizes internally via the `scale` run argument, so a pre-quantized model is both redundant and unsupported.

Add `src/graph/dequantize.rs`, a protobuf-level rewriter exposed as `dequantize::apply(&mut ModelProto) -> Result<DequantizationReport, DequantizationError>`. It collapses three patterns:

* `QuantizeLinear -> DequantizeLinear` activation identity pairs are folded into a direct edge.
* Standalone `DequantizeLinear(W_int, scale, zp)` on weight initializers is folded into a single float initializer (`(W_int - zp) * scale`).
* The `DynamicQuantizeLinear -> ConvInteger/MatMulInteger -> Cast -> Mul` fusion that `quantize_dynamic` emits is collapsed to a plain `Conv`/`MatMul` over (x, dequantized_W). Spatial attributes on `ConvInteger` are preserved on the replacement `Conv`.

The pass runs automatically inside `Model::new`: read bytes, decode via `prost::Message::decode`, rewrite, re-encode, hand cleaned bytes to tract through a `Cursor`. The existing `reject_onnx_quantization_ops` detector survives as a safety net for unsupported patterns (e.g. QLinearConv) and for the new `--disable-quantization-fixup` opt-out flag, which is the only way to surface the safety-net error today.

Add an `ezkl dequantize -M <input.onnx> -O <output.onnx>` subcommand that exposes the same rewrite as a one-shot tool: useful for inspecting what the auto-pass did, sharing cleaned models, or feeding non-EZKL toolchains.

Tests:

* 8 unit tests in `src/graph/dequantize.rs` cover each pattern, idempotence, a float-only no-op case, an unsupported `QLinearConv` case, and the shared-`DynamicQuantizeLinear`-feeding-multiple-integer-ops scenario that previously broke producer lookup.
* `tests/quantization_detection.rs` (4 tests): default `Model::new` accepts each Q/DQ fixture; with `disable_quantization_fixup=true` the safety net fires with `UnsupportedQuantizationOps`.
* `tests/dequantize_pipeline.rs` (2 tests): shells out to `ezkl dequantize`, then loads the cleaned model with the auto-pass *disabled* to prove the persisted bytes alone are accepted.
* `tests/dequantize_e2e.rs` (3 tests, new) drives the full pipeline end-to-end on a pre-quantized fixture:
  - `gen-settings → calibrate-settings → compile-circuit → gen-witness → mock` succeeds with no manual dequantize step; the witness output (after dequantising via the calibrated output scales) agrees with a tract inference of the equivalent float model within ~0.5 per element.
  - `--disable-quantization-fixup` halts at `gen-settings` with `UnsupportedQuantizationOps`.
  - an `#[ignore]`-gated companion runs the full SNARK `setup → prove → verify` on top (~4 s on the fixture, opt in with `cargo test -- --ignored`).
* `tests/python/binding_tests.py::test_py_run_args` round-trips the new `PyRunArgs.disable_quantization_fixup` attribute.
* Two checked-in fixtures under `tests/assets/`:
  - `quantized_qdq.onnx`: minimal QuantizeLinear → DequantizeLinear → Conv graph (with explicit Conv attrs so the cleaned model is also valid for `gen-settings`).
  - `quantized_dynamic.onnx`: DynamicQuantizeLinear → MatMulInteger graph mirroring the issue's `quantize_dynamic` output.
* Mirror fixture under `examples/onnx/quantized_qdq/` plus `quantized_qdq` added to `tests/integration_tests.rs::TESTS[]` (bumped from `[&str; 100]` to `[&str; 101]`); two `seq!` macro ranges bumped from `0..99`/`0..=99` to `0..=100` so the new fixture (and the previously-uncovered `large_mlp` at idx 99) get picked up by 36 existing mock/prove/accuracy wrappers.

Verified end-to-end on the issue's `face_landmark_quantized.onnx`:

* `ezkl gen-settings -M face_landmark_quantized.onnx` → succeeds (45 dynamic-quantize fusions auto-rewritten transparently).
* `ezkl dequantize -M face_landmark_quantized.onnx -O face_clean.onnx` → reports the per-pattern rewrite counts and writes the cleaned model.
* `ezkl gen-settings -M face_landmark_quantized.onnx --disable-quantization-fixup` → safety-net error listing 44 unrecognised quantization operators with an actionable message pointing at `ezkl dequantize`.

The dequantize pass adds essentially zero overhead to `gen-settings` wall time (the ~90 s on this model is intrinsic to tract analysis + ezkl circuit construction; quantized-vs-cleaned timings are within noise).

Adds `prost = "0.11"` as a direct optional dep (gated on the existing `onnx` feature) so we can `Message::decode`/`encode` `tract_onnx::pb` types directly. Tract already pulled prost in transitively; the direct dep just pins the major version we compile against.
Summary
Closes #942.
EZKL crashes with an opaque tract panic ("Failed analyse for node ... ConvHir") when given an ONNX model that was post-training-quantized by `onnxruntime.quantization` (e.g. `quantize_dynamic(weight_type=QInt8)`). The PTQ ops (`QuantizeLinear`, `DequantizeLinear`, `DynamicQuantizeLinear`, `MatMulInteger`, `ConvInteger`, the `QLinear*` family) are not analysable by tract, and they're conceptually redundant with EZKL's own internal `scale`-driven quantization. This PR makes pre-quantized models load and run end-to-end without any manual preprocessing.
What changed
**Auto-dequantize on model load.** A new in-process pass at `src/graph/dequantize.rs` (`dequantize::apply(&mut ModelProto)`) runs inside `Model::new` and canonicalises three patterns back to float equivalents:

- `QuantizeLinear → DequantizeLinear` with shared scale/zp is folded into a direct edge.
- Standalone `DequantizeLinear`: `DequantizeLinear(W_int, scale, zp)` on a weight initializer is folded into a single float initializer (`(W_int - zp) * scale`).
- `DynamicQuantizeLinear` + integer-op fusion: the `DynamicQuantizeLinear → ConvInteger/MatMulInteger → Cast → Mul` subgraph that `quantize_dynamic` emits is collapsed to a plain `Conv`/`MatMul` over `(x, dequantized_W)`. Spatial attributes on `ConvInteger` are preserved on the replacement `Conv`. Trailing bias `Add` nodes are left untouched.

The pass is purely protobuf-level: read bytes → `prost::Message::decode` → rewrite → re-encode → hand cleaned bytes to tract via a `Cursor` (sketched below). It is idempotent on already-clean models, and a no-op on float-only graphs.

**Safety net + opt-out.** A pre-existing detector (`Model::reject_onnx_quantization_ops`) survives as a fallback for patterns we don't rewrite (`QLinearConv`, `QLinearMatMul`, `QLinearAdd`, …) and for users who pass the new `--disable-quantization-fixup` flag. In both cases the loader returns `GraphError::UnsupportedQuantizationOps` with an actionable message that names the offending nodes and points at the `dequantize` subcommand.

**`ezkl dequantize` subcommand.** Same rewrite, exposed as a one-shot tool that writes a cleaned `.onnx` to disk, useful for inspection, audit, sharing, or feeding non-EZKL toolchains:

```
ezkl dequantize -M input.onnx -O cleaned.onnx
```

The command prints a per-pattern report (`collapsed N Q/DQ pairs, folded N weight DQ, replaced N dynamic-quantize fusions`).
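For orientation, a minimal sketch of that load-time flow together with the per-tensor weight fold. `dequantize::apply` and `DequantizationReport` are this PR's API; the wrapper function, imports, and error plumbing below are illustrative assumptions, not the actual `Model::new` code:

```rust
use std::io::Cursor;

use prost::Message;
use tract_onnx::pb::ModelProto;
use tract_onnx::prelude::*;

// `dequantize` is the module this PR adds at src/graph/dequantize.rs;
// the import path here is illustrative.
use crate::graph::dequantize;

/// Illustrative wrapper, not the real `Model::new` (which threads this
/// through EZKL's own error types).
fn load_dequantized(bytes: &[u8]) -> TractResult<InferenceModel> {
    // 1. Decode the raw ONNX protobuf with prost.
    let mut proto = ModelProto::decode(bytes)?;
    // 2. Collapse Q/DQ pairs, fold weight DequantizeLinear, and replace
    //    dynamic-quantize fusions; the report carries the per-pattern
    //    counts that `ezkl dequantize` prints.
    let _report = dequantize::apply(&mut proto)?;
    // 3. Re-encode and hand the cleaned bytes to tract through a Cursor.
    let cleaned = proto.encode_to_vec();
    tract_onnx::onnx().model_for_read(&mut Cursor::new(cleaned))
}

/// The weight fold, per-tensor: w_f32[i] = (w_int[i] - zp) * scale.
fn dequantize_weight(w_int: &[i8], scale: f32, zp: i8) -> Vec<f32> {
    w_int
        .iter()
        .map(|&w| (i32::from(w) - i32::from(zp)) as f32 * scale)
        .collect()
}
```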
**Bindings + CLI surface.** `RunArgs.disable_quantization_fixup: bool` (default `false`), CLI: `--disable-quantization-fixup`. `PyRunArgs.disable_quantization_fixup` mirrors the field for the Python bindings.

**Dependencies.** Adds `prost = "0.11"` as a direct optional dep (gated on the existing `onnx` feature), version-pinned to match tract's transitive use so the `Message` trait impls resolve consistently.

Demo on the issue's model
```
$ ezkl gen-settings -M face_landmark_quantized.onnx
[*] succeeded    ← auto-dequantize transparently rewrote 45 fusions

$ ezkl dequantize -M face_landmark_quantized.onnx -O face_clean.onnx
[*] wrote cleaned ONNX to face_clean.onnx
    (collapsed 0 Q/DQ pairs, folded 0 weight DQ, replaced 45 dynamic-quantize fusions)

$ ezkl gen-settings -M face_landmark_quantized.onnx --disable-quantization-fixup
[E] [graph] model contains ONNX quantization operators EZKL cannot rewrite
    (conv2d_1__52:0_QuantizeLinear (DynamicQuantizeLinear), … (+39 more)).
    EZKL handles quantization internally via the scale run argument and
    transparently strips post-training-quantization patterns it recognises.
    The operators above were not recognised — please export the original
    floating-point model, or run ezkl dequantize -M <input.onnx> -O <output.onnx>
    to inspect the partial rewrite.
```
The dequantize pass adds essentially zero overhead to `gen-settings`: timed at ~0.02 s on the issue's 86-node face_landmark model; `gen-settings` end-to-end timings on the quantized vs. pre-cleaned variant are within noise (~90 s either way, dominated by tract analysis + circuit construction).

Tests
- `cargo test --lib graph::dequantize`: 8 unit tests covering each pattern, idempotence, a float-only no-op, `QLinearConv` reporting, and the shared-`DynamicQuantizeLinear`-feeding-multiple-integer-ops scenario.
- `cargo test --test quantization_detection`: default `Model::new` accepts each Q/DQ fixture; with `disable_quantization_fixup=true` the safety net fires.
- `cargo test --test dequantize_pipeline`: the `ezkl dequantize` subcommand round-trips; the cleaned model loads with the auto-pass disabled.
- `cargo test --test dequantize_e2e`: `gen-settings → calibrate → compile-circuit → gen-witness → mock` on the pre-quantized fixture; the witness output (after dequantising via calibrated scales, sketched below) matches a tract inference of the equivalent float model within 0.5 per element. Negative test: `--disable-quantization-fixup` halts at `gen-settings`. An `#[ignore]`d companion drives the full SNARK `setup → prove → verify` (~4 s on the fixture; opt in with `cargo test -- --ignored`).
- `tests/python/binding_tests.py::test_py_run_args`: `PyRunArgs.disable_quantization_fixup` round-trip.
- `tests/integration_tests.rs::TESTS[]`: the `quantized_qdq` fixture is picked up by 36 existing wrappers (`mock_`, `kzg_prove_and_verify_`, `accuracy_measurement_*`) for free coverage. Smoke-tested: `mock_public_outputs_::tests_100_expects` passes with a witness max-abs-error of 2.4e-4.

Two small fixtures are checked in:

- `tests/assets/quantized_qdq.onnx`: minimal `QuantizeLinear → DequantizeLinear → Conv` graph (440 B).
- `tests/assets/quantized_dynamic.onnx`: `DynamicQuantizeLinear → MatMulInteger` graph mirroring the issue's `quantize_dynamic` output (697 B).
- `examples/onnx/quantized_qdq/{network.onnx,input.json}`: same Q/DQ fixture plus a deterministic input, picked up by the existing `TESTS[]` harness.

Clippy clean on all changed files.
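For reference, the shape of the e2e agreement check mentioned above, assuming EZKL's power-of-two fixed-point convention (a value at scale `s` encodes `x` as `round(x * 2^s)`); the function is a standalone illustration, not the test's actual code:

```rust
/// Element-wise max absolute error between a fixed-point witness output
/// at `scale` and a float reference from tract inference.
fn max_abs_error(witness: &[i64], scale: u32, reference: &[f64]) -> f64 {
    let multiplier = (1u64 << scale) as f64; // 2^scale
    witness
        .iter()
        .zip(reference)
        .map(|(&w, &r)| (w as f64 / multiplier - r).abs())
        .fold(0.0_f64, f64::max)
}

// dequantize_e2e asserts this stays within ~0.5 per element on the fixture.
```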
`cargo fmt --check` clean on all changed files (the existing pre-merge diffs in `eth.rs`/`pfsys/srs.rs` are untouched).

Test plan
- `cargo test --lib graph::dequantize`: 8/8 pass
- `cargo test --test quantization_detection`: 4/4 pass
- `cargo test --test dequantize_pipeline`: 2/2 pass
- `cargo test --test dequantize_e2e`: 2/2 pass; 1 ignored
- `cargo test --test dequantize_e2e -- --ignored`: full SNARK 1/1 pass (~3.9 s)
- `cargo test --test integration_tests mock_public_outputs_::tests_100_expects`: 1/1 pass (~76 s)
- `cargo clippy --features ezkl --tests --no-deps`: no new warnings on changed files
- `cargo fmt --check`: no diffs on changed files
- `face_landmark_quantized.onnx`: default load succeeds, `--disable-quantization-fixup` produces the actionable error, and `ezkl dequantize` produces a cleaned model that `gen-settings` then accepts

Notes for reviewers
- The `--disable-quantization-fixup` flag exists primarily for debugging the rewriter or for users who deliberately want EZKL to see the original pre-quantized graph.
- Please double-check (1) whether the `prost = "0.11"` pin is the right way to consume `tract_onnx::pb::*` directly, and (2) whether the `seq!` range bumps in `tests/integration_tests.rs` (from `0..99` to `0..=100`, which incidentally adds coverage for the previously-uncovered `large_mlp` at idx 99) match your testing intent.
This PR closes the immediate bug, but leaves room for follow-ups in two distinct directions.
Native (circuit-level) support for quantization ops
The current PR rewrites PTQ patterns back to their float equivalents before the constraint compiler ever sees them. An alternative direction would be to support the quantization ops natively under `src/circuit/ops/`, mapping them onto field-element arithmetic. That would unlock things this PR does not: most notably, exact integer semantics instead of the float fold `(W_int - W_zp) * W_scale`, which today re-introduces fixed-point noise on top of the int8 noise already baked into the model.

The work would touch `src/circuit/ops/poly/` (linear arithmetic), `src/circuit/ops/lookup/` (round/clamp at quantization boundaries), and `src/graph/utilities.rs` (op dispatch). It's a substantial piece of work, probably 2–3× the size of this PR, and would benefit from maintainer guidance on the field-arithmetic representation of `(x - zp) * scale` and zero-point handling.
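To make that ask concrete, the arithmetic a native op would have to express, written with $E_\sigma(x) = \lfloor x \cdot 2^\sigma \rceil$ as shorthand for EZKL's fixed-point encoding at scale $\sigma$ (the decomposition into poly/lookup pieces below is a reading of this section, not settled design):

```latex
\hat{x} = s\,(q - z)
\qquad\Longrightarrow\qquad
E_{\sigma}(\hat{x}) \;=\; \big\lfloor s \cdot 2^{\sigma}\,(q - z) \big\rceil
```

The affine part, `(q - z)` times a constant, lands naturally in `src/circuit/ops/poly/`; the rounding (and any clamp at the int8 boundaries) is where `src/circuit/ops/lookup/` would come in.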
Broader ONNX quantization-op coverage

Today the dequantize pass handles the patterns ONNX Runtime's `quantize_dynamic` emits and the textbook activation Q/DQ identity pair. The safety-net detector catches everything else and reports it. Concrete extensions, ranked by likelihood-of-being-needed:

1. `QLinearConv`/`QLinearMatMul`/`QLinearAdd`/`QLinearMul`/`QLinearGlobalAveragePool`/`QLinearLeakyRelu`/`QLinearSigmoid`/`QLinearConcat`: the QOperator family that ORT's static `quantize_static` emits (as opposed to `quantize_dynamic`, which we already handle). Same conceptual rewrite, folding the inline scale/zp tensors and emitting a plain float equivalent, but each op has its own argument layout and there are many of them. Probably the highest-value extension since static PTQ is the more common ORT path in production.
2. Per-channel quantization parameters. The dequantize pass currently rejects per-channel scale/zp (see `DequantizationError::UnsupportedQuantParamShape`). Adding per-axis broadcasting in `dequantize_weight` would cover most CNN weight quantization in the wild (`weight_axis = 0` for `Conv`, `weight_axis = 1` for `MatMul`). Mechanical change, ~50 lines; a broadcast sketch follows this list.

3. `MatMulIntegerToFloat`: an ORT-internal fused op that combines `MatMulInteger + Cast + Mul` into one node. Adding it to the pattern table next to `MatMulInteger` would be a small extension.
torch.ao.quantization.convert(...)-exported models. Pattern-wise these often look like the static QOperator family, but with per-channel scales — so this lands naturally if (1) and (2) ship.tf2onnx-converted TFLite quantized models. TFLite uses asymmetric uint8 quantization with sometimes-different graph topology after conversion. Worth a separate fixture and pattern-by-pattern triage.fp16/bfloat16weight tensors. Thetensor_to_f32helper currently errors onFloat16/Bfloat16. Straightforward to add (usehalf::f16for the reinterpret), but probably belongs as part of a broader half-precision support story rather than this dequantize pass.Symmetric vs. asymmetric quantization. The current code already handles both (zp can be zero or non-zero, signed or unsigned), but the test fixtures are all int8 symmetric. Adding an asymmetric uint8 fixture would harden coverage.
Low-risk follow-ups
- `gen-settings` writes the rewrite counts to a `debug!` log. Promoting that to `info!` (or surfacing it in `--verbose` mode) would let users see at a glance that their model was auto-rewritten, useful to dispel the "what just happened to my graph?" question.
- `ezkl dequantize --check`: a flag that runs the rewrite, reports what would be done, but does not write the cleaned file. Useful for CI integration where teams want to assert their models don't need rewriting.
- A `debug_assert!` after the rewrite would catch any future regression where a pattern accidentally re-introduces a Q/DQ op it just removed (sketched below).
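A sketch of that last invariant against tract's prost-generated `pb` types; the helper name and op list are illustrative:

```rust
use tract_onnx::pb::ModelProto;

/// Ops the rewriter claims to eliminate; after a successful rewrite the
/// graph should never contain them again.
const REWRITTEN_OPS: &[&str] = &[
    "QuantizeLinear",
    "DequantizeLinear",
    "DynamicQuantizeLinear",
    "MatMulInteger",
    "ConvInteger",
];

fn assert_rewrite_clean(proto: &ModelProto) {
    debug_assert!(
        proto.graph.as_ref().map_or(true, |g| g
            .node
            .iter()
            .all(|n| !REWRITTEN_OPS.contains(&n.op_type.as_str()))),
        "dequantize pass re-introduced a quantization op it had removed"
    );
}
```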