Skip to content

Commit 71206cf

Browse files
NXP backend: Add MobileNetV2 example model and test (#12892)
### Summary Add MobileNetV2 model as example and for integration testing ### Test plan Support for testing full conversion on this model is included in `run_aot_example.sh`. --------- Co-authored-by: Lukas Sztefek <[email protected]>
1 parent 18c71e8 commit 71206cf

File tree

5 files changed

+131
-5
lines changed

5 files changed

+131
-5
lines changed

.github/workflows/pull.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -860,8 +860,9 @@ jobs:
860860
# Run pytest
861861
PYTHON_EXECUTABLE=python bash backends/nxp/run_unittests.sh
862862
863-
# Run aot example:
864-
PYTHON_EXECUTABLE=python bash examples/nxp/run_aot_example.sh
863+
# Run aot examples:
864+
PYTHON_EXECUTABLE=python bash examples/nxp/run_aot_example.sh cifar10
865+
PYTHON_EXECUTABLE=python bash examples/nxp/run_aot_example.sh mobilenetv2
865866
866867
867868
test-vulkan-models-linux:

backends/nxp/tests/executors.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def inference(
5757
return output.detach().numpy()
5858
elif isinstance(output, tuple) and len(output) == 1:
5959
return output[0].detach().numpy()
60+
elif isinstance(output, tuple):
61+
output_names = self.edge_program.graph_signature.user_outputs
62+
63+
return {
64+
name: tensor.detach().numpy()
65+
for (name, tensor) in zip(output_names, output)
66+
}
6067

6168
raise RuntimeError(
6269
"Edge program inference with multiple outputs not implemented"

examples/nxp/aot_neutron_compile.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737

3838
from .experimental.cifar_net.cifar_net import CifarNet, test_cifarnet_model
3939

40+
from .models.mobilenet_v2 import MobilenetV2
41+
4042
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
4143
logging.basicConfig(level=logging.INFO, format=FORMAT)
4244

@@ -87,7 +89,7 @@ def get_model_and_inputs_from_name(model_name: str):
8789
logging.warning(
8890
"Using a model from examples/models not all of these are currently supported"
8991
)
90-
model, example_inputs, _ = EagerModelFactory.create_model(
92+
model, example_inputs, _, _ = EagerModelFactory.create_model(
9193
*MODEL_NAME_TO_MODEL[model_name]
9294
)
9395
else:
@@ -100,6 +102,7 @@ def get_model_and_inputs_from_name(model_name: str):
100102

101103
models = {
102104
"cifar10": CifarNet,
105+
"mobilenetv2": MobilenetV2,
103106
}
104107

105108

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import itertools
7+
from typing import Iterator
8+
9+
import torch
10+
import torchvision
11+
12+
from executorch.examples.models.mobilenet_v2 import MV2Model
13+
from torch.utils.data import DataLoader
14+
from torchvision import transforms
15+
16+
17+
class MobilenetV2(MV2Model):
18+
19+
def get_calibration_inputs(
20+
self, batch_size: int = 1
21+
) -> Iterator[tuple[torch.Tensor]]:
22+
"""
23+
Returns an iterator for the Imagenette validation dataset, downloading it if necessary.
24+
25+
Args:
26+
batch_size (int): The batch size for the iterator.
27+
28+
Returns:
29+
iterator: An iterator that yields batches of images from the Imagnetette validation dataset.
30+
"""
31+
dataloader = self.get_dataset(batch_size)
32+
33+
# Return the iterator
34+
dataloader_iterable = itertools.starmap(
35+
lambda data, label: (data,), iter(dataloader)
36+
)
37+
38+
# We want approximately 500 samples
39+
batch_count = 500 // batch_size
40+
return itertools.islice(dataloader_iterable, batch_count)
41+
42+
def get_dataset(self, batch_size):
43+
# Define data transformations
44+
data_transforms = transforms.Compose(
45+
[
46+
transforms.Resize((224, 224)),
47+
transforms.ToTensor(),
48+
transforms.Normalize(
49+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
50+
), # ImageNet stats
51+
]
52+
)
53+
54+
dataset = torchvision.datasets.Imagenette(
55+
root="./data", split="val", transform=data_transforms, download=True
56+
)
57+
dataloader = torch.utils.data.DataLoader(
58+
dataset,
59+
batch_size=batch_size,
60+
shuffle=False,
61+
num_workers=1,
62+
)
63+
return dataloader
64+
65+
66+
def gather_samples_per_class_from_dataloader(
67+
dataloader, num_samples_per_class=10
68+
) -> list[tuple]:
69+
"""
70+
Gathers a specified number of samples for each class from a DataLoader.
71+
72+
Args:
73+
dataloader (DataLoader): The PyTorch DataLoader object.
74+
num_samples_per_class (int): The number of samples to gather for each class. Defaults to 10.
75+
76+
Returns:
77+
samples: A list of (sample, label) tuples.
78+
"""
79+
80+
if not isinstance(dataloader, DataLoader):
81+
raise TypeError("dataloader must be a torch.utils.data.DataLoader object")
82+
if not isinstance(num_samples_per_class, int) or num_samples_per_class <= 0:
83+
raise ValueError("num_samples_per_class must be a positive integer")
84+
85+
labels = sorted(
86+
set([label for _, label in dataloader.dataset])
87+
) # Get unique labels from the dataset
88+
samples_per_label = {label: [] for label in labels} # Initialize dictionary
89+
90+
for sample, label in dataloader:
91+
label = label.item()
92+
if len(samples_per_label[label]) < num_samples_per_class:
93+
samples_per_label[label].append((sample, label))
94+
95+
samples = []
96+
97+
for label in labels:
98+
samples.extend(samples_per_label[label])
99+
100+
return samples
101+
102+
103+
def generate_input_samples_file():
104+
model = MobilenetV2()
105+
dataloader = model.get_dataset(batch_size=1)
106+
samples = gather_samples_per_class_from_dataloader(
107+
dataloader, num_samples_per_class=2
108+
)
109+
110+
torch.save(samples, "calibration_data.pt")
111+
112+
113+
if __name__ == "__main__":
114+
generate_input_samples_file()

examples/nxp/run_aot_example.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@ set -eux
77

88
SCRIPT_DIR=$(dirname $(readlink -fm $0))
99
EXECUTORCH_DIR=$(dirname $(dirname $SCRIPT_DIR))
10+
MODEL=${1:-"cifar10"}
1011

1112
cd $EXECUTORCH_DIR
1213

1314
# Run the AoT example
1415
python -m examples.nxp.aot_neutron_compile --quantize \
15-
--delegate --neutron_converter_flavor SDK_25_03 -m cifar10
16+
--delegate --neutron_converter_flavor SDK_25_03 -m ${MODEL}
1617
# verify file exists
17-
test -f cifar10_nxp_delegate.pte
18+
test -f ${MODEL}_nxp_delegate.pte

0 commit comments

Comments
 (0)