To simulate a training failure, we use a callback that raises an error at a pre-defined epoch:

```{r}
interrupt <- luz_callback(
  "interrupt",
  failed_epoch = 5,
  on_epoch_end = function() {
    # Simulate a crash partway through training by raising an error
    # once the chosen epoch is reached.
    if (ctx$epoch == self$failed_epoch) {
      stop("Error on epoch ", self$failed_epoch)
    }
  }
)
```

Let's now start training, adding the `luz_callback_auto_resume()`:

```{r, error=TRUE}
autoresume <- luz_callback_auto_resume(path = "state.pt")

# This run stops at the epoch where `interrupt` raises its error, but
# `luz_callback_auto_resume()` has been saving training state to `state.pt`.
# `model`, `x`, and `y` stand in for the module and data defined earlier
# in the article.
results <- model %>% fit(
  list(x, y),
  callbacks = list(autoresume, interrupt()),
  verbose = FALSE
)
```

When we call `fit()` again with the same `autoresume` callback, training restarts from the last saved state instead of from scratch, so you can have the full results:
```{r}
# Resume training: the simulated-failure callback is no longer needed.
results <- model %>% fit(
  list(x, y),
  callbacks = list(autoresume),
  verbose = FALSE
)
plot(results)
```
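
Since `state.pt` may remain on disk after training, a small housekeeping step removes it if present:

```{r}
# Delete the auto-resume state file, if it still exists.
if (fs::file_exists("state.pt")) {
  fs::file_delete("state.pt")
}
```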

## Checkpointing

Sometimes you want to have more control over how checkpoints are handled.
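
For example, `luz_callback_model_checkpoint()` saves checkpoints to disk as training progresses. The sketch below is a minimal illustration rather than the article's exact code: it reuses the `model`, `x`, and `y` stand-ins from above, and monitors the training loss since no validation data is used here.

```{r}
# Write a checkpoint into "checkpoints/" as training progresses.
checkpoint <- luz_callback_model_checkpoint(
  path = "checkpoints/",
  monitor = "train_loss"
)

results <- model %>% fit(
  list(x, y),
  callbacks = list(checkpoint),
  verbose = FALSE
)

# Inspect the checkpoint files that were written:
fs::dir_ls("checkpoints")
```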
As before, you can plot the full results:

```{r}
plot(results)
```
```{r include=FALSE}
fs::dir_delete("checkpoints")
```

### Custom callbacks state

Sometimes callbacks also need to keep internal state so that training can
continue from exactly where it stopped. In this case, callbacks can implement
the `state_dict()` and `load_state_dict()` methods, which are automatically
called when saving and reloading checkpoints.

For example, suppose you have a callback that tracks the gradients of selected
weights at every epoch, so you can later use them to analyse the training
procedure. It could be implemented like this:

```{r}
cb_weight_grad <- luz_callback(
  "weight_grad",
  gradients = list(),
  initialize = function(track_weights) {
    self$track_weights <- track_weights
  },
  on_train_batch_before_step = function() {
    # At this point `backward()` has already run, so gradients are
    # available. We keep those from the last batch of each epoch.
    self$gradients[[ctx$epoch]] <- list()
    for (w in self$track_weights) {
      # Clone, so later in-place updates don't overwrite the stored values.
      self$gradients[[ctx$epoch]][[w]] <- ctx$model$parameters[[w]]$grad$clone()
    }
  }
)
```

In the example above, the `gradients` field is **state** in the callback: if
training fails for some reason, its contents are lost. If it's important for
you to also checkpoint the callback state, implement the `state_dict()` method,
returning a named list of objects that compose the state of the callback, and
`load_state_dict()`, which takes the same named list returned by `state_dict()`
and restores the callback state.

The callback above could be reimplemented as:

```{r}
cb_weight_grad <- luz_callback(
  "weight_grad",
  gradients = list(),
  initialize = function(track_weights) {
    self$track_weights <- track_weights
  },
  on_train_batch_before_step = function() {
    self$gradients[[ctx$epoch]] <- list()
    for (w in self$track_weights) {
      self$gradients[[ctx$epoch]][[w]] <- ctx$model$parameters[[w]]$grad$clone()
    }
  },
  state_dict = function() {
    # Everything returned here gets saved into the checkpoint.
    list(gradients = self$gradients)
  },
  load_state_dict = function(d) {
    # Receives the named list produced by `state_dict()` and restores it.
    self$gradients <- d$gradients
  }
)
```
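
With this, the tracked gradients are saved and restored along with the rest of the checkpoint. As a rough usage sketch: `"fc1.weight"` below is a hypothetical parameter name (use a name that actually appears in your module's `$parameters` list), and `model`, `x`, and `y` again stand in for the module and data defined earlier in the article.

```{r}
# Hypothetical usage: "fc1.weight" is a placeholder parameter name.
grad_tracker <- cb_weight_grad(track_weights = "fc1.weight")

results <- model %>% fit(
  list(x, y),
  callbacks = list(grad_tracker),
  verbose = FALSE
)

# Assuming luz mutates the callback instance passed to `fit()`, the
# tracked gradients are readable from it after training:
str(grad_tracker$gradients)
```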
