diff --git a/configs/config/test/integration_test/models/eval_regnet_fsdp.yaml b/configs/config/test/integration_test/models/eval_regnet_fsdp.yaml
new file mode 100644
index 000000000..e3dc35115
--- /dev/null
+++ b/configs/config/test/integration_test/models/eval_regnet_fsdp.yaml
@@ -0,0 +1,19 @@
+# @package _global_
+config:
+  MODEL:
+    TRUNK:
+      NAME: regnet
+      REGNET:
+        name: anynet
+        depths: [2, 4, 11, 1]
+        widths: [224, 448, 1232, 3024]
+        group_widths: [112, 112, 112, 112]
+        bottleneck_multipliers: [1.0, 1.0, 1.0, 1.0]
+        strides: [2, 2, 2, 2]
+    HEAD:
+      PARAMS: [
+        ["eval_mlp", {"in_channels": 3024, "dims": [12096, 1000]}],
+      ]
+    SYNC_BN_CONFIG:
+      CONVERT_BN_TO_SYNC_BN: True
+      SYNC_BN_TYPE: pytorch
diff --git a/configs/config/test/integration_test/models/swav_regnet_fsdp.yaml b/configs/config/test/integration_test/models/swav_regnet_fsdp.yaml
new file mode 100644
index 000000000..51cdea00d
--- /dev/null
+++ b/configs/config/test/integration_test/models/swav_regnet_fsdp.yaml
@@ -0,0 +1,19 @@
+# @package _global_
+config:
+  MODEL:
+    TRUNK:
+      NAME: regnet
+      REGNET:
+        name: anynet
+        depths: [2, 4, 11, 1]
+        widths: [224, 448, 1232, 3024]
+        group_widths: [112, 112, 112, 112]
+        bottleneck_multipliers: [1.0, 1.0, 1.0, 1.0]
+        strides: [2, 2, 2, 2]
+    HEAD:
+      PARAMS: [
+        ["swav_head", {"dims": [3024, 3024, 128], "use_bn": False, "num_clusters": [3000]}],
+      ]
+    SYNC_BN_CONFIG:
+      CONVERT_BN_TO_SYNC_BN: True
+      SYNC_BN_TYPE: pytorch
diff --git a/configs/config/test/integration_test/quick_swav_2crops.yaml b/configs/config/test/integration_test/quick_swav_2crops.yaml
new file mode 100644
index 000000000..be9844e7f
--- /dev/null
+++ b/configs/config/test/integration_test/quick_swav_2crops.yaml
@@ -0,0 +1,113 @@
+# @package _global_
+config:
+  VERBOSE: False
+  LOG_FREQUENCY: 2
+  TEST_ONLY: False
+  TEST_MODEL: False
+  SEED_VALUE: 0
+  MULTI_PROCESSING_METHOD: forkserver
+  HOOKS:
+    PERF_STATS:
+      MONITOR_PERF_STATS: True
+      PERF_STAT_FREQUENCY: 40
+      ROLLING_BTIME_FREQ: 5
+  DATA:
+    NUM_DATALOADER_WORKERS: 5
+    TRAIN:
+      DATA_SOURCES: [disk_filelist]
+      DATASET_NAMES: [imagenet1k_filelist]
+      BATCHSIZE_PER_REPLICA: 16
+      LABEL_TYPE: sample_index # just an implementation detail. Label isn't used
+      TRANSFORMS:
+        - name: ImgPilToMultiCrop
+          total_num_crops: 2
+          size_crops: [224]
+          num_crops: [2]
+          crop_scales: [[0.14, 1]]
+        - name: RandomHorizontalFlip
+          p: 0.5
+        - name: ImgPilColorDistortion
+          strength: 1.0
+        - name: ImgPilGaussianBlur
+          p: 0.5
+          radius_min: 0.1
+          radius_max: 2.0
+        - name: ToTensor
+        - name: Normalize
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+      COLLATE_FUNCTION: multicrop_collator
+      MMAP_MODE: True
+      COPY_TO_LOCAL_DISK: False
+      COPY_DESTINATION_DIR: /tmp/imagenet1k/
+      DATA_LIMIT: 250
+      DROP_LAST: True
+  TRAINER:
+    TRAIN_STEP_NAME: standard_train_step
+  METERS:
+    name: ""
+  MODEL:
+    TRUNK:
+      NAME: resnet
+      RESNETS:
+        DEPTH: 50
+    HEAD:
+      PARAMS: [
+        ["swav_head", {"dims": [2048, 2048, 128], "use_bn": True, "num_clusters": [3000]}],
+      ]
+    TEMP_FROZEN_PARAMS_ITER_MAP: [
+      ['module.heads.0.prototypes0.weight', 313],
+      # TODO (Min): FSDP needs to return the original param name from named_parameters().
+      ['_fsdp_wrapped_module.heads.0._fsdp_wrapped_module._fpw_module.prototypes0._fsdp_wrapped_module.weight', 313]
+    ]
+    SYNC_BN_CONFIG:
+      CONVERT_BN_TO_SYNC_BN: True
+      SYNC_BN_TYPE: pytorch
+  LOSS:
+    name: swav_loss
+    swav_loss:
+      temperature: 0.1
+      use_double_precision: False
+      normalize_last_layer: True
+      num_iters: 3
+      epsilon: 0.05
+      crops_for_assign: [0, 1]
+      queue:
+        queue_length: 0
+        start_iter: 0
+  OPTIMIZER:
+    name: sgd
+    use_larc: True
+    larc_config:
+      clip: False
+      trust_coefficient: 0.001
+      eps: 0.00000001
+    weight_decay: 0.000001
+    momentum: 0.9
+    nesterov: False
+    num_epochs: 1
+    regularize_bn: True
+    regularize_bias: True
+    param_schedulers:
+      lr:
+        auto_lr_scaling:
+          auto_scale: true
+          base_value: 0.3
+          base_lr_batch_size: 256
+        name: cosine
+        start_value: 0.15 # LR for batch size 256
+        end_value: 0.0000
+        update_interval: step
+  DISTRIBUTED:
+    BACKEND: nccl
+    NUM_NODES: 1
+    NUM_PROC_PER_NODE: 2
+    INIT_METHOD: tcp
+    RUN_ID: auto
+  MACHINE:
+    DEVICE: gpu
+  CHECKPOINT:
+    DIR: "."
+    AUTO_RESUME: True
+    CHECKPOINT_FREQUENCY: 5
+    OVERWRITE_EXISTING: true
diff --git a/tests/test_regnet_fsdp_integration.py b/tests/test_regnet_fsdp_integration.py
index b8e5669b7..760b21de4 100644
--- a/tests/test_regnet_fsdp_integration.py
+++ b/tests/test_regnet_fsdp_integration.py
@@ -2,11 +2,12 @@
 
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
+import os
 import unittest
 
 import torch
 from hydra.experimental import compose, initialize_config_module
+from vissl.utils.checkpoint import CheckpointFormatConverter
 from vissl.utils.hydra_config import convert_to_attrdict
 from vissl.utils.test_utils import (
     gpu_test,
@@ -32,22 +33,23 @@ def _create_pretraining_config(
         cfg = compose(
             "defaults",
             overrides=[
-                "config=test/integration_test/quick_swav",
-                "+config/pretrain/swav/models=regnet16Gf",
-                "config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
+                "config=test/integration_test/quick_swav_2crops",
+                "+config/test/integration_test/models=swav_regnet_fsdp",
                 "config.SEED_VALUE=0",
                 "config.MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN=True",
                 "config.MODEL.SYNC_BN_CONFIG.SYNC_BN_TYPE=pytorch",
                 "config.MODEL.AMP_PARAMS.AMP_TYPE=pytorch",
-                "config.LOSS.swav_loss.epsilon=0.03",
                 "config.MODEL.FSDP_CONFIG.flatten_parameters=True",
-                "config.DISTRIBUTED.NUM_PROC_PER_NODE=2",
-                "config.LOG_FREQUENCY=1",
-                "config.OPTIMIZER.construct_single_param_group_only=True",
+                "config.LOSS.swav_loss.epsilon=0.03",
+                "config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
                 "config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=4",
+                "config.DATA.TRAIN.DATA_LIMIT=32",
+                "config.DATA.TRAIN.USE_DEBUGGING_SAMPLER=True",
                 "config.OPTIMIZER.use_larc=False",
+                "config.OPTIMIZER.construct_single_param_group_only=True",
+                "config.LOG_FREQUENCY=1",
                 "config.REPRODUCIBILITY.CUDDN_DETERMINISTIC=True",
-                "config.DATA.TRAIN.USE_DEBUGGING_SAMPLER=True",
+                "config.DISTRIBUTED.NUM_PROC_PER_NODE=2",
             ],
         )
         args, config = convert_to_attrdict(cfg)
@@ -63,12 +65,59 @@
             config["MODEL"]["TRUNK"]["NAME"] = "regnet_v2"
             config["MODEL"]["HEAD"]["PARAMS"][0][0] = "swav_head"
             config.MODEL.AMP_PARAMS.USE_AMP = with_mixed_precision
-
+        config.MODEL.TRUNK.REGNET.stage_checkpoints = [[2], [4], [6, 11], []]
         config.MODEL.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING = (
             with_activation_checkpointing
         )
         return config
 
+    def _create_linear_evaluation_config(
+        self, with_fsdp: bool, with_mixed_precision: bool, auto_wrap_threshold: int
+    ):
+        with initialize_config_module(config_module="vissl.config"):
+            cfg = compose(
+                "defaults",
+                overrides=[
+                    "config=test/integration_test/quick_eval_in1k_linear",
+                    "+config/test/integration_test/models=eval_regnet_fsdp",
+                    "config.SEED_VALUE=0",
+                    "config.MODEL.SYNC_BN_CONFIG.CONVERT_BN_TO_SYNC_BN=True",
+                    "config.MODEL.SYNC_BN_CONFIG.SYNC_BN_TYPE=pytorch",
+                    "config.MODEL.AMP_PARAMS.AMP_TYPE=pytorch",
+                    "config.MODEL.FSDP_CONFIG.flatten_parameters=True",
+                    "config.DATA.TRAIN.DATA_SOURCES=[synthetic]",
+                    "config.DATA.TRAIN.LABEL_SOURCES=[synthetic]",
+                    "config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=4",
+                    "config.DATA.TRAIN.DATA_LIMIT=32",
+                    "config.DATA.TEST.DATA_SOURCES=[synthetic]",
+                    "config.DATA.TEST.LABEL_SOURCES=[synthetic]",
+                    "config.DATA.TEST.BATCHSIZE_PER_REPLICA=4",
+                    "config.DATA.TEST.DATA_LIMIT=32",
+                    "config.DATA.TRAIN.USE_DEBUGGING_SAMPLER=True",
+                    "config.OPTIMIZER.use_larc=False",
+                    "config.OPTIMIZER.construct_single_param_group_only=True",
+                    "config.LOG_FREQUENCY=1",
+                    "config.REPRODUCIBILITY.CUDDN_DETERMINISTIC=True",
+                    "config.DISTRIBUTED.NUM_PROC_PER_NODE=2",
+                ],
+            )
+        args, config = convert_to_attrdict(cfg)
+        if with_fsdp:
+            config["MODEL"]["TRUNK"]["NAME"] = "regnet_fsdp"
+            config["MODEL"]["HEAD"]["PARAMS"][0][0] = "eval_mlp_fsdp"
+            config.TRAINER.TASK_NAME = "self_supervision_fsdp_task"
+            config.MODEL.FSDP_CONFIG.mixed_precision = with_mixed_precision
+            config.MODEL.FSDP_CONFIG.fp32_reduce_scatter = with_mixed_precision
+            config.MODEL.FSDP_CONFIG.compute_dtype = torch.float32
+            config.MODEL.FSDP_CONFIG.AUTO_WRAP_THRESHOLD = auto_wrap_threshold
+        else:
+            config["MODEL"]["TRUNK"]["NAME"] = "regnet_v2"
+            config["MODEL"]["HEAD"]["PARAMS"][0][0] = "eval_mlp"
+            config.MODEL.AMP_PARAMS.USE_AMP = with_mixed_precision
+        config.MODEL.TRUNK.REGNET.stage_checkpoints = [[2], [4], [6, 11], []]
+        config.MODEL.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING = False
+        return config
+
     def run_pretraining(
         self,
         with_fsdp: bool,
@@ -86,26 +135,77 @@
         result = run_integration_test(config)
         return result.get_losses()
 
+    def run_linear_eval(
+        self,
+        checkpoint_path: str,
+        with_fsdp: bool,
+        with_mixed_precision: bool,
+        auto_wrap_threshold: int = 0,
+    ):
+        with in_temporary_directory():
+            config = self._create_linear_evaluation_config(
+                with_fsdp=with_fsdp,
+                with_mixed_precision=with_mixed_precision,
+                auto_wrap_threshold=auto_wrap_threshold,
+            )
+            config.MODEL.WEIGHTS_INIT.PARAMS_FILE = checkpoint_path
+            result = run_integration_test(config)
+            return result.get_losses()
+
     @gpu_test(gpu_count=2)
     def test_fsdp_integration(self):
-        ddp_losses = self.run_pretraining(
-            with_fsdp=False,
-            with_activation_checkpointing=False,
+        fsdp_losses_1 = self.run_pretraining(
+            with_fsdp=True,
+            with_activation_checkpointing=True,
             with_mixed_precision=False,
         )
-        fsdp_losses_1 = self.run_pretraining(
+        fsdp_losses_2 = self.run_pretraining(
             with_fsdp=True,
             with_activation_checkpointing=False,
             with_mixed_precision=False,
         )
-        fsdp_losses_2 = self.run_pretraining(
+        fsdp_losses_3 = self.run_pretraining(
             with_fsdp=True,
             with_activation_checkpointing=False,
             with_mixed_precision=False,
             auto_wrap_threshold=100,
         )
+        ddp_losses = self.run_pretraining(
+            with_fsdp=False,
+            with_activation_checkpointing=False,
+            with_mixed_precision=False,
+        )
         self.assertEqual(ddp_losses, fsdp_losses_1)
         self.assertEqual(ddp_losses, fsdp_losses_2)
+        self.assertEqual(ddp_losses, fsdp_losses_3)
+
+    @gpu_test(gpu_count=2)
+    def test_fsdp_integration_with_linear_eval(self):
+        with in_temporary_directory() as pretrain_dir:
+
+            # Start pre-training
+            config = self._create_pretraining_config(
+                with_fsdp=True,
+                with_activation_checkpointing=True,
+                with_mixed_precision=False,
+                auto_wrap_threshold=0,
+            )
+            run_integration_test(config)
+
+            # Consolidate the weights
+            CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
+                "checkpoint.torch", "checkpoint_conso.torch"
+            )
+
+            # Load the checkpoint and perform a linear evaluation on it
+            losses = self.run_linear_eval(
+                checkpoint_path=os.path.join(pretrain_dir, "checkpoint_conso.torch"),
+                with_fsdp=True,
+                with_mixed_precision=False,
+                auto_wrap_threshold=0,
+            )
+            self.assertEqual(8, len(losses))
+            print(losses)
 
     @gpu_test(gpu_count=2)
     def test_fsdp_integration_mixed_precision(self):
diff --git a/vissl/config/defaults.yaml b/vissl/config/defaults.yaml
index d51a84045..3ae492817 100644
--- a/vissl/config/defaults.yaml
+++ b/vissl/config/defaults.yaml
@@ -503,6 +503,7 @@ config:
       # a dedicated FSDP wrapping: the number provided here is the number of
       # parameters that serves as threshold to decide if a layer is "big"
       AUTO_WRAP_THRESHOLD: 0
+      AMP_TYPE: "O1"
       # Parameters of fairscale FSDP
       flatten_parameters: True
       mixed_precision: True
diff --git a/vissl/models/trunks/regnet_fsdp.py b/vissl/models/trunks/regnet_fsdp.py
index 5893d6bf6..222dc3922 100644
--- a/vissl/models/trunks/regnet_fsdp.py
+++ b/vissl/models/trunks/regnet_fsdp.py
@@ -16,7 +16,6 @@
 and target training speed considerations.
 """
 
-import logging
 import math
 from collections import OrderedDict
 from typing import List, Tuple, Union
@@ -161,17 +160,15 @@ def create_any_stage(
         group_width: int,
         bottleneck_multiplier: float,
         params: Union[RegNetParams, AnyNetParams],
+        group_delimiters: List[int],
+        group_checkpoint: List[bool],
         stage_index: int = 0,
-        checkpoints: List[int] = 0,
     ) -> AnyStage:
-        assert sorted(checkpoints) == checkpoints, "Checkpoint indices should be sorted"
-
-        with_checkpointing = len(checkpoints) > 0
-        block_delimiters = [depth] if len(checkpoints) == 0 else checkpoints
+        assert len(group_delimiters) == len(group_checkpoint)
 
         any_stage = AnyStage()
         prev_depth = 0
-        for block_group_index, next_depth in enumerate(block_delimiters):
+        for group_index, next_depth in enumerate(group_delimiters):
             block_group = nn.Sequential()
             for i in range(prev_depth, next_depth):
                 block = self.create_block(
@@ -185,11 +182,9 @@ def create_any_stage(
                 any_stage.stage_depth += block.depth
                 block_group.add_module(f"block{stage_index}-{i}", block)
             prev_depth = next_depth
-            if with_checkpointing:
+            if group_checkpoint[group_index]:
                 block_group = checkpoint_wrapper(block_group)
-            any_stage.add_module(
-                f"block{stage_index}-part{block_group_index}", block_group
-            )
+            any_stage.add_module(f"block{stage_index}-part{group_index}", block_group)
         return any_stage
 
@@ -200,12 +195,9 @@ class RegnetFSDPBlocksFactory(RegnetBlocksFactory):
     initializing the weights etc.
""" - def __init__(self, fsdp_config: AttrDict, use_activation_checkpointing: bool): + def __init__(self, fsdp_config: AttrDict): super().__init__() self.fsdp_config = fsdp_config - self.use_activation_checkpointing = use_activation_checkpointing - if self.use_activation_checkpointing: - logging.info("Using Activation Checkpointing for FSDP training") def create_stem(self, params: Union[RegNetParams, AnyNetParams]): stem = super().create_stem(params) @@ -239,17 +231,15 @@ def create_any_stage( group_width: int, bottleneck_multiplier: float, params: Union[RegNetParams, AnyNetParams], + group_delimiters: List[int], + group_checkpoint: List[bool], stage_index: int = 0, - checkpoints: List[int] = 0, ): - assert sorted(checkpoints) == checkpoints, "Checkpoint indices should be sorted" - - with_checkpointing = len(checkpoints) > 0 - block_delimiters = [depth] if len(checkpoints) == 0 else checkpoints + assert len(group_delimiters) == len(group_checkpoint) any_stage = AnyStage() prev_depth = 0 - for block_group_index, next_depth in enumerate(block_delimiters): + for group_index, next_depth in enumerate(group_delimiters): block_group = nn.Sequential() for i in range(prev_depth, next_depth): block = self.create_block( @@ -263,12 +253,10 @@ def create_any_stage( any_stage.stage_depth += block.depth block_group.add_module(f"block{stage_index}-{i}", block) prev_depth = next_depth - if with_checkpointing: + if group_checkpoint[group_index]: block_group = checkpoint_wrapper(block_group) block_group = fsdp_wrapper(block_group, **self.fsdp_config) - any_stage.add_module( - f"block{stage_index}-part{block_group_index}", block_group - ) + any_stage.add_module(f"block{stage_index}-part{group_index}", block_group) return any_stage @@ -337,16 +325,34 @@ def create_regnet_feature_blocks(factory: RegnetBlocksFactory, model_config): # Starting from 1 stage_index = i + 1 - # Specify where the checkpoints are set in the stage - stage_checkpoints = [] - if model_config.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING: - checkpoint_config = trunk_config.get("stage_checkpoints", []) - if len(checkpoint_config) > 0: - stage_checkpoints = checkpoint_config[i] - else: - stage_checkpoints.append(depth) - - # Create the stage and add it to the trunk + # Identify where the block groups start and end, and whether they should + # be surrounded by activation checkpoints + # A block group is a group of block that is surrounded by a FSDP wrapper + # and optionally an activation checkpoint wrapper + with_checkpointing = ( + model_config.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING + ) + all_group_delimiters = trunk_config.get("stage_checkpoints", []) + group_delimiters = ( + all_group_delimiters[i] if len(all_group_delimiters) > i else [] + ) + group_checkpoint = [with_checkpointing] * len(group_delimiters) + assert ( + sorted(group_delimiters) == group_delimiters + ), "Checkpoint boundaries should be sorted" + if not group_delimiters: + # No delimiters means one group but no activation checkpointing + # for this group (even if USE_ACTIVATION_CHECKPOINTING is set) + group_delimiters.append(depth) + group_checkpoint.append(False) + elif group_delimiters[-1] != depth: + # Complete missing checkpoints at the end (user can give only + # the intermediate checkpoints to avoid repetitions) + group_delimiters.append(depth) + group_checkpoint.append(with_checkpointing) + + # Create the stage from the description of the block and the size of + # the block groups that compose this stage, then add it to the trunk new_stage = 
             width_in=current_width,
             width_out=width_out,
@@ -356,7 +362,8 @@ def create_regnet_feature_blocks(factory: RegnetBlocksFactory, model_config):
             bottleneck_multiplier=bottleneck_multiplier,
             params=params,
             stage_index=stage_index,
-            checkpoints=stage_checkpoints,
+            group_delimiters=group_delimiters,
+            group_checkpoint=group_checkpoint,
         )
         blocks.append((f"block{stage_index}", new_stage))
         trunk_depth += blocks[-1][1].stage_depth
@@ -451,8 +458,7 @@ def __init__(self, model_config: AttrDict, model_name: str):
         with set_torch_seed(self.seed):
             self._feature_blocks, self.trunk_depth = create_regnet_feature_blocks(
                 factory=RegnetFSDPBlocksFactory(
-                    fsdp_config=self.model_config.FSDP_CONFIG,
-                    use_activation_checkpointing=self.use_activation_checkpointing,
+                    fsdp_config=self.model_config.FSDP_CONFIG
                 ),
                 model_config=self.model_config,
             )
diff --git a/vissl/utils/hydra_config.py b/vissl/utils/hydra_config.py
index bc4422ac3..55e6a609b 100644
--- a/vissl/utils/hydra_config.py
+++ b/vissl/utils/hydra_config.py
@@ -543,12 +543,22 @@ def infer_and_assert_hydra_config(cfg):
         # Inference of optimizer configuration
         if cfg["OPTIMIZER"]["use_larc"]:
             cfg["OPTIMIZER"]["name"] = "sgd_fsdp"
+            # if using LARC, we set flatten_parameters=False so that we can
+            # compute the right param groups
+            cfg["MODEL"]["FSDP_CONFIG"]["flatten_parameters"] = False
 
         # AMP based inference
         if cfg["MODEL"]["AMP_PARAMS"]["USE_AMP"]:
             cfg["MODEL"]["AMP_PARAMS"]["AMP_TYPE"] = "pytorch"
             cfg["MODEL"]["FSDP_CONFIG"]["mixed_precision"] = True
-            cfg["MODEL"]["FSDP_CONFIG"]["fp32_reduce_scatter"] = True
+            # setup the compute_dtype and fp32_reduce_scatter
+            # based on whether O1 or O2 is desired
+            if cfg.MODEL.FSDP_CONFIG["AMP_TYPE"] == "O1":
+                cfg["MODEL"]["FSDP_CONFIG"]["compute_dtype"] = "float32"
+                cfg["MODEL"]["FSDP_CONFIG"]["fp32_reduce_scatter"] = True
+            elif cfg.MODEL.FSDP_CONFIG["AMP_TYPE"] == "O2":
+                cfg["MODEL"]["FSDP_CONFIG"]["compute_dtype"] = "float16"
+                cfg["MODEL"]["FSDP_CONFIG"]["fp32_reduce_scatter"] = False
         else:
             # if not using AMP, we can't use mixed_precision as it requires PyTorch AMP
             cfg["MODEL"]["FSDP_CONFIG"]["mixed_precision"] = False
@@ -577,6 +587,8 @@ def infer_and_assert_hydra_config(cfg):
         # Delete the AUTO_SETUP_FSDP key since we send the FSDP_CONFIG
         # to FSDP from fairscale which doesn't know about AUTO_SETUP_FSDP
         del cfg.MODEL.FSDP_CONFIG["AUTO_SETUP_FSDP"]
+        del cfg.MODEL.FSDP_CONFIG["AMP_TYPE"]
+        logging.info(f"Using the FSDP config: {cfg.MODEL.FSDP_CONFIG}")
 
     if cfg.DATA.TRAIN.BASE_DATASET == "generic_ssl":
         assert (