Skip to content

Commit

Permalink
Add NNAPI E2E test for Android java package (#8912)
Browse files Browse the repository at this point in the history
* Add NNAPI E2E test for Android java package

* address cr comment
  • Loading branch information
guoyu-wang authored Sep 1, 2021
1 parent a9a0d3f commit 8404a2d
Showing 1 changed file with 34 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package ai.onnxruntime.example.javavalidator
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtException
import ai.onnxruntime.OrtProvider
import ai.onnxruntime.OrtSession.SessionOptions
import android.util.Log
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import org.junit.Assert
Expand All @@ -12,14 +14,44 @@ import org.junit.runner.RunWith
import java.io.IOException
import java.util.*

private const val TAG = "ORTAndroidTest"

@RunWith(AndroidJUnit4::class)
class SimpleTest {
@Test
@Throws(OrtException::class, IOException::class)
fun runSigmoidModelTest() {
for (intraOpNumThreads in 1..4) {
runSigmoidModelTestImpl(intraOpNumThreads)
}
}

@Test
fun runSigmoidModelTestNNAPI() {
runSigmoidModelTestImpl(1, true)
}

@Throws(IOException::class)
private fun readModel(fileName: String): ByteArray {
return InstrumentationRegistry.getInstrumentation().context.assets.open(fileName)
.readBytes()
}

@Throws(OrtException::class, IOException::class)
fun runSigmoidModelTestImpl(intraOpNumThreads: Int, useNNAPI: Boolean = false) {
Log.println(Log.INFO, TAG, "Testing with intraOpNumThreads=$intraOpNumThreads")
Log.println(Log.INFO, TAG, "Testing with useNNAPI=$useNNAPI")
val env = OrtEnvironment.getEnvironment()
env.use {
val opts = SessionOptions()
opts.setIntraOpNumThreads(intraOpNumThreads)
if (useNNAPI) {
if (OrtEnvironment.getAvailableProviders().contains(OrtProvider.NNAPI)) {
opts.addNnapi()
} else {
Log.println(Log.INFO, TAG, "NO NNAPI EP available, skip the test")
return
}
}
opts.use {
val session = env.createSession(readModel("sigmoid.ort"), opts)
session.use {
Expand All @@ -40,6 +72,7 @@ class SimpleTest {
inputTensor.use {
val output = session.run(Collections.singletonMap(inputName, inputTensor))
output.use {
@Suppress("UNCHECKED_CAST")
val rawOutput = output[0].value as Array<Array<FloatArray>>
for (i in 0..2) {
for (j in 0..3) {
Expand All @@ -58,11 +91,4 @@ class SimpleTest {
}
}
}


@Throws(IOException::class)
private fun readModel(fileName: String): ByteArray {
return InstrumentationRegistry.getInstrumentation().context.assets.open(fileName)
.readBytes()
}
}

0 comments on commit 8404a2d

Please sign in to comment.