The standard dataset generator for the MNIST dataset, mnist_dataset, does not include a .getbatch() method and, as a result, getting a batch is quite slow, at least on CPU.
# load the required packages
library(torch)
library(torchvision)
library(microbenchmark)
# dataset root directory
dir <- "./dataset"
# download dataset
train_ds <- mnist_dataset(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)
# dataloader
train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
# get a batch via the dataloader iterator
train_iter <- train_dl$.iter()
microbenchmark(b <- train_iter$.next())
The timings are:
Unit: milliseconds
expr min lq mean median uq max neval
b <- train_iter$.next() 45.5263 47.78455 52.74117 49.19825 52.95675 87.2219 100
As explained in the vignette, in the absence of a .getbatch() method, the dataloader calls the .getitem() method once per index and collates the results into a batch.
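Conceptually, the per-item path amounts to something like the following (a sketch only, not the actual torch internals; ds and batch_indices are placeholder names):
# sketch of the per-item collation: one .getitem() call per index,
# then the individual items are stacked into a batch
items <- lapply(batch_indices, function(i) ds$.getitem(i))
x <- torch_stack(lapply(items, function(item) item$x))
y <- torch_tensor(sapply(items, function(item) item$y))
With batch_size = 128, that is 128 separate .getitem() calls per batch, which explains the timings above.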
Interestingly, it seems that the .getitem() method might be used as a .getbatch() method without any change:
# mnist_dataset .getitem() method
> train_ds$.getitem
function (index)
{
    img <- self$data[index, , ]
    target <- self$targets[index]
    if (!is.null(self$transform))
        img <- self$transform(img)
    if (!is.null(self$target_transform))
        target <- self$target_transform(target)
    list(x = img, y = target)
}
<environment: 0x000001527445be68>
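The reason this works is that R array indexing is vectorized: self$data[index, , ] returns a single image when index is a scalar and a stack of images when index is a vector. A quick illustration with a placeholder array of MNIST shape:
# R array indexing is vectorized over the first dimension
x <- array(0, dim = c(60000, 28, 28))
dim(x[1, , ])     # 28 28     -> one image
dim(x[1:128, , ]) # 128 28 28 -> a batch of 128 images
Note that self$transform is then applied to the whole 3D array at once rather than to one image at a time, which may matter (see the error below).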
It is easy to add a .getbatch() method to the existing mnist_dataset dataset generator:
# create a new dataset generator that extends mnist_dataset
mnist_dataset2 <- dataset(
  inherit = mnist_dataset,
  .getbatch = function(index) {
    self$.getitem(index)
  }
)
Let's measure the performance with this new dataset generator:
# create a dataset with the new dataset generator
train_ds2 <- mnist_dataset2(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)
# create a dataloader with the new dataset
train_dl2 <- dataloader(train_ds2, batch_size = 128, shuffle = TRUE)
# get a batch via the dataloader iterator
train_iter2 <- train_dl2$.iter()
microbenchmark(train_iter2$.next())
Unit: milliseconds
expr min lq mean median uq max neval
train_iter2$.next() 3.995601 4.328151 5.430246 4.601451 4.965501 11.7692 100
The new dataloader is roughly 10 times faster (median 4.6 ms vs 49.2 ms)!
That said, it seems that the new dataloader cannot be used in place of train_dl in this example, which uses luz to train the network. It yields an error message:
expected input[1, 28, 128, 28] to have 1 channels, but got 28 channels instead
I don't have a PC with a GPU to test whether there is a similar improvement when the data are loaded on the GPU. I also wonder why the .getbatch() method is not always implemented, since it seems an easy way to improve performance. Though I did not investigate the origin of the error, the luz::fit method should be able to accept a dataloader whose dataset has a .getbatch() method.
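The error suggests that transform_to_tensor, written for a single 28x28 image, puts the channel dimension in the wrong position when it receives a whole batch at once. A possible workaround, which I have not tested with luz (mnist_dataset3 is a hypothetical name), is to bypass the transform inside .getbatch() and build a correctly shaped [batch, 1, 28, 28] tensor by hand:
# untested sketch: scale the pixel values and insert the channel
# dimension explicitly so the batch has shape [batch, 1, 28, 28]
mnist_dataset3 <- dataset(
  inherit = mnist_dataset,
  .getbatch = function(index) {
    # scale to [0, 1] as transform_to_tensor would
    x <- torch_tensor(self$data[index, , , drop = FALSE], dtype = torch_float()) / 255
    x <- x$unsqueeze(2)  # [batch, 28, 28] -> [batch, 1, 28, 28]
    y <- torch_tensor(self$targets[index], dtype = torch_long())
    list(x = x, y = y)
  }
)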