|
| 1 | +use image::imageops; |
| 2 | +use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb}; |
| 3 | +use imageproc::drawing::draw_hollow_rect_mut; |
| 4 | +use imageproc::rect::Rect; |
| 5 | +use log::info; |
| 6 | +use std::collections::HashMap; |
| 7 | +use std::convert::TryInto; |
| 8 | +use std::time::Instant; |
| 9 | +use std::vec; |
| 10 | +use std::{ |
| 11 | + fs, |
| 12 | + io::{BufRead, BufReader}, |
| 13 | + path::Path, |
| 14 | +}; |
| 15 | +use wonnx::WonnxError; |
| 16 | + |
| 17 | +/*----------------------------------------------------------------------------- |
| 18 | + Post processing |
| 19 | +--------------------------------------------------------------------------------*/ |
| 20 | +fn draw_rect(image: &mut ImageBuffer<Rgb<u8>, Vec<u8>>, x1: f32, y1: f32, x2: f32, y2: f32) { |
| 21 | + let x1 = x1 as u32; |
| 22 | + let y1 = y1 as u32; |
| 23 | + let x2 = x2 as u32; |
| 24 | + let y2 = y2 as u32; |
| 25 | + let rect = Rect::at(x1 as i32, y1 as i32).of_size(x2 - x1 as u32, (y2 - y1) as u32); |
| 26 | + draw_hollow_rect_mut(image, rect, Rgb([255, 0, 0])); |
| 27 | +} |
| 28 | + |
/// Decodes raw per-cell predictions into corner-format boxes.
///
/// `positions` holds one `(px, py, pw, ph)` tuple per grid cell. The cells are laid
/// out as three consecutive row-major grids, one per stride (8, 16, 32) of a
/// 416x416 input, i.e. 52*52 + 26*26 + 13*13 = 3549 entries. Extra trailing
/// entries are ignored.
///
/// For a cell at column `col`, row `row` of a grid with stride `s`:
/// * centre = `((col + px) * s, (row + py) * s)`
/// * size   = `(exp(pw) * s, exp(ph) * s)`
///
/// Returns one `(x1, y1, x2, y2)` tuple per cell, in the same order as the input.
///
/// # Panics
///
/// Panics if `positions` contains fewer entries than there are grid cells.
fn calc_loc(positions: &[(f32, f32, f32, f32)]) -> Vec<(f32, f32, f32, f32)> {
    const INPUT_SIZE: usize = 416;
    const STRIDES: [usize; 3] = [8, 16, 32];

    let total_cells: usize = STRIDES.iter().map(|s| (INPUT_SIZE / s).pow(2)).sum();
    assert!(
        positions.len() >= total_cells,
        "calc_loc expects at least {} predictions, got {}",
        total_cells,
        positions.len()
    );

    let mut locs = Vec::with_capacity(total_cells);
    // Index of the first cell of the current stride level inside `positions`.
    let mut level_start = 0;

    for &stride in STRIDES.iter() {
        let cells = INPUT_SIZE / stride;
        let scale = stride as f32;

        for row in 0..cells {
            for col in 0..cells {
                let (px, py, pw, ph) = positions[level_start + row * cells + col];
                let (cx, cy) = ((col as f32 + px) * scale, (row as f32 + py) * scale);
                let (bw, bh) = (pw.exp() * scale, ph.exp() * scale);
                locs.push((cx - bw / 2.0, cy - bh / 2.0, cx + bw / 2.0, cy + bh / 2.0));
            }
        }
        level_start += cells * cells;
    }
    locs
}
| 71 | + |
/// Score-weighted non-maximum suppression over corner-format boxes.
///
/// `boxes[i]` is `(x1, y1, x2, y2)` and `scores[i]` is its confidence.
///
/// Candidates are visited from the highest score down. For the current best box,
/// every remaining candidate is handled as follows:
/// * IoU with the best box above `iou_threshold`: the candidate is removed and
///   blended into the output box with weight `score * iou`;
/// * otherwise, more than 70% of the candidate's own area lies inside the best
///   box: the candidate is removed without contributing;
/// * otherwise it stays in the pool for a later round.
///
/// Boxes scoring below `score_threshold` are discarded before anything else, so
/// they can neither be emitted nor used as the suppressing box.
///
/// Returns `(index_of_best_box, merged_box)` pairs in descending score order.
///
/// # Panics
///
/// Panics if `scores` is shorter than `boxes`.
fn non_max_suppression(
    boxes: &[(f32, f32, f32, f32)],
    scores: &[f32],
    score_threshold: f32,
    iou_threshold: f32,
) -> Vec<(usize, (f32, f32, f32, f32))> {
    // Filter first: this also drops NaN scores, which keeps the sort below total.
    let mut candidates: Vec<usize> = (0..boxes.len())
        .filter(|&i| scores[i] >= score_threshold)
        .collect();
    // Ascending stable sort so that `pop` yields the highest score.
    candidates.sort_by(|&a, &b| {
        scores[a]
            .partial_cmp(&scores[b])
            .expect("NaN scores are filtered out above")
    });

    let mut merged = Vec::new();
    while let Some(best) = candidates.pop() {
        let (bx1, by1, bx2, by2) = boxes[best];
        let best_area = (bx2 - bx1) * (by2 - by1);
        let best_score = scores[best];

        let mut numerator = (
            bx1 * best_score,
            by1 * best_score,
            bx2 * best_score,
            by2 * best_score,
        );
        let mut denominator = best_score;

        // Single in-place pass; returning `false` removes the candidate.
        candidates.retain(|&idx| {
            let (x1, y1, x2, y2) = boxes[idx];
            let area = (x2 - x1) * (y2 - y1);

            let inter_w = (x2.min(bx2) - x1.max(bx1)).max(0.0);
            let inter_h = (y2.min(by2) - y1.max(by1)).max(0.0);
            let inter_area = inter_w * inter_h;
            let iou = inter_area / (area + best_area - inter_area);

            if iou > iou_threshold {
                let weight = scores[idx] * iou;
                numerator.0 += x1 * weight;
                numerator.1 += y1 * weight;
                numerator.2 += x2 * weight;
                numerator.3 += y2 * weight;
                denominator += weight;
                return false;
            }
            // Written as a negated `>` so a NaN ratio (zero-area box) keeps the
            // candidate.
            !(inter_area / area > 0.7)
        });

        merged.push((
            best,
            (
                numerator.0 / denominator,
                numerator.1 / denominator,
                numerator.2 / denominator,
                numerator.3 / denominator,
            ),
        ));
    }
    merged
}
| 141 | + |
| 142 | +fn post_process(preds: &[f32]) -> Vec<(String, f32, f32, f32, f32, f32)> { |
| 143 | + let labels = get_coco_labels(); |
| 144 | + let mut positions = vec![]; |
| 145 | + let mut classes = vec![]; |
| 146 | + let mut objectnesses = vec![]; |
| 147 | + for i in 0..3549 { |
| 148 | + let offset = i * 85; |
| 149 | + let objectness = preds[offset + 4]; |
| 150 | + |
| 151 | + let (class, score) = preds[offset + 5..offset + 85] |
| 152 | + .iter() |
| 153 | + .enumerate() |
| 154 | + .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()) |
| 155 | + .unwrap(); |
| 156 | + let class = labels[class].clone(); |
| 157 | + let x1 = preds[offset]; |
| 158 | + let y1 = preds[offset + 1]; |
| 159 | + let x2 = preds[offset + 2]; |
| 160 | + let y2 = preds[offset + 3]; |
| 161 | + classes.push((class, score)); |
| 162 | + positions.push((x1, y1, x2, y2)); |
| 163 | + objectnesses.push(objectness); |
| 164 | + } |
| 165 | + |
| 166 | + let locs = calc_loc(&positions); |
| 167 | + |
| 168 | + let mut result = vec![]; |
| 169 | + // filter by objectness |
| 170 | + let indices = non_max_suppression(&locs, &objectnesses, 0.5, 0.3); |
| 171 | + for bbox in indices { |
| 172 | + let (i, (x, y, w, h)) = bbox; |
| 173 | + let (class, &score) = &classes[i]; |
| 174 | + result.push((class.clone(), score, x, y, w, h)); |
| 175 | + } |
| 176 | + result |
| 177 | +} |
| 178 | + |
| 179 | +/*----------------------------------------------------------------------------- |
| 180 | + Pre processing |
| 181 | +--------------------------------------------------------------------------------*/ |
| 182 | +fn padding_image(image: ImageBuffer<Rgb<u8>, Vec<u8>>) -> ImageBuffer<Rgb<u8>, Vec<u8>> { |
| 183 | + let (width, height) = image.dimensions(); |
| 184 | + let target_size = if width > height { width } else { height }; |
| 185 | + let mut new_image = ImageBuffer::new(target_size as u32, target_size as u32); |
| 186 | + let x_offset = (target_size as u32 - width) / 2; |
| 187 | + let y_offset = (target_size as u32 - height) / 2; |
| 188 | + for j in 0..height { |
| 189 | + for i in 0..width { |
| 190 | + let pixel = image.get_pixel(i, j); |
| 191 | + new_image.put_pixel(i + x_offset, j + y_offset, *pixel); |
| 192 | + } |
| 193 | + } |
| 194 | + new_image |
| 195 | +} |
| 196 | + |
| 197 | +fn load_image() -> (Vec<f32>, ImageBuffer<Rgb<u8>, Vec<u8>>) { |
| 198 | + let args: Vec<String> = std::env::args().collect(); |
| 199 | + let image_path = if args.len() == 2 { |
| 200 | + Path::new(&args[1]).to_path_buf() |
| 201 | + } else { |
| 202 | + Path::new(env!("CARGO_MANIFEST_DIR")) |
| 203 | + .join("../data/images") |
| 204 | + .join("dog.jpg") |
| 205 | + }; |
| 206 | + |
| 207 | + let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(image_path).unwrap().to_rgb8(); |
| 208 | + let image_buffer = padding_image(image_buffer); |
| 209 | + let image_buffer = imageops::resize(&image_buffer, 416, 416, FilterType::Nearest); |
| 210 | + |
| 211 | + // convert image to Vec<f32> with channel first format |
| 212 | + let mut image = vec![0.0; 3 * 416 * 416]; |
| 213 | + for j in 0..416 { |
| 214 | + for i in 0..416 { |
| 215 | + let pixel = image_buffer.get_pixel(i as u32, j as u32); |
| 216 | + let channels = pixel.channels(); |
| 217 | + for c in 0..3 { |
| 218 | + image[c * 416 * 416 + j * 416 + i] = channels[c] as f32; |
| 219 | + } |
| 220 | + } |
| 221 | + } |
| 222 | + return (image, image_buffer); |
| 223 | +} |
| 224 | + |
| 225 | +fn get_coco_labels() -> Vec<String> { |
| 226 | + // Download the ImageNet class labels, matching SqueezeNet's classes. |
| 227 | + let labels_path = Path::new(env!("CARGO_MANIFEST_DIR")) |
| 228 | + .join("../data/models") |
| 229 | + .join("coco-classes.txt"); |
| 230 | + let file = BufReader::new(fs::File::open(labels_path).unwrap()); |
| 231 | + |
| 232 | + file.lines().map(|line| line.unwrap()).collect() |
| 233 | +} |
| 234 | + |
| 235 | +/*----------------------------------------------------------------------------- |
| 236 | + Main |
| 237 | +--------------------------------------------------------------------------------*/ |
| 238 | +// Hardware management |
| 239 | +async fn execute_gpu() -> Result<Vec<(String, f32, f32, f32, f32, f32)>, WonnxError> { |
| 240 | + let mut input_data = HashMap::new(); |
| 241 | + let (image, _) = load_image(); |
| 242 | + let images = image.as_slice().try_into().unwrap(); |
| 243 | + input_data.insert("images".to_string(), images); |
| 244 | + |
| 245 | + let model_path = Path::new(env!("CARGO_MANIFEST_DIR")) |
| 246 | + .join("../data/models") |
| 247 | + .join("yolox_nano.onnx"); |
| 248 | + let session = wonnx::Session::from_path(model_path).await?; |
| 249 | + let time_pre_compute = Instant::now(); |
| 250 | + |
| 251 | + info!("Start Compute"); |
| 252 | + let result = session.run(&input_data).await?; |
| 253 | + let time_post_compute = Instant::now(); |
| 254 | + println!( |
| 255 | + "time: first_prediction: {:#?}", |
| 256 | + time_post_compute - time_pre_compute |
| 257 | + ); |
| 258 | + |
| 259 | + info!("Start Post Processing"); |
| 260 | + let time_pre_compute = Instant::now(); |
| 261 | + let output = result.get("output").unwrap(); |
| 262 | + let output = output.try_into().unwrap(); |
| 263 | + let positions = post_process(output); |
| 264 | + let time_post_compute = Instant::now(); |
| 265 | + println!( |
| 266 | + "time: post_processing: {:#?}", |
| 267 | + time_post_compute - time_pre_compute |
| 268 | + ); |
| 269 | + |
| 270 | + Ok(positions) |
| 271 | +} |
| 272 | + |
| 273 | +async fn run() { |
| 274 | + // Output shape is [1, 3549, 85] |
| 275 | + // 85 = 4 (bounding box) + 1 (objectness) + 80 (class probabilities) |
| 276 | + let preds = execute_gpu().await.unwrap(); |
| 277 | + |
| 278 | + let (_, image_buffer) = load_image(); |
| 279 | + let mut image_buffer = image_buffer; |
| 280 | + for (class, score, x0, y0, x1, y1) in preds.iter() { |
| 281 | + println!( |
| 282 | + "class: {}, score: {}, x0: {}, y0: {}, x1: {}, y1: {}", |
| 283 | + class, *score, *x0, *y0, *x1, *y1 |
| 284 | + ); |
| 285 | + draw_rect(&mut image_buffer, *x0, *y0, *x1, *y1); |
| 286 | + } |
| 287 | + image_buffer.save("yolox_predict.jpg").unwrap(); |
| 288 | +} |
| 289 | + |
| 290 | +fn main() { |
| 291 | + #[cfg(not(target_arch = "wasm32"))] |
| 292 | + { |
| 293 | + env_logger::init(); |
| 294 | + let time_pre_compute = Instant::now(); |
| 295 | + |
| 296 | + pollster::block_on(run()); |
| 297 | + let time_post_compute = Instant::now(); |
| 298 | + println!("time: main: {:#?}", time_post_compute - time_pre_compute); |
| 299 | + } |
| 300 | + #[cfg(target_arch = "wasm32")] |
| 301 | + { |
| 302 | + // std::panic::set_hook(Box::new(console_error_panic_hook::hook)); |
| 303 | + // console_log::init().expect("could not initialize logger"); |
| 304 | + wasm_bindgen_futures::spawn_local(run()); |
| 305 | + } |
| 306 | +} |
0 commit comments