Hello.
I've observed a systematic increase in the difference between AIMET quantization simulation results and the actual QNN model outputs as network depth increases. This discrepancy could affect the reliability of quantization-accuracy predictions for deeper networks.
Experimental Setup
Model: Simple residual network
Each layer consists of: Conv2d(64, 64, 3) -> BatchNorm2d -> ReLU6, with the block output added back to its input (a minimal sketch of one block follows this list)
Input shape: (1, 64, 128, 128)
Quantization settings:
8-bit quantization for both parameters and outputs
Training range learning with TF initialization scheme
Per-channel quantization configuration
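For reference, here is a minimal sketch of a single block, extracted from the reproduction code at the end of this issue (the full script stacks `num_layers` of these blocks):

```python
import torch

# One block of the toy residual network used in the experiment.
block = torch.nn.Sequential(
    torch.nn.Conv2d(64, 64, 3, padding="same"),
    torch.nn.BatchNorm2d(64),
    torch.nn.ReLU6(),
)

x = torch.randn(1, 64, 128, 128)  # input shape used in the experiment
y = x + block(x)                  # residual add, repeated for every block
```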
Observed Behavior
The difference between AIMET simulation and QNN execution increases significantly with network depth:
| Number of Layers | MSE Difference | L1 Difference |
|------------------|----------------|---------------|
| 1  | 5.64e-05 | 0.0012 |
| 3  | 3.77e-04 | 0.0071 |
| 5  | 9.35e-04 | 0.0148 |
| 7  | 2.35e-03 | 0.0275 |
| 9  | 5.58e-03 | 0.0459 |
| 11 | 9.93e-03 | 0.0641 |
| 13 | 1.52e-02 | 0.0818 |
| 15 | 2.36e-02 | 0.1026 |
| 17 | 3.27e-02 | 0.1223 |
| 19 | 4.18e-02 | 0.1374 |
| 21 | 5.88e-02 | 0.1540 |
| 23 | 6.14e-02 | 0.1633 |
| 25 | 7.35e-02 | 0.1825 |
| 27 | 8.77e-02 | 0.1991 |
| 29 | 1.04e-01 | 0.2153 |
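To put a number on the trend (this snippet is not part of the reproduction code; it only re-uses the values from the table above), a least-squares fit suggests the L1 difference grows by roughly 0.008 per added layer, while the MSE difference grows faster than linearly:

```python
import numpy as np

# Depth vs. divergence values copied from the table above.
layers = np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29])
mse = np.array([5.64e-05, 3.77e-04, 9.35e-04, 2.35e-03, 5.58e-03, 9.93e-03,
                1.52e-02, 2.36e-02, 3.27e-02, 4.18e-02, 5.88e-02, 6.14e-02,
                7.35e-02, 8.77e-02, 1.04e-01])
l1 = np.array([0.0012, 0.0071, 0.0148, 0.0275, 0.0459, 0.0641, 0.0818,
               0.1026, 0.1223, 0.1374, 0.1540, 0.1633, 0.1825, 0.1991, 0.2153])

# Least-squares line through the L1 differences: roughly constant growth per layer.
l1_slope, _ = np.polyfit(layers, l1, 1)
print(f"L1 difference grows by ~{l1_slope:.4f} per added layer")

# The per-step MSE increments generally grow, i.e. the MSE increases faster than linearly.
print("MSE increments per 2 layers:", np.round(np.diff(mse), 4))
```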
Questions
Is this behavior expected?
Are there known limitations or assumptions in AIMET simulation that might explain this divergence?
Are there recommended practices for more accurate simulation of deeper networks?
Reproducible Code
Here's the complete code to reproduce this issue:
import torch
import os
from aimet_common.defs import QuantScheme
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch import model_preparer
from aimet_torch import batch_norm_fold
import qai_hub as hub
import shutil

torch.manual_seed(1517)


class SimpleModel(torch.nn.Module):
    def __init__(self, num_layers=10):
        super(SimpleModel, self).__init__()
        self.features = torch.nn.ModuleList([torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, 3, padding="same"),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU6(),
        ) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.features:
            x = torch.add(x, layer(x))
        return x


input_shape = (1, 64, 128, 128)


def main(num_layers: int):
    # Step 1: Create and prepare model
    model = SimpleModel(num_layers)
    model = model_preparer.prepare_model(model)                             # Prepare for quantization
    batch_norm_fold.fold_all_batch_norms(model, input_shapes=input_shape)   # Fold BN for better quantization

    # Step 2: Set up quantization simulation
    dummy_input = torch.randn(input_shape)  # Create random input tensor
    sim = QuantizationSimModel(
        model,
        dummy_input=dummy_input,
        quant_scheme=QuantScheme.training_range_learning_with_tf_init,  # Use TF initialization
        default_param_bw=8,                                             # 8-bit quantization for parameters
        default_output_bw=8,                                            # 8-bit quantization for activations
        config_file=get_path_for_per_channel_config(),                  # Use per-channel quantization
    )

    # Step 3: Calibrate the quantization parameters
    def pass_calibration_data(model: torch.nn.Module):
        model.eval()
        # Pass random data through the model 10 times for calibration
        for _ in range(10):
            model(torch.randn(input_shape))

    sim.compute_encodings(pass_calibration_data)

    # Step 4: Export the quantized model
    model_dir = f"simple_{num_layers}_layer_model"
    file_prefix = f"simple_{num_layers}_layer_model"
    os.makedirs(model_dir, exist_ok=True)
    sim.export(
        model_dir,
        file_prefix,
        dummy_input=dummy_input
    )

    # Step 5: Prepare model for QNN compilation
    # Create the .aimet directory and copy the necessary files
    aimet_dir = f"{model_dir}.aimet"
    os.makedirs(aimet_dir, exist_ok=True)
    shutil.copy(f"{model_dir}/{file_prefix}.encodings", f"{aimet_dir}/{file_prefix}.encodings")
    shutil.copy(f"{model_dir}/{file_prefix}.onnx", f"{aimet_dir}/{file_prefix}.onnx")

    # Step 6: Compile model for target device
    compile_job = hub.submit_compile_job(
        name=f"simple_{num_layers}_layer_model",
        model=aimet_dir,
        device=hub.Device("Samsung Galaxy S24 Ultra"),  # Target device
        options="--target_runtime qnn_context_binary --compute_unit all",
    )
    compile_job.download_target_model(f"{model_dir}/{file_prefix}.bin")

    # Step 7: Run inference on target device
    inference_job = hub.submit_inference_job(
        model=compile_job.get_target_model(),
        device=hub.Device("Samsung Galaxy S24 Ultra"),
        inputs={list(compile_job.target_shapes.keys())[0]: [dummy_input.detach().numpy()]},
    )

    # Get inference results
    data = inference_job.download_output_data()

    # Step 8: Compare AIMET simulation vs QNN results
    torch_output = sim.model(dummy_input)                      # AIMET simulation output
    qnn_output = torch.from_numpy(list(data.values())[0][0])   # QNN actual output

    # Calculate differences using MSE and L1 metrics
    mse_diff = torch.nn.functional.mse_loss(torch_output, qnn_output)
    l1_diff = torch.nn.functional.l1_loss(torch_output, qnn_output)
    print(f"num_layers: {num_layers}, MSE diff: {mse_diff}, L1 diff: {l1_diff}")


if __name__ == "__main__":
    # Test models with different depths (1 to 29 layers, odd numbers only)
    for num_layers in range(1, 30, 2):
        try:
            main(num_layers)
        except Exception as e:
            print(f"Error: {e}")
Environment
aimet-torch version: 2.3.0+cu121
AI Hub (qai-hub) version: 0.26.0
Device tested: Samsung Galaxy S24 Ultra
Python version: 3.10
PyTorch version: 2.4.0