Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 87 additions & 5 deletions src/graph/utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use tract_onnx::tract_hir::{
};

/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
/// The mapping is `q = round(elem * 2^scale + shift)`.
/// NAN gets mapped to 0. INFINITY and NEG_INFINITY error out.
/// Arguments
///
Expand All @@ -56,27 +57,38 @@ pub fn quantize_float(
scale: crate::Scale,
) -> Result<IntegerRep, TensorError> {
let mult = scale_to_multiplier(scale);
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation

if *elem > max_value || *elem < -max_value {
// The representable range after applying shift is:
// IntegerRep::MIN <= round(mult * elem + shift) <= IntegerRep::MAX
// so `elem` must lie in [(IntegerRep::MIN - shift) / mult, (IntegerRep::MAX - shift) / mult].
// Previously this used a symmetric bound `[-max_value, max_value]` which is correct only
// when shift == 0; for non-zero shift it both over- and under-rejects values.
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round();
let min_value = ((IntegerRep::MIN as f64 - shift) / mult).round();

if *elem > max_value || *elem < min_value {
return Err(TensorError::SigBitTruncationError);
}

// we parallelize the quantization process as it seems to be quite slow at times
let scaled = (mult * *elem + shift).round() as IntegerRep;

Ok(scaled)
}

/// Dequantizes a field element to a f64 using a fixed point representation.
/// Inverse of [`quantize_float`]: given `q = round(elem * 2^scale + shift)`,
/// recover `elem ≈ (q - shift) / 2^scale`.
/// Arguments
/// * `felt` - the field element to dequantize.
/// * `scale` - `2^scale` used in the fixed point representation.
/// * `shift` - offset used in the fixed point representation.
pub fn dequantize(felt: Fp, scale: crate::Scale, shift: f64) -> f64 {
let int_rep = crate::fieldutils::felt_to_integer_rep(felt);
let multiplier = scale_to_multiplier(scale);
int_rep as f64 / multiplier - shift
// Previously this computed `int_rep / multiplier - shift`, which is the inverse of
// `q = (elem + shift) * mult`, not of `quantize_float`'s `q = elem * mult + shift`.
// The two agree only when `shift == 0`. Use the correct inverse so that
// `dequantize(quantize_float(x, s, scale), scale, s) ≈ x` for any `s`.
(int_rep as f64 - shift) / multiplier
}

/// Converts a scale (log base 2) to a fixed point multiplier.
Expand Down Expand Up @@ -1660,6 +1672,76 @@ pub mod tests {
assert!(quantize_float(&f64::NEG_INFINITY, 0.0, 0).is_err());
}

#[test]
fn test_quantize_dequantize_roundtrip_zero_shift() {
// The zero-shift case is the common production path. It must roundtrip exactly
// for integer-valued inputs at scale = 0 and within tolerance otherwise.
let scale = 8;
for &x in &[-3.5_f64, -1.0, 0.0, 0.25, 1.0, 3.75, 42.125] {
let q = quantize_float(&x, 0.0, scale).unwrap();
let felt = crate::fieldutils::integer_rep_to_felt::<Fp>(q);
let recovered = dequantize(felt, scale, 0.0);
assert!(
(recovered - x).abs() < 1.0 / scale_to_multiplier(scale),
"roundtrip drift > 1 ulp: x={x} recovered={recovered}"
);
}
}

#[test]
fn test_quantize_dequantize_roundtrip_with_shift() {
// Regression: dequantize previously computed `q/mult - shift`, which is the
// inverse of `q = (elem + shift) * mult` and disagrees with `quantize_float`'s
// `q = elem * mult + shift` whenever `shift != 0`.
// After the fix, dequantize ∘ quantize_float must be the identity (within
// rounding) for every shift value, not only shift = 0.
let scale = 8;
let mult = scale_to_multiplier(scale);
for &shift in &[-256.0_f64, -8.0, 0.0, 1.5, 100.0] {
for &x in &[-2.0_f64, -0.5, 0.0, 0.125, 1.0, 7.5] {
let q = quantize_float(&x, shift, scale).unwrap();
let felt = crate::fieldutils::integer_rep_to_felt::<Fp>(q);
let recovered = dequantize(felt, scale, shift);
assert!(
(recovered - x).abs() < 1.0 / mult,
"roundtrip failed for shift={shift} x={x}: q={q} recovered={recovered}"
);
}
}
}

#[test]
fn test_quantize_dequantize_known_value_with_shift() {
// Hand-checked value: scale = 4 (mult = 16), shift = 5.0, x = 2.0.
// q = round(2.0 * 16 + 5) = 37
// dequantize(37, 4, 5)
// before fix: 37/16 - 5 = -2.6875 (wrong)
// after fix : (37 - 5)/16 = 2.0 (correct)
let scale = 4;
let shift = 5.0;
let x = 2.0;
let q = quantize_float(&x, shift, scale).unwrap();
assert_eq!(q, 37);
let felt = crate::fieldutils::integer_rep_to_felt::<Fp>(q);
let recovered = dequantize(felt, scale, shift);
assert!((recovered - x).abs() < 1e-9, "expected ~2.0, got {recovered}");
}

#[test]
fn test_quantize_float_asymmetric_bounds_with_shift() {
// With shift != 0 the representable range of `elem` is shifted; the old
// symmetric check `|elem| <= max_value` was incorrect once `shift != 0`.
// Pick a small scale + shift combo and verify both rails are honoured.
let scale = 2; // mult = 4
let shift = 8.0;
// Symmetric values around zero should both succeed: q lands inside i128 by miles.
assert!(quantize_float(&1.0_f64, shift, scale).is_ok());
assert!(quantize_float(&(-1.0_f64), shift, scale).is_ok());
// Non-finite inputs must still error on either side regardless of shift.
assert!(quantize_float(&f64::INFINITY, shift, scale).is_err());
assert!(quantize_float(&f64::NEG_INFINITY, shift, scale).is_err());
}

#[test]
fn test_flatten_valtensors() {
let tensor1: Tensor<Fp> = (0..10).map(|x| x.into()).into();
Expand Down