forked from facebookresearch/vissl
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path test_model_load.py
49 lines (38 loc) · 1.58 KB
/
test_model_load.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import unittest
from parameterized import parameterized
from utils import (
BENCHMARK_CONFIGS,
INTEGRATION_TEST_CONFIGS,
PRETRAIN_CONFIGS,
SSLHydraConfig,
)
from vissl.models import build_model
from vissl.utils.hydra_config import convert_to_attrdict
# Module-level logger named after this module. ``__name__`` must be the
# module variable, not the string literal "__name__": the literal would
# create a logger literally named "__name__" instead of one named after
# this test module.
logger = logging.getLogger(__name__)
class TestBenchmarkModel(unittest.TestCase):
    """Smoke tests: every benchmark config must produce a buildable model."""

    @parameterized.expand(BENCHMARK_CONFIGS)
    def test_benchmark_model(self, filepath):
        """Load ``filepath`` and build the model it describes.

        The config is composed together with an override setting
        ``config.DISTRIBUTED.NUM_PROC_PER_NODE`` to 1. No assertions are
        made; the test passes as long as config loading and
        ``build_model`` raise no exception.
        """
        logger.info(f"Loading {filepath}")
        overrides = [filepath, "config.DISTRIBUTED.NUM_PROC_PER_NODE=1"]
        hydra_cfg = SSLHydraConfig.from_configs(overrides)
        _unused, attr_cfg = convert_to_attrdict(hydra_cfg.default_cfg)
        build_model(attr_cfg.MODEL, attr_cfg.OPTIMIZER)
class TestPretrainModel(unittest.TestCase):
    """Smoke tests: every pretraining config must produce a buildable model."""

    @parameterized.expand(PRETRAIN_CONFIGS)
    def test_pretrain_model(self, filepath):
        """Load ``filepath`` with no extra overrides and build its model.

        No assertions are made; the test passes as long as config loading
        and ``build_model`` raise no exception.
        """
        logger.info(f"Loading {filepath}")
        overrides = [filepath]
        hydra_cfg = SSLHydraConfig.from_configs(overrides)
        _unused, attr_cfg = convert_to_attrdict(hydra_cfg.default_cfg)
        build_model(attr_cfg.MODEL, attr_cfg.OPTIMIZER)
class TestIntegrationTestModel(unittest.TestCase):
    """Smoke tests: every integration-test config must produce a buildable model."""

    @parameterized.expand(INTEGRATION_TEST_CONFIGS)
    def test_integration_test_model(self, filepath):
        """Load ``filepath`` with no extra overrides and build its model.

        No assertions are made; the test passes as long as config loading
        and ``build_model`` raise no exception.
        """
        logger.info(f"Loading {filepath}")
        overrides = [filepath]
        hydra_cfg = SSLHydraConfig.from_configs(overrides)
        _unused, attr_cfg = convert_to_attrdict(hydra_cfg.default_cfg)
        build_model(attr_cfg.MODEL, attr_cfg.OPTIMIZER)