Skip to content

Commit e7dd278

Browse files
committed
Integrate Changes from kadu-v: webonnx#214
Add: initial slice operator wip Update: the implementation of slice operator to move the fixed inputs to attributes Update: concat operator to set an axis attribute Update: Concat.wgsl to adapt any axis WIP: yolox nano sample Update: the post processing of yolox_ nano sample Cleanup: the example code for yolox_nano Cleanup Update: README.md Remove: unnecessary images
1 parent dc1b3ac commit e7dd278

File tree

10 files changed

+1933
-20
lines changed

10 files changed

+1933
-20
lines changed

Cargo.lock

+497-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ fn test_matmul_square_matrix() {
346346
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sin">Sin</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sin-7">7</a>|||
347347
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sinh">Sinh</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sinh-9">9</a>|||
348348
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Size">Size</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Size-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Size-1">1</a>|||
349-
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Slice">Slice</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Slice-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Slice-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Slice-10">10</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Slice-1">1</a>|||
349+
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Slice">Slice</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Slice-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Slice-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Slice-10">10</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Slice-1">1</a>|||
350350
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Softplus">Softplus</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Softplus-1">1</a>||
351351
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Softsign">Softsign</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Softsign-1">1</a>||
352352
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#SpaceToDepth">SpaceToDepth</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#SpaceToDepth-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#SpaceToDepth-1">1</a>|
@@ -398,6 +398,8 @@ fn test_matmul_square_matrix() {
398398
* For `MatMul` and `Gemm`, the matrix dimensions must be divisible by 2, or the output matrix must be of size (1, N). Matrix
399399
multiplication only supports floats, not integers (this is a WebGPU/WGSL limitation).
400400

401+
* The `Slice` operator is only supported when slicing along a single axis (i.e., exactly one axis must be specified).
402+
401403
### Shape inference
402404

403405
WONNX needs to know the shape of input and output tensors for each operation in order to generate shader code for executing

wonnx/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ futures = "^0.3.26"
4444
parking_lot = { version = "0.11.1", features = ["wasm-bindgen"] }
4545

4646
[dev-dependencies]
47-
image = "0.24.2"
47+
image = "0.25.1"
48+
imageproc = "0.24.0"
4849
ndarray = "0.15.4"
4950
approx = "0.5.1"
5051
pollster = "0.3.0"

wonnx/examples/yolox_nano.rs

+306
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
use image::imageops;
2+
use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb};
3+
use imageproc::drawing::draw_hollow_rect_mut;
4+
use imageproc::rect::Rect;
5+
use log::info;
6+
use std::collections::HashMap;
7+
use std::convert::TryInto;
8+
use std::time::Instant;
9+
use std::vec;
10+
use std::{
11+
fs,
12+
io::{BufRead, BufReader},
13+
path::Path,
14+
};
15+
use wonnx::WonnxError;
16+
17+
/*-----------------------------------------------------------------------------
18+
Post processing
19+
--------------------------------------------------------------------------------*/
20+
/// Draws a red rectangle outline onto `image`.
///
/// `(x1, y1)` and `(x2, y2)` are the top-left and bottom-right corners in pixel
/// coordinates. Each coordinate is truncated to a whole pixel; the float-to-`u32`
/// cast maps negative and NaN values to 0.
///
/// Boxes that are inverted or collapse to zero width/height after truncation are
/// skipped: subtracting the corners would underflow, and `Rect::of_size` requires
/// strictly positive dimensions.
fn draw_rect(image: &mut ImageBuffer<Rgb<u8>, Vec<u8>>, x1: f32, y1: f32, x2: f32, y2: f32) {
    let (left, top) = (x1 as u32, y1 as u32);
    let (right, bottom) = (x2 as u32, y2 as u32);

    // Saturating subtraction turns an inverted box into an empty one instead of panicking.
    let width = right.saturating_sub(left);
    let height = bottom.saturating_sub(top);
    if width == 0 || height == 0 {
        return;
    }

    // NOTE(review): coordinates are assumed to fit in i32 (they come from a 416x416 image).
    let rect = Rect::at(left as i32, top as i32).of_size(width, height);
    draw_hollow_rect_mut(image, rect, Rgb([255, 0, 0]));
}
28+
29+
/// Decodes raw per-anchor box predictions into corner-format boxes.
///
/// `positions` holds one `(x, y, w, h)` tuple per anchor. Anchors are laid out level
/// by level for strides 8, 16 and 32 over a 416x416 input (52x52, 26x26 and 13x13
/// grids, 3549 anchors in total), row-major within each grid.
///
/// For every anchor the centre is `(grid_cell + offset) * stride` and the size is
/// `exp(raw_size) * stride`. The returned tuples are `(x_min, y_min, x_max, y_max)`,
/// in the same order as the input anchors.
///
/// # Panics
///
/// Panics if `positions` contains fewer than 3549 entries.
fn calc_loc(positions: &[(f32, f32, f32, f32)]) -> Vec<(f32, f32, f32, f32)> {
    const INPUT_SIZE: usize = 416;
    const STRIDES: [usize; 3] = [8, 16, 32];

    let mut locs = Vec::with_capacity(positions.len());
    // Index of the first anchor belonging to the current stride level.
    let mut level_start = 0;

    for &stride in STRIDES.iter() {
        let cells = INPUT_SIZE / stride;
        let scale = stride as f32;

        for row in 0..cells {
            for col in 0..cells {
                let (px, py, pw, ph) = positions[level_start + row * cells + col];
                let (cx, cy) = ((col as f32 + px) * scale, (row as f32 + py) * scale);
                let (w, h) = (pw.exp() * scale, ph.exp() * scale);
                locs.push((cx - w / 2.0, cy - h / 2.0, cx + w / 2.0, cy + h / 2.0));
            }
        }
        level_start += cells * cells;
    }
    locs
}
71+
72+
/// Score-weighted non-maximum suppression over corner-format boxes.
///
/// `boxes[i]` is `(x_min, y_min, x_max, y_max)` and `scores[i]` is its confidence.
/// Candidates whose score is below `score_threshold` (or is NaN) are discarded before
/// anything else, so they are never returned and never influence a merged box.
///
/// The remaining candidates are visited from highest to lowest score. For each winner,
/// every other candidate whose IoU with it exceeds `iou_threshold` is removed and
/// blended into the winner, weighted by `score * iou`. Candidates whose own area is
/// more than 70% covered by the winner are removed without being blended.
///
/// Returns `(index_of_winner_in_boxes, merged_box)` pairs, highest score first.
///
/// # Panics
///
/// Panics if `scores` is shorter than `boxes`.
fn non_max_suppression(
    boxes: &[(f32, f32, f32, f32)],
    scores: &[f32],
    score_threshold: f32,
    iou_threshold: f32,
) -> Vec<(usize, (f32, f32, f32, f32))> {
    // Fraction of a candidate's own area that must lie inside the winner for the
    // candidate to be dropped even though its IoU is below the threshold.
    const CONTAINMENT_THRESHOLD: f32 = 0.7;

    // Filtering first guarantees the top-scored box is also subject to the threshold.
    // `>=` is false for NaN, so NaN scores are dropped here as well.
    let mut candidates: Vec<usize> = (0..boxes.len())
        .filter(|&i| scores[i] >= score_threshold)
        .collect();
    // Stable ascending sort: the best candidate is popped from the back.
    candidates.sort_by(|&a, &b| {
        scores[a]
            .partial_cmp(&scores[b])
            .expect("NaN scores were filtered out above")
    });

    let mut kept = Vec::new();
    while let Some(best) = candidates.pop() {
        let best_score = scores[best];
        let (bx1, by1, bx2, by2) = boxes[best];
        let best_area = (bx2 - bx1) * (by2 - by1);

        let mut weighted_sum = (
            bx1 * best_score,
            by1 * best_score,
            bx2 * best_score,
            by2 * best_score,
        );
        let mut weight_total = best_score;

        // `retain` visits candidates in order, so the accumulation order is deterministic.
        candidates.retain(|&idx| {
            let (x1, y1, x2, y2) = boxes[idx];
            let area = (x2 - x1) * (y2 - y1);

            let inter_w = (x2.min(bx2) - x1.max(bx1)).max(0.0);
            let inter_h = (y2.min(by2) - y1.max(by1)).max(0.0);
            let inter_area = inter_w * inter_h;
            let iou = inter_area / (area + best_area - inter_area);

            if iou > iou_threshold {
                let weight = scores[idx] * iou;
                weighted_sum = (
                    weighted_sum.0 + x1 * weight,
                    weighted_sum.1 + y1 * weight,
                    weighted_sum.2 + x2 * weight,
                    weighted_sum.3 + y2 * weight,
                );
                weight_total += weight;
                return false;
            }

            let mostly_covered = inter_area / area > CONTAINMENT_THRESHOLD;
            !mostly_covered
        });

        kept.push((
            best,
            (
                weighted_sum.0 / weight_total,
                weighted_sum.1 / weight_total,
                weighted_sum.2 / weight_total,
                weighted_sum.3 / weight_total,
            ),
        ));
    }
    kept
}
141+
142+
/// Converts the flat model output into labelled detections.
///
/// `preds` is read as 3549 consecutive records of 85 values each:
/// four raw box values, one objectness value, then 80 per-class scores.
/// For every record the best-scoring class is selected, the raw boxes are decoded
/// with `calc_loc`, and the result is filtered with `non_max_suppression`
/// (objectness threshold 0.5, IoU threshold 0.3).
///
/// Each returned tuple is `(label, class_score, x_min, y_min, x_max, y_max)`.
///
/// # Panics
///
/// Panics if `preds` holds fewer than `3549 * 85` values, if a class score is NaN,
/// or if the label file has fewer entries than the selected class index.
fn post_process(preds: &[f32]) -> Vec<(String, f32, f32, f32, f32, f32)> {
    const NUM_ANCHORS: usize = 3549;
    const VALUES_PER_ANCHOR: usize = 85;
    assert!(
        preds.len() >= NUM_ANCHORS * VALUES_PER_ANCHOR,
        "expected at least {} prediction values, got {}",
        NUM_ANCHORS * VALUES_PER_ANCHOR,
        preds.len()
    );

    let mut positions = Vec::with_capacity(NUM_ANCHORS);
    let mut classes = Vec::with_capacity(NUM_ANCHORS);
    let mut objectnesses = Vec::with_capacity(NUM_ANCHORS);
    for anchor in preds.chunks_exact(VALUES_PER_ANCHOR).take(NUM_ANCHORS) {
        // Only the class index is stored here; the label string is looked up later,
        // and only for the few boxes that survive suppression.
        let (class_idx, class_score) = anchor[5..]
            .iter()
            .copied()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
            .unwrap();
        classes.push((class_idx, class_score));
        // Raw (x, y, w, h) values; they become corner coordinates in `calc_loc`.
        positions.push((anchor[0], anchor[1], anchor[2], anchor[3]));
        objectnesses.push(anchor[4]);
    }

    let locs = calc_loc(&positions);
    let labels = get_coco_labels();

    // Candidates are ranked and filtered by objectness, not by class score.
    non_max_suppression(&locs, &objectnesses, 0.5, 0.3)
        .into_iter()
        .map(|(i, (x_min, y_min, x_max, y_max))| {
            let (class_idx, class_score) = classes[i];
            (
                labels[class_idx].clone(),
                class_score,
                x_min,
                y_min,
                x_max,
                y_max,
            )
        })
        .collect()
}
178+
179+
/*-----------------------------------------------------------------------------
180+
Pre processing
181+
--------------------------------------------------------------------------------*/
182+
/// Pads `image` to a square canvas whose side equals the longer input edge.
///
/// The original pixels are centred on the canvas; the remaining area keeps the
/// default pixel value produced by `ImageBuffer::new`.
fn padding_image(image: ImageBuffer<Rgb<u8>, Vec<u8>>) -> ImageBuffer<Rgb<u8>, Vec<u8>> {
    let (src_w, src_h) = image.dimensions();
    let side = src_w.max(src_h);
    let (dx, dy) = ((side - src_w) / 2, (side - src_h) / 2);

    let mut canvas = ImageBuffer::new(side, side);
    for (col, row, px) in image.enumerate_pixels() {
        canvas.put_pixel(col + dx, row + dy, *px);
    }
    canvas
}
196+
197+
fn load_image() -> (Vec<f32>, ImageBuffer<Rgb<u8>, Vec<u8>>) {
198+
let args: Vec<String> = std::env::args().collect();
199+
let image_path = if args.len() == 2 {
200+
Path::new(&args[1]).to_path_buf()
201+
} else {
202+
Path::new(env!("CARGO_MANIFEST_DIR"))
203+
.join("../data/images")
204+
.join("dog.jpg")
205+
};
206+
207+
let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(image_path).unwrap().to_rgb8();
208+
let image_buffer = padding_image(image_buffer);
209+
let image_buffer = imageops::resize(&image_buffer, 416, 416, FilterType::Nearest);
210+
211+
// convert image to Vec<f32> with channel first format
212+
let mut image = vec![0.0; 3 * 416 * 416];
213+
for j in 0..416 {
214+
for i in 0..416 {
215+
let pixel = image_buffer.get_pixel(i as u32, j as u32);
216+
let channels = pixel.channels();
217+
for c in 0..3 {
218+
image[c * 416 * 416 + j * 416 + i] = channels[c] as f32;
219+
}
220+
}
221+
}
222+
return (image, image_buffer);
223+
}
224+
225+
fn get_coco_labels() -> Vec<String> {
226+
// Download the ImageNet class labels, matching SqueezeNet's classes.
227+
let labels_path = Path::new(env!("CARGO_MANIFEST_DIR"))
228+
.join("../data/models")
229+
.join("coco-classes.txt");
230+
let file = BufReader::new(fs::File::open(labels_path).unwrap());
231+
232+
file.lines().map(|line| line.unwrap()).collect()
233+
}
234+
235+
/*-----------------------------------------------------------------------------
236+
Main
237+
--------------------------------------------------------------------------------*/
238+
// Hardware management
239+
async fn execute_gpu() -> Result<Vec<(String, f32, f32, f32, f32, f32)>, WonnxError> {
240+
let mut input_data = HashMap::new();
241+
let (image, _) = load_image();
242+
let images = image.as_slice().try_into().unwrap();
243+
input_data.insert("images".to_string(), images);
244+
245+
let model_path = Path::new(env!("CARGO_MANIFEST_DIR"))
246+
.join("../data/models")
247+
.join("yolox_nano.onnx");
248+
let session = wonnx::Session::from_path(model_path).await?;
249+
let time_pre_compute = Instant::now();
250+
251+
info!("Start Compute");
252+
let result = session.run(&input_data).await?;
253+
let time_post_compute = Instant::now();
254+
println!(
255+
"time: first_prediction: {:#?}",
256+
time_post_compute - time_pre_compute
257+
);
258+
259+
info!("Start Post Processing");
260+
let time_pre_compute = Instant::now();
261+
let output = result.get("output").unwrap();
262+
let output = output.try_into().unwrap();
263+
let positions = post_process(output);
264+
let time_post_compute = Instant::now();
265+
println!(
266+
"time: post_processing: {:#?}",
267+
time_post_compute - time_pre_compute
268+
);
269+
270+
Ok(positions)
271+
}
272+
273+
async fn run() {
274+
// Output shape is [1, 3549, 85]
275+
// 85 = 4 (bounding box) + 1 (objectness) + 80 (class probabilities)
276+
let preds = execute_gpu().await.unwrap();
277+
278+
let (_, image_buffer) = load_image();
279+
let mut image_buffer = image_buffer;
280+
for (class, score, x0, y0, x1, y1) in preds.iter() {
281+
println!(
282+
"class: {}, score: {}, x0: {}, y0: {}, x1: {}, y1: {}",
283+
class, *score, *x0, *y0, *x1, *y1
284+
);
285+
draw_rect(&mut image_buffer, *x0, *y0, *x1, *y1);
286+
}
287+
image_buffer.save("yolox_predict.jpg").unwrap();
288+
}
289+
290+
fn main() {
291+
#[cfg(not(target_arch = "wasm32"))]
292+
{
293+
env_logger::init();
294+
let time_pre_compute = Instant::now();
295+
296+
pollster::block_on(run());
297+
let time_post_compute = Instant::now();
298+
println!("time: main: {:#?}", time_post_compute - time_pre_compute);
299+
}
300+
#[cfg(target_arch = "wasm32")]
301+
{
302+
// std::panic::set_hook(Box::new(console_error_panic_hook::hook));
303+
// console_log::init().expect("could not initialize logger");
304+
wasm_bindgen_futures::spawn_local(run());
305+
}
306+
}

0 commit comments

Comments
 (0)