From 6ff9d10ffad160b54f51e796d11238a45f90e381 Mon Sep 17 00:00:00 2001 From: Keith Alcock Date: Sat, 5 Apr 2025 21:50:50 -0700 Subject: [PATCH 1/3] Implement nonLin slightly differently --- .../scala_transformers/encoder/Encoder.scala | 15 ++++----------- .../scala_transformers/encoder/NonLinearity.scala | 6 ++---- .../encoder/math/EjmlMath.scala | 11 +++++++++++ .../scala_transformers/encoder/math/Math.scala | 1 + .../encoder/math/Mathematics.scala | 1 + 5 files changed, 19 insertions(+), 15 deletions(-) diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala index 860a6eb..7f2b8d2 100644 --- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala +++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala @@ -7,7 +7,7 @@ import org.clulab.scala_transformers.encoder.math.Mathematics.{Math, MathMatrix} import java.io.DataInputStream import java.util.{HashMap => JHashMap} -class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession, nonLin: Option[NonLinearity] = None) { +class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession, nonLinOpt: Option[NonLinearity] = None) { /** * Runs the inference using a transformer encoder over a batch of sentences * @@ -21,18 +21,11 @@ class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSes val result: OrtSession.Result = encoderSession.run(inputs) val outputs = Math.fromResult(result) - if(nonLin.isDefined) { - for (matrix <- outputs) { - for (i <- 0 until Math.rows(matrix)) { - val row = Math.row(matrix, i) - for (j <- 0 until Math.cols(matrix)) { - val orig = Math.get(row, j) - Math.set(row, j, nonLin.get.compute(orig)) - } - } + nonLinOpt.foreach { nonLin => + outputs.foreach { matrix => + Math.map(matrix, nonLin.compute) } } - outputs } diff --git 
a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala
index 1633b61..dd6545c 100644
--- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala
+++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala
@@ -1,8 +1,6 @@
 package org.clulab.scala_transformers.encoder
 
-import org.clulab.scala_transformers.encoder.math.EjmlMath.MathValue
-
-import java.lang
+import org.clulab.scala_transformers.encoder.math.Mathematics.MathValue
 
 trait NonLinearity {
   def compute(input: MathValue): MathValue
@@ -10,6 +8,6 @@ trait NonLinearity {
 
 class ReLU extends NonLinearity {
   override def compute(input: MathValue): MathValue = {
-    lang.Float.max(0, input)
+    scala.math.max(0, input)
   }
 }
diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala
index 04a0c95..69a4ce3 100644
--- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala
+++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala
@@ -100,6 +100,17 @@ object EjmlMath extends Math {
     new FMatrixRMaj(rows, cols)
   }
 
+  /** Overwrites every element of `matrix`, in place, with the result of applying `f` to it. */
+  override def map(matrix: MathRowMatrix, f: MathValue => MathValue): Unit = {
+    // EJML's iterator bounds are inclusive; an empty matrix has no valid range and is skipped.
+    val lastRow = matrix.getNumRows - 1
+    val lastCol = matrix.getNumCols - 1
+    if (lastRow >= 0 && lastCol >= 0) {
+      val iterator = matrix.iterator(true, 0, 0, lastRow, lastCol)
+      while (iterator.hasNext) iterator.set(f(iterator.next()))
+    }
+  }
+
   def row(matrix: MathRowMatrix, index: Int): MathRowVector = {
     val result = SimpleMatrix.wrap(matrix).rows(index, index + 1).getMatrix[FMatrixRMaj]
 
diff --git 
a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Math.scala +++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Math.scala @@ -13,6 +13,7 @@ trait Math { def inplaceMatrixAddition(matrix: MathRowMatrix, colVector: MathColVector): Unit def inplaceMatrixAddition(matrix: MathRowMatrix, rowIndex: Int, rowVector: MathRowVector): Unit // def rowVectorAddition(leftRowVector: MathRowVector, rightRowVector: MathRowVector): MathRowVector + def map(matrix: MathRowMatrix, f: MathValue => MathValue): Unit def mul(leftMatrix: MathRowMatrix, rightMatrix: MathRowMatrix): MathRowMatrix def rows(matrix: MathRowMatrix): Int def cols(matrix: MathRowMatrix): Int diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Mathematics.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Mathematics.scala index 0aeb62d..e8a7fe6 100644 --- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Mathematics.scala +++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Mathematics.scala @@ -7,6 +7,7 @@ object Mathematics { // val Math = CommonsMath // val Math = CluMath + type MathValue = Math.MathValue type MathMatrix = Math.MathRowMatrix type MathColVector = Math.MathColVector type MathRowVector = Math.MathRowVector From 05c689bd5e80859034980d2b5a40d47d6cfbbe66 Mon Sep 17 00:00:00 2001 From: Keith Alcock Date: Sat, 5 Apr 2025 21:57:43 -0700 Subject: [PATCH 2/3] Make ReLU an object --- .../org/clulab/scala_transformers/encoder/NonLinearity.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala index dd6545c..c9717e7 100644 --- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala +++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala @@ -6,7 +6,7 @@ 
trait NonLinearity { def compute(input: MathValue): MathValue } -class ReLU extends NonLinearity { +object ReLU extends NonLinearity { override def compute(input: MathValue): MathValue = { scala.math.max(0, input) } From 1dd8069a3f462e8f3cbe0ad3644c04f37336bea3 Mon Sep 17 00:00:00 2001 From: Keith Alcock Date: Sat, 5 Apr 2025 22:01:18 -0700 Subject: [PATCH 3/3] Remove unused import --- .../scala/org/clulab/scala_transformers/encoder/Encoder.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala index 7f2b8d2..da62a08 100644 --- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala +++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala @@ -1,7 +1,6 @@ package org.clulab.scala_transformers.encoder import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession} -import org.clulab.scala_transformers.encoder.math.Mathematics import org.clulab.scala_transformers.encoder.math.Mathematics.{Math, MathMatrix} import java.io.DataInputStream