diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 7af2355df..a5fe310d5 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -44,6 +44,7 @@ use tract_onnx::tract_hir::{ }; /// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation. +/// The mapping is `q = round(elem * 2^scale + shift)`. /// NAN gets mapped to 0. INFINITY and NEG_INFINITY error out. /// Arguments /// @@ -56,19 +57,26 @@ pub fn quantize_float( scale: crate::Scale, ) -> Result { let mult = scale_to_multiplier(scale); - let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation - - if *elem > max_value || *elem < -max_value { + // The representable range after applying shift is: + // IntegerRep::MIN <= round(mult * elem + shift) <= IntegerRep::MAX + // so `elem` must lie in [(IntegerRep::MIN - shift) / mult, (IntegerRep::MAX - shift) / mult]. + // Previously this used a symmetric bound `[-max_value, max_value]` which is correct only + // when shift == 0; for non-zero shift it both over- and under-rejects values. + let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); + let min_value = ((IntegerRep::MIN as f64 - shift) / mult).round(); + + if *elem > max_value || *elem < min_value { return Err(TensorError::SigBitTruncationError); } - // we parallelize the quantization process as it seems to be quite slow at times let scaled = (mult * *elem + shift).round() as IntegerRep; Ok(scaled) } /// Dequantizes a field element to a f64 using a fixed point representation. +/// Inverse of [`quantize_float`]: given `q = round(elem * 2^scale + shift)`, +/// recover `elem ≈ (q - shift) / 2^scale`. /// Arguments /// * `felt` - the field element to dequantize. /// * `scale` - `2^scale` used in the fixed point representation. @@ -76,7 +84,11 @@ pub fn quantize_float( pub fn dequantize(felt: Fp, scale: crate::Scale, shift: f64) -> f64 { let int_rep = crate::fieldutils::felt_to_integer_rep(felt); let multiplier = scale_to_multiplier(scale); - int_rep as f64 / multiplier - shift + // Previously this computed `int_rep / multiplier - shift`, which is the inverse of + // `q = (elem + shift) * mult`, not of `quantize_float`'s `q = elem * mult + shift`. + // The two agree only when `shift == 0`. Use the correct inverse so that + // `dequantize(quantize_float(x, s, scale), scale, s) ≈ x` for any `s`. + (int_rep as f64 - shift) / multiplier } /// Converts a scale (log base 2) to a fixed point multiplier. @@ -1660,6 +1672,76 @@ pub mod tests { assert!(quantize_float(&f64::NEG_INFINITY, 0.0, 0).is_err()); } + #[test] + fn test_quantize_dequantize_roundtrip_zero_shift() { + // The zero-shift case is the common production path. It must roundtrip exactly + // for integer-valued inputs at scale = 0 and within tolerance otherwise. + let scale = 8; + for &x in &[-3.5_f64, -1.0, 0.0, 0.25, 1.0, 3.75, 42.125] { + let q = quantize_float(&x, 0.0, scale).unwrap(); + let felt = crate::fieldutils::integer_rep_to_felt::(q); + let recovered = dequantize(felt, scale, 0.0); + assert!( + (recovered - x).abs() < 1.0 / scale_to_multiplier(scale), + "roundtrip drift > 1 ulp: x={x} recovered={recovered}" + ); + } + } + + #[test] + fn test_quantize_dequantize_roundtrip_with_shift() { + // Regression: dequantize previously computed `q/mult - shift`, which is the + // inverse of `q = (elem + shift) * mult` and disagrees with `quantize_float`'s + // `q = elem * mult + shift` whenever `shift != 0`. + // After the fix, dequantize ∘ quantize_float must be the identity (within + // rounding) for every shift value, not only shift = 0. + let scale = 8; + let mult = scale_to_multiplier(scale); + for &shift in &[-256.0_f64, -8.0, 0.0, 1.5, 100.0] { + for &x in &[-2.0_f64, -0.5, 0.0, 0.125, 1.0, 7.5] { + let q = quantize_float(&x, shift, scale).unwrap(); + let felt = crate::fieldutils::integer_rep_to_felt::(q); + let recovered = dequantize(felt, scale, shift); + assert!( + (recovered - x).abs() < 1.0 / mult, + "roundtrip failed for shift={shift} x={x}: q={q} recovered={recovered}" + ); + } + } + } + + #[test] + fn test_quantize_dequantize_known_value_with_shift() { + // Hand-checked value: scale = 4 (mult = 16), shift = 5.0, x = 2.0. + // q = round(2.0 * 16 + 5) = 37 + // dequantize(37, 4, 5) + // before fix: 37/16 - 5 = -2.6875 (wrong) + // after fix : (37 - 5)/16 = 2.0 (correct) + let scale = 4; + let shift = 5.0; + let x = 2.0; + let q = quantize_float(&x, shift, scale).unwrap(); + assert_eq!(q, 37); + let felt = crate::fieldutils::integer_rep_to_felt::(q); + let recovered = dequantize(felt, scale, shift); + assert!((recovered - x).abs() < 1e-9, "expected ~2.0, got {recovered}"); + } + + #[test] + fn test_quantize_float_asymmetric_bounds_with_shift() { + // With shift != 0 the representable range of `elem` is shifted; the old + // symmetric check `|elem| <= max_value` was incorrect once `shift != 0`. + // Pick a small scale + shift combo and verify both rails are honoured. + let scale = 2; // mult = 4 + let shift = 8.0; + // Symmetric values around zero should both succeed: q lands inside i128 by miles. + assert!(quantize_float(&1.0_f64, shift, scale).is_ok()); + assert!(quantize_float(&(-1.0_f64), shift, scale).is_ok()); + // Non-finite inputs must still error on either side regardless of shift. + assert!(quantize_float(&f64::INFINITY, shift, scale).is_err()); + assert!(quantize_float(&f64::NEG_INFINITY, shift, scale).is_err()); + } + #[test] fn test_flatten_valtensors() { let tensor1: Tensor = (0..10).map(|x| x.into()).into();