-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1233990
commit ab37675
Showing
827 changed files
with
94,397 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import os.path as osp | ||
import sys | ||
|
||
|
||
def add_path(path):
    """Prepend *path* to ``sys.path`` unless it is already present."""
    if path in sys.path:
        return
    sys.path.insert(0, path)
|
||
|
||
# Make both ``../lib`` and the repository root (``..``) importable when
# scripts are run from this directory.
this_dir = osp.dirname(__file__)
for lib_path in (osp.join(this_dir, '..', 'lib'), osp.join(this_dir, '..')):
    add_path(lib_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"0": "Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('skip_connect', 0), ('skip_connect', 2)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('max_pool_3x3', 0)], [('dil_conv_5x5', 2), ('skip_connect', 0)], [('sep_conv_3x3', 0), ('sep_conv_5x5', 1)], [('sep_conv_3x3', 0), ('dil_conv_5x5', 2)]], reduce_concat=range(2, 6))", "1": "Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('skip_connect', 0), ('skip_connect', 2)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('max_pool_3x3', 0)], [('dil_conv_5x5', 2), ('skip_connect', 0)], [('sep_conv_3x3', 0), ('sep_conv_5x5', 1)], [('sep_conv_3x3', 0), ('dil_conv_5x5', 2)]], reduce_concat=range(2, 6))", "2": "Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('skip_connect', 0), ('skip_connect', 2)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('max_pool_3x3', 0)], [('dil_conv_5x5', 2), ('skip_connect', 0)], [('sep_conv_3x3', 0), ('sep_conv_5x5', 1)], [('sep_conv_3x3', 0), ('dil_conv_5x5', 2)]], reduce_concat=range(2, 6))"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"0": "Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], [('skip_connect', 0), ('sep_conv_3x3', 1)], [('skip_connect', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('skip_connect', 2)]], normal_concat=[2, 3, 4, 5], reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('skip_connect', 2), ('max_pool_3x3', 0)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 2), ('avg_pool_3x3', 0)]], reduce_concat=[2, 3, 4, 5])", "1": "Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], [('skip_connect', 0), ('sep_conv_3x3', 1)], [('skip_connect', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('skip_connect', 2)]], normal_concat=[2, 3, 4, 5], reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('skip_connect', 2), ('max_pool_3x3', 0)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 2), ('avg_pool_3x3', 0)]], reduce_concat=[2, 3, 4, 5])", "2": "Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], [('skip_connect', 0), ('sep_conv_3x3', 1)], [('skip_connect', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('skip_connect', 2)]], normal_concat=[2, 3, 4, 5], reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('skip_connect', 2), ('max_pool_3x3', 0)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 2), ('avg_pool_3x3', 0)]], reduce_concat=[2, 3, 4, 5])"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"0": "Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('skip_connect', 0)], [('skip_connect', 0), ('dil_conv_3x3', 2)]], normal_concat=[2, 3, 4, 5], reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('skip_connect', 2), ('max_pool_3x3', 1)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 2), ('max_pool_3x3', 1)]], reduce_concat=[2, 3, 4, 5])", "1": "Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('skip_connect', 0)], [('skip_connect', 0), ('dil_conv_3x3', 2)]], normal_concat=[2, 3, 4, 5], reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('skip_connect', 2), ('max_pool_3x3', 1)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 2), ('max_pool_3x3', 1)]], reduce_concat=[2, 3, 4, 5])", "2": "Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('skip_connect', 0)], [('skip_connect', 0), ('dil_conv_3x3', 2)]], normal_concat=[2, 3, 4, 5], reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('skip_connect', 2), ('max_pool_3x3', 1)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 2), ('max_pool_3x3', 1)]], reduce_concat=[2, 3, 4, 5])"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"0": "Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], [('sep_conv_3x3', 0), ('sep_conv_5x5', 2)], [('sep_conv_5x5', 2), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('avg_pool_3x3', 0)], [('max_pool_3x3', 1), ('dil_conv_5x5', 2)], [('sep_conv_5x5', 1), ('dil_conv_5x5', 2)], [('sep_conv_3x3', 3), ('dil_conv_3x3', 4)]], reduce_concat=range(2, 6))", "1": "Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], [('sep_conv_3x3', 0), ('sep_conv_5x5', 2)], [('sep_conv_5x5', 2), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('avg_pool_3x3', 0)], [('max_pool_3x3', 1), ('dil_conv_5x5', 2)], [('sep_conv_5x5', 1), ('dil_conv_5x5', 2)], [('sep_conv_3x3', 3), ('dil_conv_3x3', 4)]], reduce_concat=range(2, 6))", "2": "Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], [('sep_conv_3x3', 0), ('sep_conv_5x5', 2)], [('sep_conv_5x5', 2), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('avg_pool_3x3', 0)], [('max_pool_3x3', 1), ('dil_conv_5x5', 2)], [('sep_conv_5x5', 1), ('dil_conv_5x5', 2)], [('sep_conv_3x3', 3), ('dil_conv_3x3', 4)]], reduce_concat=range(2, 6))"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"0": "Genotype(normal=[[('sep_conv_3x3', 1), ('skip_connect', 0)], [('sep_conv_3x3', 0), ('dil_conv_3x3', 1)], [('sep_conv_5x5', 0), ('sep_conv_3x3', 1)], [('avg_pool_3x3', 0), ('dil_conv_3x3', 1)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('max_pool_3x3', 0)], [('sep_conv_5x5', 1), ('sep_conv_5x5', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 3)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)]], reduce_concat=range(2, 6))", "1": "Genotype(normal=[[('sep_conv_3x3', 1), ('skip_connect', 0)], [('sep_conv_3x3', 0), ('dil_conv_3x3', 1)], [('sep_conv_5x5', 0), ('sep_conv_3x3', 1)], [('avg_pool_3x3', 0), ('dil_conv_3x3', 1)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('max_pool_3x3', 0)], [('sep_conv_5x5', 1), ('sep_conv_5x5', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 3)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)]], reduce_concat=range(2, 6))", "2": "Genotype(normal=[[('sep_conv_3x3', 1), ('skip_connect', 0)], [('sep_conv_3x3', 0), ('dil_conv_3x3', 1)], [('sep_conv_5x5', 0), ('sep_conv_3x3', 1)], [('avg_pool_3x3', 0), ('dil_conv_3x3', 1)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('max_pool_3x3', 0)], [('sep_conv_5x5', 1), ('sep_conv_5x5', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 3)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)]], reduce_concat=range(2, 6))"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"0": "Genotype(normal=[[('skip_connect', 1), ('sep_conv_3x3', 0)], [('sep_conv_3x3', 0), ('skip_connect', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], [('sep_conv_3x3', 1), ('dil_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('sep_conv_3x3', 0), ('skip_connect', 1)], [('dil_conv_5x5', 2), ('max_pool_3x3', 1)], [('sep_conv_3x3', 2), ('sep_conv_3x3', 1)], [('sep_conv_5x5', 0), ('sep_conv_3x3', 3)]], reduce_concat=range(2, 6))", "1": "Genotype(normal=[[('skip_connect', 1), ('sep_conv_3x3', 0)], [('sep_conv_3x3', 0), ('skip_connect', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], [('sep_conv_3x3', 1), ('dil_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('sep_conv_3x3', 0), ('skip_connect', 1)], [('dil_conv_5x5', 2), ('max_pool_3x3', 1)], [('sep_conv_3x3', 2), ('sep_conv_3x3', 1)], [('sep_conv_5x5', 0), ('sep_conv_3x3', 3)]], reduce_concat=range(2, 6))", "2": "Genotype(normal=[[('skip_connect', 1), ('sep_conv_3x3', 0)], [('sep_conv_3x3', 0), ('skip_connect', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], [('sep_conv_3x3', 1), ('dil_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('sep_conv_3x3', 0), ('skip_connect', 1)], [('dil_conv_5x5', 2), ('max_pool_3x3', 1)], [('sep_conv_3x3', 2), ('sep_conv_3x3', 1)], [('sep_conv_5x5', 0), ('sep_conv_3x3', 3)]], reduce_concat=range(2, 6))"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"0": "Genotype(normal=[[('skip_connect', 0), ('dil_conv_3x3', 1)], [('skip_connect', 0),('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], [('sep_conv_3x3',0), ('dil_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('avg_pool_3x3', 0), ('sep_conv_5x5', 1)], [('sep_conv_3x3', 0), ('dil_conv_5x5', 2)], [('max_pool_3x3', 0), ('dil_conv_3x3', 1)], [('dil_conv_3x3', 1), ('dil_conv_5x5', 3)]], reduce_concat=range(2, 6))", "1": "Genotype(normal=[[('skip_connect', 0), ('dil_conv_3x3', 1)], [('skip_connect', 0),('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], [('sep_conv_3x3',0), ('dil_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('avg_pool_3x3', 0), ('sep_conv_5x5', 1)], [('sep_conv_3x3', 0), ('dil_conv_5x5', 2)], [('max_pool_3x3', 0), ('dil_conv_3x3', 1)], [('dil_conv_3x3', 1), ('dil_conv_5x5', 3)]], reduce_concat=range(2, 6))", "2": "Genotype(normal=[[('skip_connect', 0), ('dil_conv_3x3', 1)], [('skip_connect', 0),('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], [('sep_conv_3x3',0), ('dil_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('avg_pool_3x3', 0), ('sep_conv_5x5', 1)], [('sep_conv_3x3', 0), ('dil_conv_5x5', 2)], [('max_pool_3x3', 0), ('dil_conv_3x3', 1)], [('dil_conv_3x3', 1), ('dil_conv_5x5', 3)]], reduce_concat=range(2, 6))"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
""" Retrain cell """ | ||
import _init_paths | ||
import os | ||
import torch | ||
import json | ||
import torch.nn as nn | ||
import numpy as np | ||
import lib.utils.genotypes as gt | ||
|
||
from tensorboardX import SummaryWriter | ||
from lib.models.cdarts_controller import CDARTSController | ||
from lib.utils import utils | ||
from lib.config import AugmentConfig | ||
from lib.core.augment_function import train, validate | ||
|
||
# config
# Parse all command-line/config options for this retraining run; everything
# below (logging paths, dataset choice, distributed flags) reads from it.
config = AugmentConfig()

# make apex optional
# NVIDIA apex is only required for distributed runs; fail fast with an
# actionable message when it is missing.
if config.distributed:
    # DDP = torch.nn.parallel.DistributedDataParallel
    try:
        import apex
        from apex.parallel import DistributedDataParallel as DDP
        from apex import amp, optimizers
        # NOTE(review): star import pulls unknown names into module scope —
        # consider importing the specific fp16 utilities actually used.
        from apex.fp16_utils import *
    except ImportError:
        raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")


# tensorboard
# Event files go to <config.path>/tb; the full config is logged as markdown.
writer = SummaryWriter(log_dir=os.path.join(config.path, "tb"))
writer.add_text('config', config.as_markdown(), 0)

# File logger named after the run; only rank 0 prints the parameter dump to
# avoid duplicated output in multi-process runs.
logger = utils.get_logger(os.path.join(config.path, "{}.log".format(config.name)))
if config.local_rank == 0:
    config.print_params(logger.info)

# Select the dataset-specific loader factory at import time based on the
# dataset name (substring match, so e.g. "cifar10" and "cifar100" both hit
# the cifar branch).
if 'cifar' in config.dataset:
    from lib.datasets.cifar import get_augment_datasets
elif 'imagenet' in config.dataset:
    from lib.datasets.imagenet import get_augment_datasets
else:
    raise Exception("Not support dataset!")
|
||
def main():
    """Retrain a searched CDARTS architecture from a genotype JSON file.

    Driven entirely by the module-level ``config`` (an ``AugmentConfig``):
    seeds RNGs, optionally initialises NCCL distributed training, builds the
    augment (retrain) model from ``config.cell_file``, then runs the
    train/validate loop for ``config.epochs`` epochs, saving a resumable
    checkpoint every epoch and the best top-1 model under ``config.path``.
    """
    logger.info("Logger is set - training start")

    # set seed
    # NOTE(review): cudnn.benchmark=True lets cuDNN pick potentially
    # non-deterministic algorithms, which can undercut deterministic=True —
    # confirm which guarantee is intended.
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

    if config.distributed:
        # Map this process to one local GPU, then join the NCCL process group.
        config.gpu = config.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(config.gpu)
        # distributed init
        torch.distributed.init_process_group(backend='nccl', init_method=config.dist_url,
                                             world_size=config.world_size, rank=config.local_rank)

        # Trust the process group for the actual world size.
        config.world_size = torch.distributed.get_world_size()

        # Effective global batch size across all ranks.
        config.total_batch_size = config.world_size * config.batch_size
    else:
        config.total_batch_size = config.batch_size


    loaders, samplers = get_augment_datasets(config)
    train_loader, valid_loader = loaders
    train_sampler, valid_sampler = samplers

    net_crit = nn.CrossEntropyLoss().cuda()
    controller = CDARTSController(config, net_crit, n_nodes=4, stem_multiplier=config.stem_multiplier)

    # Load the searched genotypes: a JSON dict mapping layer index (string)
    # to a Genotype repr string, parsed back via gt.from_str.
    # NOTE(review): prefer a with-statement so the file is closed on error.
    file = open(config.cell_file, 'r')
    js = file.read()
    r_dict = json.loads(js)
    if config.local_rank == 0:
        logger.info(r_dict)
    file.close()
    genotypes_dict = {}
    for layer_idx, genotype in r_dict.items():
        genotypes_dict[int(layer_idx)] = gt.from_str(genotype)

    controller.build_augment_model(controller.init_channel, genotypes_dict)
    resume_state = None
    if config.resume:
        # Weights are restored here; optimizer/scheduler state is restored
        # further below, after those objects exist.
        resume_state = torch.load(config.resume_path, map_location='cpu')
        controller.model_main.load_state_dict(resume_state['model_main'])

    controller.model_main = controller.model_main.cuda()
    param_size = utils.param_size(controller.model_main)
    logger.info("param size = %fMB", param_size)

    # change training hyper parameters according to cell type
    # NOTE(review): param_size == 3.0 exactly falls through to the else
    # branch (5e-4) — confirm the boundary is intended to be exclusive.
    if 'cifar' in config.dataset:
        if param_size < 3.0:
            config.weight_decay = 3e-4
            config.drop_path_prob = 0.2
        elif param_size > 3.0 and param_size < 3.5:
            config.weight_decay = 3e-4
            config.drop_path_prob = 0.3
        else:
            config.weight_decay = 5e-4
            config.drop_path_prob = 0.3

    if config.local_rank == 0:
        logger.info("Current weight decay: {}".format(config.weight_decay))
        logger.info("Current drop path prob: {}".format(config.drop_path_prob))

    # NOTE(review): apex is only imported when config.distributed is set
    # (module level), but it is used unconditionally here — non-distributed
    # runs would raise NameError; confirm this script is distributed-only.
    controller.model_main = apex.parallel.convert_syncbn_model(controller.model_main)
    # weights optimizer
    optimizer = torch.optim.SGD(controller.model_main.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay)
    # optimizer = torch.optim.SGD(controller.model_main.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay, nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.epochs)

    if config.use_amp:
        controller.model_main, optimizer = amp.initialize(controller.model_main, optimizer, opt_level=config.opt_level)

    if config.distributed:
        controller.model_main = DDP(controller.model_main, delay_allreduce=True)

    best_top1 = 0.
    best_top5 = 0.
    sta_epoch = 0
    # training loop
    if config.resume:
        # Restore optimizer/scheduler/bookkeeping saved by the epoch loop.
        optimizer.load_state_dict(resume_state['optimizer'])
        lr_scheduler.load_state_dict(resume_state['lr_scheduler'])
        best_top1 = resume_state['best_top1']
        best_top5 = resume_state['best_top5']
        sta_epoch = resume_state['sta_epoch']

    # Epochs (1-based) at which an extra named ImageNet snapshot is kept.
    epoch_pool = [220, 230, 235, 240, 245]
    for epoch in range(sta_epoch, config.epochs):
        # reset iterators
        train_sampler.set_epoch(epoch)
        valid_sampler.set_epoch(epoch)
        # NOTE(review): get_lr() is deprecated in newer torch in favour of
        # get_last_lr(); confirm against the pinned torch version.
        current_lr = lr_scheduler.get_lr()[0]
        # current_lr = utils.adjust_lr(optimizer, epoch, config)

        if config.local_rank == 0:
            logger.info('Epoch: %d lr %e', epoch, current_lr)
        # Linear LR warmup for large global batches.
        # NOTE(review): divisor is hard-coded to 5.0 rather than
        # config.warmup_epochs — the ramp only reaches full LR when
        # warmup_epochs == 5; confirm.
        if epoch < config.warmup_epochs and config.total_batch_size > 256:
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr * (epoch + 1) / 5.0
            if config.local_rank == 0:
                logger.info('Warming-up Epoch: %d, LR: %e', epoch, current_lr * (epoch + 1) / 5.0)

        # Drop-path probability ramps linearly from 0 to drop_path_prob.
        # NOTE(review): `.module` assumes the DDP wrapper — this line would
        # raise AttributeError in a non-distributed run; confirm.
        drop_prob = config.drop_path_prob * epoch / config.epochs
        controller.model_main.module.drop_path_prob(drop_prob)

        # training
        train(train_loader, controller.model_main, optimizer, epoch, writer, logger, config)

        # validation
        cur_step = (epoch+1) * len(train_loader)
        top1, top5 = validate(valid_loader, controller.model_main, epoch, cur_step, writer, logger, config)

        # NOTE(review): both dataset branches do exactly the same thing; the
        # split presumably existed for the commented-out adjust_lr variant.
        if 'cifar' in config.dataset:
            lr_scheduler.step()
        elif 'imagenet' in config.dataset:
            lr_scheduler.step()
            # current_lr = utils.adjust_lr(optimizer, epoch, config)
        else:
            raise Exception('Lr error!')

        # save
        # Track the best top-1 (and its matching top-5) seen so far.
        if best_top1 < top1:
            best_top1 = top1
            best_top5 = top5
            is_best = True
        else:
            is_best = False

        # save
        # Only rank 0 writes checkpoints.
        if config.local_rank == 0:
            # NOTE(review): the inner `config.local_rank == 0` test is
            # redundant under the enclosing guard.
            if ('imagenet' in config.dataset) and ((epoch+1) in epoch_pool) and (not config.resume) and (config.local_rank == 0):
                torch.save({
                    "model_main":controller.model_main.module.state_dict(),
                    "optimizer":optimizer.state_dict(),
                    "lr_scheduler":lr_scheduler.state_dict(),
                    "best_top1":best_top1,
                    "best_top5":best_top5,
                    "sta_epoch":epoch + 1
                    }, os.path.join(config.path, "epoch_{}.pth.tar".format(epoch+1)))
                utils.save_checkpoint(controller.model_main.module.state_dict(), config.path, is_best)

            # Always-refreshed rolling checkpoint used by --resume.
            torch.save({
                "model_main":controller.model_main.module.state_dict(),
                "optimizer":optimizer.state_dict(),
                "lr_scheduler":lr_scheduler.state_dict(),
                "best_top1":best_top1,
                "best_top5":best_top5,
                "sta_epoch":epoch + 1
                }, os.path.join(config.path, "retrain_resume.pth.tar"))
            utils.save_checkpoint(controller.model_main.module.state_dict(), config.path, is_best)

    if config.local_rank == 0:
        logger.info("Final best Prec@1 = {:.4%}, Prec@5 = {:.4%}".format(best_top1, best_top5))


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Retrain the searched CDARTS cell on CIFAR-10 using 1 GPU via
# torch.distributed.launch (one process per GPU).
NGPUS=1
SGPU=0
# Last GPU index; $((...)) replaces the deprecated $[...] arithmetic form.
EGPU=$((NGPUS + SGPU - 1))
# Comma-separated GPU list (e.g. "0"); $(...) replaces deprecated backticks.
GPU_ID=$(seq -s , "$SGPU" "$EGPU")
CUDA_VISIBLE_DEVICES=$GPU_ID python -m torch.distributed.launch --nproc_per_node=$NGPUS ./CDARTS/retrain.py \
    --name cifar10-retrain --dataset cifar10 --model_type cifar \
    --n_classes 10 --init_channels 36 --stem_multiplier 3 \
    --cell_file './genotypes.json' \
    --batch_size 128 --workers 1 --print_freq 100 \
    --world_size $NGPUS --weight_decay 5e-4 \
    --distributed --dist_url 'tcp://127.0.0.1:26443' \
    --lr 0.025 --warmup_epochs 0 --epochs 600 \
    --cutout_length 16 --aux_weight 0.4 --drop_path_prob 0.3 \
    --label_smooth 0.0 --mixup_alpha 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Retrain the searched CDARTS cell on CIFAR-10 using 4 GPUs via
# torch.distributed.launch (one process per GPU; higher base LR than 1-GPU run).
NGPUS=4
SGPU=0
# Last GPU index; $((...)) replaces the deprecated $[...] arithmetic form.
EGPU=$((NGPUS + SGPU - 1))
# Comma-separated GPU list (e.g. "0,1,2,3"); $(...) replaces deprecated backticks.
GPU_ID=$(seq -s , "$SGPU" "$EGPU")
CUDA_VISIBLE_DEVICES=$GPU_ID python -m torch.distributed.launch --nproc_per_node=$NGPUS ./CDARTS/retrain.py \
    --name cifar10-retrain --dataset cifar10 --model_type cifar \
    --n_classes 10 --init_channels 36 --stem_multiplier 3 \
    --cell_file './genotypes.json' \
    --batch_size 128 --workers 1 --print_freq 100 \
    --world_size $NGPUS --weight_decay 5e-4 \
    --distributed --dist_url 'tcp://127.0.0.1:26443' \
    --lr 0.1 --warmup_epochs 0 --epochs 600 \
    --cutout_length 16 --aux_weight 0.4 --drop_path_prob 0.3 \
    --label_smooth 0.0 --mixup_alpha 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Retrain the searched CDARTS cell on ImageNet using 8 GPUs via
# torch.distributed.launch; resumes from "retrain_resume.pth.tar" if present.
NGPUS=8
SGPU=0
# Last GPU index; $((...)) replaces the deprecated $[...] arithmetic form.
EGPU=$((NGPUS + SGPU - 1))
# Comma-separated GPU list (e.g. "0,...,7"); $(...) replaces deprecated backticks.
GPU_ID=$(seq -s , "$SGPU" "$EGPU")
CUDA_VISIBLE_DEVICES=$GPU_ID python -m torch.distributed.launch --nproc_per_node=$NGPUS ./CDARTS/retrain.py \
    --name imagenet-retrain --dataset imagenet --model_type imagenet \
    --n_classes 1000 --init_channels 48 --stem_multiplier 1 \
    --batch_size 128 --workers 4 --print_freq 100 \
    --cell_file './genotypes.json' \
    --world_size $NGPUS --weight_decay 3e-5 \
    --distributed --dist_url 'tcp://127.0.0.1:24443' \
    --lr 0.5 --warmup_epochs 5 --epochs 250 \
    --cutout_length 0 --aux_weight 0.4 --drop_path_prob 0.0 \
    --label_smooth 0.1 --mixup_alpha 0 \
    --resume_name "retrain_resume.pth.tar"
Oops, something went wrong.