adding .getbatch() method to mnist_dataset dataset generator improves performance markedly #106

Open

Description

@gavril0

The standard dataset generator for the MNIST dataset (mnist_dataset) does not include a .getbatch() method and, as a result, getting a batch is quite slow, at least on the CPU.

# dataset root directory
dir <- "./dataset"

# download dataset
train_ds <- mnist_dataset(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)
# dataloader 
train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
# get a batch via the dataloader iterator
train_iter <- train_dl$.iter()
microbenchmark::microbenchmark(b <- train_iter$.next())

The timings are:

Unit: milliseconds
                    expr     min       lq     mean   median       uq     max neval
 b <- train_iter$.next() 45.5263 47.78455 52.74117 49.19825 52.95675 87.2219   100

As explained in the vignette, in the absence of a .getbatch() method the dataloader calls the .getitem() method once per element to assemble a batch.
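Roughly, the per-item path amounts to something like the following (a simplified sketch, not the package's actual collation code, which is more general):

```r
library(torch)

# Simplified sketch of batching without a .getbatch() method:
# one .getitem() call per index, then stack the per-item tensors.
get_batch_itemwise <- function(ds, indices) {
  items <- lapply(indices, function(i) ds$.getitem(i))
  list(
    x = torch_stack(lapply(items, function(it) it$x)),
    y = torch_stack(lapply(items, function(it) it$y))
  )
}
```

Each .getitem() call pays the per-call overhead (indexing, transform dispatch) separately, which is why a single vectorized call for the whole batch is so much cheaper.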

Interestingly, it seems that the .getitem() method might be used as a .getbatch() method without any change:

# mnist_dataset .getitem() method
> train_ds$.getitem
function (index) 
{
    img <- self$data[index, , ]
    target <- self$targets[index]
    if (!is.null(self$transform)) 
        img <- self$transform(img)
    if (!is.null(self$target_transform)) 
        target <- self$target_transform(target)
    list(x = img, y = target)
}
<environment: 0x000001527445be68>
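This works because base-R array indexing is vectorized: passing a vector of indices to self$data[index, , ] extracts all the corresponding slices in one call, and self$targets[index] likewise. A minimal illustration:

```r
# A vector of row indices pulls out several slices at once,
# keeping the leading (batch) dimension.
a <- array(seq_len(3 * 2 * 2), dim = c(3, 2, 2))
dim(a[c(1, 3), , ])  # 2 2 2: two 2x2 slices in one call
```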

It is easy to add a .getbatch() method to the existing mnist_dataset dataset generator:

# create a new dataset generator that extends mnist_dataset
mnist_dataset2 <- dataset(
  inherit = mnist_dataset,
  .getbatch = function(index) {
    self$.getitem(index)
  }
)

Let's measure the performance with this new dataset generator:

# create a dataset with the new dataset generator
train_ds2 <- mnist_dataset2(
  dir,
  download = TRUE,
  transform = transform_to_tensor
)
# create a dataloader with the new dataset
train_dl2 <- dataloader(train_ds2, batch_size = 128, shuffle = TRUE)
# get a batch via the dataloader
train_iter2 <- train_dl2$.iter()
microbenchmark::microbenchmark(train_iter2$.next())
Unit: milliseconds
                expr      min       lq     mean   median       uq     max neval
 train_iter2$.next() 3.995601 4.328151 5.430246 4.601451 4.965501 11.7692   100

The new dataloader is almost 10 times faster!

That said, the new dataloader cannot be used in place of train_dl in this example, which uses luz to train the network:

fitted <- mnist_module %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_accuracy()
    )
  ) %>%
  fit(train_dl, epochs = 1, valid_data = test_dl)

It yields an error message:

expected input[1, 28, 128, 28] to have 1 channels, but got 28 channels instead

I don't have a PC with a GPU to test whether there is a similar improvement when the data are loaded on the GPU. I also wonder why a .getbatch() method is not implemented more often, since it seems an easy way to improve performance. Though I did not investigate the origin of the error, the luz::fit method should be able to accept a dataloader whose dataset defines a .getbatch method.
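Though untested, a plausible explanation is a shape mismatch: the per-item path yields one [1, 28, 28] tensor per image that the dataloader stacks into [128, 1, 28, 28], while the batched path hands luz a tensor without that singleton channel dimension (the [1, 28, 128, 28] in the error message hints at dimensions ending up in the wrong order). If that is the cause, a .getbatch() that inserts the channel dimension explicitly might work; this is a speculative sketch, not a verified fix:

```r
# Speculative sketch: batched dataset that adds an explicit channel
# dimension so the batch has shape [N, 1, 28, 28], matching what the
# per-item path produces. Untested against luz.
mnist_dataset3 <- dataset(
  inherit = mnist_dataset,
  .getbatch = function(index) {
    batch <- self$.getitem(index)
    batch$x <- batch$x$unsqueeze(2)  # [N, 28, 28] -> [N, 1, 28, 28]
    batch
  }
)
```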
