@@ -7,7 +7,7 @@
 from parlai.agents.bert_ranker.helpers import (
     BertWrapper,
     get_bert_optimizer,
-    MODEL_PATH
+    MODEL_PATH,
 )
 from parlai.core.utils import load_opt_file
 from parlai.core.torch_agent import History
@@ -18,11 +18,16 @@
 from collections import deque
 import os
 import torch
+
 try:
     from pytorch_pretrained_bert import BertModel
 except ImportError:
-    raise Exception(("BERT rankers needs pytorch-pretrained-BERT installed. \n"
-                     "pip install pytorch-pretrained-bert"))
+    raise Exception(
+        (
+            "BERT rankers needs pytorch-pretrained-BERT installed. \n"
+            "pip install pytorch-pretrained-bert"
+        )
+    )
 
 
 class BertClassifierHistory(History):
@@ -49,11 +54,13 @@ class BertClassifierAgent(TorchClassifierAgent):
     """
     Classifier based on Hugging Face BERT implementation.
     """
+
     def __init__(self, opt, shared=None):
         # download pretrained models
         download(opt['datapath'])
-        self.pretrained_path = os.path.join(opt['datapath'], 'models',
-                                            'bert_models', MODEL_PATH)
+        self.pretrained_path = os.path.join(
+            opt['datapath'], 'models', 'bert_models', MODEL_PATH
+        )
         opt['pretrained_path'] = self.pretrained_path
         self._upgrade_opt(opt)
         self.add_cls_token = opt.get('add_cls_token', True)
@@ -68,21 +75,34 @@ def history_class(cls):
     def add_cmdline_args(parser):
         TorchClassifierAgent.add_cmdline_args(parser)
         parser = parser.add_argument_group('BERT Classifier Arguments')
-        parser.add_argument('--type-optimization', type=str,
-                            default='all_encoder_layers',
-                            choices=['additional_layers', 'top_layer',
-                                     'top4_layers', 'all_encoder_layers',
-                                     'all'],
-                            help='which part of the encoders do we optimize '
-                                 '(defaults to all layers)')
-        parser.add_argument('--add-cls-token', type='bool', default=True,
-                            help='add [CLS] token to text vec')
-        parser.add_argument('--sep-last-utt', type='bool', default=False,
-                            help='separate the last utterance into a different'
-                                 'segment with [SEP] token in between')
-        parser.set_defaults(
-            dict_maxexs=0,  # skip building dictionary
+        parser.add_argument(
+            '--type-optimization',
+            type=str,
+            default='all_encoder_layers',
+            choices=[
+                'additional_layers',
+                'top_layer',
+                'top4_layers',
+                'all_encoder_layers',
+                'all',
+            ],
+            help='which part of the encoders do we optimize '
+            '(defaults to all layers)',
         )
+        parser.add_argument(
+            '--add-cls-token',
+            type='bool',
+            default=True,
+            help='add [CLS] token to text vec',
+        )
+        parser.add_argument(
+            '--sep-last-utt',
+            type='bool',
+            default=False,
+            help='separate the last utterance into a different'
+            'segment with [SEP] token in between',
+        )
+        parser.set_defaults(dict_maxexs=0)  # skip building dictionary
 
     @staticmethod
     def dictionary_class():
@@ -95,23 +115,20 @@ def _upgrade_opt(self, opt):
         old_opt = load_opt_file(model_opt)
         if 'add_cls_token' not in old_opt:
             # old model, make this default to False
-            warn_once(
-                'Old model: overriding `add_cls_token` to False.'
-            )
+            warn_once('Old model: overriding `add_cls_token` to False.')
             opt['add_cls_token'] = False
             return
 
     def build_model(self):
         num_classes = len(self.class_list)
         self.model = BertWrapper(
-            BertModel.from_pretrained(self.pretrained_path),
-            num_classes
+            BertModel.from_pretrained(self.pretrained_path), num_classes
         )
 
     def init_optim(self, params, optim_states=None, saved_optim_type=None):
-        self.optimizer = get_bert_optimizer([self.model],
-                                            self.opt['type_optimization'],
-                                            self.opt['learningrate'])
+        self.optimizer = get_bert_optimizer(
+            [self.model], self.opt['type_optimization'], self.opt['learningrate']
+        )
 
     def _set_text_vec(self, *args, **kwargs):
         obs = super()._set_text_vec(*args, **kwargs)
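
The reformatting above is layout-only: the flags registered in `add_cmdline_args` keep the same names, types, and defaults. As a rough, non-authoritative sketch of how those flags are consumed, the snippet below registers them on a fresh parser and reads one back; the import path for `BertClassifierAgent` and the `print_args=False` keyword are assumptions about the ParlAI version this diff targets, not something the diff itself shows.

# Sketch only: assumed import path and parser keywords; flag values are illustrative.
from parlai.core.params import ParlaiParser
from parlai.agents.bert_classifier.bert_classifier import BertClassifierAgent

parser = ParlaiParser(add_parlai_args=True, add_model_args=True)
# Registers --type-optimization, --add-cls-token, and --sep-last-utt as in the diff.
BertClassifierAgent.add_cmdline_args(parser)
opt = parser.parse_args(['--type-optimization', 'top4_layers'], print_args=False)
assert opt['type_optimization'] == 'top4_layers'  # one of the choices listed above
assert opt['add_cls_token'] is True               # default set in the diff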