diff --git a/argload/ArgumentLoader.py b/argload/ArgumentLoader.py index 4cbddcd..027af93 100644 --- a/argload/ArgumentLoader.py +++ b/argload/ArgumentLoader.py @@ -1,10 +1,12 @@ """ Defining ArgumentLoader class """ import os.path as path -from copy import copy +import os +from copy import copy, deepcopy import pickle from argparse import Namespace import sys + class ArgumentLoader(object): """ Wrap around a parser, to make it automatically reload arguments. @@ -21,6 +23,8 @@ def __init__(self, parser, to_reload): self._parser.add_argument('--overwrite', action='store_true', help='Flag explicitly specifying that specified arguments must ' 'overwrite reloaded arguments.') + self._parser.add_argument('--name_from_args', action='store_true', + help='Create sub-logdir name from non-default arguments ') self._parser.add_argument('--dump', action='store_true', help='Flag specifying that overwritten arguments will replace ' 'previous reloading arguments') @@ -32,7 +36,7 @@ def parse_known_args(self, args=None, namespace=None): specified_args = Namespace() raw_args = sys.argv[1:] if args is None else args self._parser._parse_known_args(raw_args, specified_args) - specified_args = vars(specified_args).keys() + specified_keys = vars(specified_args).keys() # starts by parsing current arguments args, argv = self._parser.parse_known_args(args, namespace) @@ -41,6 +45,11 @@ def parse_known_args(self, args=None, namespace=None): if not path.exists(args.logdir): raise ValueError("Logdir does not exist.") + if args.name_from_args: + args.logdir = path.join(args.logdir, self._make_name(specified_args)) + if not path.exists(args.logdir): + os.makedirs(args.logdir) + # retrieve logdir overwrite and dump and delete them from args: we don't # want to store them in the args file logdir = args.logdir @@ -71,14 +80,13 @@ def parse_known_args(self, args=None, namespace=None): dumped_args = vars(pickle.load(f)) args = vars(args) - args = self._fuse_args(dumped_args, args, specified_args, overwrite) + args = self._fuse_args(dumped_args, args, specified_keys, overwrite) if dump: self._dump_args(args, args_file, readable_args_file) return args, argv - def _dump_args(self, args, args_file, readable_args_file): args_to_dump = copy(args) @@ -94,10 +102,10 @@ def _dump_args(self, args, args_file, readable_args_file): with open(readable_args_file, 'w') as f: print(vars(args_to_dump), file=f) - def _fuse_args(self, dumped_args, args, specified_args, overwrite): + def _fuse_args(self, dumped_args, args, specified_keys, overwrite): fused_args = {} for k in list(dumped_args.keys()) + list(args.keys()): - if k in dumped_args and k in specified_args and dumped_args[k] != args[k]: + if k in dumped_args and k in specified_keys and dumped_args[k] != args[k]: if overwrite: fused_args[k] = args[k] else: @@ -109,6 +117,27 @@ def _fuse_args(self, dumped_args, args, specified_args, overwrite): return Namespace(**fused_args) + def _make_name(self, specified_args): + specified_args = vars(deepcopy(specified_args)) + for key in ['root', 'logdir', 'overwrite', 'dump', 'name_from_args']: + try: + del specified_args[key] + except KeyError: + pass + + if not specified_args: + raise ValueError("When using '--name_from_args' flag, at least " + "one argument must be provided") + + # BEWARE: this naming is dependent on the order of arguments. + # Maybe use on ordered dictionnary instead, ordered lexicographically + name = '' + for key, arg in specified_args.items(): + if (arg is True) or (arg is False): + name += key + '_' + else: + name += key + '=' + str(arg) + '_' + return name[:-1] def parse_args(self, args=None, namespace=None): """ Parse args """ diff --git a/argload/tests/test_argload.py b/argload/tests/test_argload.py index 9059de7..e7a31a8 100644 --- a/argload/tests/test_argload.py +++ b/argload/tests/test_argload.py @@ -1,5 +1,5 @@ """ Unit testing for Argload class """ -from os import mkdir +from os import mkdir, path import argparse import shutil from unittest import TestCase @@ -121,3 +121,26 @@ def overwrite_default_test(self): self.assertEqual(args.b, 3.) finally: shutil.rmtree('log') + + def name_from_args_test(self): + """ Test normal use case """ + try: + parser = argparse.ArgumentParser() + parser.add_argument('--a', type=int, nargs='*') + parser.add_argument('--b', type=int, default=2) + parser = argload.ArgumentLoader(parser, ['a']) + mkdir('log') + + args = parser.parse_args(['--logdir', 'log', '--name_from_args', '--a', '3', '5']) + self.assertTrue(args.logdir == 'log/a=[3, 5]') + + parser.parse_args(['--logdir', 'log', '--name_from_args', '--a', '2', '5']) + self.assertTrue(path.exists('log/a=[2, 5]') and path.exists('log/a=[3, 5]')) + + args = parser.parse_args(['--logdir', 'log/a=[3, 5]']) + + self.assertTrue(args.a == [3, 5]) + self.assertTrue(args.b == 2) + self.assertTrue(args.logdir == 'log/a=[3, 5]') + finally: + shutil.rmtree('log')