-
Notifications
You must be signed in to change notification settings - Fork 0
Reorganized TokenClassificationHead to use the Sequential module. Added optional nonlinearity. #59
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,13 @@ | ||
package org.clulab.scala_transformers.encoder | ||
|
||
import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession} | ||
import org.clulab.scala_transformers.encoder.math.Mathematics | ||
import org.clulab.scala_transformers.encoder.math.Mathematics.{Math, MathMatrix} | ||
|
||
import java.io.DataInputStream | ||
import java.util.{HashMap => JHashMap} | ||
|
||
class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession) { | ||
class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSession, nonLin: Option[NonLinearity] = None) { | ||
/** | ||
* Runs the inference using a transformer encoder over a batch of sentences | ||
* | ||
|
@@ -19,6 +20,19 @@ class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSes | |
|
||
val result: OrtSession.Result = encoderSession.run(inputs) | ||
val outputs = Math.fromResult(result) | ||
|
||
if(nonLin.isDefined) { | ||
for (matrix <- outputs) { | ||
for (i <- 0 until Math.rows(matrix)) { | ||
val row = Math.row(matrix, i) | ||
for (j <- 0 until Math.cols(matrix)) { | ||
val orig = Math.get(row, j) | ||
Math.set(row, j, nonLin.get.compute(orig)) | ||
} | ||
} | ||
} | ||
} | ||
Comment on lines
+24
to
+34
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like we don't have a map method. It might be nice to hide these details and maybe make it more efficient with code like def map(matrix: MathRowMatrix, f: MathValue => MathValue): Unit = {
val iterator = matrix.iterator(true, 0, 0, matrix.getNumRows, matrix.getNumCols)
while (iterator.hasNext) {
val oldValue = iterator.next()
val newValue = f(oldValue)
iterator.set(newValue)
}
} Then the Encoder would have nonLinOpt.foreach { nonLin =>
outputs.foreach { matrix =>
Math.map(matrix, nonLin.compute)
}
} There will be an example PR with other details soon. |
||
|
||
outputs | ||
} | ||
|
||
|
@@ -36,8 +50,10 @@ class Encoder(val encoderEnvironment: OrtEnvironment, val encoderSession: OrtSes | |
object Encoder { | ||
val ortEnvironment = OrtEnvironment.getEnvironment | ||
|
||
// val nonLinearity = Some(new ReLU()) | ||
|
||
protected def fromSession(ortSession: OrtSession): Encoder = | ||
new Encoder(ortEnvironment, ortSession) | ||
new Encoder(ortEnvironment, ortSession, None) // nonLinearity // ms: skipping the nonlinearity is better | ||
|
||
protected def ortSessionFromFile(fileName: String): OrtSession = | ||
ortEnvironment.createSession(fileName, new OrtSession.SessionOptions) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package org.clulab.scala_transformers.encoder | ||
|
||
import org.clulab.scala_transformers.encoder.math.EjmlMath.MathValue | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a way to make this independent of EjmlMath. |
||
|
||
import java.lang | ||
|
||
trait NonLinearity { | ||
def compute(input: MathValue): MathValue | ||
} | ||
|
||
class ReLU extends NonLinearity { | ||
override def compute(input: MathValue): MathValue = { | ||
lang.Float.max(0, input) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this does stay, the
nonLin.get
might be moved outside the triply nested loops.