-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
start adding a more advanced vignette
- Loading branch information
Showing
3 changed files
with
125 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
*.html | ||
*.R |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
--- | ||
title: "Custom loops with luz" | ||
output: rmarkdown::html_vignette | ||
vignette: > | ||
%\VignetteIndexEntry{Custom loops with luz} | ||
%\VignetteEngine{knitr::rmarkdown} | ||
%\VignetteEncoding{UTF-8} | ||
--- | ||
|
||
```{r, include = FALSE} | ||
knitr::opts_chunk$set( | ||
collapse = TRUE, | ||
comment = "#>" | ||
) | ||
``` | ||
|
||
```{r setup} | ||
library(torch) | ||
library(luz) | ||
``` | ||
|
||
Luz is a higher level API for torch that is designed to be highly flexible by providing a layered API that allows it to be useful no matter the level of control your need for your training loop. | ||
|
||
In the getting started vignette we have seen the basics of luz and how to quickly modify parts of the training loop using callbacks and custom metrics. In this document we will find describe how luz allows the user to get fine grained control of the training loop. | ||
|
||
A part from the use of callbacks there are three more ways that you can use luz depending on how much control you need: | ||
|
||
- **Multiple optimizers or losses**: You might be optimizing two loss functions each with its own optimizer, but you still don't want to modify the `backward()` - `zero_grad()` and `step()` calls. This is common in models like GANs (Generative Adversarial Networks) when you have competing neural networks trained with different losses and optimizers. | ||
|
||
- **Fully flexible step:** You might want to be in control of how to call `backward()`, `zero_grad()`and `step()` as well as maybe having more control of gradient computation. For example, you might want to use 'virtual batch sizes', ie. you accumulate the gradients for a few steps before updating the weights. | ||
|
||
- **Completely flexible loop**: Your training loop can be anything you want but you still want to use luz to handle device placement of the dataloaders, optimizers and models. | ||
|
||
Now we will describe each of these approaches by implementing the same model for each API. | ||
|
||
Let's consider a simplified version of the `net` that we implemented in the getting started vignette: | ||
|
||
``` {.r} | ||
net <- nn_module( | ||
"Net", | ||
initialize = function() { | ||
self$fc1 <- nn_linear(100, 50) | ||
self$fc1 <- nn_linear(50, 10) | ||
}, | ||
forward = function(x) { | ||
x %>% | ||
self$fc1() %>% | ||
nnf_relu() %>% | ||
self$fc2() | ||
} | ||
) | ||
``` | ||
|
||
Using the highest level of luz API we would fit it using: | ||
|
||
``` {.r} | ||
fitted <- net %>% | ||
setup( | ||
loss = nn_cross_entropy_loss(), | ||
optimizer = optim_adam, | ||
metrics = list( | ||
luz_metric_accuracy | ||
) | ||
) %>% | ||
fit(train_dl, epochs = 10, valid_data = test_dl) | ||
``` | ||
|
||
## Multiple optimizers | ||
|
||
Suppose we want to do an experiment where we train the first fully connected layer using a learning rate of 0.1 and the second one using learning rate of 0.01. Both minimizing the same `nn_cross_entropy_loss()` but for the first layer we want to add L1 regularization on the weights. | ||
|
||
In order to use luz for this we will implement two methods in the `net` module: | ||
|
||
- `optimizers`: returns a named list of optimizers depending on the `ctx`. | ||
|
||
- `loss`: computes the loss depending on the selected optimizer. | ||
|
||
Let's go to the code: | ||
|
||
```{r} | ||
net <- nn_module( | ||
"Net", | ||
initialize = function() { | ||
self$fc1 <- nn_linear(100, 50) | ||
self$fc1 <- nn_linear(50, 10) | ||
}, | ||
forward = function(x) { | ||
x %>% | ||
self$fc1() %>% | ||
nnf_relu() %>% | ||
self$fc2() | ||
}, | ||
optimizers = function(lr_fc1 = 0.1, lr_fc2 = 0.01) { | ||
list( | ||
opt_fc1 = optim_adam(self$fc1$parameters, lr = lr_fc1), | ||
opt_fc2 = optim_adam(self$fc2$parameters, lr = lr_fc2) | ||
) | ||
}, | ||
loss = function(input, target) { | ||
pred <- ctx$model(input) | ||
if (ctx$opt_name == "opt_fc1") | ||
nnf_cross_entropy(pred, target) + torch_norm(self$fc1$weight, p = 1) | ||
else if (ctx$opt_name == "opt_fc2") | ||
nnf_cross_entropy(pred, target) | ||
} | ||
) | ||
``` | ||
|
||
Notice that model optimizers will be initialized according to the `optimizers()` method return value. In this case, we are initializing the optimizers using different model parameters and learning rates. | ||
|
||
The `loss()` method is responsible for computing the loss that will be then backpropagated to compute gradients and update the weights. This `loss()` method can access the `ctx` object that will contain a `opt_name` field, describing which optimizer is currently being used. Note that this function will be called once for each optimizer for each training and validation step. | ||
|
||
We can finally `setup` and `fit` this module, however we no longer need to specify optimizers and loss functions. | ||
|
||
```{r} | ||
fitted <- net %>% | ||
setup(metrics = list( | ||
luz_metric_accuracy | ||
)) %>% | ||
fit(train_dl, epochs = 10, valid_data = test_dl) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters