|
| 1 | +#' Pascal VOC Segmentation Dataset |
| 2 | +#' |
| 3 | +#' The Pascal Visual Object Classes (VOC) dataset is a widely used benchmark for object detection and semantic segmentation tasks in computer vision. |
| 4 | +#' |
| 5 | +#' This dataset provides RGB images along with per-pixel class segmentation masks for 20 object categories, plus a background class. |
| 6 | +#' Each pixel in the mask is labeled with a class index corresponding to one of the predefined semantic categories. |
| 7 | +#' |
| 8 | +#' The VOC dataset was released in yearly editions (2007 to 2012), with slight variations in data splits and annotation formats. |
| 9 | +#' Notably, only the 2007 edition includes a separate `test` split; all other years (2008–2012) provide only the `train`, `val`, and `trainval` splits. |
| 10 | +#' |
| 11 | +#' The dataset defines 21 semantic classes: `"background"`, `"aeroplane"`, `"bicycle"`, `"bird"`, `"boat"`, `"bottle"`, `"bus"`, `"car"`, `"cat"`, `"chair"`, |
| 12 | +#' `"cow"`, `"dining table"`, `"dog"`, `"horse"`, `"motorbike"`, `"person"`, `"potted plant"`, `"sheep"`, `"sofa"`, `"train"`, and `"tv/monitor"`. |
| 13 | +#' They are available through the `classes` variable of the dataset object. |
| 14 | +#' |
| 15 | +#' This dataset is frequently used for training and evaluating semantic segmentation models, and supports tasks requiring dense, per-pixel annotations. |
| 16 | +#' |
| 17 | +#' @inheritParams oxfordiiitpet_dataset |
| 18 | +#' @param root Character. Root directory where the dataset will be stored under `root/pascal_voc_<year>`. |
| 19 | +#' @param year Character. VOC dataset version to use. One of `"2007"`, `"2008"`, `"2009"`, `"2010"`, `"2011"`, or `"2012"`. Default is `"2012"`. |
| 20 | +#' @param split Character. One of `"train"`, `"val"`, `"trainval"`, or `"test"`. Determines the dataset split. Default is `"train"`. The `"test"` split is only available for `year = "2007"`. |
| 21 | +#' |
| 22 | +#' @return A torch dataset of class \code{pascal_segmentation_dataset}. |
| 23 | +#' |
| 24 | +#' The returned list inherits class \code{image_with_segmentation_mask}, which allows generic visualization |
| 25 | +#' utilities to be applied. |
| 26 | +#' |
| 27 | +#' Each element is a named list with the following structure: |
| 28 | +#' - `x`: a H x W x 3 array representing the RGB image. |
| 29 | +#' - `y`: A named list containing: |
| 30 | +#' - `masks`: A `torch_tensor` of dtype `bool` and shape `(21, H, W)`, representing a multi-channel segmentation mask. |
| 31 | +#' Each of the 21 channels corresponds to one Pascal VOC class. |
| 32 | +#' - `labels`: An integer vector indicating the indices of the classes present in the mask. |
| 33 | +#' |
| 34 | +#' @examples |
| 35 | +#' \dontrun{ |
| 36 | +#' # Load Pascal VOC segmentation dataset (2007 train split) |
| 37 | +#' pascal_seg <- pascal_segmentation_dataset( |
| 38 | +#' transform = transform_to_tensor, |
| 39 | +#' download = TRUE, |
| 40 | +#' year = "2007" |
| 41 | +#' ) |
| 42 | +#' |
| 43 | +#' # Access the first image and its mask |
| 44 | +#' first_item <- pascal_seg[1] |
| 45 | +#' first_item$x # Image |
| 46 | +#' first_item$y$masks # Segmentation mask |
| 47 | +#' first_item$y$labels # Unique class labels in the mask |
| 48 | +#' pascal_seg$classes[first_item$y$labels] # Class names |
| 49 | +#' |
| 50 | +#' # Visualise the first image and its mask |
| 51 | +#' masked_img <- draw_segmentation_masks(first_item) |
| 52 | +#' tensor_image_browse(masked_img) |
| 53 | +#' |
| 54 | +#' # Load Pascal VOC detection dataset (2007 train split) |
| 55 | +#' pascal_det <- pascal_detection_dataset( |
| 56 | +#' transform = transform_to_tensor, |
| 57 | +#' download = TRUE, |
| 58 | +#' year = "2007" |
| 59 | +#' ) |
| 60 | +#' |
| 61 | +#' # Access the first image and its bounding boxes |
| 62 | +#' first_item <- pascal_det[1] |
| 63 | +#' first_item$x # Image |
| 64 | +#' first_item$y$labels # Object labels |
| 65 | +#' first_item$y$boxes # Bounding box tensor |
| 66 | +#' |
| 67 | +#' # Visualise the first image with bounding boxes |
| 68 | +#' boxed_img <- draw_bounding_boxes(first_item) |
| 69 | +#' tensor_image_browse(boxed_img) |
| 70 | +#' } |
| 71 | +#' |
| 72 | +#' @name pascal_voc_datasets |
| 73 | +#' @title Pascal VOC Datasets |
| 74 | +#' @rdname pascal_voc_datasets |
| 75 | +#' @family segmentation_dataset |
| 76 | +#' @export |
pascal_segmentation_dataset <- torch::dataset(
  name = "pascal_segmentation_dataset",

  # One row per downloadable archive. Only year 2007 has a separate "test"
  # archive; every other year only lists a "trainval" archive.
  resources = data.frame(
    year = c("2007", "2007", "2008", "2009", "2010", "2011", "2012"),
    type = c("trainval", "test", "trainval", "trainval", "trainval", "trainval", "trainval"),
    url = c("https://huggingface.co/datasets/JimmyUnleashed/Pascal_VOC/resolve/main/VOCtrainval_06-Nov-2007.tar",
            "https://huggingface.co/datasets/JimmyUnleashed/Pascal_VOC/resolve/main/VOCtest_06-Nov-2007.tar",
            "https://huggingface.co/datasets/JimmyUnleashed/Pascal_VOC/resolve/main/VOCtrainval_14-Jul-2008.tar",
            "https://huggingface.co/datasets/JimmyUnleashed/Pascal_VOC/resolve/main/VOCtrainval_11-May-2009.tar",
            "https://huggingface.co/datasets/JimmyUnleashed/Pascal_VOC/resolve/main/VOCtrainval_03-May-2010.tar",
            "https://huggingface.co/datasets/JimmyUnleashed/Pascal_VOC/resolve/main/VOCtrainval_25-May-2011.tar",
            "https://huggingface.co/datasets/JimmyUnleashed/Pascal_VOC/resolve/main/VOCtrainval_11-May-2012.tar"),
    md5 = c("c52e279531787c972589f7e41ab4ae64",
            "b6e924de25625d8de591ea690078ad9f",
            "2629fa636546599198acfcfbfcf1904a",
            "59065e4b188729180974ef6572f6a212",
            "da459979d0c395079b5c75ee67908abb",
            "6c3384ef61512963050cb5d687e5bf1e",
            "6cd6e144f989b92b3379bac3b3de84fd"),
    size = c("440 MB", "440 MB", "550 MB", "890 MB", "1.3 GB", "1.7 GB", "1.9 GB")
  ),
  # Class names; position i in this vector pairs with the i-th RGB triplet of
  # `voc_colormap` and with channel i of the mask tensor built in `.getitem`.
  classes = c(
    "background", "aeroplane", "bicycle", "bird", "boat",
    "bottle", "bus", "car", "cat", "chair",
    "cow", "dining table", "dog", "horse", "motorbike",
    "person", "potted plant", "sheep", "sofa", "train",
    "tv/monitor"
  ),
  # Flat numeric vector of 21 RGB triplets (c() flattens the nested vectors);
  # `.getitem` reshapes it row-wise into a 21 x 3 matrix.
  voc_colormap = c(
    c(0, 0, 0), c(128, 0, 0), c(0, 128, 0), c(128, 128, 0),
    c(0, 0, 128), c(128, 0, 128), c(0, 128, 128), c(128, 128, 128),
    c(64, 0, 0), c(192, 0, 0), c(64, 128, 0), c(192, 128, 0),
    c(64, 0, 128), c(192, 0, 128), c(64, 128, 128), c(192, 128, 128),
    c(0, 64, 0), c(128, 64, 0), c(0, 192, 0), c(128, 192, 0),
    c(0, 64, 128)
  ),

  # Validate the arguments, optionally download the archive, then load the
  # image and mask paths that `download()` stored for the requested split.
  initialize = function(
    root = tempdir(),
    year = "2012",
    split = "train",
    transform = NULL,
    target_transform = NULL,
    download = FALSE
  ) {
    self$root_path <- root
    self$year <- match.arg(year, choices = unique(self$resources$year))
    self$split <- match.arg(split, choices = c("train", "val", "trainval", "test"))
    self$transform <- transform
    self$target_transform <- target_transform

    # The "test" split lives in its own archive; all other splits share "trainval".
    self$archive_key <- if (self$split == "test") "test" else "trainval"
    resource <- self$resources[
      self$resources$year == self$year & self$resources$type == self$archive_key,
    ]
    self$archive_size <- resource$size

    if (download) {
      # Fail early with a readable message instead of an obscure download error
      # when the requested year has no archive for this split.
      if (nrow(resource) == 0L) {
        cli_abort("No {.val {self$archive_key}} archive is listed for year {.val {self$year}}.")
      }
      cli_inform("Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
      self$download()
    }

    if (!self$check_exists()) {
      cli_abort("Dataset not found. You can use `download = TRUE` to download it.")
    }

    data_file <- file.path(self$processed_folder, paste0(self$split, ".rds"))
    data <- readRDS(data_file)
    self$img_path <- data$img_path
    self$mask_paths <- data$mask_paths

    cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
  },

  # Download, verify and extract the archive for `self$year`, then store the
  # image/mask paths of `self$split` as an .rds file in the processed folder.
  # Does nothing when that .rds file already exists.
  download = function() {

    if (self$check_exists()) {
      return(invisible(NULL))
    }

    resource <- self$resources[
      self$resources$year == self$year & self$resources$type == self$archive_key,
    ]
    # Guard repeated here because subclasses may call download() from their
    # own initialize() without performing the check themselves.
    if (nrow(resource) != 1L) {
      cli_abort("No {.val {self$archive_key}} archive is listed for year {.val {self$year}}.")
    }

    cli_inform("Downloading {.cls {class(self)[[1]]}}...")

    fs::dir_create(self$raw_folder)
    fs::dir_create(self$processed_folder)

    archive <- download_and_cache(resource$url, prefix = class(self)[1])
    actual_md5 <- tools::md5sum(archive)

    # isTRUE() also treats an NA checksum (unreadable file) as a mismatch
    # rather than letting `if (NA)` raise an unrelated error.
    if (!isTRUE(unname(actual_md5) == as.character(resource$md5))) {
      runtime_error("Corrupt file! Delete the file in {archive} and try again.")
    }

    utils::untar(archive, exdir = self$raw_folder)

    # The 2011 archive is read from an extra top-level "TrainVal" directory.
    voc_root <- self$raw_folder
    if (self$year == "2011") {
      voc_root <- file.path(voc_root, "TrainVal")
    }
    voc_dir <- file.path(voc_root, "VOCdevkit", paste0("VOC", self$year))

    split_file <- file.path(voc_dir, "ImageSets", "Segmentation", paste0(self$split, ".txt"))

    ids <- readLines(split_file)
    img_path <- file.path(voc_dir, "JPEGImages", paste0(ids, ".jpg"))
    mask_paths <- file.path(voc_dir, "SegmentationClass", paste0(ids, ".png"))

    saveRDS(list(
      img_path = img_path,
      mask_paths = mask_paths
    ), file.path(self$processed_folder, paste0(self$split, ".rds")))

    cli_inform("Dataset {.cls {class(self)[[1]]}} downloaded and extracted successfully.")
  },

  # TRUE when the processed path index for the current split is on disk.
  check_exists = function() {
    fs::file_exists(file.path(self$processed_folder, paste0(self$split, ".rds")))
  },

  # Return item `index` as list(x = image, y = list(labels, masks)) with class
  # "image_with_segmentation_mask". `masks` is a (21, H, W) boolean tensor and
  # `labels` holds the 1-based positions in `self$classes` of the classes that
  # occur in the mask.
  .getitem = function(index) {

    x <- jpeg::readJPEG(self$img_path[index])
    # readPNG() yields values in [0, 1]; rescale to the 0-255 colormap range.
    mask_data <- png::readPNG(self$mask_paths[index]) * 255

    # One row per pixel, one column per RGB channel (column-major pixel order).
    flat_mask <- matrix(mask_data[, , 1:3], ncol = 3)
    colormap_mat <- matrix(self$voc_colormap, ncol = 3, byrow = TRUE)

    # Pack an RGB triplet into a single number so colours can be matched
    # with one vectorised match() call. round() protects against
    # floating-point error introduced by the [0, 1] -> [0, 255] rescaling.
    rgb_to_int <- function(mat) {
      channels <- round(mat)
      channels[, 1] * 256^2 + channels[, 2] * 256 + channels[, 3]
    }

    # Zero-based class index per pixel. Colours absent from the colormap
    # (presumably VOC's "void" object boundary -- TODO confirm) become -1,
    # which never equals a class id and therefore stays FALSE in every channel.
    match_indices <- match(rgb_to_int(flat_mask), rgb_to_int(colormap_mat), nomatch = 0L) - 1L
    class_idx <- matrix(match_indices, nrow = dim(mask_data)[1], ncol = dim(mask_data)[2])
    class_idx_tensor <- torch_tensor(class_idx, dtype = torch_long())

    # Compare a (21, 1, 1) column of class ids against the (1, H, W) index map
    # to obtain one boolean channel per class.
    class_ids <- torch_arange(0, 20, dtype = torch_long())$view(c(21, 1, 1))
    masks <- (class_ids == class_idx_tensor$unsqueeze(1))$to(dtype = torch_bool())
    labels <- which(as_array(masks$any(dim = c(2, 3))))

    y <- list(labels = labels, masks = masks)

    if (!is.null(self$transform)) {
      x <- self$transform(x)
    }
    if (!is.null(self$target_transform)) {
      y <- self$target_transform(y)
    }

    structure(list(x = x, y = y), class = "image_with_segmentation_mask")
  },

  # Number of images in the loaded split.
  .length = function() {
    length(self$img_path)
  },

  active = list(
    # Directory the archive is extracted into.
    raw_folder = function() {
      file.path(self$root_path, paste0("pascal_voc_", self$year), "raw")
    },
    # Directory holding the per-split .rds path index.
    processed_folder = function() {
      file.path(self$root_path, paste0("pascal_voc_", self$year), "processed")
    }
  )
)
| 248 | + |
| 249 | +#' Pascal VOC Detection Dataset |
| 250 | +#' |
| 251 | +#' @inheritParams pascal_segmentation_dataset |
| 252 | +#' |
| 253 | +#' @return A torch dataset of class \code{pascal_detection_dataset}. |
| 254 | +#' |
| 255 | +#' The returned list inherits class \code{image_with_bounding_box}, which allows generic visualization |
| 256 | +#' utilities to be applied. |
| 257 | +#' |
| 258 | +#' Each element is a named list: |
| 259 | +#' - `x`: a H x W x 3 array representing the RGB image. |
| 260 | +#' - `y`: a list with: |
| 261 | +#' - `labels`: a character vector with object class names. |
| 262 | +#' - `boxes`: a tensor of shape (N, 4) with bounding box coordinates in `(xmin, ymin, xmax, ymax)` format. |
| 263 | +#' |
| 264 | +#' @rdname pascal_voc_datasets |
| 265 | +#' @family detection_dataset |
| 266 | +#' @export |
pascal_detection_dataset <- torch::dataset(
  name = "pascal_detection_dataset",

  # Reuses resources, classes, download(), check_exists(), .length() and the
  # folder bindings of the segmentation dataset.
  inherit = pascal_segmentation_dataset,

  # Validate the arguments, optionally download the archive, then build the
  # image and annotation paths for the requested split.
  initialize = function(
    root = tempdir(),
    year = "2012",
    split = "train",
    transform = NULL,
    target_transform = NULL,
    download = FALSE
  ) {

    # xml2 is needed by .getitem()/parse_voc_xml(). Ask the user to install it
    # instead of installing a package behind their back at runtime.
    if (!requireNamespace("xml2", quietly = TRUE)) {
      cli_abort("Package {.pkg xml2} is required to read Pascal VOC annotations. Please install it first.")
    }

    self$root_path <- root
    self$year <- match.arg(year, choices = unique(self$resources$year))
    self$split <- match.arg(split, choices = c("train", "val", "trainval", "test"))
    self$transform <- transform
    self$target_transform <- target_transform

    # The "test" split lives in its own archive; all other splits share "trainval".
    self$archive_key <- if (self$split == "test") "test" else "trainval"
    resource <- self$resources[
      self$resources$year == self$year & self$resources$type == self$archive_key,
    ]
    self$archive_size <- resource$size

    if (download) {
      # Fail early with a readable message instead of an obscure download error
      # when the requested year has no archive for this split.
      if (nrow(resource) == 0L) {
        cli_abort("No {.val {self$archive_key}} archive is listed for year {.val {self$year}}.")
      }
      cli_inform("Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
      self$download()
    }

    if (!self$check_exists()) {
      cli_abort("Dataset not found. You can use `download = TRUE` to download it.")
    }

    # The 2011 archive is read from an extra top-level "TrainVal" directory.
    voc_root <- self$raw_folder
    if (self$year == "2011") {
      voc_root <- file.path(voc_root, "TrainVal")
    }
    voc_dir <- file.path(voc_root, "VOCdevkit", paste0("VOC", self$year))

    # Detection uses the "Main" image set, not the "Segmentation" one.
    ids_file <- file.path(voc_dir, "ImageSets", "Main", paste0(self$split, ".txt"))
    ids <- readLines(ids_file)

    self$img_path <- file.path(voc_dir, "JPEGImages", paste0(ids, ".jpg"))
    self$annotation_paths <- file.path(voc_dir, "Annotations", paste0(ids, ".xml"))

    cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
  },

  # Return item `index` as list(x = image, y = list(labels, boxes)) with class
  # "image_with_bounding_box".
  .getitem = function(index) {

    x <- jpeg::readJPEG(self$img_path[index])
    annotation <- xml2::read_xml(self$annotation_paths[index])
    y <- self$parse_voc_xml(annotation)

    if (!is.null(self$transform)) {
      x <- self$transform(x)
    }
    if (!is.null(self$target_transform)) {
      y <- self$target_transform(y)
    }

    structure(list(x = x, y = y), class = "image_with_bounding_box")
  },

  # Extract object names and bounding boxes from a parsed annotation document.
  # Returns list(labels = <character vector>, boxes = <long tensor with the
  # columns xmin, ymin, xmax, ymax, one row per bndbox found>).
  parse_voc_xml = function(xml) {
    objects <- xml2::xml_find_all(xml, ".//object")

    # Relative XPath "name"/"bndbox" only matches direct children of <object>.
    labels <- xml2::xml_text(xml2::xml_find_all(objects, "name"))
    bboxes <- xml2::xml_find_all(objects, "bndbox")

    coordinate <- function(tag) {
      xml2::xml_integer(xml2::xml_find_all(bboxes, tag))
    }

    # cbind() builds the N x 4 matrix directly (0 x 4 when there are no boxes).
    box_matrix <- cbind(
      xmin = coordinate("xmin"),
      ymin = coordinate("ymin"),
      xmax = coordinate("xmax"),
      ymax = coordinate("ymax")
    )
    boxes <- torch_tensor(box_matrix, dtype = torch_long())

    list(
      labels = labels,
      boxes = boxes
    )
  }
)
0 commit comments