 from transformers.modeling_outputs import TokenClassifierOutput
 from typing import Any, Callable, List, Optional, Union

+# This global variable indicates the position of the linear layer in TokenClassificationHead.classifier
+linear_pos = 1
+
 # This class is adapted from: https://towardsdatascience.com/how-to-create-and-train-a-multi-task-transformer-model-18c54a146240
 class TokenClassificationModel(PreTrainedModel):
     def __init__(self, config: AutoConfig, transformer_name: str) -> None:
@@ -112,9 +115,10 @@ def save_pretrained(self,
         # https://pytorch.org/tutorials/beginner/saving_loading_models.html
         torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
         print("pickle saving done.")
-        #super().save_pretrained(save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, **kwargs)
+
+        # some optional debugging prints:
         #for i in range(5):
-        #  key = f"output_heads.{i}.classifier.weight"
+        #  key = f"output_heads.{i}.classifier.{linear_pos}.weight"
         #  print(f"{key} = {self.state_dict()[key]}")

     def from_pretrained(self, pretrained_model_name_or_path: str, *model_args, **kwargs) -> None:
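
A note on the debug key above: a minimal standalone sketch (illustrative only, not part of this commit) of why the state_dict key uses a numeric index. PyTorch names the children of an nn.Sequential by position, so once the classifier is rebuilt below as Sequential(Dropout, Linear), the linear layer's parameters live under index 1:

import torch.nn as nn

classifier = nn.Sequential(nn.Dropout(0.1), nn.Linear(768, 5))  # hypothetical sizes
print(list(classifier.state_dict().keys()))
# ['1.weight', '1.bias']  ->  e.g. "output_heads.0.classifier.1.weight" in the full model
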
@@ -129,8 +133,8 @@ def from_pretrained(self, pretrained_model_name_or_path: str, *model_args, **kwa
         print("Done loading.")

     def export_task(self, task_head, task: Task, task_checkpoint) -> None:
-        numpy_weights = task_head.classifier.weight.cpu().detach().numpy()
-        numpy_bias = task_head.classifier.bias.cpu().detach().numpy()
+        numpy_weights = task_head.classifier[linear_pos].weight.cpu().detach().numpy()
+        numpy_bias = task_head.classifier[linear_pos].bias.cpu().detach().numpy()
         labels = task.labels
         #print(f"Shape of weights: {numpy_weights.shape}")
         #print(f"Weights are:\n{numpy_weights}")
@@ -220,25 +224,31 @@ def export_model(self, tasks: List[Task], tokenizer: AutoTokenizer, checkpoint_d
 class TokenClassificationHead(nn.Module):
     def __init__(self, hidden_size: int, num_labels: int, task_id, dual_mode: bool = False, dropout_p: float = 0.1):
         super().__init__()
-        self.dropout = nn.Dropout(dropout_p)
+        # self.dropout = nn.Dropout(dropout_p)
         self.dual_mode = dual_mode
-        self.classifier = nn.Linear(
-            hidden_size * 2 if (self.dual_mode and Parameters.use_concat) else hidden_size,
-            num_labels
-        )
+        # make sure to adjust linear_pos if the position of Linear here changes!
+        self.classifier = nn.Sequential(
+            # nn.ReLU(), # it works better without a nonlinearity here
+            nn.Dropout(dropout_p),
+            nn.Linear(
+                hidden_size * 2 if (self.dual_mode and Parameters.use_concat) else hidden_size,
+                num_labels
+            )
+        )
         self.num_labels = num_labels
         self.task_id = task_id
         self._init_weights()

     def _init_weights(self):
-        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
-        # torch.nn.init.xavier_normal_(self.classifier.weight.data)
-        if self.classifier.bias is not None:
-            self.classifier.bias.data.zero_()
+        # Re-initialize only the nn.Linear layer inside the Sequential classifier.
+        self.classifier[linear_pos].weight.data.normal_(mean=0.0, std=0.02)
+        #torch.nn.init.xavier_normal_(self.classifier[linear_pos].weight.data)
+        if self.classifier[linear_pos].bias is not None:
+            self.classifier[linear_pos].bias.data.zero_()

     def summarize(self, task_id):
         print(f"Task {self.task_id} with {self.num_labels} labels.")
-        print(f"Dropout is {self.dropout}")
+        # print(f"Dropout is {self.dropout}")
         print(f"Classifier layer is {self.classifier}")

     def concatenate(self, sequence_output, head_positions):
@@ -250,16 +260,12 @@ def concatenate(self, sequence_output, head_positions):
         # Concatenate the hidden states from modifier + head.
         modifier_head_states = torch.cat([sequence_output, head_states], dim=2) if Parameters.use_concat else torch.add(sequence_output, head_states)
         #print(f"modifier_head_states.size = {modifier_head_states.size()}")
-        #print("EXIT")
-        #exit(1)
         return modifier_head_states

     def forward(self, sequence_output, pooled_output, head_positions, labels=None, attention_mask=None, **kwargs):
         #print(f"sequence_output size = {sequence_output.size()}")
         sequence_output_for_classification = sequence_output if not self.dual_mode else self.concatenate(sequence_output, head_positions)
-
-        sequence_output_dropout = self.dropout(sequence_output_for_classification)
-        logits = self.classifier(sequence_output_dropout)
+        logits = self.classifier(sequence_output_for_classification)

         loss = None
         if labels is not None:
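
As a quick sanity check on the refactor, a standalone sketch (made-up sizes, not code from this commit) of the new head layout: dropout and the linear projection now run in one call through the Sequential, and the linear weights stay reachable through linear_pos, as export_task expects:

import torch
import torch.nn as nn

linear_pos = 1  # index of nn.Linear inside the Sequential below

classifier = nn.Sequential(
    nn.Dropout(0.1),
    nn.Linear(768, 5),  # hidden_size=768, num_labels=5 are illustrative values
)

x = torch.randn(2, 10, 768)                     # (batch, seq_len, hidden_size)
logits = classifier(x)                          # dropout + linear, as in the new forward()
weights = classifier[linear_pos].weight.detach().cpu().numpy()
bias = classifier[linear_pos].bias.detach().cpu().numpy()
print(logits.shape, weights.shape, bias.shape)  # torch.Size([2, 10, 5]) (5, 768) (5,)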