forked from facebookresearch/vissl
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_utils_hydra_config.py
85 lines (70 loc) · 3.65 KB
/
test_utils_hydra_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from typing import List
from hydra.core.global_hydra import GlobalHydra
from hydra.experimental import compose, initialize_config_module
from vissl.utils.hydra_config import convert_to_attrdict
class TestUtilsHydraConfig(unittest.TestCase):
    """Tests that FSDP-specific settings are inferred from the hydra config.

    Each test composes a configuration twice — once as-is and once with
    ``AUTO_SETUP_FSDP`` enabled — and checks that the head, trunk and task
    names (and, for pretraining, the AMP type and optimizer) are rewritten
    to their ``*_fsdp`` variants when the flag is set.
    """

    # Override that turns on automatic FSDP setup in the composed config.
    _AUTO_FSDP_OVERRIDE = "config.MODEL.FSDP_CONFIG.AUTO_SETUP_FSDP=True"

    @classmethod
    def setUpClass(cls) -> None:
        # Hydra keeps global state in a singleton: clear any previous
        # initialization so the vissl config module registers cleanly.
        hydra_instance = GlobalHydra.instance()
        if hydra_instance.is_initialized():
            hydra_instance.clear()
        initialize_config_module(config_module="vissl.config")

    @staticmethod
    def _create_config(overrides: List[str]):
        """Compose the "defaults" config with *overrides*, returned as an AttrDict."""
        composed = compose("defaults", overrides=overrides)
        _unused_args, attr_config = convert_to_attrdict(composed)
        return attr_config

    def test_inference_of_fsdp_settings_for_swav_pretraining(self):
        base_overrides = [
            "config=pretrain/swav/swav_8node_resnet",
            "+config/pretrain/swav/models=regnet16Gf",
            "config.MODEL.AMP_PARAMS.USE_AMP=True",
            "config.OPTIMIZER.use_larc=True",
        ]

        # Without the FSDP flag: standard component names and apex AMP.
        config = self._create_config(base_overrides)
        self.assertEqual(config["MODEL"]["AMP_PARAMS"]["AMP_TYPE"], "apex")
        self.assertEqual(config.MODEL.HEAD.PARAMS[0][0], "swav_head")
        self.assertEqual(config.MODEL.TRUNK.NAME, "regnet")
        self.assertEqual(config.TRAINER.TASK_NAME, "self_supervision_task")
        self.assertEqual(config.OPTIMIZER.name, "sgd")

        # With the flag: every component swaps to its FSDP variant and the
        # AMP type switches to pytorch.
        config = self._create_config(base_overrides + [self._AUTO_FSDP_OVERRIDE])
        self.assertEqual(config["MODEL"]["AMP_PARAMS"]["AMP_TYPE"], "pytorch")
        self.assertEqual(config.MODEL.HEAD.PARAMS[0][0], "swav_head_fsdp")
        self.assertEqual(config.MODEL.TRUNK.NAME, "regnet_fsdp")
        self.assertEqual(config.TRAINER.TASK_NAME, "self_supervision_fsdp_task")
        self.assertEqual(config.OPTIMIZER.name, "sgd_fsdp")

    def test_inference_of_fsdp_settings_for_linear_evaluation(self):
        base_overrides = [
            "config=debugging/benchmark/linear_image_classification/eval_resnet_8gpu_transfer_imagenette_160",
            "+config/debugging/benchmark/linear_image_classification/models=regnet16Gf_mlp",
        ]

        # Without the FSDP flag: standard mlp head, trunk and task.
        config = self._create_config(base_overrides)
        self.assertEqual(config.MODEL.HEAD.PARAMS[0][0], "mlp")
        self.assertEqual(config.MODEL.TRUNK.NAME, "regnet")
        self.assertEqual(config.TRAINER.TASK_NAME, "self_supervision_task")

        # With the flag: head, trunk and task swap to their FSDP variants.
        config = self._create_config(base_overrides + [self._AUTO_FSDP_OVERRIDE])
        self.assertEqual(config.MODEL.HEAD.PARAMS[0][0], "mlp_fsdp")
        self.assertEqual(config.MODEL.TRUNK.NAME, "regnet_fsdp")
        self.assertEqual(config.TRAINER.TASK_NAME, "self_supervision_fsdp_task")

    def test_inference_of_fsdp_settings_for_linear_evaluation_with_bn(self):
        base_overrides = [
            "config=debugging/benchmark/linear_image_classification/eval_resnet_8gpu_transfer_imagenette_160",
            "+config/debugging/benchmark/linear_image_classification/models=regnet16Gf_eval_mlp",
        ]

        # Without the FSDP flag: eval_mlp head, standard trunk and task.
        config = self._create_config(base_overrides)
        self.assertEqual(config.MODEL.HEAD.PARAMS[0][0], "eval_mlp")
        self.assertEqual(config.MODEL.TRUNK.NAME, "regnet")
        self.assertEqual(config.TRAINER.TASK_NAME, "self_supervision_task")

        # With the flag: head, trunk and task swap to their FSDP variants.
        config = self._create_config(base_overrides + [self._AUTO_FSDP_OVERRIDE])
        self.assertEqual(config.MODEL.HEAD.PARAMS[0][0], "eval_mlp_fsdp")
        self.assertEqual(config.MODEL.TRUNK.NAME, "regnet_fsdp")
        self.assertEqual(config.TRAINER.TASK_NAME, "self_supervision_fsdp_task")