-
Notifications
You must be signed in to change notification settings - Fork 1
Sourcery Starbot ⭐ refactored testvinder/autogluon #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,9 @@ | |
|
|
||
| def enas_unit(**kwvars): | ||
| def registered_class(Cls): | ||
|
|
||
|
|
||
|
|
||
| class enas_unit(ENAS_Unit): | ||
| def __init__(self, *args, **kwargs): | ||
| kwvars.update(kwargs) | ||
|
|
@@ -32,7 +35,7 @@ def node(self): | |
| arg = self._args[self.index] | ||
| if arg is None: return arg | ||
| summary = {} | ||
| name = self.module_list[self.index].__class__.__name__ + '(' | ||
| name = f'{self.module_list[self.index].__class__.__name__}(' | ||
| for k, v in json.loads(arg).items(): | ||
| if 'kernel' in k.lower(): | ||
| cm = ("#8dd3c7", "#fb8072", "#ffffb3", "#bebada", "#80b1d3", | ||
|
|
@@ -61,11 +64,16 @@ def get_config_grid(dict_space): | |
| config.update(constants) | ||
| return configs | ||
|
|
||
|
|
||
| return enas_unit | ||
|
|
||
| return registered_class | ||
|
|
||
| def enas_net(**kwvars): | ||
| def registered_class(Cls): | ||
|
|
||
|
|
||
|
|
||
|
Comment on lines
+74
to
+76
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| class ENAS_Net(Cls): | ||
| def __init__(self, *args, **kwargs): | ||
| kwvars.update(kwargs) | ||
|
|
@@ -134,7 +142,7 @@ def kwspaces(self): | |
| return self._kwspaces | ||
|
|
||
| def sample(self, **configs): | ||
| striped_keys = [k.split('.')[0] for k in configs.keys()] | ||
| striped_keys = [k.split('.')[0] for k in configs] | ||
| for k in striped_keys: | ||
| if isinstance(self._modules[k], ENAS_Unit): | ||
| self._modules[k].sample(configs[k]) | ||
|
|
@@ -168,7 +176,9 @@ def evaluate_latency(self, x): | |
| self._avg_latency = avg_latency | ||
| self.latency_evaluated = True | ||
|
|
||
|
|
||
| return ENAS_Net | ||
|
|
||
| return registered_class | ||
|
|
||
| class ENAS_Sequential(gluon.HybridBlock): | ||
|
|
@@ -279,10 +289,7 @@ def evaluate_latency(self, x): | |
| import time | ||
| # evaluate submodule latency | ||
| for k, op in self._modules.items(): | ||
| if hasattr(op, 'evaluate_latency'): | ||
| x = op.evaluate_latency(x) | ||
| else: | ||
| x = op(x) | ||
| x = op.evaluate_latency(x) if hasattr(op, 'evaluate_latency') else op(x) | ||
|
Comment on lines
-282
to
+292
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| # calc avg_latency | ||
| avg_latency = 0.0 | ||
| for k, op in self._modules.items(): | ||
|
|
@@ -297,9 +304,9 @@ def sample(self, **configs): | |
| self._modules[k].sample(v) | ||
|
|
||
| def __repr__(self): | ||
| reprstr = self.__class__.__name__ + '(' | ||
| reprstr = f'{self.__class__.__name__}(' | ||
| for i, op in self._modules.items(): | ||
| reprstr += '\n\t{}: {}'.format(i, op) | ||
| reprstr += f'\n\t{i}: {op}' | ||
|
Comment on lines
-300
to
+309
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| reprstr += ')\n' | ||
| return reprstr | ||
|
|
||
|
|
@@ -334,10 +341,10 @@ def kwspaces(self): | |
|
|
||
| @property | ||
| def nparams(self): | ||
| nparams = 0 | ||
| for _, v in self.module_list[self.index].collect_params().items(): | ||
| nparams += v.data().size | ||
| return nparams | ||
| return sum( | ||
| v.data().size | ||
| for _, v in self.module_list[self.index].collect_params().items() | ||
| ) | ||
|
Comment on lines
-337
to
+347
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
| @property | ||
| def latency(self): | ||
|
|
@@ -374,6 +381,4 @@ def __len__(self): | |
| return len(self.module_list) | ||
|
|
||
| def __repr__(self): | ||
| reprstr = self.__class__.__name__ + '(num of choices: {}), current architecture:\n\t {}' \ | ||
| .format(len(self.module_list), self.module_list[self.index]) | ||
| return reprstr | ||
| return f'{self.__class__.__name__}(num of choices: {len(self.module_list)}), current architecture:\n\t {self.module_list[self.index]}' | ||
|
Comment on lines
-377
to
+384
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,8 +44,10 @@ def __init__(self, blocks_args=[], dropout_rate=0.2, num_classes=1000, input_siz | |
| block_arg.update(in_channels=out_channels, stride=1, | ||
| input_size=input_size) | ||
|
|
||
| for _ in range(block_arg.num_repeat - 1): | ||
| _blocks.append(ENAS_MbBlock(**block_arg)) | ||
| _blocks.extend( | ||
| ENAS_MbBlock(**block_arg) | ||
| for _ in range(block_arg.num_repeat - 1) | ||
| ) | ||
|
Comment on lines
-47
to
+50
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
| if blocks is not None: | ||
| self._blocks = ENAS_Sequential(blocks) | ||
|
|
@@ -98,19 +100,128 @@ def hybrid_forward(self, F, x): | |
|
|
||
| def get_enas_blockargs(): | ||
| """ Creates a predefined efficientnet model, which searched by original paper. """ | ||
| blocks_args = [ | ||
| Dict(kernel=3, num_repeat=1, channels=16, expand_ratio=1, stride=1, se_ratio=0.25, in_channels=32), | ||
| Dict(kernel=3, num_repeat=1, channels=16, expand_ratio=1, stride=1, se_ratio=0.25, in_channels=16, with_zero=True), | ||
| Dict(kernel=Categorical(3, 5, 7), num_repeat=1, channels=24, expand_ratio=Categorical(3, 6), stride=2, se_ratio=0.25, in_channels=16), | ||
| Dict(kernel=Categorical(3, 5, 7), num_repeat=3, channels=24, expand_ratio=Categorical(3, 6), stride=1, se_ratio=0.25, in_channels=24, with_zero=True), | ||
| Dict(kernel=Categorical(3, 5, 7), num_repeat=1, channels=40, expand_ratio=Categorical(3, 6), stride=2, se_ratio=0.25, in_channels=24), | ||
| Dict(kernel=Categorical(3, 5, 7), num_repeat=3, channels=40, expand_ratio=Categorical(3, 6), stride=1, se_ratio=0.25, in_channels=40, with_zero=True), | ||
| Dict(kernel=Categorical(3, 5, 7), num_repeat=1, channels=80, expand_ratio=Categorical(3, 6), stride=2, se_ratio=0.25, in_channels=40), | ||
| Dict(kernel=Categorical(3, 5, 7), num_repeat=4, channels=80, expand_ratio=Categorical(3, 6), stride=1, se_ratio=0.25, in_channels=80, with_zero=True), | ||
| Dict(kernel=Categorical(3, 5, 7), num_repeat=1, channels=112, expand_ratio=Categorical(3, 6), stride=1, se_ratio=0.25, in_channels=80), | ||
| Dict(kernel=Categorical(3, 5, 7), num_repeat=4, channels=112, expand_ratio=Categorical(3, 6), stride=1, se_ratio=0.25, in_channels=112, with_zero=True), | ||
| Dict(kernel=Categorical(3, 5, 7), num_repeat=1, channels=192, expand_ratio=Categorical(3, 6), stride=2, se_ratio=0.25, in_channels=112), | ||
| Dict(kernel=Categorical(3, 5, 7), num_repeat=5, channels=192, expand_ratio=Categorical(3, 6), stride=1, se_ratio=0.25, in_channels=192, with_zero=True), | ||
| Dict(kernel=3, num_repeat=1, channels=320, expand_ratio=6, stride=1, se_ratio=0.25, in_channels=192), | ||
| return [ | ||
| Dict( | ||
| kernel=3, | ||
| num_repeat=1, | ||
| channels=16, | ||
| expand_ratio=1, | ||
| stride=1, | ||
| se_ratio=0.25, | ||
| in_channels=32, | ||
| ), | ||
| Dict( | ||
| kernel=3, | ||
| num_repeat=1, | ||
| channels=16, | ||
| expand_ratio=1, | ||
| stride=1, | ||
| se_ratio=0.25, | ||
| in_channels=16, | ||
| with_zero=True, | ||
| ), | ||
| Dict( | ||
| kernel=Categorical(3, 5, 7), | ||
| num_repeat=1, | ||
| channels=24, | ||
| expand_ratio=Categorical(3, 6), | ||
| stride=2, | ||
| se_ratio=0.25, | ||
| in_channels=16, | ||
| ), | ||
| Dict( | ||
| kernel=Categorical(3, 5, 7), | ||
| num_repeat=3, | ||
| channels=24, | ||
| expand_ratio=Categorical(3, 6), | ||
| stride=1, | ||
| se_ratio=0.25, | ||
| in_channels=24, | ||
| with_zero=True, | ||
| ), | ||
| Dict( | ||
| kernel=Categorical(3, 5, 7), | ||
| num_repeat=1, | ||
| channels=40, | ||
| expand_ratio=Categorical(3, 6), | ||
| stride=2, | ||
| se_ratio=0.25, | ||
| in_channels=24, | ||
| ), | ||
| Dict( | ||
| kernel=Categorical(3, 5, 7), | ||
| num_repeat=3, | ||
| channels=40, | ||
| expand_ratio=Categorical(3, 6), | ||
| stride=1, | ||
| se_ratio=0.25, | ||
| in_channels=40, | ||
| with_zero=True, | ||
| ), | ||
| Dict( | ||
| kernel=Categorical(3, 5, 7), | ||
| num_repeat=1, | ||
| channels=80, | ||
| expand_ratio=Categorical(3, 6), | ||
| stride=2, | ||
| se_ratio=0.25, | ||
| in_channels=40, | ||
| ), | ||
| Dict( | ||
| kernel=Categorical(3, 5, 7), | ||
| num_repeat=4, | ||
| channels=80, | ||
| expand_ratio=Categorical(3, 6), | ||
| stride=1, | ||
| se_ratio=0.25, | ||
| in_channels=80, | ||
| with_zero=True, | ||
| ), | ||
| Dict( | ||
| kernel=Categorical(3, 5, 7), | ||
| num_repeat=1, | ||
| channels=112, | ||
| expand_ratio=Categorical(3, 6), | ||
| stride=1, | ||
| se_ratio=0.25, | ||
| in_channels=80, | ||
| ), | ||
| Dict( | ||
| kernel=Categorical(3, 5, 7), | ||
| num_repeat=4, | ||
| channels=112, | ||
| expand_ratio=Categorical(3, 6), | ||
| stride=1, | ||
| se_ratio=0.25, | ||
| in_channels=112, | ||
| with_zero=True, | ||
| ), | ||
| Dict( | ||
| kernel=Categorical(3, 5, 7), | ||
| num_repeat=1, | ||
| channels=192, | ||
| expand_ratio=Categorical(3, 6), | ||
| stride=2, | ||
| se_ratio=0.25, | ||
| in_channels=112, | ||
| ), | ||
| Dict( | ||
| kernel=Categorical(3, 5, 7), | ||
| num_repeat=5, | ||
| channels=192, | ||
| expand_ratio=Categorical(3, 6), | ||
| stride=1, | ||
| se_ratio=0.25, | ||
| in_channels=192, | ||
| with_zero=True, | ||
| ), | ||
| Dict( | ||
| kernel=3, | ||
| num_repeat=1, | ||
| channels=320, | ||
| expand_ratio=6, | ||
| stride=1, | ||
| se_ratio=0.25, | ||
| in_channels=192, | ||
| ), | ||
| ] | ||
| return blocks_args | ||
|
Comment on lines
-101
to
-116
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -113,8 +113,7 @@ def run(self): | |
| # for recordio data | ||
| if hasattr(self.train_data, 'reset'): self.train_data.reset() | ||
| tbar = tqdm(self.train_data) | ||
| idx = 0 | ||
| for batch in tbar: | ||
| for idx, batch in enumerate(tbar): | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| # sample network configuration | ||
| config = self.controller.pre_sample()[0] | ||
| self.supernet.sample(**config) | ||
|
|
@@ -129,7 +128,6 @@ def run(self): | |
| tbar.set_svg(graph._repr_svg_()) | ||
| if self.baseline: | ||
| tbar.set_description('avg reward: {:.2f}'.format(self.baseline)) | ||
| idx += 1 | ||
| self.validation() | ||
| self.save() | ||
| msg = 'epoch {}, val_acc: {:.2f}'.format(epoch, self.val_acc) | ||
|
|
@@ -150,7 +148,7 @@ def validation(self): | |
| for batch in tbar: | ||
| self.eval_fn(self.supernet, batch, metric=metric, **self.val_args) | ||
| reward = metric.get()[1] | ||
| tbar.set_description('Val Acc: {}'.format(reward)) | ||
| tbar.set_description(f'Val Acc: {reward}') | ||
|
Comment on lines
-153
to
+151
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
| self.val_acc = reward | ||
| self.training_history.append(reward) | ||
|
|
@@ -195,7 +193,7 @@ def train_controller(self): | |
| self.eval_fn(self.supernet, batch, metric=metric, **self.val_args) | ||
| reward = metric.get()[1] | ||
| reward = self.reward_fn(reward, self.supernet) | ||
| self.baseline = reward if not self.baseline else self.baseline | ||
| self.baseline = self.baseline or reward | ||
|
Comment on lines
-198
to
+196
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| # substract baseline | ||
| avg_rewards = mx.nd.array([reward - self.baseline], | ||
| ctx=self.controller.context) | ||
|
|
@@ -213,12 +211,12 @@ def train_controller(self): | |
| self._prefetch_controller() | ||
|
|
||
| def load(self, checkname=None): | ||
| checkname = checkname if checkname else self.checkname | ||
| checkname = checkname or self.checkname | ||
|
Comment on lines
-216
to
+214
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| state_dict = load(checkname) | ||
| self.load_state_dict(state_dict) | ||
|
|
||
| def save(self, checkname=None): | ||
| checkname = checkname if checkname else self.checkname | ||
| checkname = checkname or self.checkname | ||
|
Comment on lines
-221
to
+219
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| mkdir(os.path.dirname(checkname)) | ||
| save(self.state_dict(), checkname) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,16 +5,17 @@ | |
| from ...scheduler.resource import get_gpu_count | ||
|
|
||
| def default_reward_fn(metric, net): | ||
| reward = metric * ((net.avg_latency / net.latency) ** 0.07) | ||
| return reward | ||
| return metric * ((net.avg_latency / net.latency) ** 0.07) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
| def init_default_train_args(batch_size, net, epochs, iters_per_epoch): | ||
| train_args = {} | ||
| base_lr = 0.1 * batch_size / 256 | ||
| lr_scheduler = gcv.utils.LRScheduler('cosine', base_lr=base_lr, target_lr=0.0001, | ||
| nepochs=epochs, iters_per_epoch=iters_per_epoch) | ||
| optimizer_params = {'wd': 1e-4, 'momentum': 0.9, 'lr_scheduler': lr_scheduler} | ||
| train_args['trainer'] = gluon.Trainer(net.collect_params(), 'sgd', optimizer_params) | ||
| train_args = { | ||
| 'trainer': gluon.Trainer(net.collect_params(), 'sgd', optimizer_params) | ||
| } | ||
|
|
||
|
Comment on lines
-12
to
+18
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| train_args['batch_size'] = batch_size | ||
| train_args['criterion'] = gluon.loss.SoftmaxCrossEntropyLoss() | ||
| return train_args | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,9 +32,7 @@ def sample_config(args, config): | |
| if isinstance(v, NestedSpace): | ||
| sub_config = _strip_config_space(config, prefix=k) | ||
| args_dict[k] = v.sample(**sub_config) | ||
| else: | ||
| if SPLITTER in k: | ||
| continue | ||
| elif SPLITTER not in k: | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| args_dict[k] = config[k] | ||
| elif isinstance(v, AutoGluonObject): | ||
| args_dict[k] = v.init() | ||
|
|
@@ -75,13 +73,11 @@ def update(self, **kwargs): | |
| """ | ||
| self.kwvars.update(kwargs) | ||
| for k, v in self.kwvars.items(): | ||
| if isinstance(v, (NestedSpace)): | ||
| if isinstance(v, (NestedSpace)) or not isinstance(v, Space): | ||
| self.args.update({k: v}) | ||
| elif isinstance(v, Space): | ||
| else: | ||
| hp = v.get_hp(name=k) | ||
| self.args.update({k: hp.default_value}) | ||
| else: | ||
| self.args.update({k: v}) | ||
|
Comment on lines
-78
to
-84
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
| @property | ||
| def cs(self): | ||
|
|
@@ -104,9 +100,9 @@ def kwspaces(self): | |
| for k, v in self.kwvars.items(): | ||
| if isinstance(v, NestedSpace): | ||
| if isinstance(v, Categorical): | ||
| kw_spaces['{}{}choice'.format(k, SPLITTER)] = v | ||
| kw_spaces[f'{k}{SPLITTER}choice'] = v | ||
| for sub_k, sub_v in v.kwspaces.items(): | ||
| new_k = '{}{}{}'.format(k, SPLITTER, sub_k) | ||
| new_k = f'{k}{SPLITTER}{sub_k}' | ||
|
Comment on lines
-107
to
+105
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| kw_spaces[new_k] = sub_v | ||
| elif isinstance(v, Space): | ||
| kw_spaces[k] = v | ||
|
|
@@ -134,7 +130,7 @@ def args(default=None, **kwvars): | |
| ... print('Batch size is {}, LR is {}'.format(args.batch_size, arg.lr)) | ||
| """ | ||
| if default is None: | ||
| default = dict() | ||
| default = {} | ||
|
Comment on lines
-137
to
+133
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| kwvars['_default_config'] = default | ||
| def registered_func(func): | ||
| @_autogluon_method | ||
|
|
@@ -263,6 +259,9 @@ def wrapper_call(*args, **kwargs): | |
| return registered_func | ||
|
|
||
| def registered_class(Cls): | ||
|
|
||
|
|
||
|
|
||
|
Comment on lines
+262
to
+264
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| class autogluonobject(AutoGluonObject): | ||
| @_autogluon_kwargs_obj(**kwvars) | ||
| def __init__(self, *args, **kwargs): | ||
|
|
@@ -284,7 +283,8 @@ def sample(self, **config): | |
| return Cls(*args, **kwargs) | ||
|
|
||
| def __repr__(self): | ||
| return 'AutoGluonObject -- ' + Cls.__name__ | ||
| return f'AutoGluonObject -- {Cls.__name__}' | ||
|
|
||
|
|
||
| autogluonobject.kwvars = autogluonobject.__init__.kwvars | ||
| autogluonobject.__doc__ = Cls.__doc__ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function
enas_unitrefactored with the following changes:use-fstring-for-concatenation)