Increasing Discrepancy Between AIMET Simulation and QNN Models with Network Depth #3978

Open · pei0033 opened this issue Apr 17, 2025 · 2 comments

pei0033 commented Apr 17, 2025

Hello.
I've observed a systematic increase in the difference between AIMET quantization simulation results and actual QNN model outputs as network depth increases. This discrepancy could undermine the reliability of quantization-accuracy predictions for deeper networks.

Experimental Setup

  • Model: simple residual network; each block computes x + f(x), where f is Conv2d(64, 64, 3) -> BatchNorm2d -> ReLU6
  • Input shape: (1, 64, 128, 128)
  • Quantization settings:
    • 8-bit quantization for both parameters and outputs
    • Training range learning with TF initialization scheme
    • Per-channel quantization configuration

Observed Behavior

The difference between AIMET simulation and QNN execution increases significantly with network depth:

| Number of Layers | MSE Difference | L1 Difference |
|------------------|----------------|---------------|
| 1  | 5.64e-05 | 0.0012 |
| 3  | 3.77e-04 | 0.0071 |
| 5  | 9.35e-04 | 0.0148 |
| 7  | 2.35e-03 | 0.0275 |
| 9  | 5.58e-03 | 0.0459 |
| 11 | 9.93e-03 | 0.0641 |
| 13 | 1.52e-02 | 0.0818 |
| 15 | 2.36e-02 | 0.1026 |
| 17 | 3.27e-02 | 0.1223 |
| 19 | 4.18e-02 | 0.1374 |
| 21 | 5.88e-02 | 0.1540 |
| 23 | 6.14e-02 | 0.1633 |
| 25 | 7.35e-02 | 0.1825 |
| 27 | 8.77e-02 | 0.1991 |
| 29 | 1.04e-01 | 0.2153 |
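
For intuition, this roughly linear growth is what one would expect if each block contributed a small, roughly independent numerical mismatch between the two implementations, since the residual additions carry earlier errors forward largely unchanged. The toy sketch below (plain PyTorch, no AIMET; the per-layer noise level of 1e-3 is an arbitrary stand-in for that mismatch, not a measured value) reproduces the qualitative trend:

import torch

torch.manual_seed(1517)

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.f = torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, 3, padding="same"),
            torch.nn.ReLU6(),
        )

    def forward(self, x, noise_std=0.0):
        # noise_std stands in for the small per-layer numerical mismatch
        # between two implementations of the same quantized block.
        y = self.f(x)
        return x + y + noise_std * torch.randn_like(y)

blocks = torch.nn.ModuleList(Block() for _ in range(29))
x0 = torch.randn(1, 64, 32, 32)

with torch.no_grad():
    ref = dev = x0
    for n, blk in enumerate(blocks, start=1):
        ref = blk(ref)                  # "simulation": no extra error
        dev = blk(dev, noise_std=1e-3)  # "device": small per-layer error
        if n % 4 == 1:
            print(n, torch.nn.functional.l1_loss(dev, ref).item())

The divergence printed here grows close to linearly with depth, matching the shape of the table above.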

Questions

  1. Is this behavior expected?
  2. Are there known limitations or assumptions in AIMET simulation that might explain this divergence?
  3. Are there recommended practices for more accurate simulation of deeper networks?

Reproducible Code

Here's the complete code to reproduce this issue:

import os
import shutil

import torch

from aimet_common.defs import QuantScheme
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch import model_preparer
from aimet_torch import batch_norm_fold
import qai_hub as hub

torch.manual_seed(1517)

class SimpleModel(torch.nn.Module):
    def __init__(self, num_layers=10):
        super().__init__()
        self.features = torch.nn.ModuleList([torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, 3, padding="same"),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU6(),
        ) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.features:
            x = torch.add(x, layer(x))
        return x


input_shape = (1, 64, 128, 128)

def main(num_layers: int):
    # Step 1: Create and prepare model
    model = SimpleModel(num_layers)
    model = model_preparer.prepare_model(model)  # Prepare for quantization
    batch_norm_fold.fold_all_batch_norms(model, input_shapes=input_shape)  # Fold BN for better quantization

    # Step 2: Setup quantization simulation
    dummy_input = torch.randn(input_shape)  # Create random input tensor
    sim = QuantizationSimModel(
        model,
        dummy_input=dummy_input,
        quant_scheme=QuantScheme.training_range_learning_with_tf_init,  # Use TF initialization
        default_param_bw=8,      # 8-bit quantization for parameters
        default_output_bw=8,     # 8-bit quantization for activations
        config_file=get_path_for_per_channel_config()  # Use per-channel quantization
    )

    # Step 3: Calibrate the quantization parameters
    def pass_calibration_data(model: torch.nn.Module):
        model.eval()
        # Pass random data through model 10 times for calibration
        for _ in range(10):
            model(torch.randn(input_shape))

    sim.compute_encodings(pass_calibration_data)

    # Step 4: Export the quantized model
    model_dir = f"simple_{num_layers}_layer_model"
    file_prefix = f"simple_{num_layers}_layer_model"
    os.makedirs(model_dir, exist_ok=True)
    sim.export(
        model_dir,
        file_prefix,
        dummy_input=dummy_input
    )
    
    # Step 5: Prepare model for QNN compilation
    # Create .aimet directory and copy necessary files
    aimet_dir = f"{model_dir}.aimet"
    os.makedirs(aimet_dir, exist_ok=True)
    shutil.copy(f"{model_dir}/{file_prefix}.encodings", f"{aimet_dir}/{file_prefix}.encodings")
    shutil.copy(f"{model_dir}/{file_prefix}.onnx", f"{aimet_dir}/{file_prefix}.onnx")

    # Step 6: Compile model for target device
    compile_job = hub.submit_compile_job(
        name = f"simple_{num_layers}_layer_model",
        model = aimet_dir,
        device = hub.Device("Samsung Galaxy S24 Ultra"),  # Target device
        options = "--target_runtime qnn_context_binary --compute_unit all",
    )
    compile_job.download_target_model(f"{model_dir}/{file_prefix}.bin")

    # Step 7: Run inference on target device
    inference_job = hub.submit_inference_job(
        model = compile_job.get_target_model(),
        device = hub.Device("Samsung Galaxy S24 Ultra"),
        inputs = {list(compile_job.target_shapes.keys())[0]: [dummy_input.detach().numpy()]},
    )

    # Get inference results
    data = inference_job.download_output_data()

    # Step 8: Compare AIMET simulation vs QNN results
    with torch.no_grad():
        torch_output = sim.model(dummy_input)  # AIMET simulation output
    qnn_output = torch.from_numpy(list(data.values())[0][0])  # QNN actual output

    # Calculate differences using MSE and L1 metrics (as Python floats)
    mse_diff = torch.nn.functional.mse_loss(torch_output, qnn_output).item()
    l1_diff = torch.nn.functional.l1_loss(torch_output, qnn_output).item()

    print(f"num_layers: {num_layers}, MSE diff: {mse_diff}, L1 diff: {l1_diff}")


if __name__ == "__main__":
    # Test models with different depths (1 to 29 layers, odd numbers only)
    for num_layers in range(1, 30, 2):
        try:
            main(num_layers)
        except Exception as e:
            print(f"Error: {e}")

Environment

  • aimet-torch version: 2.3.0+cu121
  • AI_HUB version : 0.26.0
  • Device tested: Samsung Galaxy S24 Ultra
  • Python version: 3.10
  • PyTorch version: 2.4.0

huijjj commented Apr 17, 2025

I'm observing the same phenomenon with floating point (fp16) as well. Is there a way to properly simulate QNN with PyTorch? If not, how can we do QAT?
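
For reference, the QAT flow AIMET documents is to fine-tune sim.model directly with an ordinary PyTorch training loop after compute_encodings(); with the training_range_learning_with_tf_init scheme the encodings are learned along with the weights. A minimal sketch, assuming the sim object from the script above (the data, loss, and hyperparameters are placeholders, not recommendations):

import torch

# Fine-tune the fake-quantized model in place; export afterwards as usual.
optimizer = torch.optim.Adam(sim.model.parameters(), lr=1e-5)
sim.model.train()
for _ in range(100):  # placeholder number of fine-tuning steps
    inputs = torch.randn(1, 64, 128, 128)   # stand-in for real batches
    targets = torch.randn(1, 64, 128, 128)  # stand-in for a real objective
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(sim.model(inputs), targets)
    loss.backward()
    optimizer.step()
sim.model.eval()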

quic-bhushans (Contributor) commented

Thanks @pei0033 for the detailed issue and repro script. We will run this internally and get back to you with initial feedback.

quic-kyunggeu self-assigned this Apr 18, 2025