Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions autogluon/contrib/enas/enas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

def enas_unit(**kwvars):
def registered_class(Cls):



Comment on lines +15 to +17
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function enas_unit refactored with the following changes:

class enas_unit(ENAS_Unit):
def __init__(self, *args, **kwargs):
kwvars.update(kwargs)
Expand All @@ -32,7 +35,7 @@ def node(self):
arg = self._args[self.index]
if arg is None: return arg
summary = {}
name = self.module_list[self.index].__class__.__name__ + '('
name = f'{self.module_list[self.index].__class__.__name__}('
for k, v in json.loads(arg).items():
if 'kernel' in k.lower():
cm = ("#8dd3c7", "#fb8072", "#ffffb3", "#bebada", "#80b1d3",
Expand Down Expand Up @@ -61,11 +64,16 @@ def get_config_grid(dict_space):
config.update(constants)
return configs


return enas_unit

return registered_class

def enas_net(**kwvars):
def registered_class(Cls):



Comment on lines +74 to +76
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function enas_net refactored with the following changes:

class ENAS_Net(Cls):
def __init__(self, *args, **kwargs):
kwvars.update(kwargs)
Expand Down Expand Up @@ -134,7 +142,7 @@ def kwspaces(self):
return self._kwspaces

def sample(self, **configs):
striped_keys = [k.split('.')[0] for k in configs.keys()]
striped_keys = [k.split('.')[0] for k in configs]
for k in striped_keys:
if isinstance(self._modules[k], ENAS_Unit):
self._modules[k].sample(configs[k])
Expand Down Expand Up @@ -168,7 +176,9 @@ def evaluate_latency(self, x):
self._avg_latency = avg_latency
self.latency_evaluated = True


return ENAS_Net

return registered_class

class ENAS_Sequential(gluon.HybridBlock):
Expand Down Expand Up @@ -279,10 +289,7 @@ def evaluate_latency(self, x):
import time
# evaluate submodule latency
for k, op in self._modules.items():
if hasattr(op, 'evaluate_latency'):
x = op.evaluate_latency(x)
else:
x = op(x)
x = op.evaluate_latency(x) if hasattr(op, 'evaluate_latency') else op(x)
Comment on lines -282 to +292
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ENAS_Sequential.evaluate_latency refactored with the following changes:

# calc avg_latency
avg_latency = 0.0
for k, op in self._modules.items():
Expand All @@ -297,9 +304,9 @@ def sample(self, **configs):
self._modules[k].sample(v)

def __repr__(self):
reprstr = self.__class__.__name__ + '('
reprstr = f'{self.__class__.__name__}('
for i, op in self._modules.items():
reprstr += '\n\t{}: {}'.format(i, op)
reprstr += f'\n\t{i}: {op}'
Comment on lines -300 to +309
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ENAS_Sequential.__repr__ refactored with the following changes:

reprstr += ')\n'
return reprstr

Expand Down Expand Up @@ -334,10 +341,10 @@ def kwspaces(self):

@property
def nparams(self):
nparams = 0
for _, v in self.module_list[self.index].collect_params().items():
nparams += v.data().size
return nparams
return sum(
v.data().size
for _, v in self.module_list[self.index].collect_params().items()
)
Comment on lines -337 to +347
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ENAS_Unit.nparams refactored with the following changes:


@property
def latency(self):
Expand Down Expand Up @@ -374,6 +381,4 @@ def __len__(self):
return len(self.module_list)

def __repr__(self):
reprstr = self.__class__.__name__ + '(num of choices: {}), current architecture:\n\t {}' \
.format(len(self.module_list), self.module_list[self.index])
return reprstr
return f'{self.__class__.__name__}(num of choices: {len(self.module_list)}), current architecture:\n\t {self.module_list[self.index]}'
Comment on lines -377 to +384
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ENAS_Unit.__repr__ refactored with the following changes:

145 changes: 128 additions & 17 deletions autogluon/contrib/enas/enas_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def __init__(self, blocks_args=[], dropout_rate=0.2, num_classes=1000, input_siz
block_arg.update(in_channels=out_channels, stride=1,
input_size=input_size)

for _ in range(block_arg.num_repeat - 1):
_blocks.append(ENAS_MbBlock(**block_arg))
_blocks.extend(
ENAS_MbBlock(**block_arg)
for _ in range(block_arg.num_repeat - 1)
)
Comment on lines -47 to +50
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ENAS_MBNet.__init__ refactored with the following changes:


if blocks is not None:
self._blocks = ENAS_Sequential(blocks)
Expand Down Expand Up @@ -98,19 +100,128 @@ def hybrid_forward(self, F, x):

def get_enas_blockargs():
""" Creates a predefined efficientnet model, which searched by original paper. """
blocks_args = [
Dict(kernel=3, num_repeat=1, channels=16, expand_ratio=1, stride=1, se_ratio=0.25, in_channels=32),
Dict(kernel=3, num_repeat=1, channels=16, expand_ratio=1, stride=1, se_ratio=0.25, in_channels=16, with_zero=True),
Dict(kernel=Categorical(3, 5, 7), num_repeat=1, channels=24, expand_ratio=Categorical(3, 6), stride=2, se_ratio=0.25, in_channels=16),
Dict(kernel=Categorical(3, 5, 7), num_repeat=3, channels=24, expand_ratio=Categorical(3, 6), stride=1, se_ratio=0.25, in_channels=24, with_zero=True),
Dict(kernel=Categorical(3, 5, 7), num_repeat=1, channels=40, expand_ratio=Categorical(3, 6), stride=2, se_ratio=0.25, in_channels=24),
Dict(kernel=Categorical(3, 5, 7), num_repeat=3, channels=40, expand_ratio=Categorical(3, 6), stride=1, se_ratio=0.25, in_channels=40, with_zero=True),
Dict(kernel=Categorical(3, 5, 7), num_repeat=1, channels=80, expand_ratio=Categorical(3, 6), stride=2, se_ratio=0.25, in_channels=40),
Dict(kernel=Categorical(3, 5, 7), num_repeat=4, channels=80, expand_ratio=Categorical(3, 6), stride=1, se_ratio=0.25, in_channels=80, with_zero=True),
Dict(kernel=Categorical(3, 5, 7), num_repeat=1, channels=112, expand_ratio=Categorical(3, 6), stride=1, se_ratio=0.25, in_channels=80),
Dict(kernel=Categorical(3, 5, 7), num_repeat=4, channels=112, expand_ratio=Categorical(3, 6), stride=1, se_ratio=0.25, in_channels=112, with_zero=True),
Dict(kernel=Categorical(3, 5, 7), num_repeat=1, channels=192, expand_ratio=Categorical(3, 6), stride=2, se_ratio=0.25, in_channels=112),
Dict(kernel=Categorical(3, 5, 7), num_repeat=5, channels=192, expand_ratio=Categorical(3, 6), stride=1, se_ratio=0.25, in_channels=192, with_zero=True),
Dict(kernel=3, num_repeat=1, channels=320, expand_ratio=6, stride=1, se_ratio=0.25, in_channels=192),
return [
Dict(
kernel=3,
num_repeat=1,
channels=16,
expand_ratio=1,
stride=1,
se_ratio=0.25,
in_channels=32,
),
Dict(
kernel=3,
num_repeat=1,
channels=16,
expand_ratio=1,
stride=1,
se_ratio=0.25,
in_channels=16,
with_zero=True,
),
Dict(
kernel=Categorical(3, 5, 7),
num_repeat=1,
channels=24,
expand_ratio=Categorical(3, 6),
stride=2,
se_ratio=0.25,
in_channels=16,
),
Dict(
kernel=Categorical(3, 5, 7),
num_repeat=3,
channels=24,
expand_ratio=Categorical(3, 6),
stride=1,
se_ratio=0.25,
in_channels=24,
with_zero=True,
),
Dict(
kernel=Categorical(3, 5, 7),
num_repeat=1,
channels=40,
expand_ratio=Categorical(3, 6),
stride=2,
se_ratio=0.25,
in_channels=24,
),
Dict(
kernel=Categorical(3, 5, 7),
num_repeat=3,
channels=40,
expand_ratio=Categorical(3, 6),
stride=1,
se_ratio=0.25,
in_channels=40,
with_zero=True,
),
Dict(
kernel=Categorical(3, 5, 7),
num_repeat=1,
channels=80,
expand_ratio=Categorical(3, 6),
stride=2,
se_ratio=0.25,
in_channels=40,
),
Dict(
kernel=Categorical(3, 5, 7),
num_repeat=4,
channels=80,
expand_ratio=Categorical(3, 6),
stride=1,
se_ratio=0.25,
in_channels=80,
with_zero=True,
),
Dict(
kernel=Categorical(3, 5, 7),
num_repeat=1,
channels=112,
expand_ratio=Categorical(3, 6),
stride=1,
se_ratio=0.25,
in_channels=80,
),
Dict(
kernel=Categorical(3, 5, 7),
num_repeat=4,
channels=112,
expand_ratio=Categorical(3, 6),
stride=1,
se_ratio=0.25,
in_channels=112,
with_zero=True,
),
Dict(
kernel=Categorical(3, 5, 7),
num_repeat=1,
channels=192,
expand_ratio=Categorical(3, 6),
stride=2,
se_ratio=0.25,
in_channels=112,
),
Dict(
kernel=Categorical(3, 5, 7),
num_repeat=5,
channels=192,
expand_ratio=Categorical(3, 6),
stride=1,
se_ratio=0.25,
in_channels=192,
with_zero=True,
),
Dict(
kernel=3,
num_repeat=1,
channels=320,
expand_ratio=6,
stride=1,
se_ratio=0.25,
in_channels=192,
),
]
return blocks_args
Comment on lines -101 to -116
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_enas_blockargs refactored with the following changes:

12 changes: 5 additions & 7 deletions autogluon/contrib/enas/enas_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ def run(self):
# for recordio data
if hasattr(self.train_data, 'reset'): self.train_data.reset()
tbar = tqdm(self.train_data)
idx = 0
for batch in tbar:
for idx, batch in enumerate(tbar):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ENAS_Scheduler.run refactored with the following changes:

# sample network configuration
config = self.controller.pre_sample()[0]
self.supernet.sample(**config)
Expand All @@ -129,7 +128,6 @@ def run(self):
tbar.set_svg(graph._repr_svg_())
if self.baseline:
tbar.set_description('avg reward: {:.2f}'.format(self.baseline))
idx += 1
self.validation()
self.save()
msg = 'epoch {}, val_acc: {:.2f}'.format(epoch, self.val_acc)
Expand All @@ -150,7 +148,7 @@ def validation(self):
for batch in tbar:
self.eval_fn(self.supernet, batch, metric=metric, **self.val_args)
reward = metric.get()[1]
tbar.set_description('Val Acc: {}'.format(reward))
tbar.set_description(f'Val Acc: {reward}')
Comment on lines -153 to +151
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ENAS_Scheduler.validation refactored with the following changes:


self.val_acc = reward
self.training_history.append(reward)
Expand Down Expand Up @@ -195,7 +193,7 @@ def train_controller(self):
self.eval_fn(self.supernet, batch, metric=metric, **self.val_args)
reward = metric.get()[1]
reward = self.reward_fn(reward, self.supernet)
self.baseline = reward if not self.baseline else self.baseline
self.baseline = self.baseline or reward
Comment on lines -198 to +196
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ENAS_Scheduler.train_controller refactored with the following changes:

# substract baseline
avg_rewards = mx.nd.array([reward - self.baseline],
ctx=self.controller.context)
Expand All @@ -213,12 +211,12 @@ def train_controller(self):
self._prefetch_controller()

def load(self, checkname=None):
checkname = checkname if checkname else self.checkname
checkname = checkname or self.checkname
Comment on lines -216 to +214
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ENAS_Scheduler.load refactored with the following changes:

state_dict = load(checkname)
self.load_state_dict(state_dict)

def save(self, checkname=None):
checkname = checkname if checkname else self.checkname
checkname = checkname or self.checkname
Comment on lines -221 to +219
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ENAS_Scheduler.save refactored with the following changes:

mkdir(os.path.dirname(checkname))
save(self.state_dict(), checkname)

Expand Down
9 changes: 5 additions & 4 deletions autogluon/contrib/enas/enas_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
from ...scheduler.resource import get_gpu_count

def default_reward_fn(metric, net):
reward = metric * ((net.avg_latency / net.latency) ** 0.07)
return reward
return metric * ((net.avg_latency / net.latency) ** 0.07)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function default_reward_fn refactored with the following changes:


def init_default_train_args(batch_size, net, epochs, iters_per_epoch):
train_args = {}
base_lr = 0.1 * batch_size / 256
lr_scheduler = gcv.utils.LRScheduler('cosine', base_lr=base_lr, target_lr=0.0001,
nepochs=epochs, iters_per_epoch=iters_per_epoch)
optimizer_params = {'wd': 1e-4, 'momentum': 0.9, 'lr_scheduler': lr_scheduler}
train_args['trainer'] = gluon.Trainer(net.collect_params(), 'sgd', optimizer_params)
train_args = {
'trainer': gluon.Trainer(net.collect_params(), 'sgd', optimizer_params)
}

Comment on lines -12 to +18
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function init_default_train_args refactored with the following changes:

train_args['batch_size'] = batch_size
train_args['criterion'] = gluon.loss.SoftmaxCrossEntropyLoss()
return train_args
22 changes: 11 additions & 11 deletions autogluon/core/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def sample_config(args, config):
if isinstance(v, NestedSpace):
sub_config = _strip_config_space(config, prefix=k)
args_dict[k] = v.sample(**sub_config)
else:
if SPLITTER in k:
continue
elif SPLITTER not in k:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function sample_config refactored with the following changes:

args_dict[k] = config[k]
elif isinstance(v, AutoGluonObject):
args_dict[k] = v.init()
Expand Down Expand Up @@ -75,13 +73,11 @@ def update(self, **kwargs):
"""
self.kwvars.update(kwargs)
for k, v in self.kwvars.items():
if isinstance(v, (NestedSpace)):
if isinstance(v, (NestedSpace)) or not isinstance(v, Space):
self.args.update({k: v})
elif isinstance(v, Space):
else:
hp = v.get_hp(name=k)
self.args.update({k: hp.default_value})
else:
self.args.update({k: v})
Comment on lines -78 to -84
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function _autogluon_method.update refactored with the following changes:


@property
def cs(self):
Expand All @@ -104,9 +100,9 @@ def kwspaces(self):
for k, v in self.kwvars.items():
if isinstance(v, NestedSpace):
if isinstance(v, Categorical):
kw_spaces['{}{}choice'.format(k, SPLITTER)] = v
kw_spaces[f'{k}{SPLITTER}choice'] = v
for sub_k, sub_v in v.kwspaces.items():
new_k = '{}{}{}'.format(k, SPLITTER, sub_k)
new_k = f'{k}{SPLITTER}{sub_k}'
Comment on lines -107 to +105
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function _autogluon_method.kwspaces refactored with the following changes:

kw_spaces[new_k] = sub_v
elif isinstance(v, Space):
kw_spaces[k] = v
Expand Down Expand Up @@ -134,7 +130,7 @@ def args(default=None, **kwvars):
... print('Batch size is {}, LR is {}'.format(args.batch_size, arg.lr))
"""
if default is None:
default = dict()
default = {}
Comment on lines -137 to +133
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function args refactored with the following changes:

kwvars['_default_config'] = default
def registered_func(func):
@_autogluon_method
Expand Down Expand Up @@ -263,6 +259,9 @@ def wrapper_call(*args, **kwargs):
return registered_func

def registered_class(Cls):



Comment on lines +262 to +264
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function obj refactored with the following changes:

class autogluonobject(AutoGluonObject):
@_autogluon_kwargs_obj(**kwvars)
def __init__(self, *args, **kwargs):
Expand All @@ -284,7 +283,8 @@ def sample(self, **config):
return Cls(*args, **kwargs)

def __repr__(self):
return 'AutoGluonObject -- ' + Cls.__name__
return f'AutoGluonObject -- {Cls.__name__}'


autogluonobject.kwvars = autogluonobject.__init__.kwvars
autogluonobject.__doc__ = Cls.__doc__
Expand Down
Loading