From fc72cce88e31becc31b0b22c107e1a785f502a88 Mon Sep 17 00:00:00 2001 From: Julia Beliaeva Date: Thu, 1 Dec 2022 14:59:29 +0100 Subject: [PATCH] Convert all the model types without constructor parameters to objects (#492) --- .../cv/efficicentnet/EfficientNet4Lite.kt | 2 +- .../examples/onnx/cv/resnet/ResNet18.kt | 2 +- .../onnx/cv/resnet/ResNet18LightAPI.kt | 2 +- .../onnx/ExecutionProvidersTestSuite.kt | 24 +++++++++---------- .../examples/onnx/ModelCopyingTestSuite.kt | 22 ++++++++--------- .../examples/onnx/cv/OnnxResNetTestSuite.kt | 18 +++++++------- .../kotlinx/dl/onnx/inference/ONNXModels.kt | 4 ++-- .../kotlinx/dl/onnx/inference/ONNXModels.kt | 22 ++++++++--------- 8 files changed, 48 insertions(+), 48 deletions(-) diff --git a/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/EfficientNet4Lite.kt b/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/EfficientNet4Lite.kt index 711b7892b..2388ef09a 100644 --- a/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/EfficientNet4Lite.kt +++ b/examples/src/main/kotlin/examples/onnx/cv/efficicentnet/EfficientNet4Lite.kt @@ -29,7 +29,7 @@ import java.io.File fun efficientNet4LitePrediction() { val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels")) - val modelType = ONNXModels.CV.EfficientNet4Lite() + val modelType = ONNXModels.CV.EfficientNet4Lite val model = modelHub.loadModel(modelType) model.printSummary() diff --git a/examples/src/main/kotlin/examples/onnx/cv/resnet/ResNet18.kt b/examples/src/main/kotlin/examples/onnx/cv/resnet/ResNet18.kt index 02aae06dc..673bda8af 100644 --- a/examples/src/main/kotlin/examples/onnx/cv/resnet/ResNet18.kt +++ b/examples/src/main/kotlin/examples/onnx/cv/resnet/ResNet18.kt @@ -16,7 +16,7 @@ import org.jetbrains.kotlinx.dl.onnx.inference.ONNXModels * - Special preprocessing (used in ResNet'18 during training on ImageNet dataset) is applied to each image before prediction. */ fun resnet18prediction() { - runImageRecognitionPrediction(ONNXModels.CV.ResNet18()) + runImageRecognitionPrediction(ONNXModels.CV.ResNet18) } /** */ diff --git a/examples/src/main/kotlin/examples/onnx/cv/resnet/ResNet18LightAPI.kt b/examples/src/main/kotlin/examples/onnx/cv/resnet/ResNet18LightAPI.kt index 87217841d..7bcad15a6 100644 --- a/examples/src/main/kotlin/examples/onnx/cv/resnet/ResNet18LightAPI.kt +++ b/examples/src/main/kotlin/examples/onnx/cv/resnet/ResNet18LightAPI.kt @@ -20,7 +20,7 @@ fun resnet18LightAPIPrediction() { val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels")) - val model = ONNXModels.CV.ResNet18().pretrainedModel(modelHub) + val model = ONNXModels.CV.ResNet18.pretrainedModel(modelHub) model.printSummary() model.use { diff --git a/examples/src/test/kotlin/examples/onnx/ExecutionProvidersTestSuite.kt b/examples/src/test/kotlin/examples/onnx/ExecutionProvidersTestSuite.kt index d57136769..4dc8e8d4d 100644 --- a/examples/src/test/kotlin/examples/onnx/ExecutionProvidersTestSuite.kt +++ b/examples/src/test/kotlin/examples/onnx/ExecutionProvidersTestSuite.kt @@ -17,16 +17,16 @@ import java.io.File class ExecutionProvidersTestSuite { private fun resnetModelsInference(executionProvider: ExecutionProvider) { val modelsToTest = listOf( - ONNXModels.CV.ResNet101(), - ONNXModels.CV.ResNet101v2(), - ONNXModels.CV.ResNet152(), - ONNXModels.CV.ResNet152v2(), - ONNXModels.CV.ResNet18(), - ONNXModels.CV.ResNet18v2(), - ONNXModels.CV.ResNet34(), - ONNXModels.CV.ResNet34v2(), - ONNXModels.CV.ResNet50(), - ONNXModels.CV.ResNet50v2(), + ONNXModels.CV.ResNet101, + ONNXModels.CV.ResNet101v2, + ONNXModels.CV.ResNet152, + ONNXModels.CV.ResNet152v2, + ONNXModels.CV.ResNet18, + ONNXModels.CV.ResNet18v2, + ONNXModels.CV.ResNet34, + ONNXModels.CV.ResNet34v2, + ONNXModels.CV.ResNet50, + ONNXModels.CV.ResNet50v2, ONNXModels.CV.ResNet50custom, ) @@ -78,7 +78,7 @@ class ExecutionProvidersTestSuite { @Test fun executionProvidersDuplicatesTest() { val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels")) - val model = modelHub.loadModel(ONNXModels.CV.ResNet18()) + val model = modelHub.loadModel(ONNXModels.CV.ResNet18) model.use { assertDoesNotThrow { @@ -90,7 +90,7 @@ class ExecutionProvidersTestSuite { @Test fun twoCpuExecutorsWithDifferentAllocatorsTest() { val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels")) - val model = modelHub.loadModel(ONNXModels.CV.ResNet18()) + val model = modelHub.loadModel(ONNXModels.CV.ResNet18) model.use { assertThrows { diff --git a/examples/src/test/kotlin/examples/onnx/ModelCopyingTestSuite.kt b/examples/src/test/kotlin/examples/onnx/ModelCopyingTestSuite.kt index de317bc63..3e28f85ed 100644 --- a/examples/src/test/kotlin/examples/onnx/ModelCopyingTestSuite.kt +++ b/examples/src/test/kotlin/examples/onnx/ModelCopyingTestSuite.kt @@ -34,7 +34,7 @@ class ModelCopyingTestSuite { ONNXModels.CV.EfficientNetB5(), ONNXModels.CV.EfficientNetB6(), ONNXModels.CV.EfficientNetB7(), - ONNXModels.CV.EfficientNet4Lite() + ONNXModels.CV.EfficientNet4Lite ), "datasets/vgg/image0.jpg", ImageRecognitionModel::predictObject @@ -45,17 +45,17 @@ class ModelCopyingTestSuite { fun resNetCopyTest() { runCopyTest( listOf( - ONNXModels.CV.ResNet18(), - ONNXModels.CV.ResNet18v2(), - ONNXModels.CV.ResNet34(), - ONNXModels.CV.ResNet34v2(), - ONNXModels.CV.ResNet50(), - ONNXModels.CV.ResNet50v2(), + ONNXModels.CV.ResNet18, + ONNXModels.CV.ResNet18v2, + ONNXModels.CV.ResNet34, + ONNXModels.CV.ResNet34v2, + ONNXModels.CV.ResNet50, + ONNXModels.CV.ResNet50v2, ONNXModels.CV.ResNet50custom, - ONNXModels.CV.ResNet101(), - ONNXModels.CV.ResNet101v2(), - ONNXModels.CV.ResNet152(), - ONNXModels.CV.ResNet152v2() + ONNXModels.CV.ResNet101, + ONNXModels.CV.ResNet101v2, + ONNXModels.CV.ResNet152, + ONNXModels.CV.ResNet152v2 ), "datasets/vgg/image0.jpg", ImageRecognitionModel::predictObject diff --git a/examples/src/test/kotlin/examples/onnx/cv/OnnxResNetTestSuite.kt b/examples/src/test/kotlin/examples/onnx/cv/OnnxResNetTestSuite.kt index e0975f6c7..2af9c22ab 100644 --- a/examples/src/test/kotlin/examples/onnx/cv/OnnxResNetTestSuite.kt +++ b/examples/src/test/kotlin/examples/onnx/cv/OnnxResNetTestSuite.kt @@ -26,27 +26,27 @@ class OnnxResNetTestSuite { @Test fun resnet18v2predictionTest() { - runImageRecognitionPrediction(ONNXModels.CV.ResNet18v2()) + runImageRecognitionPrediction(ONNXModels.CV.ResNet18v2) } @Test fun resnet34predictionTest() { - runImageRecognitionPrediction(ONNXModels.CV.ResNet34()) + runImageRecognitionPrediction(ONNXModels.CV.ResNet34) } @Test fun resnet34v2predictionTest() { - runImageRecognitionPrediction(ONNXModels.CV.ResNet34v2()) + runImageRecognitionPrediction(ONNXModels.CV.ResNet34v2) } @Test fun resnet50predictionTest() { - runImageRecognitionPrediction(ONNXModels.CV.ResNet50()) + runImageRecognitionPrediction(ONNXModels.CV.ResNet50) } @Test fun resnet50v2predictionTest() { - runImageRecognitionPrediction(ONNXModels.CV.ResNet50v2()) + runImageRecognitionPrediction(ONNXModels.CV.ResNet50v2) } @Test @@ -61,22 +61,22 @@ class OnnxResNetTestSuite { @Test fun resnet101predictionTest() { - runImageRecognitionPrediction(ONNXModels.CV.ResNet101()) + runImageRecognitionPrediction(ONNXModels.CV.ResNet101) } @Test fun resnet101v2predictionTest() { - runImageRecognitionPrediction(ONNXModels.CV.ResNet101v2()) + runImageRecognitionPrediction(ONNXModels.CV.ResNet101v2) } @Test fun resnet152predictionTest() { - runImageRecognitionPrediction(ONNXModels.CV.ResNet152()) + runImageRecognitionPrediction(ONNXModels.CV.ResNet152) } @Test fun resnet152v2predictionTest() { - runImageRecognitionPrediction(ONNXModels.CV.ResNet152v2()) + runImageRecognitionPrediction(ONNXModels.CV.ResNet152v2) } @Test diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/ONNXModels.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/ONNXModels.kt index 5832f825c..075fbd79a 100644 --- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/ONNXModels.kt +++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/ONNXModels.kt @@ -62,7 +62,7 @@ public object ONNXModels { * @see * Official EfficientNet4Lite model from ONNX Github. */ - public class EfficientNet4Lite : CV("efficientnet_lite4", channelsFirst = false) { + public object EfficientNet4Lite : CV("efficientnet_lite4", channelsFirst = false) { override val preprocessor: Operation, Pair> get() = InputType.TF.preprocessing(channelsLast = !channelsFirst) } @@ -83,7 +83,7 @@ public object ONNXModels { * @see * Official EfficientNet4Lite model from ONNX Github. */ - public class MobilenetV1 : CV("mobilenet_v1", channelsFirst = false) { + public object MobilenetV1 : CV("mobilenet_v1", channelsFirst = false) { override val preprocessor: Operation, Pair> get() = pipeline>() .rescale { scalingCoefficient = 255f } diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/ONNXModels.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/ONNXModels.kt index a52baf6e7..bdb65dd11 100644 --- a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/ONNXModels.kt +++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/onnx/inference/ONNXModels.kt @@ -71,7 +71,7 @@ public object ONNXModels { * @see * Official ResNet model from ONNX Github. */ - public class ResNet18 : CV("models/onnx/cv/resnet/resnet18-v1", channelsFirst = true) { + public object ResNet18 : CV("models/onnx/cv/resnet/resnet18-v1", channelsFirst = true) { override val preprocessor: Operation, Pair> get() = resNetOnnxPreprocessing() } @@ -93,7 +93,7 @@ public object ONNXModels { * @see * Official ResNet model from ONNX Github. */ - public class ResNet34 : CV("models/onnx/cv/resnet/resnet34-v1", channelsFirst = true) { + public object ResNet34 : CV("models/onnx/cv/resnet/resnet34-v1", channelsFirst = true) { override val preprocessor: Operation, Pair> get() = resNetOnnxPreprocessing() } @@ -115,7 +115,7 @@ public object ONNXModels { * @see * Official ResNet model from ONNX Github. */ - public class ResNet50 : + public object ResNet50 : CV("models/onnx/cv/resnet/resnet50-v1", channelsFirst = true) { override val preprocessor: Operation, Pair> get() = resNetOnnxPreprocessing() @@ -138,7 +138,7 @@ public object ONNXModels { * @see * Official ResNet model from ONNX Github. */ - public class ResNet101 : + public object ResNet101 : CV("models/onnx/cv/resnet/resnet101-v1", channelsFirst = true) { override val preprocessor: Operation, Pair> get() = resNetOnnxPreprocessing() @@ -161,7 +161,7 @@ public object ONNXModels { * @see * Official ResNet model from ONNX Github. */ - public class ResNet152 : + public object ResNet152 : CV("models/onnx/cv/resnet/resnet152-v1", channelsFirst = true) { override val preprocessor: Operation, Pair> get() = resNetOnnxPreprocessing() @@ -184,7 +184,7 @@ public object ONNXModels { * @see * Official ResNet model from ONNX Github. */ - public class ResNet18v2 : + public object ResNet18v2 : CV("models/onnx/cv/resnet/resnet18-v2", channelsFirst = true) { override val preprocessor: Operation, Pair> get() = resNetOnnxPreprocessing() @@ -207,7 +207,7 @@ public object ONNXModels { * @see * Official ResNet model from ONNX Github. */ - public class ResNet34v2 : + public object ResNet34v2 : CV("models/onnx/cv/resnet/resnet34-v2", channelsFirst = true) { override val preprocessor: Operation, Pair> get() = resNetOnnxPreprocessing() @@ -230,7 +230,7 @@ public object ONNXModels { * @see * Official ResNet model from ONNX Github. */ - public class ResNet50v2 : + public object ResNet50v2 : CV("models/onnx/cv/resnet/resnet50-v2", channelsFirst = true) { override val preprocessor: Operation, Pair> get() = resNetOnnxPreprocessing() @@ -253,7 +253,7 @@ public object ONNXModels { * @see * Official ResNet model from ONNX Github. */ - public class ResNet101v2 : + public object ResNet101v2 : CV("models/onnx/cv/resnet/resnet101-v2", channelsFirst = true) { override val preprocessor: Operation, Pair> get() = resNetOnnxPreprocessing() @@ -276,7 +276,7 @@ public object ONNXModels { * @see * Official ResNet model from ONNX Github. */ - public class ResNet152v2 : + public object ResNet152v2 : CV("models/onnx/cv/resnet/resnet152-v2", channelsFirst = true) { override val preprocessor: Operation, Pair> get() = resNetOnnxPreprocessing() @@ -299,7 +299,7 @@ public object ONNXModels { * @see * Official EfficientNet4Lite model from ONNX Github. */ - public class EfficientNet4Lite : + public object EfficientNet4Lite : CV("models/onnx/cv/efficientnet/efficientnet-lite4", channelsFirst = false) { override val preprocessor: Operation, Pair> get() = InputType.TF.preprocessing(channelsLast = !channelsFirst)