FSDP Activation Checkpointing improvements (facebookresearch#176)
Summary:
[enhancement] FSDP with activation checkpointing now allows specifying blocks without activation checkpointing (useful for linear evaluation) and completes incomplete stage_checkpoints configurations
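For context, "completing incomplete stage_checkpoints configurations" refers to padding a stage_checkpoints list that only covers some trunk stages so that every stage ends up with an entry. A minimal sketch of that completion logic follows, assuming missing stages simply default to an empty list (no activation checkpointing); the helper name is hypothetical and this is not the actual VISSL code.

from typing import List


def complete_stage_checkpoints(
    stage_checkpoints: List[List[int]], num_stages: int = 4
) -> List[List[int]]:
    """Hypothetical sketch: pad an incomplete stage_checkpoints configuration so that
    every trunk stage has an entry; an empty list means no activation checkpointing
    in that stage (useful for linear evaluation)."""
    completed = [list(blocks) for blocks in stage_checkpoints]
    while len(completed) < num_stages:
        completed.append([])  # assumed default: leave missing stages without checkpointing
    return completed


# Only the first two stages specified -> [[2], [4], [], []]
print(complete_stage_checkpoints([[2], [4]]))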

Pull Request resolved: fairinternal/ssl_scaling#176

Reviewed By: prigoyal

Differential Revision: D30143386

Pulled By: QuentinDuval

fbshipit-source-id: 6fa85059d36d0bfa44ea7c07ac92994985674943
QuentinDuval authored and facebook-github-bot committed Aug 6, 2021
1 parent 5a622f4 commit 554aa15
Showing 7 changed files with 324 additions and 54 deletions.
19 changes: 19 additions & 0 deletions configs/config/test/integration_test/models/eval_regnet_fsdp.yaml
@@ -0,0 +1,19 @@
# @package _global_
config:
  MODEL:
    TRUNK:
      NAME: regnet
      REGNET:
        name: anynet
        depths: [2, 4, 11, 1]
        widths: [224, 448, 1232, 3024]
        group_widths: [112, 112, 112, 112]
        bottleneck_multipliers: [1.0, 1.0, 1.0, 1.0]
        strides: [2, 2, 2, 2]
    HEAD:
      PARAMS: [
        ["eval_mlp", {"in_channels": 3024, "dims": [12096, 1000]}],
      ]
    SYNC_BN_CONFIG:
      CONVERT_BN_TO_SYNC_BN: True
      SYNC_BN_TYPE: pytorch
19 changes: 19 additions & 0 deletions configs/config/test/integration_test/models/swav_regnet_fsdp.yaml
@@ -0,0 +1,19 @@
# @package _global_
config:
  MODEL:
    TRUNK:
      NAME: regnet
      REGNET:
        name: anynet
        depths: [2, 4, 11, 1]
        widths: [224, 448, 1232, 3024]
        group_widths: [112, 112, 112, 112]
        bottleneck_multipliers: [1.0, 1.0, 1.0, 1.0]
        strides: [2, 2, 2, 2]
    HEAD:
      PARAMS: [
        ["swav_head", {"dims": [3024, 3024, 128], "use_bn": False, "num_clusters": [3000]}],
      ]
    SYNC_BN_CONFIG:
      CONVERT_BN_TO_SYNC_BN: True
      SYNC_BN_TYPE: pytorch
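The swav_head parameters above (dims [3024, 3024, 128], num_clusters [3000]) describe a projection MLP followed by a prototypes layer. Below is a minimal illustrative sketch of such a head, assuming a two-layer MLP without batch norm (use_bn: False) and a single bias-free prototypes matrix; it is not VISSL's actual SwAV head implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToySwAVHead(nn.Module):
    """Illustrative SwAV-style head: projection MLP + prototype scores."""

    def __init__(self, dims=(3024, 3024, 128), num_prototypes=3000):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(dims[0], dims[1]),
            nn.ReLU(inplace=True),
            nn.Linear(dims[1], dims[2]),
        )
        # Bias-free linear layer whose weight rows act as the learnable prototypes.
        self.prototypes0 = nn.Linear(dims[2], num_prototypes, bias=False)

    def forward(self, x):
        z = F.normalize(self.projection(x), dim=1, p=2)
        return z, self.prototypes0(z)


scores = ToySwAVHead()(torch.randn(4, 3024))[1]  # prototype scores, shape (4, 3000)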
113 changes: 113 additions & 0 deletions configs/config/test/integration_test/quick_swav_2crops.yaml
@@ -0,0 +1,113 @@
# @package _global_
config:
  VERBOSE: False
  LOG_FREQUENCY: 2
  TEST_ONLY: False
  TEST_MODEL: False
  SEED_VALUE: 0
  MULTI_PROCESSING_METHOD: forkserver
  HOOKS:
    PERF_STATS:
      MONITOR_PERF_STATS: True
      PERF_STAT_FREQUENCY: 40
      ROLLING_BTIME_FREQ: 5
  DATA:
    NUM_DATALOADER_WORKERS: 5
    TRAIN:
      DATA_SOURCES: [disk_filelist]
      DATASET_NAMES: [imagenet1k_filelist]
      BATCHSIZE_PER_REPLICA: 16
      LABEL_TYPE: sample_index  # just an implementation detail. Label isn't used
      TRANSFORMS:
        - name: ImgPilToMultiCrop
          total_num_crops: 2
          size_crops: [224]
          num_crops: [2]
          crop_scales: [[0.14, 1]]
        - name: RandomHorizontalFlip
          p: 0.5
        - name: ImgPilColorDistortion
          strength: 1.0
        - name: ImgPilGaussianBlur
          p: 0.5
          radius_min: 0.1
          radius_max: 2.0
        - name: ToTensor
        - name: Normalize
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
      COLLATE_FUNCTION: multicrop_collator
      MMAP_MODE: True
      COPY_TO_LOCAL_DISK: False
      COPY_DESTINATION_DIR: /tmp/imagenet1k/
      DATA_LIMIT: 250
      DROP_LAST: True
  TRAINER:
    TRAIN_STEP_NAME: standard_train_step
  METERS:
    name: ""
  MODEL:
    TRUNK:
      NAME: resnet
      RESNETS:
        DEPTH: 50
    HEAD:
      PARAMS: [
        ["swav_head", {"dims": [2048, 2048, 128], "use_bn": True, "num_clusters": [3000]}],
      ]
    TEMP_FROZEN_PARAMS_ITER_MAP: [
      ['module.heads.0.prototypes0.weight', 313],
      # TODO (Min): FSDP need to return the original param name from named_parameters().
      ['_fsdp_wrapped_module.heads.0._fsdp_wrapped_module._fpw_module.prototypes0._fsdp_wrapped_module.weight', 313]
    ]
    SYNC_BN_CONFIG:
      CONVERT_BN_TO_SYNC_BN: True
      SYNC_BN_TYPE: pytorch
  LOSS:
    name: swav_loss
    swav_loss:
      temperature: 0.1
      use_double_precision: False
      normalize_last_layer: True
      num_iters: 3
      epsilon: 0.05
      crops_for_assign: [0, 1]
      queue:
        queue_length: 0
        start_iter: 0
  OPTIMIZER:
    name: sgd
    use_larc: True
    larc_config:
      clip: False
      trust_coefficient: 0.001
      eps: 0.00000001
    weight_decay: 0.000001
    momentum: 0.9
    nesterov: False
    num_epochs: 1
    regularize_bn: True
    regularize_bias: True
    param_schedulers:
      lr:
        auto_lr_scaling:
          auto_scale: true
          base_value: 0.3
          base_lr_batch_size: 256
        name: cosine
        start_value: 0.15  # LR for batch size 256
        end_value: 0.0000
        update_interval: step
  DISTRIBUTED:
    BACKEND: nccl
    NUM_NODES: 1
    NUM_PROC_PER_NODE: 2
    INIT_METHOD: tcp
    RUN_ID: auto
  MACHINE:
    DEVICE: gpu
  CHECKPOINT:
    DIR: "."
    AUTO_RESUME: True
    CHECKPOINT_FREQUENCY: 5
    OVERWRITE_EXISTING: true
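The auto_lr_scaling block above scales the learning rate with the effective batch size. A quick sketch of the arithmetic, assuming the usual linear scaling rule scaled_lr = base_value * total_batch_size / base_lr_batch_size:

# Assumed linear LR scaling implied by the auto_lr_scaling section above.
base_value = 0.3
base_lr_batch_size = 256
batch_size_per_replica = 16
num_replicas = 2  # NUM_NODES * NUM_PROC_PER_NODE

total_batch_size = batch_size_per_replica * num_replicas        # 32
scaled_lr = base_value * total_batch_size / base_lr_batch_size  # 0.0375
print(total_batch_size, scaled_lr)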
130 changes: 115 additions & 15 deletions tests/test_regnet_fsdp_integration.py
@@ -2,11 +2,12 @@

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest

import torch
from hydra.experimental import compose, initialize_config_module
from vissl.utils.checkpoint import CheckpointFormatConverter
from vissl.utils.hydra_config import convert_to_attrdict
from vissl.utils.test_utils import (
    gpu_test,
@@ -32,22 +33,23 @@ def _create_pretraining_config(
            cfg = compose(
                "defaults",
                overrides=[
                    "config=test/integration_test/quick_swav",
                    "+config/pretrain/swav/models=regnet16Gf",
                    "config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
                    "config=test/integration_test/quick_swav_2crops",
                    "+config/test/integration_test/models=swav_regnet_fsdp",
                    "config.SEED_VALUE=0",
                    "config.MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN=True",
                    "config.MODEL.SYNC_BN_CONFIG.SYNC_BN_TYPE=pytorch",
                    "config.MODEL.AMP_PARAMS.AMP_TYPE=pytorch",
                    "config.LOSS.swav_loss.epsilon=0.03",
                    "config.MODEL.FSDP_CONFIG.flatten_parameters=True",
                    "config.DISTRIBUTED.NUM_PROC_PER_NODE=2",
                    "config.LOG_FREQUENCY=1",
                    "config.OPTIMIZER.construct_single_param_group_only=True",
                    "config.LOSS.swav_loss.epsilon=0.03",
                    "config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
                    "config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=4",
                    "config.DATA.TRAIN.DATA_LIMIT=32",
                    "config.DATA.TRAIN.USE_DEBUGGING_SAMPLER=True",
                    "config.OPTIMIZER.use_larc=False",
                    "config.OPTIMIZER.construct_single_param_group_only=True",
                    "config.LOG_FREQUENCY=1",
                    "config.REPRODUCIBILITY.CUDDN_DETERMINISTIC=True",
                    "config.DATA.TRAIN.USE_DEBUGGING_SAMPLER=True",
                    "config.DISTRIBUTED.NUM_PROC_PER_NODE=2",
                ],
            )
        args, config = convert_to_attrdict(cfg)
@@ -63,12 +65,59 @@ def _create_pretraining_config(
            config["MODEL"]["TRUNK"]["NAME"] = "regnet_v2"
            config["MODEL"]["HEAD"]["PARAMS"][0][0] = "swav_head"
            config.MODEL.AMP_PARAMS.USE_AMP = with_mixed_precision

        config.MODEL.TRUNK.REGNET.stage_checkpoints = [[2], [4], [6, 11], []]
        config.MODEL.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING = (
            with_activation_checkpointing
        )
        return config
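The stage_checkpoints value above ([[2], [4], [6, 11], []]) lists, per RegNet stage, the block indices at which blocks are grouped for activation checkpointing; an empty list leaves a stage without checkpointing, which is what this commit newly allows. The snippet below is a rough illustration of running such groups through torch.utils.checkpoint, not VISSL's actual regnet_fsdp trunk code, and the grouping semantics are my reading of the configuration.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def run_stage(blocks, x, boundaries, use_activation_checkpointing):
    """Illustration only: split a stage's blocks into groups ending at the given
    boundaries and recompute each group's activations during the backward pass.
    An empty `boundaries` list runs the stage without checkpointing."""
    if not use_activation_checkpointing or not boundaries:
        for block in blocks:
            x = block(x)
        return x
    start = 0
    for end in boundaries:
        group = nn.Sequential(*blocks[start:end])
        x = checkpoint(group, x)  # this group's activations are recomputed on backward
        start = end
    for block in blocks[start:]:  # any trailing blocks run without checkpointing
        x = block(x)
    return x


# Toy usage mirroring the third stage of stage_checkpoints=[[2], [4], [6, 11], []]:
stage3 = nn.ModuleList([nn.Linear(8, 8) for _ in range(11)])
out = run_stage(stage3, torch.randn(2, 8, requires_grad=True), [6, 11], True)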

    def _create_linear_evaluation_config(
        self, with_fsdp: bool, with_mixed_precision: bool, auto_wrap_threshold: int
    ):
        with initialize_config_module(config_module="vissl.config"):
            cfg = compose(
                "defaults",
                overrides=[
                    "config=test/integration_test/quick_eval_in1k_linear",
                    "+config/test/integration_test/models=eval_regnet_fsdp",
                    "config.SEED_VALUE=0",
                    "config.MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN=True",
                    "config.MODEL.SYNC_BN_CONFIG.SYNC_BN_TYPE=pytorch",
                    "config.MODEL.AMP_PARAMS.AMP_TYPE=pytorch",
                    "config.MODEL.FSDP_CONFIG.flatten_parameters=True",
                    "config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
                    "config.DATA.TRAIN.LABEL_SOURCES=[synthetic]",
                    "config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=4",
                    "config.DATA.TRAIN.DATA_LIMIT=32",
                    "config.DATA.TEST.DATA_SOURCES=[synthetic]",
                    "config.DATA.TEST.LABEL_SOURCES=[synthetic]",
                    "config.DATA.TEST.BATCHSIZE_PER_REPLICA=4",
                    "config.DATA.TEST.DATA_LIMIT=32",
                    "config.DATA.TRAIN.USE_DEBUGGING_SAMPLER=True",
                    "config.OPTIMIZER.use_larc=False",
                    "config.OPTIMIZER.construct_single_param_group_only=True",
                    "config.LOG_FREQUENCY=1",
                    "config.REPRODUCIBILITY.CUDDN_DETERMINISTIC=True",
                    "config.DISTRIBUTED.NUM_PROC_PER_NODE=2",
                ],
            )
        args, config = convert_to_attrdict(cfg)
        if with_fsdp:
            config["MODEL"]["TRUNK"]["NAME"] = "regnet_fsdp"
            config["MODEL"]["HEAD"]["PARAMS"][0][0] = "eval_mlp_fsdp"
            config.TRAINER.TASK_NAME = "self_supervision_fsdp_task"
            config.MODEL.FSDP_CONFIG.mixed_precision = with_mixed_precision
            config.MODEL.FSDP_CONFIG.fp32_reduce_scatter = with_mixed_precision
            config.MODEL.FSDP_CONFIG.compute_dtype = torch.float32
            config.MODEL.FSDP_CONFIG.AUTO_WRAP_THRESHOLD = auto_wrap_threshold
        else:
            config["MODEL"]["TRUNK"]["NAME"] = "regnet_v2"
            config["MODEL"]["HEAD"]["PARAMS"][0][0] = "eval_mlp"
            config.MODEL.AMP_PARAMS.USE_AMP = with_mixed_precision
        config.MODEL.TRUNK.REGNET.stage_checkpoints = [[2], [4], [6, 11], []]
        config.MODEL.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING = False
        return config

    def run_pretraining(
        self,
        with_fsdp: bool,
@@ -86,26 +135,77 @@ def run_pretraining(
            result = run_integration_test(config)
            return result.get_losses()

    def run_linear_eval(
        self,
        checkpoint_path: str,
        with_fsdp: bool,
        with_mixed_precision: bool,
        auto_wrap_threshold: int = 0,
    ):
        with in_temporary_directory():
            config = self._create_linear_evaluation_config(
                with_fsdp=with_fsdp,
                with_mixed_precision=with_mixed_precision,
                auto_wrap_threshold=auto_wrap_threshold,
            )
            config.MODEL.WEIGHTS_INIT.PARAMS_FILE = checkpoint_path
            result = run_integration_test(config)
            return result.get_losses()

    @gpu_test(gpu_count=2)
    def test_fsdp_integration(self):
        ddp_losses = self.run_pretraining(
            with_fsdp=False,
            with_activation_checkpointing=False,
        fsdp_losses_1 = self.run_pretraining(
            with_fsdp=True,
            with_activation_checkpointing=True,
            with_mixed_precision=False,
        )
        fsdp_losses_1 = self.run_pretraining(
        fsdp_losses_2 = self.run_pretraining(
            with_fsdp=True,
            with_activation_checkpointing=False,
            with_mixed_precision=False,
        )
        fsdp_losses_2 = self.run_pretraining(
        fsdp_losses_3 = self.run_pretraining(
            with_fsdp=True,
            with_activation_checkpointing=False,
            with_mixed_precision=False,
            auto_wrap_threshold=100,
        )
        ddp_losses = self.run_pretraining(
            with_fsdp=False,
            with_activation_checkpointing=False,
            with_mixed_precision=False,
        )
        self.assertEqual(ddp_losses, fsdp_losses_1)
        self.assertEqual(ddp_losses, fsdp_losses_2)
        self.assertEqual(ddp_losses, fsdp_losses_3)

    @gpu_test(gpu_count=2)
    def test_fsdp_integration_with_linear_eval(self):
        with in_temporary_directory() as pretrain_dir:

            # Start pre-training
            config = self._create_pretraining_config(
                with_fsdp=True,
                with_activation_checkpointing=True,
                with_mixed_precision=False,
                auto_wrap_threshold=0,
            )
            run_integration_test(config)

            # Consolidate the weights
            CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
                "checkpoint.torch", "checkpoint_conso.torch"
            )

            # Load the checkpoint and perform a linear evaluation on it
            losses = self.run_linear_eval(
                checkpoint_path=os.path.join(pretrain_dir, "checkpoint_conso.torch"),
                with_fsdp=True,
                with_mixed_precision=False,
                auto_wrap_threshold=0,
            )
            self.assertEqual(8, len(losses))
            print(losses)
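A note on the consolidation step above: sharded_to_consolidated_checkpoint turns the per-rank shards written by FSDP during pre-training into a single checkpoint file that the linear-evaluation job can load through MODEL.WEIGHTS_INIT.PARAMS_FILE. The toy sketch below only illustrates the general idea of stitching per-rank shards back together by key; it is not VISSL's CheckpointFormatConverter logic, and real FSDP shards (especially with flatten_parameters=True) need the wrapper's own metadata to be reassembled correctly.

import torch


def toy_consolidate(shard_paths, output_path):
    """Toy illustration: concatenate each parameter's per-rank shards into one tensor.
    Real FSDP consolidation must account for flattening, padding and wrapper metadata."""
    shards = [torch.load(path, map_location="cpu") for path in shard_paths]
    consolidated = {
        name: torch.cat([shard[name] for shard in shards], dim=0)
        for name in shards[0]
    }
    torch.save(consolidated, output_path)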

    @gpu_test(gpu_count=2)
    def test_fsdp_integration_mixed_precision(self):
1 change: 1 addition & 0 deletions vissl/config/defaults.yaml
@@ -503,6 +503,7 @@ config:
      # a dedicated FSDP wrapping: the number provided here is the number of
      # parameters that serves as threshold to decide if a layer is "big"
      AUTO_WRAP_THRESHOLD: 0
      AMP_TYPE: "O1"
      # Parameters of fairscale FSDP
      flatten_parameters: True
      mixed_precision: True
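The AUTO_WRAP_THRESHOLD comment above describes wrapping "big" layers in a dedicated FSDP instance based on their parameter count. A minimal sketch of that size test, assuming the decision is simply a comparison of a module's parameter count against the threshold (the helper below is illustrative, not fairscale's auto-wrap API):

import torch.nn as nn


def is_big_enough_to_wrap(module: nn.Module, auto_wrap_threshold: int) -> bool:
    """Assumed size test: wrap a layer in its own FSDP instance when its number of
    parameters exceeds the configured threshold (0 disables auto wrapping)."""
    if auto_wrap_threshold <= 0:
        return False
    num_params = sum(p.numel() for p in module.parameters())
    return num_params > auto_wrap_threshold


print(is_big_enough_to_wrap(nn.Linear(224, 224), auto_wrap_threshold=100))  # True: 50400 params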