
Commit

update
hongyuanyu committed Feb 13, 2022
1 parent 1233990 commit ab37675
Showing 827 changed files with 94,397 additions and 0 deletions.
19 changes: 19 additions & 0 deletions CDARTS/_init_paths.py
@@ -0,0 +1,19 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path as osp
import sys


def add_path(path):
    if path not in sys.path:
        sys.path.insert(0, path)


this_dir = osp.dirname(__file__)
lib_path = osp.join(this_dir, '..', 'lib')
add_path(lib_path)

lib_path = osp.join(this_dir, '..')
add_path(lib_path)
1 change: 1 addition & 0 deletions CDARTS/cells/cifar_genotype.json
@@ -0,0 +1 @@
{"0": "Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('skip_connect', 0), ('skip_connect', 2)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('max_pool_3x3', 0)], [('dil_conv_5x5', 2), ('skip_connect', 0)], [('sep_conv_3x3', 0), ('sep_conv_5x5', 1)], [('sep_conv_3x3', 0), ('dil_conv_5x5', 2)]], reduce_concat=range(2, 6))", "1": "Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('skip_connect', 0), ('skip_connect', 2)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('max_pool_3x3', 0)], [('dil_conv_5x5', 2), ('skip_connect', 0)], [('sep_conv_3x3', 0), ('sep_conv_5x5', 1)], [('sep_conv_3x3', 0), ('dil_conv_5x5', 2)]], reduce_concat=range(2, 6))", "2": "Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('skip_connect', 0), ('skip_connect', 2)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('max_pool_3x3', 0)], [('dil_conv_5x5', 2), ('skip_connect', 0)], [('sep_conv_3x3', 0), ('sep_conv_5x5', 1)], [('sep_conv_3x3', 0), ('dil_conv_5x5', 2)]], reduce_concat=range(2, 6))"}
1 change: 1 addition & 0 deletions CDARTS/cells/dartsv1_genotype.json
@@ -0,0 +1 @@
{"0": "Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], [('skip_connect', 0), ('sep_conv_3x3', 1)], [('skip_connect', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('skip_connect', 2)]], normal_concat=[2, 3, 4, 5], reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('skip_connect', 2), ('max_pool_3x3', 0)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 2), ('avg_pool_3x3', 0)]], reduce_concat=[2, 3, 4, 5])", "1": "Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], [('skip_connect', 0), ('sep_conv_3x3', 1)], [('skip_connect', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('skip_connect', 2)]], normal_concat=[2, 3, 4, 5], reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('skip_connect', 2), ('max_pool_3x3', 0)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 2), ('avg_pool_3x3', 0)]], reduce_concat=[2, 3, 4, 5])", "2": "Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], [('skip_connect', 0), ('sep_conv_3x3', 1)], [('skip_connect', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('skip_connect', 2)]], normal_concat=[2, 3, 4, 5], reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('skip_connect', 2), ('max_pool_3x3', 0)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 2), ('avg_pool_3x3', 0)]], reduce_concat=[2, 3, 4, 5])"}
1 change: 1 addition & 0 deletions CDARTS/cells/dartsv2_genotype.json
@@ -0,0 +1 @@
{"0": "Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('skip_connect', 0)], [('skip_connect', 0), ('dil_conv_3x3', 2)]], normal_concat=[2, 3, 4, 5], reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('skip_connect', 2), ('max_pool_3x3', 1)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 2), ('max_pool_3x3', 1)]], reduce_concat=[2, 3, 4, 5])", "1": "Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('skip_connect', 0)], [('skip_connect', 0), ('dil_conv_3x3', 2)]], normal_concat=[2, 3, 4, 5], reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('skip_connect', 2), ('max_pool_3x3', 1)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 2), ('max_pool_3x3', 1)]], reduce_concat=[2, 3, 4, 5])", "2": "Genotype(normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('skip_connect', 0)], [('skip_connect', 0), ('dil_conv_3x3', 2)]], normal_concat=[2, 3, 4, 5], reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)], [('skip_connect', 2), ('max_pool_3x3', 1)], [('max_pool_3x3', 0), ('skip_connect', 2)], [('skip_connect', 2), ('max_pool_3x3', 1)]], reduce_concat=[2, 3, 4, 5])"}
1 change: 1 addition & 0 deletions CDARTS/cells/imagenet_genotype.json
@@ -0,0 +1 @@
{"0": "Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], [('sep_conv_3x3', 0), ('sep_conv_5x5', 2)], [('sep_conv_5x5', 2), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('avg_pool_3x3', 0)], [('max_pool_3x3', 1), ('dil_conv_5x5', 2)], [('sep_conv_5x5', 1), ('dil_conv_5x5', 2)], [('sep_conv_3x3', 3), ('dil_conv_3x3', 4)]], reduce_concat=range(2, 6))", "1": "Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], [('sep_conv_3x3', 0), ('sep_conv_5x5', 2)], [('sep_conv_5x5', 2), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('avg_pool_3x3', 0)], [('max_pool_3x3', 1), ('dil_conv_5x5', 2)], [('sep_conv_5x5', 1), ('dil_conv_5x5', 2)], [('sep_conv_3x3', 3), ('dil_conv_3x3', 4)]], reduce_concat=range(2, 6))", "2": "Genotype(normal=[[('sep_conv_3x3', 1), ('sep_conv_3x3', 0)], [('sep_conv_3x3', 0), ('sep_conv_5x5', 2)], [('sep_conv_5x5', 2), ('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('avg_pool_3x3', 0)], [('max_pool_3x3', 1), ('dil_conv_5x5', 2)], [('sep_conv_5x5', 1), ('dil_conv_5x5', 2)], [('sep_conv_3x3', 3), ('dil_conv_3x3', 4)]], reduce_concat=range(2, 6))"}
1 change: 1 addition & 0 deletions CDARTS/cells/pcdarts_cifar_genotype.json
@@ -0,0 +1 @@
{"0": "Genotype(normal=[[('sep_conv_3x3', 1), ('skip_connect', 0)], [('sep_conv_3x3', 0), ('dil_conv_3x3', 1)], [('sep_conv_5x5', 0), ('sep_conv_3x3', 1)], [('avg_pool_3x3', 0), ('dil_conv_3x3', 1)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('max_pool_3x3', 0)], [('sep_conv_5x5', 1), ('sep_conv_5x5', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 3)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)]], reduce_concat=range(2, 6))", "1": "Genotype(normal=[[('sep_conv_3x3', 1), ('skip_connect', 0)], [('sep_conv_3x3', 0), ('dil_conv_3x3', 1)], [('sep_conv_5x5', 0), ('sep_conv_3x3', 1)], [('avg_pool_3x3', 0), ('dil_conv_3x3', 1)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('max_pool_3x3', 0)], [('sep_conv_5x5', 1), ('sep_conv_5x5', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 3)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)]], reduce_concat=range(2, 6))", "2": "Genotype(normal=[[('sep_conv_3x3', 1), ('skip_connect', 0)], [('sep_conv_3x3', 0), ('dil_conv_3x3', 1)], [('sep_conv_5x5', 0), ('sep_conv_3x3', 1)], [('avg_pool_3x3', 0), ('dil_conv_3x3', 1)]], normal_concat=range(2, 6), reduce=[[('sep_conv_5x5', 1), ('max_pool_3x3', 0)], [('sep_conv_5x5', 1), ('sep_conv_5x5', 2)], [('sep_conv_3x3', 0), ('sep_conv_3x3', 3)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)]], reduce_concat=range(2, 6))"}
1 change: 1 addition & 0 deletions CDARTS/cells/pcdarts_imagenet_genotype.json
@@ -0,0 +1 @@
{"0": "Genotype(normal=[[('skip_connect', 1), ('sep_conv_3x3', 0)], [('sep_conv_3x3', 0), ('skip_connect', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], [('sep_conv_3x3', 1), ('dil_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('sep_conv_3x3', 0), ('skip_connect', 1)], [('dil_conv_5x5', 2), ('max_pool_3x3', 1)], [('sep_conv_3x3', 2), ('sep_conv_3x3', 1)], [('sep_conv_5x5', 0), ('sep_conv_3x3', 3)]], reduce_concat=range(2, 6))", "1": "Genotype(normal=[[('skip_connect', 1), ('sep_conv_3x3', 0)], [('sep_conv_3x3', 0), ('skip_connect', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], [('sep_conv_3x3', 1), ('dil_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('sep_conv_3x3', 0), ('skip_connect', 1)], [('dil_conv_5x5', 2), ('max_pool_3x3', 1)], [('sep_conv_3x3', 2), ('sep_conv_3x3', 1)], [('sep_conv_5x5', 0), ('sep_conv_3x3', 3)]], reduce_concat=range(2, 6))", "2": "Genotype(normal=[[('skip_connect', 1), ('sep_conv_3x3', 0)], [('sep_conv_3x3', 0), ('skip_connect', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], [('sep_conv_3x3', 1), ('dil_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('sep_conv_3x3', 0), ('skip_connect', 1)], [('dil_conv_5x5', 2), ('max_pool_3x3', 1)], [('sep_conv_3x3', 2), ('sep_conv_3x3', 1)], [('sep_conv_5x5', 0), ('sep_conv_3x3', 3)]], reduce_concat=range(2, 6))"}
1 change: 1 addition & 0 deletions CDARTS/cells/pdarts_genotype.json
@@ -0,0 +1 @@
{"0": "Genotype(normal=[[('skip_connect', 0), ('dil_conv_3x3', 1)], [('skip_connect', 0),('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], [('sep_conv_3x3',0), ('dil_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('avg_pool_3x3', 0), ('sep_conv_5x5', 1)], [('sep_conv_3x3', 0), ('dil_conv_5x5', 2)], [('max_pool_3x3', 0), ('dil_conv_3x3', 1)], [('dil_conv_3x3', 1), ('dil_conv_5x5', 3)]], reduce_concat=range(2, 6))", "1": "Genotype(normal=[[('skip_connect', 0), ('dil_conv_3x3', 1)], [('skip_connect', 0),('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], [('sep_conv_3x3',0), ('dil_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('avg_pool_3x3', 0), ('sep_conv_5x5', 1)], [('sep_conv_3x3', 0), ('dil_conv_5x5', 2)], [('max_pool_3x3', 0), ('dil_conv_3x3', 1)], [('dil_conv_3x3', 1), ('dil_conv_5x5', 3)]], reduce_concat=range(2, 6))", "2": "Genotype(normal=[[('skip_connect', 0), ('dil_conv_3x3', 1)], [('skip_connect', 0),('sep_conv_3x3', 1)], [('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], [('sep_conv_3x3',0), ('dil_conv_5x5', 4)]], normal_concat=range(2, 6), reduce=[[('avg_pool_3x3', 0), ('sep_conv_5x5', 1)], [('sep_conv_3x3', 0), ('dil_conv_5x5', 2)], [('max_pool_3x3', 0), ('dil_conv_3x3', 1)], [('dil_conv_3x3', 1), ('dil_conv_5x5', 3)]], reduce_concat=range(2, 6))"}
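
These cell files are consumed by retrain.py (below) through its --cell_file argument. As a minimal sketch of that loading step, mirroring the code in retrain.py (gt.from_str is the helper imported there from lib.utils.genotypes that turns a "Genotype(...)" string back into a genotype object):

# Sketch only: mirrors the cell-file loading in CDARTS/retrain.py below.
# Each key is a layer index ("0", "1", "2"); each value is a "Genotype(...)" string.
import json
import lib.utils.genotypes as gt

with open('CDARTS/cells/cifar_genotype.json') as f:
    r_dict = json.load(f)
genotypes_dict = {int(layer_idx): gt.from_str(s) for layer_idx, s in r_dict.items()}
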
206 changes: 206 additions & 0 deletions CDARTS/retrain.py
@@ -0,0 +1,206 @@
""" Retrain cell """
import _init_paths
import os
import torch
import json
import torch.nn as nn
import numpy as np
import lib.utils.genotypes as gt

from tensorboardX import SummaryWriter
from lib.models.cdarts_controller import CDARTSController
from lib.utils import utils
from lib.config import AugmentConfig
from lib.core.augment_function import train, validate

# config
config = AugmentConfig()

# make apex optional
if config.distributed:
    # DDP = torch.nn.parallel.DistributedDataParallel
    try:
        import apex
        from apex.parallel import DistributedDataParallel as DDP
        from apex import amp, optimizers
        from apex.fp16_utils import *
    except ImportError:
        raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")


# tensorboard
writer = SummaryWriter(log_dir=os.path.join(config.path, "tb"))
writer.add_text('config', config.as_markdown(), 0)

logger = utils.get_logger(os.path.join(config.path, "{}.log".format(config.name)))
if config.local_rank == 0:
    config.print_params(logger.info)

if 'cifar' in config.dataset:
    from lib.datasets.cifar import get_augment_datasets
elif 'imagenet' in config.dataset:
    from lib.datasets.imagenet import get_augment_datasets
else:
    raise Exception("Unsupported dataset!")

def main():
    logger.info("Logger is set - training start")

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

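    # distributed setup: one process per GPU over the NCCL backend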
    if config.distributed:
        config.gpu = config.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(config.gpu)
        # distributed init
        torch.distributed.init_process_group(backend='nccl', init_method=config.dist_url,
                                             world_size=config.world_size, rank=config.local_rank)

        config.world_size = torch.distributed.get_world_size()
        config.total_batch_size = config.world_size * config.batch_size
    else:
        config.total_batch_size = config.batch_size

    loaders, samplers = get_augment_datasets(config)
    train_loader, valid_loader = loaders
    train_sampler, valid_sampler = samplers

    net_crit = nn.CrossEntropyLoss().cuda()
    controller = CDARTSController(config, net_crit, n_nodes=4, stem_multiplier=config.stem_multiplier)

    with open(config.cell_file, 'r') as f:
        r_dict = json.loads(f.read())
    if config.local_rank == 0:
        logger.info(r_dict)
    genotypes_dict = {}
    for layer_idx, genotype in r_dict.items():
        genotypes_dict[int(layer_idx)] = gt.from_str(genotype)

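    # build the evaluation (augment) network from the discovered per-layer genotypes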
    controller.build_augment_model(controller.init_channel, genotypes_dict)
    resume_state = None
    if config.resume:
        resume_state = torch.load(config.resume_path, map_location='cpu')
        controller.model_main.load_state_dict(resume_state['model_main'])

    controller.model_main = controller.model_main.cuda()
    param_size = utils.param_size(controller.model_main)
    logger.info("param size = %fMB", param_size)

    # change training hyperparameters according to cell type
    if 'cifar' in config.dataset:
        if param_size < 3.0:
            config.weight_decay = 3e-4
            config.drop_path_prob = 0.2
        elif param_size > 3.0 and param_size < 3.5:
            config.weight_decay = 3e-4
            config.drop_path_prob = 0.3
        else:
            config.weight_decay = 5e-4
            config.drop_path_prob = 0.3

    if config.local_rank == 0:
        logger.info("Current weight decay: {}".format(config.weight_decay))
        logger.info("Current drop path prob: {}".format(config.drop_path_prob))

    controller.model_main = apex.parallel.convert_syncbn_model(controller.model_main)
    # weights optimizer
    optimizer = torch.optim.SGD(controller.model_main.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay)
    # optimizer = torch.optim.SGD(controller.model_main.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay, nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.epochs)

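    # optional mixed-precision (apex amp) and distributed (apex DDP) wrapping of the main model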
    if config.use_amp:
        controller.model_main, optimizer = amp.initialize(controller.model_main, optimizer, opt_level=config.opt_level)

    if config.distributed:
        controller.model_main = DDP(controller.model_main, delay_allreduce=True)

    best_top1 = 0.
    best_top5 = 0.
    sta_epoch = 0
    # restore training state when resuming
    if config.resume:
        optimizer.load_state_dict(resume_state['optimizer'])
        lr_scheduler.load_state_dict(resume_state['lr_scheduler'])
        best_top1 = resume_state['best_top1']
        best_top5 = resume_state['best_top5']
        sta_epoch = resume_state['sta_epoch']

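    # training loop; epoch_pool lists the ImageNet epochs at which extra checkpoints are written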
    epoch_pool = [220, 230, 235, 240, 245]
    for epoch in range(sta_epoch, config.epochs):
        # reset iterators
        train_sampler.set_epoch(epoch)
        valid_sampler.set_epoch(epoch)
        current_lr = lr_scheduler.get_lr()[0]
        # current_lr = utils.adjust_lr(optimizer, epoch, config)

        if config.local_rank == 0:
            logger.info('Epoch: %d lr %e', epoch, current_lr)
        if epoch < config.warmup_epochs and config.total_batch_size > 256:
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr * (epoch + 1) / 5.0
            if config.local_rank == 0:
                logger.info('Warming-up Epoch: %d, LR: %e', epoch, current_lr * (epoch + 1) / 5.0)

        drop_prob = config.drop_path_prob * epoch / config.epochs
        controller.model_main.module.drop_path_prob(drop_prob)

        # training
        train(train_loader, controller.model_main, optimizer, epoch, writer, logger, config)

        # validation
        cur_step = (epoch+1) * len(train_loader)
        top1, top5 = validate(valid_loader, controller.model_main, epoch, cur_step, writer, logger, config)

        if 'cifar' in config.dataset:
            lr_scheduler.step()
        elif 'imagenet' in config.dataset:
            lr_scheduler.step()
            # current_lr = utils.adjust_lr(optimizer, epoch, config)
        else:
            raise Exception('Lr error!')

        # track best accuracy
        if best_top1 < top1:
            best_top1 = top1
            best_top5 = top5
            is_best = True
        else:
            is_best = False

        # save checkpoints
        if config.local_rank == 0:
            if ('imagenet' in config.dataset) and ((epoch+1) in epoch_pool) and (not config.resume) and (config.local_rank == 0):
                torch.save({
                    "model_main": controller.model_main.module.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "best_top1": best_top1,
                    "best_top5": best_top5,
                    "sta_epoch": epoch + 1
                }, os.path.join(config.path, "epoch_{}.pth.tar".format(epoch+1)))
                utils.save_checkpoint(controller.model_main.module.state_dict(), config.path, is_best)

            torch.save({
                "model_main": controller.model_main.module.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "best_top1": best_top1,
                "best_top5": best_top5,
                "sta_epoch": epoch + 1
            }, os.path.join(config.path, "retrain_resume.pth.tar"))
            utils.save_checkpoint(controller.model_main.module.state_dict(), config.path, is_best)

    if config.local_rank == 0:
        logger.info("Final best Prec@1 = {:.4%}, Prec@5 = {:.4%}".format(best_top1, best_top5))


if __name__ == "__main__":
    main()
14 changes: 14 additions & 0 deletions CDARTS/scripts/run_retrain_cifar_1gpu.sh
@@ -0,0 +1,14 @@
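# Retrain a searched CIFAR-10 cell on a single GPU; expects the chosen cell file
# (e.g. one of CDARTS/cells/*.json) at ./genotypes.json, as passed via --cell_file.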
NGPUS=1
SGPU=0
EGPU=$((NGPUS+SGPU-1))
GPU_ID=$(seq -s , $SGPU $EGPU)
CUDA_VISIBLE_DEVICES=$GPU_ID python -m torch.distributed.launch --nproc_per_node=$NGPUS ./CDARTS/retrain.py \
--name cifar10-retrain --dataset cifar10 --model_type cifar \
--n_classes 10 --init_channels 36 --stem_multiplier 3 \
--cell_file './genotypes.json' \
--batch_size 128 --workers 1 --print_freq 100 \
--world_size $NGPUS --weight_decay 5e-4 \
--distributed --dist_url 'tcp://127.0.0.1:26443' \
--lr 0.025 --warmup_epochs 0 --epochs 600 \
--cutout_length 16 --aux_weight 0.4 --drop_path_prob 0.3 \
--label_smooth 0.0 --mixup_alpha 0
15 changes: 15 additions & 0 deletions CDARTS/scripts/run_retrain_cifar_4gpus.sh
@@ -0,0 +1,15 @@
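# Same CIFAR-10 retraining across 4 GPUs; the learning rate is raised to 0.1
# (vs 0.025 on a single GPU), presumably to match the larger total batch size.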
NGPUS=4
SGPU=0
EGPU=$((NGPUS+SGPU-1))
GPU_ID=$(seq -s , $SGPU $EGPU)
CUDA_VISIBLE_DEVICES=$GPU_ID python -m torch.distributed.launch --nproc_per_node=$NGPUS ./CDARTS/retrain.py \
--name cifar10-retrain --dataset cifar10 --model_type cifar \
--n_classes 10 --init_channels 36 --stem_multiplier 3 \
--cell_file './genotypes.json' \
--batch_size 128 --workers 1 --print_freq 100 \
--world_size $NGPUS --weight_decay 5e-4 \
--distributed --dist_url 'tcp://127.0.0.1:26443' \
--lr 0.1 --warmup_epochs 0 --epochs 600 \
--cutout_length 16 --aux_weight 0.4 --drop_path_prob 0.3 \
--label_smooth 0.0 --mixup_alpha 0

15 changes: 15 additions & 0 deletions CDARTS/scripts/run_retrain_imagenet.sh
@@ -0,0 +1,15 @@
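# Retrain the searched cell on ImageNet with 8 GPUs (250 epochs, 5 warmup epochs,
# label smoothing 0.1); --resume_name points at the retrain_resume.pth.tar
# checkpoint written by retrain.py.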
NGPUS=8
SGPU=0
EGPU=$((NGPUS+SGPU-1))
GPU_ID=$(seq -s , $SGPU $EGPU)
CUDA_VISIBLE_DEVICES=$GPU_ID python -m torch.distributed.launch --nproc_per_node=$NGPUS ./CDARTS/retrain.py \
--name imagenet-retrain --dataset imagenet --model_type imagenet \
--n_classes 1000 --init_channels 48 --stem_multiplier 1 \
--batch_size 128 --workers 4 --print_freq 100 \
--cell_file './genotypes.json' \
--world_size $NGPUS --weight_decay 3e-5 \
--distributed --dist_url 'tcp://127.0.0.1:24443' \
--lr 0.5 --warmup_epochs 5 --epochs 250 \
--cutout_length 0 --aux_weight 0.4 --drop_path_prob 0.0 \
--label_smooth 0.1 --mixup_alpha 0 \
--resume_name "retrain_resume.pth.tar"