Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions argload/ArgumentLoader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
""" Defining ArgumentLoader class """
import os.path as path
from copy import copy
import os
from copy import copy, deepcopy
import pickle
from argparse import Namespace
import sys


class ArgumentLoader(object):
""" Wrap around a parser, to make it automatically reload arguments.

Expand All @@ -21,6 +23,8 @@ def __init__(self, parser, to_reload):
self._parser.add_argument('--overwrite', action='store_true',
help='Flag explicitly specifying that specified arguments must '
'overwrite reloaded arguments.')
self._parser.add_argument('--name_from_args', action='store_true',
help='Create sub-logdir name from non-default arguments ')
self._parser.add_argument('--dump', action='store_true',
help='Flag specifying that overwritten arguments will replace '
'previous reloading arguments')
Expand All @@ -32,7 +36,7 @@ def parse_known_args(self, args=None, namespace=None):
specified_args = Namespace()
raw_args = sys.argv[1:] if args is None else args
self._parser._parse_known_args(raw_args, specified_args)
specified_args = vars(specified_args).keys()
specified_keys = vars(specified_args).keys()

# starts by parsing current arguments
args, argv = self._parser.parse_known_args(args, namespace)
Expand All @@ -41,6 +45,11 @@ def parse_known_args(self, args=None, namespace=None):
if not path.exists(args.logdir):
raise ValueError("Logdir does not exist.")

if args.name_from_args:
args.logdir = path.join(args.logdir, self._make_name(specified_args))
if not path.exists(args.logdir):
os.makedirs(args.logdir)

# retrieve logdir overwrite and dump and delete them from args: we don't
# want to store them in the args file
logdir = args.logdir
Expand Down Expand Up @@ -71,14 +80,13 @@ def parse_known_args(self, args=None, namespace=None):
dumped_args = vars(pickle.load(f))

args = vars(args)
args = self._fuse_args(dumped_args, args, specified_args, overwrite)
args = self._fuse_args(dumped_args, args, specified_keys, overwrite)

if dump:
self._dump_args(args, args_file, readable_args_file)

return args, argv


def _dump_args(self, args, args_file, readable_args_file):
args_to_dump = copy(args)

Expand All @@ -94,10 +102,10 @@ def _dump_args(self, args, args_file, readable_args_file):
with open(readable_args_file, 'w') as f:
print(vars(args_to_dump), file=f)

def _fuse_args(self, dumped_args, args, specified_args, overwrite):
def _fuse_args(self, dumped_args, args, specified_keys, overwrite):
fused_args = {}
for k in list(dumped_args.keys()) + list(args.keys()):
if k in dumped_args and k in specified_args and dumped_args[k] != args[k]:
if k in dumped_args and k in specified_keys and dumped_args[k] != args[k]:
if overwrite:
fused_args[k] = args[k]
else:
Expand All @@ -109,6 +117,27 @@ def _fuse_args(self, dumped_args, args, specified_args, overwrite):

return Namespace(**fused_args)

def _make_name(self, specified_args):
specified_args = vars(deepcopy(specified_args))
for key in ['root', 'logdir', 'overwrite', 'dump', 'name_from_args']:
try:
del specified_args[key]
except KeyError:
pass

if not specified_args:
raise ValueError("When using '--name_from_args' flag, at least "
"one argument must be provided")

# BEWARE: this naming is dependent on the order of arguments.
# Maybe use on ordered dictionnary instead, ordered lexicographically
name = ''
for key, arg in specified_args.items():
if (arg is True) or (arg is False):
name += key + '_'
else:
name += key + '=' + str(arg) + '_'
return name[:-1]

def parse_args(self, args=None, namespace=None):
""" Parse args """
Expand Down
25 changes: 24 additions & 1 deletion argload/tests/test_argload.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
""" Unit testing for Argload class """
from os import mkdir
from os import mkdir, path
import argparse
import shutil
from unittest import TestCase
Expand Down Expand Up @@ -121,3 +121,26 @@ def overwrite_default_test(self):
self.assertEqual(args.b, 3.)
finally:
shutil.rmtree('log')

def name_from_args_test(self):
""" Test normal use case """
try:
parser = argparse.ArgumentParser()
parser.add_argument('--a', type=int, nargs='*')
parser.add_argument('--b', type=int, default=2)
parser = argload.ArgumentLoader(parser, ['a'])
mkdir('log')

args = parser.parse_args(['--logdir', 'log', '--name_from_args', '--a', '3', '5'])
self.assertTrue(args.logdir == 'log/a=[3, 5]')

parser.parse_args(['--logdir', 'log', '--name_from_args', '--a', '2', '5'])
self.assertTrue(path.exists('log/a=[2, 5]') and path.exists('log/a=[3, 5]'))

args = parser.parse_args(['--logdir', 'log/a=[3, 5]'])

self.assertTrue(args.a == [3, 5])
self.assertTrue(args.b == 2)
self.assertTrue(args.logdir == 'log/a=[3, 5]')
finally:
shutil.rmtree('log')