diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala index 860a6eb..da62a08 100644 --- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala +++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/Encoder.scala @@ -1,13 +1,12 @@ package org.clulab.scala_transformers.encoder import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession} -import org.clulab.scala_transformers.encoder.math.Mathematics import org.clulab.scala_transformers.encoder.math.Mathematics.{Math, MathMatrix} import java.io.DataInputStream import java.util.{HashMap => JHashMap} -class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession, nonLin: Option[NonLinearity] = None) { +class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession, nonLinOpt: Option[NonLinearity] = None) { /** * Runs the inference using a transformer encoder over a batch of sentences * @@ -21,18 +20,11 @@ class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSes val result: OrtSession.Result = encoderSession.run(inputs) val outputs = Math.fromResult(result) - if(nonLin.isDefined) { - for (matrix <- outputs) { - for (i <- 0 until Math.rows(matrix)) { - val row = Math.row(matrix, i) - for (j <- 0 until Math.cols(matrix)) { - val orig = Math.get(row, j) - Math.set(row, j, nonLin.get.compute(orig)) - } - } + nonLinOpt.foreach { nonLin => + outputs.foreach { matrix => + Math.map(matrix, nonLin.compute) } } - outputs } diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala index 1633b61..c9717e7 100644 --- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala +++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/NonLinearity.scala @@ -1,15 +1,13 @@ package org.clulab.scala_transformers.encoder -import org.clulab.scala_transformers.encoder.math.EjmlMath.MathValue - -import java.lang +import org.clulab.scala_transformers.encoder.math.Mathematics.MathValue trait NonLinearity { def compute(input: MathValue): MathValue } -class ReLU extends NonLinearity { +object ReLU extends NonLinearity { override def compute(input: MathValue): MathValue = { - lang.Float.max(0, input) + scala.math.max(0, input) } } diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala index 04a0c95..69a4ce3 100644 --- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala +++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/EjmlMath.scala @@ -100,6 +100,17 @@ object EjmlMath extends Math { new FMatrixRMaj(rows, cols) } + override def map(matrix: MathRowMatrix, f: MathValue => MathValue): Unit = { + val iterator = matrix.iterator(true, 0, 0, matrix.getNumRows, matrix.getNumCols) + + while (iterator.hasNext) { + val oldValue = iterator.next() + val newValue = f(oldValue) + + iterator.set(newValue) + } + } + def row(matrix: MathRowMatrix, index: Int): MathRowVector = { val result = SimpleMatrix.wrap(matrix).rows(index, index + 1).getMatrix[FMatrixRMaj] diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Math.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Math.scala index 523be6b..130c099 100644 --- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Math.scala +++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Math.scala @@ -13,6 +13,7 @@ trait Math { def inplaceMatrixAddition(matrix: MathRowMatrix, colVector: MathColVector): Unit def inplaceMatrixAddition(matrix: MathRowMatrix, rowIndex: Int, rowVector: MathRowVector): Unit // def rowVectorAddition(leftRowVector: MathRowVector, rightRowVector: MathRowVector): MathRowVector + def map(matrix: MathRowMatrix, f: MathValue => MathValue): Unit def mul(leftMatrix: MathRowMatrix, rightMatrix: MathRowMatrix): MathRowMatrix def rows(matrix: MathRowMatrix): Int def cols(matrix: MathRowMatrix): Int diff --git a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Mathematics.scala b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Mathematics.scala index 0aeb62d..e8a7fe6 100644 --- a/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Mathematics.scala +++ b/encoder/src/main/scala/org/clulab/scala_transformers/encoder/math/Mathematics.scala @@ -7,6 +7,7 @@ object Mathematics { // val Math = CommonsMath // val Math = CluMath + type MathValue = Math.MathValue type MathMatrix = Math.MathRowMatrix type MathColVector = Math.MathColVector type MathRowVector = Math.MathRowVector