Skip to content

Commit 6f45c7d

Browse files
Merge pull request #59 from clulab/nonlinearity
Reorganized TokenClassificationHead to use the Sequential module. Added optional nonlinearity.
2 parents ac810c0 + e6e7249 commit 6f45c7d

File tree

7 files changed

+73
-28
lines changed

7 files changed

+73
-28
lines changed

apps/src/main/scala/org/clulab/scala_transformers/apps/TokenClassifierExampleApp.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ import org.clulab.scala_transformers.encoder.TokenClassifier
44

55
object TokenClassifierExampleApp extends App {
66
// Choose one of these.
7-
// val tokenClassifier = TokenClassifier.fromFiles("../microsoft-deberta-v3-base-mtl/avg_export")
8-
val tokenClassifier = TokenClassifier.fromResources("/org/clulab/scala_transformers/models/microsoft_deberta_v3_base_mtl/avg_export")
7+
val tokenClassifier = TokenClassifier.fromFiles("../microsoft-deberta-v3-base-mtl/avg_export")
8+
//val tokenClassifier = TokenClassifier.fromResources("/org/clulab/scala_transformers/models/microsoft_deberta_v3_base_mtl/avg_export")
99
// val tokenClassifier = TokenClassifier.fromFiles("../models/google_electra_small_discriminator_mtl/avg_export")
1010
// val tokenClassifier = TokenClassifier.fromResources("/org/clulab/scala_transformers/models/google_electra_small_discriminator_mtl/avg_export")
1111
// val tokenClassifier = TokenClassifier.fromFiles("../models/roberta_base_mtl/avg_export")
@@ -14,7 +14,8 @@ object TokenClassifierExampleApp extends App {
1414
//val words = Seq("EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", ".")
1515
//val words = Seq("John", "Doe", "went", "to", "China", ".")
1616
//val words = Seq("John", "Doe", "went", "to", "China", ".")
17-
val words = Seq("Ras1", "has", "phosphorylated", "Mek2", ".")
17+
//val words = Seq("Ras1", "has", "phosphorylated", "Mek2", ".")
18+
val words = Seq("Jack", "Doe", "and", "John", "Doe", "were", "friends", "in", "yyy", "zzz", ".")
1819

1920
println(s"Words: ${words.mkString(", ")}")
2021

encoder/src/main/python/averaging_trainer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,28 +83,29 @@ def average_checkpoints(self,
8383
main_model = self.load_model(checkpoints[0], config, tasks)
8484
print("Done loading.")
8585

86-
self.print_some_params(main_model, "before averaging:")
86+
# TODO: update print
87+
# self.print_some_params(main_model, "before averaging:")
8788

8889
for i in range(1, len(checkpoints)): # Skip 0, which is the main_model.
8990
print(f"Loading satellite checkpoint[{i}] {checkpoints[i]}...")
9091
satellite_model = self.load_model(checkpoints[i], config, tasks)
91-
self.print_some_params(satellite_model, "satellite model:")
92+
#self.print_some_params(satellite_model, "satellite model:") # TODO
9293

9394
print("Adding its parameter weights to the main model...")
9495
for key, value in main_model.state_dict().items():
9596
if value.data.type() != "torch.LongTensor":
9697
value.data += satellite_model.state_dict()[key].data.clone()
9798
print("Done adding")
9899

99-
self.print_some_params(main_model, "after summing:")
100+
# self.print_some_params(main_model, "after summing:") # TODO
100101
if len(checkpoints) > 1:
101102
print("Computing average weights...")
102103
for value in main_model.state_dict().values():
103104
if value.data.type() != "torch.LongTensor":
104105
value.data /= len(checkpoints)
105106
print("Done computing.")
106107

107-
self.print_some_params(main_model, "after averaging:")
108+
# self.print_some_params(main_model, "after averaging:") # TODO
108109
print("Saving averaged model...")
109110
main_model.save_pretrained(path_to_save)
110111
main_model.export_model(tasks, tokenizer, path_to_export)

encoder/src/main/python/token_classifier.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from transformers.modeling_outputs import TokenClassifierOutput
1515
from typing import Any, Callable, List, Optional, Union
1616

17+
# This global variable indicates the position of the linear layer in TokenClassificationHead.classifier
18+
linear_pos = 1
19+
1720
# This class is adapted from: https://towardsdatascience.com/how-to-create-and-train-a-multi-task-transformer-model-18c54a146240
1821
class TokenClassificationModel(PreTrainedModel):
1922
def __init__(self, config: AutoConfig, transformer_name: str) -> None:
@@ -112,9 +115,10 @@ def save_pretrained(self,
112115
# https://pytorch.org/tutorials/beginner/saving_loading_models.html
113116
torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
114117
print("pickle saving done.")
115-
#super().save_pretrained(save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, **kwargs)
118+
119+
# some optional debugging prints:
116120
#for i in range(5):
117-
# key = f"output_heads.{i}.classifier.weight"
121+
# key = f"output_heads.{i}.classifier[linear_pos].weight"
118122
# print(f"{key} = {self.state_dict()[key]}")
119123

120124
def from_pretrained(self, pretrained_model_name_or_path: str, *model_args, **kwargs) -> None:
@@ -129,8 +133,8 @@ def from_pretrained(self, pretrained_model_name_or_path: str, *model_args, **kwa
129133
print("Done loading.")
130134

131135
def export_task(self, task_head, task: Task, task_checkpoint) -> None:
132-
numpy_weights = task_head.classifier.weight.cpu().detach().numpy()
133-
numpy_bias = task_head.classifier.bias.cpu().detach().numpy()
136+
numpy_weights = task_head.classifier[linear_pos].weight.cpu().detach().numpy()
137+
numpy_bias = task_head.classifier[linear_pos].bias.cpu().detach().numpy()
134138
labels = task.labels
135139
#print(f"Shape of weights: {numpy_weights.shape}")
136140
#print(f"Weights are:\n{numpy_weights}")
@@ -220,25 +224,31 @@ def export_model(self, tasks: List[Task], tokenizer: AutoTokenizer, checkpoint_d
220224
class TokenClassificationHead(nn.Module):
221225
def __init__(self, hidden_size: int, num_labels: int, task_id, dual_mode: bool=False, dropout_p: float=0.1):
222226
super().__init__()
223-
self.dropout = nn.Dropout(dropout_p)
227+
#self.dropout = nn.Dropout(dropout_p)
224228
self.dual_mode = dual_mode
225-
self.classifier = nn.Linear(
226-
hidden_size * 2 if (self.dual_mode and Parameters.use_concat) else hidden_size,
227-
num_labels
228-
)
229+
# make sure to adjust linear_pos if the position of Linear here changes!
230+
self.classifier = nn.Sequential(
231+
# nn.ReLU(), # it works better wo/ a nonlinearity here
232+
nn.Dropout(dropout_p),
233+
nn.Linear(
234+
hidden_size * 2 if (self.dual_mode and Parameters.use_concat) else hidden_size,
235+
num_labels
236+
)
237+
)
229238
self.num_labels = num_labels
230239
self.task_id = task_id
231240
self._init_weights()
232241

233242
def _init_weights(self):
234-
self.classifier.weight.data.normal_(mean=0.0, std=0.02)
235-
# torch.nn.init.xavier_normal_(self.classifier.weight.data)
236-
if self.classifier.bias is not None:
237-
self.classifier.bias.data.zero_()
243+
print("Nothing to initialize here.")
244+
self.classifier[linear_pos].weight.data.normal_(mean=0.0, std=0.02)
245+
#torch.nn.init.xavier_normal_(self.classifier[linear_pos].weight.data)
246+
if self.classifier[linear_pos].bias is not None:
247+
self.classifier[linear_pos].bias.data.zero_()
238248

239249
def summarize(self, task_id):
240250
print(f"Task {self.task_id} with {self.num_labels} labels.")
241-
print(f"Dropout is {self.dropout}")
251+
#print(f"Dropout is {self.dropout}")
242252
print(f"Classifier layer is {self.classifier}")
243253

244254
def concatenate(self, sequence_output, head_positions):
@@ -250,16 +260,12 @@ def concatenate(self, sequence_output, head_positions):
250260
# Concatenate the hidden states from modifier + head.
251261
modifier_head_states = torch.cat([sequence_output, head_states], dim=2) if Parameters.use_concat else torch.add(sequence_output, head_states)
252262
#print(f"modifier_head_states.size = {modifier_head_states.size()}")
253-
#print("EXIT")
254-
#exit(1)
255263
return modifier_head_states
256264

257265
def forward(self, sequence_output, pooled_output, head_positions, labels=None, attention_mask=None, **kwargs):
258266
#print(f"sequence_output size = {sequence_output.size()}")
259267
sequence_output_for_classification = sequence_output if not self.dual_mode else self.concatenate(sequence_output, head_positions)
260-
261-
sequence_output_dropout = self.dropout(sequence_output_for_classification)
262-
logits = self.classifier(sequence_output_dropout)
268+
logits = self.classifier(sequence_output_for_classification)
263269

264270
loss = None
265271
if labels is not None:

encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
package org.clulab.scala_transformers.encoder
22

33
import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}
4+
import org.clulab.scala_transformers.encoder.math.Mathematics
45
import org.clulab.scala_transformers.encoder.math.Mathematics.{Math, MathMatrix}
56

67
import java.io.DataInputStream
78
import java.util.{HashMap => JHashMap}
89

9-
class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession) {
10+
class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession, nonLin: Option[NonLinearity] = None) {
1011
/**
1112
* Runs the inference using a transformer encoder over a batch of sentences
1213
*
@@ -19,6 +20,19 @@ class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSes
1920

2021
val result: OrtSession.Result = encoderSession.run(inputs)
2122
val outputs = Math.fromResult(result)
23+
24+
if(nonLin.isDefined) {
25+
for (matrix <- outputs) {
26+
for (i <- 0 until Math.rows(matrix)) {
27+
val row = Math.row(matrix, i)
28+
for (j <- 0 until Math.cols(matrix)) {
29+
val orig = Math.get(row, j)
30+
Math.set(row, j, nonLin.get.compute(orig))
31+
}
32+
}
33+
}
34+
}
35+
2236
outputs
2337
}
2438

@@ -36,8 +50,10 @@ class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSes
3650
object Encoder {
3751
val ortEnvironment = OrtEnvironment.getEnvironment
3852

53+
// val nonLinearity = Some(new ReLU())
54+
3955
protected def fromSession(ortSession: OrtSession): Encoder =
40-
new Encoder(ortEnvironment, ortSession)
56+
new Encoder(ortEnvironment, ortSession, None) // nonLinearity // ms: skipping the nonlinearity is better
4157

4258
protected def ortSessionFromFile(fileName: String): OrtSession =
4359
ortEnvironment.createSession(fileName, new OrtSession.SessionOptions)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package org.clulab.scala_transformers.encoder
2+
3+
import org.clulab.scala_transformers.encoder.math.EjmlMath.MathValue
4+
5+
import java.lang
6+
7+
trait NonLinearity {
8+
def compute(input: MathValue): MathValue
9+
}
10+
11+
class ReLU extends NonLinearity {
12+
override def compute(input: MathValue): MathValue = {
13+
lang.Float.max(0, input)
14+
}
15+
}

encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ object EjmlMath extends Math {
132132
rowVector.get(index)
133133
}
134134

135+
def set(rowVector: MathRowVector, index: Int, value: MathValue): Unit = {
136+
assert(isRowVector(rowVector))
137+
rowVector.set(index, value)
138+
}
139+
135140
def mkMatrixFromRows(values: Array[Array[MathValue]]): MathRowMatrix = {
136141
new FMatrixRMaj(values)
137142
}

encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Math.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ trait Math {
2323
def horcat(leftRowVector: MathRowVector, rightRowVector: MathRowVector): MathRowVector
2424
def toArray(rowVector: MathRowVector): Array[MathValue]
2525
def get(rowVector: MathRowVector, index: Int): MathValue
26+
def set(rowVector: MathRowVector, index: Int, value: MathValue): Unit
2627
def mkMatrixFromRows(values: Array[Array[MathValue]]): MathRowMatrix
2728
// For this, the array is specified in column-major order,
2829
// but it should be converted to the normal representation.

0 commit comments

Comments
 (0)