diff --git a/README.md b/README.md index 94c4bbc..a508625 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ python extract_attention.py --preprocessed_data_file --bert_ ``` The following optional arguments can also be added: * `--max_sequence_length`: Maximum input sequence length after tokenization (default is 128). +* `--num_docs`: Number of documents to use (default is 1000). * `--batch_size`: Batch size when running BERT over examples (default is 16). * `--debug`: Use a tiny BERT model for fast debugging. * `--cased`: Do not lowercase the input text. diff --git a/preprocess_unlabeled.py b/preprocess_unlabeled.py index 404ce5d..2caa698 100644 --- a/preprocess_unlabeled.py +++ b/preprocess_unlabeled.py @@ -89,11 +89,11 @@ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): def main(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - "--data-file", required=True, + "--data_file", required=True, help="Location of input data; see the README for expected data format.") - parser.add_argument("--bert-dir", required=True, + parser.add_argument("--bert_dir", required=True, help="Location of the pre-trained BERT model.") - parser.add_argument("--num-docs", default=1000, type=int, + parser.add_argument("--num_docs", default=1000, type=int, help="Number of documents to use (default=1000).") parser.add_argument("--cased", default=False, action='store_true', help="Don't lowercase the input.") diff --git a/utils.py b/utils.py index 0cf0df6..14e8c49 100644 --- a/utils.py +++ b/utils.py @@ -15,7 +15,8 @@ def load_json(path): def write_json(o, path): - tf.gfile.MakeDirs(path.rsplit('/', 1)[0]) + if '/' in path: + tf.gfile.MakeDirs(path.rsplit('/', 1)[0]) with tf.gfile.GFile(path, 'w') as f: json.dump(o, f)