68 changes: 54 additions & 14 deletions docs/source-pytorch/accelerators/gpu_faq.rst
@@ -5,31 +5,71 @@
GPU training (FAQ)
==================

***************************************************************
How should I adjust the batch size when using multiple devices?
***************************************************************

Lightning automatically shards your data across multiple GPUs, meaning that each device sees only a unique
subset of your data, while the ``batch_size`` set in your DataLoader remains the same per device. This means
that the effective batch size, i.e. the total number of samples processed in one forward/backward pass, is

.. math::

    \text{Effective Batch Size} = \text{DataLoader Batch Size} \times \text{Number of Devices} \times \text{Number of Nodes}

A couple of examples to illustrate this:

.. code-block:: python

    dataloader = DataLoader(..., batch_size=7)

    # Single GPU: effective batch size = 7
    Trainer(accelerator="gpu", devices=1)

    # Multi-GPU: effective batch size = 7 * 8 = 56
    Trainer(accelerator="gpu", devices=8, strategy=...)

    # Multi-node: effective batch size = 7 * 8 * 10 = 560
    Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy=...)

In general, you should be able to use the same ``batch_size`` in your DataLoader regardless of the number of
devices you are using.

.. note::

    If you want distributed training to behave exactly like single-GPU training, set the ``batch_size``
    in your DataLoader to ``original_batch_size / num_devices`` to maintain the same effective batch size.
    However, this can lead to poor GPU utilization.
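
For instance, here is a minimal sketch of that division, assuming the device count is known up front;
``original_batch_size`` and ``num_devices`` are illustrative names, not Lightning API:

.. code-block:: python

    from torch.utils.data import DataLoader

    original_batch_size = 256  # batch size tuned on a single device
    num_devices = 8

    # Keep the effective batch size at 256 by splitting it across devices.
    per_device_batch_size, remainder = divmod(original_batch_size, num_devices)
    assert remainder == 0, "batch size must be divisible by the device count"

    dataloader = DataLoader(..., batch_size=per_device_batch_size)  # 32 per device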

----

******************************************************************
How should I adjust the learning rate when using multiple devices?
******************************************************************

Because the effective batch size is larger when using multiple devices, you need to adjust your learning rate
accordingly. The learning rate controls how much the model weights change in response to the estimated error
on each update, so it should be scaled together with the effective batch size.

In general, there are two common scaling rules:

1. **Linear scaling**: Increase the learning rate linearly with the number of devices.

   .. code-block:: python

       # Example: Linear scaling
       base_lr = 1e-3
       num_devices = 8
       scaled_lr = base_lr * num_devices  # 8e-3

2. **Square root scaling**: Increase the learning rate by the square root of the number of devices.

   .. code-block:: python

       # Example: Square root scaling
       base_lr = 1e-3
       num_devices = 8
       scaled_lr = base_lr * (num_devices**0.5)  # about 2.83e-3
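
In practice, either rule can be applied once the trainer is attached to the model. Below is a minimal sketch
of the linear rule inside a ``LightningModule``, assuming ``base_lr`` was tuned on a single device;
``LitScaledModel`` and ``base_lr`` are illustrative names:

.. code-block:: python

    import torch
    from lightning.pytorch import LightningModule


    class LitScaledModel(LightningModule):
        def __init__(self, base_lr=1e-3):
            super().__init__()
            self.base_lr = base_lr

        def configure_optimizers(self):
            # Linear scaling rule: multiply the single-device learning rate by the
            # total number of training processes (devices * num_nodes).
            scaled_lr = self.base_lr * self.trainer.world_size
            return torch.optim.SGD(self.parameters(), lr=scaled_lr)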

.. note:: Very large batch sizes can actually hurt convergence. Check out
    `Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour <https://arxiv.org/abs/1706.02677>`_
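
That paper pairs the linear scaling rule with a gradual learning-rate warmup to stabilize early training.
Below is a minimal sketch using PyTorch's ``LinearLR`` scheduler; the warmup length and start factor are
illustrative, not prescribed values:

.. code-block:: python

    import torch
    from lightning.pytorch import LightningModule


    class LitWarmupModel(LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.parameters(), lr=1e-3 * self.trainer.world_size)
            # Ramp the learning rate from 10% of the scaled value to 100% over the
            # first 5 epochs, then continue at the scaled rate.
            warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=5)
            return {"optimizer": optimizer, "lr_scheduler": warmup}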
4 changes: 3 additions & 1 deletion docs/source-pytorch/common/trainer.rst
@@ -510,6 +510,7 @@ limit_train_batches

How much of the training dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
Value is per device.
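
For example, a sketch of the per-device behavior (the device count and limit are illustrative):

.. code-block:: python

    # With 4 devices and limit_train_batches=100, each device runs 100 batches
    # per epoch, so 400 batches are processed in total.
    trainer = Trainer(accelerator="gpu", devices=4, limit_train_batches=100)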

.. testcode::

@@ -535,7 +536,7 @@ limit_test_batches
:width: 400
:muted:

How much of the test dataset to check. Value is per device.

.. testcode::

@@ -560,6 +561,7 @@ limit_val_batches

How much of the validation dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
Value is per device.

.. testcode::

2 changes: 1 addition & 1 deletion docs/source-pytorch/expertise_levels.rst
@@ -84,7 +84,7 @@ Learn to scale up your models and enable collaborative model development at acad
.. Add callout items below this line

.. displayitem::
   :header: Level 7: Hardware acceleration
   :description: Learn how to access GPUs and TPUs on the cloud.
   :button_link: levels/intermediate_level_7.html
   :col_css: col-md-6
2 changes: 1 addition & 1 deletion docs/source-pytorch/levels/intermediate.rst
@@ -16,7 +16,7 @@ Learn to scale up your models and enable collaborative model development at acad
.. Add callout items below this line

.. displayitem::
   :header: Level 7: Hardware acceleration
   :description: Learn how to access GPUs and TPUs on the cloud.
   :button_link: intermediate_level_7.html
   :col_css: col-md-6
6 changes: 3 additions & 3 deletions docs/source-pytorch/levels/intermediate_level_7.rst
@@ -1,8 +1,8 @@
:orphan:

##############################
Level 7: Hardware acceleration
##############################

Learn to develop models on cloud GPUs and TPUs.

8 changes: 4 additions & 4 deletions src/lightning/pytorch/trainer/trainer.py
@@ -185,16 +185,16 @@ def __init__(
:class:`datetime.timedelta`.

limit_train_batches: How much of the training dataset to check (float = fraction, int = num_batches).
    Value is per device. Default: ``1.0``.

limit_val_batches: How much of the validation dataset to check (float = fraction, int = num_batches).
    Value is per device. Default: ``1.0``.

limit_test_batches: How much of the test dataset to check (float = fraction, int = num_batches).
    Value is per device. Default: ``1.0``.

limit_predict_batches: How much of the prediction dataset to check (float = fraction, int = num_batches).
    Value is per device. Default: ``1.0``.

overfit_batches: Overfit a fraction of training/validation data (float) or a set number of batches (int).
    Default: ``0.0``.