Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import org.clulab.scala_transformers.encoder.TokenClassifier

object TokenClassifierExampleApp extends App {
// Choose one of these.
// val tokenClassifier = TokenClassifier.fromFiles("../microsoft-deberta-v3-base-mtl/avg_export")
val tokenClassifier = TokenClassifier.fromResources("/org/clulab/scala_transformers/models/microsoft_deberta_v3_base_mtl/avg_export")
val tokenClassifier = TokenClassifier.fromFiles("../microsoft-deberta-v3-base-mtl/avg_export")
//val tokenClassifier = TokenClassifier.fromResources("/org/clulab/scala_transformers/models/microsoft_deberta_v3_base_mtl/avg_export")
// val tokenClassifier = TokenClassifier.fromFiles("../models/google_electra_small_discriminator_mtl/avg_export")
// val tokenClassifier = TokenClassifier.fromResources("/org/clulab/scala_transformers/models/google_electra_small_discriminator_mtl/avg_export")
// val tokenClassifier = TokenClassifier.fromFiles("../models/roberta_base_mtl/avg_export")
Expand All @@ -14,7 +14,8 @@ object TokenClassifierExampleApp extends App {
//val words = Seq("EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", ".")
//val words = Seq("John", "Doe", "went", "to", "China", ".")
//val words = Seq("John", "Doe", "went", "to", "China", ".")
val words = Seq("Ras1", "has", "phosphorylated", "Mek2", ".")
//val words = Seq("Ras1", "has", "phosphorylated", "Mek2", ".")
val words = Seq("Jack", "Doe", "and", "John", "Doe", "were", "friends", "in", "yyy", "zzz", ".")

println(s"Words: ${words.mkString(", ")}")

Expand Down
9 changes: 5 additions & 4 deletions encoder/src/main/python/averaging_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,28 +83,29 @@ def average_checkpoints(self,
main_model = self.load_model(checkpoints[0], config, tasks)
print("Done loading.")

self.print_some_params(main_model, "before averaging:")
# TODO: update print
# self.print_some_params(main_model, "before averaging:")

for i in range(1, len(checkpoints)): # Skip 0, which is the main_model.
print(f"Loading satellite checkpoint[{i}] {checkpoints[i]}...")
satellite_model = self.load_model(checkpoints[i], config, tasks)
self.print_some_params(satellite_model, "satellite model:")
#self.print_some_params(satellite_model, "satellite model:") # TODO

print("Adding its parameter weights to the main model...")
for key, value in main_model.state_dict().items():
if value.data.type() != "torch.LongTensor":
value.data += satellite_model.state_dict()[key].data.clone()
print("Done adding")

self.print_some_params(main_model, "after summing:")
# self.print_some_params(main_model, "after summing:") # TODO
if len(checkpoints) > 1:
print("Computing average weights...")
for value in main_model.state_dict().values():
if value.data.type() != "torch.LongTensor":
value.data /= len(checkpoints)
print("Done computing.")

self.print_some_params(main_model, "after averaging:")
# self.print_some_params(main_model, "after averaging:") # TODO
print("Saving averaged model...")
main_model.save_pretrained(path_to_save)
main_model.export_model(tasks, tokenizer, path_to_export)
Expand Down
44 changes: 25 additions & 19 deletions encoder/src/main/python/token_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from transformers.modeling_outputs import TokenClassifierOutput
from typing import Any, Callable, List, Optional, Union

# This global variable indicates the position of the linear layer in TokenClassificationHead.classifier
linear_pos = 1

# This class is adapted from: https://towardsdatascience.com/how-to-create-and-train-a-multi-task-transformer-model-18c54a146240
class TokenClassificationModel(PreTrainedModel):
def __init__(self, config: AutoConfig, transformer_name: str) -> None:
Expand Down Expand Up @@ -112,9 +115,10 @@ def save_pretrained(self,
# https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
print("pickle saving done.")
#super().save_pretrained(save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, **kwargs)

# some optional debugging prints:
#for i in range(5):
# key = f"output_heads.{i}.classifier.weight"
# key = f"output_heads.{i}.classifier[linear_pos].weight"
# print(f"{key} = {self.state_dict()[key]}")

def from_pretrained(self, pretrained_model_name_or_path: str, *model_args, **kwargs) -> None:
Expand All @@ -129,8 +133,8 @@ def from_pretrained(self, pretrained_model_name_or_path: str, *model_args, **kwa
print("Done loading.")

def export_task(self, task_head, task: Task, task_checkpoint) -> None:
numpy_weights = task_head.classifier.weight.cpu().detach().numpy()
numpy_bias = task_head.classifier.bias.cpu().detach().numpy()
numpy_weights = task_head.classifier[linear_pos].weight.cpu().detach().numpy()
numpy_bias = task_head.classifier[linear_pos].bias.cpu().detach().numpy()
labels = task.labels
#print(f"Shape of weights: {numpy_weights.shape}")
#print(f"Weights are:\n{numpy_weights}")
Expand Down Expand Up @@ -220,25 +224,31 @@ def export_model(self, tasks: List[Task], tokenizer: AutoTokenizer, checkpoint_d
class TokenClassificationHead(nn.Module):
def __init__(self, hidden_size: int, num_labels: int, task_id, dual_mode: bool=False, dropout_p: float=0.1):
super().__init__()
self.dropout = nn.Dropout(dropout_p)
#self.dropout = nn.Dropout(dropout_p)
self.dual_mode = dual_mode
self.classifier = nn.Linear(
hidden_size * 2 if (self.dual_mode and Parameters.use_concat) else hidden_size,
num_labels
)
# make sure to adjust linear_pos if the position of Linear here changes!
self.classifier = nn.Sequential(
# nn.ReLU(), # it works better wo/ a nonlinearity here
nn.Dropout(dropout_p),
nn.Linear(
hidden_size * 2 if (self.dual_mode and Parameters.use_concat) else hidden_size,
num_labels
)
)
self.num_labels = num_labels
self.task_id = task_id
self._init_weights()

def _init_weights(self):
self.classifier.weight.data.normal_(mean=0.0, std=0.02)
# torch.nn.init.xavier_normal_(self.classifier.weight.data)
if self.classifier.bias is not None:
self.classifier.bias.data.zero_()
print("Nothing to initialize here.")
self.classifier[linear_pos].weight.data.normal_(mean=0.0, std=0.02)
#torch.nn.init.xavier_normal_(self.classifier[linear_pos].weight.data)
if self.classifier[linear_pos].bias is not None:
self.classifier[linear_pos].bias.data.zero_()

def summarize(self, task_id):
print(f"Task {self.task_id} with {self.num_labels} labels.")
print(f"Dropout is {self.dropout}")
#print(f"Dropout is {self.dropout}")
print(f"Classifier layer is {self.classifier}")

def concatenate(self, sequence_output, head_positions):
Expand All @@ -250,16 +260,12 @@ def concatenate(self, sequence_output, head_positions):
# Concatenate the hidden states from modifier + head.
modifier_head_states = torch.cat([sequence_output, head_states], dim=2) if Parameters.use_concat else torch.add(sequence_output, head_states)
#print(f"modifier_head_states.size = {modifier_head_states.size()}")
#print("EXIT")
#exit(1)
return modifier_head_states

def forward(self, sequence_output, pooled_output, head_positions, labels=None, attention_mask=None, **kwargs):
#print(f"sequence_output size = {sequence_output.size()}")
sequence_output_for_classification = sequence_output if not self.dual_mode else self.concatenate(sequence_output, head_positions)

sequence_output_dropout = self.dropout(sequence_output_for_classification)
logits = self.classifier(sequence_output_dropout)
logits = self.classifier(sequence_output_for_classification)

loss = None
if labels is not None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package org.clulab.scala_transformers.encoder

import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}
import org.clulab.scala_transformers.encoder.math.Mathematics
import org.clulab.scala_transformers.encoder.math.Mathematics.{Math, MathMatrix}

import java.io.DataInputStream
import java.util.{HashMap => JHashMap}

class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession) {
class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession, nonLin: Option[NonLinearity] = None) {
/**
* Runs the inference using a transformer encoder over a batch of sentences
*
Expand All @@ -19,6 +20,19 @@ class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSes

val result: OrtSession.Result = encoderSession.run(inputs)
val outputs = Math.fromResult(result)

if(nonLin.isDefined) {
for (matrix <- outputs) {
for (i <- 0 until Math.rows(matrix)) {
val row = Math.row(matrix, i)
for (j <- 0 until Math.cols(matrix)) {
val orig = Math.get(row, j)
Math.set(row, j, nonLin.get.compute(orig))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this code stays, the `nonLin.get` call could be hoisted out of the triply nested loops, since its result does not change between iterations.

}
}
}
}
Comment on lines +24 to +34
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like `Math` doesn't have a `map` method yet. It would be nice to hide these iteration details behind one, which might also make the loop more efficient, with code like:

  def map(matrix: MathRowMatrix, f: MathValue => MathValue): Unit = {
    val iterator = matrix.iterator(true, 0, 0, matrix.getNumRows, matrix.getNumCols)
    
    while (iterator.hasNext) {
      val oldValue = iterator.next()
      val newValue = f(oldValue)

      iterator.set(newValue)
    }
  }

Then the Encoder would have

    nonLinOpt.foreach { nonLin =>
      outputs.foreach { matrix =>
        Math.map(matrix, nonLin.compute)
      }
    }

There will be an example PR with other details soon.


outputs
}

Expand All @@ -36,8 +50,10 @@ class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSes
object Encoder {
val ortEnvironment = OrtEnvironment.getEnvironment

// val nonLinearity = Some(new ReLU())

protected def fromSession(ortSession: OrtSession): Encoder =
new Encoder(ortEnvironment, ortSession)
new Encoder(ortEnvironment, ortSession, None) // nonLinearity // ms: skipping the nonlinearity is better

protected def ortSessionFromFile(fileName: String): OrtSession =
ortEnvironment.createSession(fileName, new OrtSession.SessionOptions)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.clulab.scala_transformers.encoder

import org.clulab.scala_transformers.encoder.math.EjmlMath.MathValue
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a way to make this independent of EjmlMath.


import java.lang

/** A scalar function from one MathValue to another.
  *
  * Implementations supply the function through compute.
  */
trait NonLinearity {
// Maps a single input value to its transformed output value.
def compute(input: MathValue): MathValue
}

/** Rectified linear unit: yields the larger of zero and the input value. */
class ReLU extends NonLinearity {
  // `scala.math` is spelled out in full because this package also contains a
  // project sub-package named `math`, which a bare `math.max` could resolve to.
  override def compute(input: MathValue): MathValue = scala.math.max(0f, input)
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ object EjmlMath extends Math {
rowVector.get(index)
}

// Writes value into rowVector at position index, modifying rowVector in place.
// The argument is first checked with isRowVector; a failed check raises an
// AssertionError (unless assertions are elided at compile time).
// NOTE(review): index is presumably 0-based, mirroring get -- confirm.
def set(rowVector: MathRowVector, index: Int, value: MathValue): Unit = {
assert(isRowVector(rowVector))
rowVector.set(index, value)
}

// Builds a matrix by passing the two-dimensional array straight to the
// FMatrixRMaj constructor.
// NOTE(review): presumably values(i) is row i of the result and all rows must
// have the same length -- confirm against the FMatrixRMaj documentation.
def mkMatrixFromRows(values: Array[Array[MathValue]]): MathRowMatrix = {
new FMatrixRMaj(values)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ trait Math {
def horcat(leftRowVector: MathRowVector, rightRowVector: MathRowVector): MathRowVector
def toArray(rowVector: MathRowVector): Array[MathValue]
def get(rowVector: MathRowVector, index: Int): MathValue
// Stores value at position index of rowVector; the counterpart of get.
// NOTE(review): implementations presumably mutate rowVector in place -- confirm.
def set(rowVector: MathRowVector, index: Int, value: MathValue): Unit
def mkMatrixFromRows(values: Array[Array[MathValue]]): MathRowMatrix
// For this, the array is specified in column-major order,
// but it should be converted to the normal representation.
Expand Down
Loading