Skip to content

Commit da2dd3d

Browse files
authored
addition of whoi_small_plankton and whoi_plankton datasets (#236)
1 parent f26d82e commit da2dd3d

33 files changed

+598
-119
lines changed

DESCRIPTION

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ Imports:
4848
glue,
4949
zeallot
5050
Suggests:
51+
arrow,
5152
magick,
53+
prettyunits,
5254
testthat,
5355
coro,
5456
R.matlab,

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ export(transform_ten_crop)
188188
export(transform_to_tensor)
189189
export(transform_vflip)
190190
export(vision_make_grid)
191+
export(whoi_plankton_dataset)
192+
export(whoi_small_plankton_dataset)
191193
importFrom(grDevices,dev.off)
192194
importFrom(graphics,polygon)
193195
importFrom(jsonlite,fromJSON)

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
* Added `lfw_people_dataset()` and `lfw_pairs_dataset()` for loading Labelled Faces in the Wild (LFW) datasets (@DerrickUnleashed, #203).
66
* Added `places365_dataset()` for loading the Places365 dataset (@koshtiakanksha, #196).
77
* Added `pascal_segmentation_dataset()`, and `pascal_detection_dataset()` for loading the Pascal Visual Object Classes datasets (@DerrickUnleashed, #209).
8+
* Added `whoi_plankton_dataset()`, and `whoi_small_plankton_dataset()` (@cregouby, #236).
89

910
## New models
1011

R/dataset-caltech.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ caltech101_dataset <- torch::dataset(
8181
self$image_indices <- c(self$image_indices, seq_along(imgs))
8282
}
8383

84-
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
84+
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {self$.length()} images across {length(self$classes)} classes.")
8585
},
8686

8787
.getitem = function(index) {
@@ -205,7 +205,7 @@ caltech256_dataset <- torch::dataset(
205205
}, seq_along(self$classes), images_per_class, SIMPLIFY = FALSE),
206206
use.names = FALSE
207207
)
208-
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
208+
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {self$.length()} images across {length(self$classes)} classes.")
209209
},
210210

211211
check_exists = function() {

R/dataset-coco.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ coco_detection_dataset <- torch::dataset(
6464
) {
6565

6666
year <- match.arg(year)
67-
split <- if (train) "train" else "val"
67+
split <- ifelse(train, "train", "val")
6868

6969
root <- fs::path_expand(root)
7070
self$root <- root
@@ -76,7 +76,7 @@ coco_detection_dataset <- torch::dataset(
7676

7777
self$data_dir <- fs::path(root, glue::glue("coco{year}"))
7878

79-
image_year <- if (year == "2016") "2014" else year
79+
image_year <- ifelse(year == "2016", "2014", year)
8080
self$image_dir <- fs::path(self$data_dir, glue::glue("{split}{image_year}"))
8181
self$annotation_file <- fs::path(self$data_dir, "annotations",
8282
glue::glue("instances_{split}{year}.json"))
@@ -288,7 +288,7 @@ coco_caption_dataset <- torch::dataset(
288288
) {
289289

290290
year <- match.arg(year)
291-
split <- if (train) "train" else "val"
291+
split <- ifelse(train, "train", "val")
292292

293293
root <- fs::path_expand(root)
294294
self$root <- root
@@ -329,7 +329,7 @@ coco_caption_dataset <- torch::dataset(
329329
image_id <- ann$image_id
330330
y <- ann$caption
331331

332-
prefix <- if (self$split == "train") "COCO_train2014_" else "COCO_val2014_"
332+
prefix <- ifelse(self$split == "train", "COCO_train2014_", "COCO_val2014_")
333333
filename <- paste0(prefix, sprintf("%012d", image_id), ".jpg")
334334
image_path <- fs::path(self$image_dir, filename)
335335

R/dataset-eurosat.R

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#' @inheritParams mnist_dataset
1313
#' @param root (Optional) Character. The root directory where the dataset will be stored.
1414
#' if empty, will use the default `rappdirs::user_cache_dir("torch")`.
15-
#' @param split Character. Must be one of `train`, `val`, or `test`.
15+
#' @param split One of `"train"`, `"val"`, or `"test"`. Default is `"val"`.
1616
#'
1717
#' @return A `torch::dataset` object. Each item is a list with:
1818
#' * `x`: a 64x64 image tensor with 3 (RGB) or 13 (all bands) channels
@@ -39,7 +39,7 @@ eurosat_dataset <- torch::dataset(
3939

4040
initialize = function(
4141
root = tempdir(),
42-
split = "train",
42+
split = "val",
4343
download = FALSE,
4444
transform = NULL,
4545
target_transform = NULL
@@ -53,7 +53,7 @@ eurosat_dataset <- torch::dataset(
5353
self$images_dir <- file.path(self$root, class(self)[1], "images")
5454
self$split_file <- file.path(self$root, fs::path_ext_remove(basename(self$split_url)))
5555

56-
if (download){
56+
if (download) {
5757
cli_inform("Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
5858
self$download()
5959
}
@@ -184,5 +184,3 @@ eurosat100_dataset <- torch::dataset(
184184
split_url = "https://huggingface.co/datasets/torchgeo/eurosat/resolve/main/eurosat-100-{split}.txt?download=true",
185185
archive_size = "7 MB"
186186
)
187-
188-

R/dataset-fer.R

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,17 @@
3131
fer_dataset <- dataset(
3232
name = "fer_dataset",
3333
archive_size = "90 MB",
34+
url = "https://huggingface.co/datasets/JimmyUnleashed/FER-2013/resolve/main/fer2013.tar.gz",
35+
md5 = "ca95d94fe42f6ce65aaae694d18c628a",
36+
classes = c(
37+
"Angry",
38+
"Disgust",
39+
"Fear",
40+
"Happy",
41+
"Sad",
42+
"Surprise",
43+
"Neutral"
44+
),
3445

3546
initialize = function(
3647
root = tempdir(),
@@ -39,25 +50,25 @@ fer_dataset <- dataset(
3950
target_transform = NULL,
4051
download = FALSE
4152
) {
42-
4353
self$root <- root
4454
self$train <- train
4555
self$transform <- transform
4656
self$target_transform <- target_transform
47-
self$split <- if (train) "Train" else "Test"
57+
self$split <- ifelse(train, "Train", "Test")
4858
self$folder_name <- "fer2013"
49-
self$url <- "https://huggingface.co/datasets/JimmyUnleashed/FER-2013/resolve/main/fer2013.tar.gz"
50-
self$md5 <- "ca95d94fe42f6ce65aaae694d18c628a"
51-
self$classes <- c("Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral")
5259
self$class_to_idx <- setNames(seq_along(self$classes), self$classes)
5360

54-
if (download){
55-
cli_inform("Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
61+
if (download) {
62+
cli_inform(
63+
"Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available."
64+
)
5665
self$download()
5766
}
5867

5968
if (!self$check_files()) {
60-
runtime_error("Dataset not found. You can use `download = TRUE` to download it.")
69+
runtime_error(
70+
"Dataset not found. You can use `download = TRUE` to download it."
71+
)
6172
}
6273

6374
csv_file <- file.path(self$root, self$folder_name, "fer2013.csv")
@@ -87,11 +98,13 @@ fer_dataset <- dataset(
8798

8899
y <- self$y[i]
89100

90-
if (!is.null(self$transform))
101+
if (!is.null(self$transform)) {
91102
x <- self$transform(x)
103+
}
92104

93-
if (!is.null(self$target_transform))
105+
if (!is.null(self$target_transform)) {
94106
y <- self$target_transform(y)
107+
}
95108

96109
list(x = x, y = y)
97110
},
@@ -112,11 +125,14 @@ fer_dataset <- dataset(
112125

113126
archive <- download_and_cache(self$url)
114127

115-
if (!tools::md5sum(archive) == self$md5)
128+
if (!tools::md5sum(archive) == self$md5) {
116129
runtime_error("Corrupt file! Delete the file in {archive} and try again.")
130+
}
117131

118132
untar(archive, exdir = self$root)
119-
cli_inform("Dataset {.cls {class(self)[[1]]}} downloaded and extracted successfully.")
133+
cli_inform(
134+
"Dataset {.cls {class(self)[[1]]}} downloaded and extracted successfully."
135+
)
120136
},
121137

122138
check_files = function() {

R/dataset-fgvc.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ fgvc_aircraft_dataset <- dataset(
7878
target_transform = NULL,
7979
download = FALSE
8080
) {
81-
81+
8282
self$root <- root
8383
self$split <- split
8484
self$annotation_level <- annotation_level
@@ -132,10 +132,10 @@ fgvc_aircraft_dataset <- dataset(
132132
.getitem = function(index) {
133133
x <- jpeg::readJPEG(self$image_paths[index]) * 255
134134

135-
y <- if (self$annotation_level == "all") {
136-
as.integer(self$labels_df[index, ])
135+
if (self$annotation_level == "all") {
136+
y <- as.integer(self$labels_df[index, ])
137137
} else {
138-
self$labels_df[[self$annotation_level]][index]
138+
y <- self$labels_df[[self$annotation_level]][index]
139139
}
140140

141141
if (!is.null(self$transform)) {

R/dataset-flickr.R

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ flickr8k_caption_dataset <- torch::dataset(
6363
self$transform <- transform
6464
self$target_transform <- target_transform
6565
self$train <- train
66-
self$split <- if (train) "train" else "test"
67-
66+
self$split <- ifelse(train, "train", "test")
67+
6868
if (download)
6969
cli_inform("Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
7070
self$download()
@@ -130,7 +130,7 @@ flickr8k_caption_dataset <- torch::dataset(
130130

131131
download = function() {
132132

133-
if (self$check_exists())
133+
if (self$check_exists())
134134
return()
135135

136136
cli_inform("Downloading {.cls {class(self)[[1]]}}...")
@@ -173,10 +173,10 @@ flickr8k_caption_dataset <- torch::dataset(
173173
caption_index <- self$captions[[index]]
174174
y <- self$classes[[caption_index]]
175175

176-
if (!is.null(self$transform))
176+
if (!is.null(self$transform))
177177
x <- self$transform(x)
178178

179-
if (!is.null(self$target_transform))
179+
if (!is.null(self$target_transform))
180180
y <- self$target_transform(y)
181181

182182
list(x = x, y = y)
@@ -225,13 +225,13 @@ flickr30k_caption_dataset <- torch::dataset(
225225
self$transform <- transform
226226
self$target_transform <- target_transform
227227
self$train <- train
228-
self$split <- if (train) "train" else "test"
228+
self$split <- ifelse(train, "train", "test")
229229

230230
if (download)
231231
cli_inform("Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
232232
self$download()
233233

234-
if (!self$check_exists())
234+
if (!self$check_exists())
235235
cli_abort("Dataset not found. Use `download = TRUE` to download it.")
236236

237237
captions_path <- file.path(self$raw_folder, "dataset_flickr30k.json")

R/dataset-lfw.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ lfw_people_dataset <- torch::dataset(
132132
self$classes <- class_names
133133
self$class_to_idx <- class_to_idx
134134

135-
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
135+
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {self$.length()} images across {length(self$classes)} classes.")
136136
},
137137

138138
download = function() {
@@ -283,7 +283,7 @@ lfw_pairs_dataset <- torch::dataset(
283283
self$pairs <- do.call(rbind, pair_list)
284284
self$img_path <- c(self$pairs$img1, self$pairs$img2)
285285

286-
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
286+
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {self$.length()} images across {length(self$classes)} classes.")
287287
},
288288

289289
.getitem = function(index) {

0 commit comments

Comments (0)