
adding .getbatch() method to mnist_dataset dataset generator improves performance markedly #106

gavril0 opened this issue Mar 15, 2024 · 0 comments


gavril0 commented Mar 15, 2024

The standard mnist_dataset dataset generator does not include a .getbatch() method and, as a result, fetching a batch is quite slow, at least on CPU.

# dataset root directory
dir <- "./dataset"

# download dataset
train_ds <- mnist_dataset(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)
# dataloader
train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
# get a batch via the dataloader iterator
train_iter <- train_dl$.iter()
microbenchmark::microbenchmark(b <- train_iter$.next())

The timings are:

Unit: milliseconds
                    expr     min       lq     mean   median       uq     max neval
 b <- train_iter$.next() 45.5263 47.78455 52.74117 49.19825 52.95675 87.2219   100

As explained in the vignette, in the absence of a .getbatch() method the dataloader calls the .getitem() method once per observation to assemble a batch.
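Conceptually, that fallback path looks roughly like the following (a simplified sketch of the idea, not torch's actual internals; the collation details are an assumption on my part):

```r
library(torch)

# simplified sketch: what the dataloader effectively does per batch
# when only .getitem() is available
batch_indices <- 1:128
items <- lapply(batch_indices, function(i) train_ds$.getitem(i))  # 128 calls
x <- torch_stack(lapply(items, function(it) it$x))  # stack 128 single-image tensors
y <- torch_stack(lapply(items, function(it) it$y))
```

The 128 separate .getitem() calls plus the final stacking are what make this path slow, compared with a single vectorized indexing operation in .getbatch().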

Interestingly, it seems that the .getitem() method might be used as a .getbatch() method without any change, since the indexing it performs already accepts a vector of indices:

# mnist_dataset .getitem() method
> train_ds$.getitem
function (index) 
{
    img <- self$data[index, , ]
    target <- self$targets[index]
    if (!is.null(self$transform)) 
        img <- self$transform(img)
    if (!is.null(self$target_transform)) 
        target <- self$target_transform(target)
    list(x = img, y = target)
}
<environment: 0x000001527445be68>

It is easy to add a .getbatch() method to the existing mnist_dataset dataset generator:

# create a new dataset generator that extends mnist_dataset
mnist_dataset2 <- dataset(
  inherit = mnist_dataset,
  .getbatch = function(index) {
    self$.getitem(index)
  }
)

Let's measure the performance with this new dataset generator:

# create a dataset with the new dataset generator
train_ds2 <- mnist_dataset2(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)
# create a dataloader with the new dataset
train_dl2 <- dataloader(train_ds2, batch_size = 128, shuffle = TRUE)
# get a batch via the dataloader
train_iter2 <- train_dl2$.iter()
microbenchmark::microbenchmark(train_iter2$.next())
Unit: milliseconds
                expr      min       lq     mean   median       uq     max neval
 train_iter2$.next() 3.995601 4.328151 5.430246 4.601451 4.965501 11.7692   100

The new dataloader is almost ten times faster!

That said, the new dataloader cannot be used in place of train_dl in this example, which uses luz to train the network:

fitted <- mnist_module %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_accuracy()
    )
  ) %>%
  fit(train_dl, epochs = 1, valid_data = test_dl)

It fails with the error message:

expected input[1, 28, 128, 28] to have 1 channels, but got 28 channels instead

I don't have a PC with a GPU to test whether there is a similar improvement when the data are loaded on the GPU. I also wonder why a .getbatch() method is not implemented by default, since it seems an easy way to improve performance. Although I have not investigated the origin of the error, the luz::fit() method should be able to accept a dataloader whose dataset has a .getbatch() method.
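The shape in the error message suggests (this is a guess on my part) that transform_to_tensor treated the whole [128, 28, 28] slab as one height × width × channel image instead of a batch of 28×28 images. One possible workaround, untested and with the hypothetical name mnist_dataset3, is a .getbatch() that builds the [batch, 1, 28, 28] tensor directly instead of routing the batch through the per-image transform:

```r
library(torch)
library(torchvision)

# Untested sketch: build the batch tensor directly, bypassing transform_to_tensor.
# Assumes self$data is an integer array of shape [n, 28, 28] and self$targets
# a vector of integer labels, as in mnist_dataset.
mnist_dataset3 <- dataset(
  inherit = mnist_dataset,
  .getbatch = function(index) {
    x <- torch_tensor(self$data[index, , , drop = FALSE])$
      to(dtype = torch_float())$
      div(255)$        # scale to [0, 1], mimicking transform_to_tensor
      unsqueeze(2)     # insert the channel dimension: [batch, 1, 28, 28]
    y <- torch_tensor(self$targets[index], dtype = torch_long())
    list(x = x, y = y)
  }
)

# note: no transform argument, since .getbatch() handles the conversion itself
train_ds3 <- mnist_dataset3(dir, download = TRUE)
train_dl3 <- dataloader(train_ds3, batch_size = 128, shuffle = TRUE)
```

If the diagnosis is right, a dataloader built this way should hand luz batches in the [N, C, H, W] layout the convolutional module expects, while keeping the speed advantage of vectorized indexing.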
