diff --git a/apps/src/main/scala/org/clulab/scala_transformers/apps/TokenClassifierExampleApp.scala b/apps/src/main/scala/org/clulab/scala_transformers/apps/TokenClassifierExampleApp.scala
index 6141648..b2fab96 100644
--- a/apps/src/main/scala/org/clulab/scala_transformers/apps/TokenClassifierExampleApp.scala
+++ b/apps/src/main/scala/org/clulab/scala_transformers/apps/TokenClassifierExampleApp.scala
@@ -4,8 +4,8 @@ import org.clulab.scala_transformers.encoder.TokenClassifier
 
 object TokenClassifierExampleApp extends App {
   // Choose one of these.
-  // val tokenClassifier = TokenClassifier.fromFiles("../microsoft-deberta-v3-base-mtl/avg_export")
-  val tokenClassifier = TokenClassifier.fromResources("/org/clulab/scala_transformers/models/microsoft_deberta_v3_base_mtl/avg_export")
+  val tokenClassifier = TokenClassifier.fromFiles("../microsoft-deberta-v3-base-mtl/avg_export")
+  //val tokenClassifier = TokenClassifier.fromResources("/org/clulab/scala_transformers/models/microsoft_deberta_v3_base_mtl/avg_export")
   // val tokenClassifier = TokenClassifier.fromFiles("../models/google_electra_small_discriminator_mtl/avg_export")
   // val tokenClassifier = TokenClassifier.fromResources("/org/clulab/scala_transformers/models/google_electra_small_discriminator_mtl/avg_export")
   // val tokenClassifier = TokenClassifier.fromFiles("../models/roberta_base_mtl/avg_export")
@@ -14,7 +14,8 @@ object TokenClassifierExampleApp extends App {
   //val words = Seq("EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", ".")
   //val words = Seq("John", "Doe", "went", "to", "China", ".")
   //val words = Seq("John", "Doe", "went", "to", "China", ".")
-  val words = Seq("Ras1", "has", "phosphorylated", "Mek2", ".")
+  //val words = Seq("Ras1", "has", "phosphorylated", "Mek2", ".")
+  val words = Seq("Jack", "Doe", "and", "John", "Doe", "were", "friends", "in", "yyy", "zzz", ".")
 
   println(s"Words: ${words.mkString(", ")}")
 
diff --git a/encoder/src/main/python/averaging_trainer.py b/encoder/src/main/python/averaging_trainer.py
index dce21f0..bcfb9a1 100644
--- a/encoder/src/main/python/averaging_trainer.py
+++ b/encoder/src/main/python/averaging_trainer.py
@@ -83,12 +83,13 @@ def average_checkpoints(self,
         main_model = self.load_model(checkpoints[0], config, tasks)
         print("Done loading.")
 
-        self.print_some_params(main_model, "before averaging:")
+        # TODO: update print
+        # self.print_some_params(main_model, "before averaging:")
 
         for i in range(1, len(checkpoints)): # Skip 0, which is the main_model.
             print(f"Loading satellite checkpoint[{i}] {checkpoints[i]}...")
             satellite_model = self.load_model(checkpoints[i], config, tasks)
-            self.print_some_params(satellite_model, "satellite model:")
+            #self.print_some_params(satellite_model, "satellite model:") # TODO
 
             print("Adding its parameter weights to the main model...")
             for key, value in main_model.state_dict().items():
@@ -96,7 +97,7 @@ def average_checkpoints(self,
                 value.data += satellite_model.state_dict()[key].data.clone()
         print("Done adding")
 
-        self.print_some_params(main_model, "after summing:")
+        # self.print_some_params(main_model, "after summing:") # TODO
         if len(checkpoints) > 1:
             print("Computing average weights...")
             for value in main_model.state_dict().values():
@@ -104,7 +105,7 @@ def average_checkpoints(self,
                 value.data /= len(checkpoints)
         print("Done computing.")
 
-        self.print_some_params(main_model, "after averaging:")
+        # self.print_some_params(main_model, "after averaging:") # TODO
         print("Saving averaged model...")
         main_model.save_pretrained(path_to_save)
         main_model.export_model(tasks, tokenizer, path_to_export)
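
Note: the three hunks above implement checkpoint weight averaging. A minimal, self-contained sketch of the same arithmetic, with toy nn.Linear modules standing in for the checkpoints loaded via load_model (the hunks elide one line before each in-place update, so the dtype guard below is an assumption):

    import torch
    import torch.nn as nn

    # Toy stand-ins for the real checkpoints; shapes are illustrative only.
    models = [nn.Linear(4, 2) for _ in range(3)]
    main_model = models[0]

    # Sum each satellite checkpoint's parameters into the main model.
    for satellite_model in models[1:]:
        for key, value in main_model.state_dict().items():
            if value.dtype == torch.float32:  # assumed guard for non-float buffers
                value.data += satellite_model.state_dict()[key].data.clone()

    # Divide in place to get the element-wise average over all checkpoints.
    for value in main_model.state_dict().values():
        if value.dtype == torch.float32:
            value.data /= len(models)
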
print(f"Loading satellite checkpoint[{i}] {checkpoints[i]}...") satellite_model = self.load_model(checkpoints[i], config, tasks) - self.print_some_params(satellite_model, "satellite model:") + #self.print_some_params(satellite_model, "satellite model:") # TODO print("Adding its parameter weights to the main model...") for key, value in main_model.state_dict().items(): @@ -96,7 +97,7 @@ def average_checkpoints(self, value.data += satellite_model.state_dict()[key].data.clone() print("Done adding") - self.print_some_params(main_model, "after summing:") + # self.print_some_params(main_model, "after summing:") # TODO if len(checkpoints) > 1: print("Computing average weights...") for value in main_model.state_dict().values(): @@ -104,7 +105,7 @@ def average_checkpoints(self, value.data /= len(checkpoints) print("Done computing.") - self.print_some_params(main_model, "after averaging:") + # self.print_some_params(main_model, "after averaging:") # TODO print("Saving averaged model...") main_model.save_pretrained(path_to_save) main_model.export_model(tasks, tokenizer, path_to_export) diff --git a/encoder/src/main/python/token_classifier.py b/encoder/src/main/python/token_classifier.py index 7f667b6..d4c1932 100644 --- a/encoder/src/main/python/token_classifier.py +++ b/encoder/src/main/python/token_classifier.py @@ -14,6 +14,9 @@ from transformers.modeling_outputs import TokenClassifierOutput from typing import Any, Callable, List, Optional, Union +# This global variable indicates the position of the linear layer in TokenClassificationHead.classifier +linear_pos = 1 + # This class is adapted from: https://towardsdatascience.com/how-to-create-and-train-a-multi-task-transformer-model-18c54a146240 class TokenClassificationModel(PreTrainedModel): def __init__(self, config: AutoConfig, transformer_name: str) -> None: @@ -112,9 +115,10 @@ def save_pretrained(self, # https://pytorch.org/tutorials/beginner/saving_loading_models.html torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin") print("pickle saving done.") - #super().save_pretrained(save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, **kwargs) + + # some optional debugging prints: #for i in range(5): - # key = f"output_heads.{i}.classifier.weight" + # key = f"output_heads.{i}.classifier[linear_pos].weight" # print(f"{key} = {self.state_dict()[key]}") def from_pretrained(self, pretrained_model_name_or_path: str, *model_args, **kwargs) -> None: @@ -129,8 +133,8 @@ def from_pretrained(self, pretrained_model_name_or_path: str, *model_args, **kwa print("Done loading.") def export_task(self, task_head, task: Task, task_checkpoint) -> None: - numpy_weights = task_head.classifier.weight.cpu().detach().numpy() - numpy_bias = task_head.classifier.bias.cpu().detach().numpy() + numpy_weights = task_head.classifier[linear_pos].weight.cpu().detach().numpy() + numpy_bias = task_head.classifier[linear_pos].bias.cpu().detach().numpy() labels = task.labels #print(f"Shape of weights: {numpy_weights.shape}") #print(f"Weights are:\n{numpy_weights}") @@ -220,25 +224,31 @@ def export_model(self, tasks: List[Task], tokenizer: AutoTokenizer, checkpoint_d class TokenClassificationHead(nn.Module): def __init__(self, hidden_size: int, num_labels: int, task_id, dual_mode: bool=False, dropout_p: float=0.1): super().__init__() - self.dropout = nn.Dropout(dropout_p) + #self.dropout = nn.Dropout(dropout_p) self.dual_mode = dual_mode - self.classifier = nn.Linear( - hidden_size * 2 if (self.dual_mode and 
diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala
index 19fa2d1..860a6eb 100644
--- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala
+++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala
@@ -1,12 +1,13 @@
 package org.clulab.scala_transformers.encoder
 
 import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}
+import org.clulab.scala_transformers.encoder.math.Mathematics
 import org.clulab.scala_transformers.encoder.math.Mathematics.{Math, MathMatrix}
 
 import java.io.DataInputStream
 import java.util.{HashMap => JHashMap}
 
-class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession) {
+class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession, nonLin: Option[NonLinearity] = None) {
   /**
    * Runs the inference using a transformer encoder over a batch of sentences
    *
@@ -19,6 +20,19 @@ class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSes
 
     val result: OrtSession.Result = encoderSession.run(inputs)
     val outputs = Math.fromResult(result)
+
+    if(nonLin.isDefined) {
+      for (matrix <- outputs) {
+        for (i <- 0 until Math.rows(matrix)) {
+          val row = Math.row(matrix, i)
+          for (j <- 0 until Math.cols(matrix)) {
+            val orig = Math.get(row, j)
+            Math.set(row, j, nonLin.get.compute(orig))
+          }
+        }
+      }
+    }
+
     outputs
   }
 
@@ -36,8 +50,10 @@ class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSes
 object Encoder {
   val ortEnvironment = OrtEnvironment.getEnvironment
 
+  // val nonLinearity = Some(new ReLU())
+
   protected def fromSession(ortSession: OrtSession): Encoder =
-    new Encoder(ortEnvironment, ortSession)
+    new Encoder(ortEnvironment, ortSession, None) // nonLinearity // ms: skipping the nonlinearity is better
 
   protected def ortSessionFromFile(fileName: String): OrtSession =
     ortEnvironment.createSession(fileName, new OrtSession.SessionOptions)
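
Note: the triple-nested loop above applies the configured nonlinearity element-wise to each encoder output matrix. For intuition only, the equivalent computation in NumPy terms (an illustration, not the project's code; the Scala loop itself goes through the Math facade over EJML matrices):

    import numpy as np

    # Toy encoder outputs: one (tokens x hidden) float32 matrix per sentence.
    outputs = [np.random.randn(5, 8).astype(np.float32) for _ in range(2)]

    # Element-wise ReLU, matching what the row/column loops compute when
    # nonLin is Some(new ReLU()).
    outputs = [np.maximum(0.0, matrix) for matrix in outputs]
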
diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala
new file mode 100644
index 0000000..1633b61
--- /dev/null
+++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala
@@ -0,0 +1,15 @@
+package org.clulab.scala_transformers.encoder
+
+import org.clulab.scala_transformers.encoder.math.EjmlMath.MathValue
+
+import java.lang
+
+trait NonLinearity {
+  def compute(input: MathValue): MathValue
+}
+
+class ReLU extends NonLinearity {
+  override def compute(input: MathValue): MathValue = {
+    lang.Float.max(0, input)
+  }
+}
diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala
index 655d446..04a0c95 100644
--- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala
+++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala
@@ -132,6 +132,11 @@ object EjmlMath extends Math {
     rowVector.get(index)
   }
 
+  def set(rowVector: MathRowVector, index: Int, value: MathValue): Unit = {
+    assert(isRowVector(rowVector))
+    rowVector.set(index, value)
+  }
+
   def mkMatrixFromRows(values: Array[Array[MathValue]]): MathRowMatrix = {
     new FMatrixRMaj(values)
   }
diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Math.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Math.scala
index c968f98..523be6b 100644
--- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Math.scala
+++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Math.scala
@@ -23,6 +23,7 @@ trait Math {
   def horcat(leftRowVector: MathRowVector, rightRowVector: MathRowVector): MathRowVector
   def toArray(rowVector: MathRowVector): Array[MathValue]
   def get(rowVector: MathRowVector, index: Int): MathValue
+  def set(rowVector: MathRowVector, index: Int, value: MathValue): Unit
   def mkMatrixFromRows(values: Array[Array[MathValue]]): MathRowMatrix
   // For this, the array is specified in column-major order,
   // but it should be converted to the normal representation.