diff --git a/.history/analysis_options_20211031112442.yaml b/.history/analysis_options_20211031112442.yaml new file mode 100644 index 0000000..c3323b8 --- /dev/null +++ b/.history/analysis_options_20211031112442.yaml @@ -0,0 +1,43 @@ +# This file configures the analyzer, which statically analyzes Dart code to +# check for errors, warnings, and lints. +# +# The issues identified by the analyzer are surfaced in the UI of Dart-enabled +# IDEs (https://dart.dev/tools#ides-and-editors). The analyzer can also be +# invoked from the command line by running `flutter analyze`. + +# The following line activates a set of recommended lints for Flutter apps, +# packages, and plugins designed to encourage good coding practices. +include: package:flutter_lints/flutter.yaml + +analyzer: + errors: + todo: ignore + +linter: + # The lint rules applied to this project can be customized in the + # section below to disable rules from the `package:flutter_lints/flutter.yaml` + # included above or to enable additional rules. A list of all available lints + # and their documentation is published at + # https://dart-lang.github.io/linter/lints/index.html. + # + # Instead of disabling a lint rule for the entire project in the + # section below, it can also be suppressed for a single line of code + # or a specific dart file by using the `// ignore: name_of_lint` and + # `// ignore_for_file: name_of_lint` syntax on the line or in the file + # producing the lint. + rules: + - always_use_package_imports + - always_declare_return_types + - cancel_subscriptions + - close_sinks + - comment_references + - one_member_abstracts + - only_throw_errors + - package_api_docs + - prefer_final_in_for_each + - prefer_single_quotes + # avoid_print: false # Uncomment to disable the `avoid_print` rule + # prefer_single_quotes: true # Uncomment to enable the `prefer_single_quotes` rule + +# Additional information about this file can be found at +# https://dart.dev/guides/language/analysis-options diff --git a/.history/analysis_options_20211104093118.yaml b/.history/analysis_options_20211104093118.yaml new file mode 100644 index 0000000..013b99c --- /dev/null +++ b/.history/analysis_options_20211104093118.yaml @@ -0,0 +1,43 @@ +# This file configures the analyzer, which statically analyzes Dart code to +# check for errors, warnings, and lints. +# +# The issues identified by the analyzer are surfaced in the UI of Dart-enabled +# IDEs (https://dart.dev/tools#ides-and-editors). The analyzer can also be +# invoked from the command line by running `flutter analyze`. + +# The following line activates a set of recommended lints for Flutter apps, +# packages, and plugins designed to encourage good coding practices. +include: package:flutter_lints/flutter.yaml + +analyzer: + errors: + todo: ignore + +linter: + # The lint rules applied to this project can be customized in the + # section below to disable rules from the `package:flutter_lints/flutter.yaml` + # included above or to enable additional rules. A list of all available lints + # and their documentation is published at + # https://dart-lang.github.io/linter/lints/index.html. + # + # Instead of disabling a lint rule for the entire project in the + # section below, it can also be suppressed for a single line of code + # or a specific dart file by using the `// ignore: name_of_lint` and + # `// ignore_for_file: name_of_lint` syntax on the line or in the file + # producing the lint. + rules: + #- always_use_package_imports + - always_declare_return_types + - cancel_subscriptions + - close_sinks + - comment_references + - one_member_abstracts + - only_throw_errors + - package_api_docs + - prefer_final_in_for_each + - prefer_single_quotes + # avoid_print: false # Uncomment to disable the `avoid_print` rule + # prefer_single_quotes: true # Uncomment to enable the `prefer_single_quotes` rule + +# Additional information about this file can be found at +# https://dart.dev/guides/language/analysis-options diff --git a/.history/example/pubspec_20211030063242.yaml b/.history/example/pubspec_20211030063242.yaml new file mode 100644 index 0000000..2c66d4d --- /dev/null +++ b/.history/example/pubspec_20211030063242.yaml @@ -0,0 +1,86 @@ +name: tflite_example +description: Demonstrates how to use the tflite plugin. + +# The following line prevents the package from being accidentally published to +# pub.dev using `flutter pub publish`. This is preferred for private packages. +publish_to: 'none' # Remove this line if you wish to publish to pub.dev + +environment: + sdk: ">=2.12.0 <3.0.0" + +# Dependencies specify other packages that your package needs in order to work. +# To automatically upgrade your package dependencies to the latest versions +# consider running `flutter pub upgrade --major-versions`. Alternatively, +# dependencies can be manually updated by changing the version numbers below to +# the latest version available on pub.dev. To see which dependencies have newer +# versions available, run `flutter pub outdated`. +dependencies: + flutter: + sdk: flutter + + tflite: + # When depending on this package from a real application you should use: + # tflite: ^x.y.z + # See https://dart.dev/tools/pub/dependencies#version-constraints + # The example app is bundled with the plugin so we use a path dependency on + # the parent directory to use the current plugin's version. + path: ../ + + # The following adds the Cupertino Icons font to your application. + # Use with the CupertinoIcons class for iOS style icons. + cupertino_icons: ^1.0.2 + image_picker: ^0.8.4+3 + image: ^3.0.8 + +dev_dependencies: + flutter_test: + sdk: flutter + + # The "flutter_lints" package below contains a set of recommended lints to + # encourage good coding practices. The lint set provided by the package is + # activated in the `analysis_options.yaml` file located at the root of your + # package. See that file for information about deactivating specific lint + # rules and activating additional ones. + flutter_lints: ^1.0.0 + +# For information on the generic Dart part of this file, see the +# following page: https://dart.dev/tools/pub/pubspec + +# The following section is specific to Flutter. +flutter: + + # The following line ensures that the Material Icons font is + # included with your application, so that you can use the icons in + # the material Icons class. + uses-material-design: true + + # To add assets to your application, add an assets section, like this: + # assets: + # - images/a_dot_burr.jpeg + # - images/a_dot_ham.jpeg + + # An image asset can refer to one or more resolution-specific "variants", see + # https://flutter.dev/assets-and-images/#resolution-aware. + + # For details regarding adding assets from package dependencies, see + # https://flutter.dev/assets-and-images/#from-packages + + # To add custom fonts to your application, add a fonts section here, + # in this "flutter" section. Each entry in this list should have a + # "family" key with the font family name, and a "fonts" key with a + # list giving the asset and other descriptors for the font. For + # example: + # fonts: + # - family: Schyler + # fonts: + # - asset: fonts/Schyler-Regular.ttf + # - asset: fonts/Schyler-Italic.ttf + # style: italic + # - family: Trajan Pro + # fonts: + # - asset: fonts/TrajanPro.ttf + # - asset: fonts/TrajanPro_Bold.ttf + # weight: 700 + # + # For details regarding fonts from package dependencies, + # see https://flutter.dev/custom-fonts/#from-packages diff --git a/.history/example/pubspec_20211104101124.yaml b/.history/example/pubspec_20211104101124.yaml new file mode 100644 index 0000000..7449391 --- /dev/null +++ b/.history/example/pubspec_20211104101124.yaml @@ -0,0 +1,94 @@ +name: tflite_example +description: Demonstrates how to use the tflite plugin. + +# The following line prevents the package from being accidentally published to +# pub.dev using `flutter pub publish`. This is preferred for private packages. +publish_to: 'none' # Remove this line if you wish to publish to pub.dev + +environment: + sdk: ">=2.12.0 <3.0.0" + +# Dependencies specify other packages that your package needs in order to work. +# To automatically upgrade your package dependencies to the latest versions +# consider running `flutter pub upgrade --major-versions`. Alternatively, +# dependencies can be manually updated by changing the version numbers below to +# the latest version available on pub.dev. To see which dependencies have newer +# versions available, run `flutter pub outdated`. +dependencies: + flutter: + sdk: flutter + + tflite: + # When depending on this package from a real application you should use: + # tflite: ^x.y.z + # See https://dart.dev/tools/pub/dependencies#version-constraints + # The example app is bundled with the plugin so we use a path dependency on + # the parent directory to use the current plugin's version. + path: ../ + + # The following adds the Cupertino Icons font to your application. + # Use with the CupertinoIcons class for iOS style icons. + cupertino_icons: ^1.0.2 + image_picker: ^0.8.4+3 + image: ^3.0.8 + +dev_dependencies: + flutter_test: + sdk: flutter + + # The "flutter_lints" package below contains a set of recommended lints to + # encourage good coding practices. The lint set provided by the package is + # activated in the `analysis_options.yaml` file located at the root of your + # package. See that file for information about deactivating specific lint + # rules and activating additional ones. + flutter_lints: ^1.0.0 + +# For information on the generic Dart part of this file, see the +# following page: https://dart.dev/tools/pub/pubspec + +# The following section is specific to Flutter. +flutter: + + # The following line ensures that the Material Icons font is + # included with your application, so that you can use the icons in + # the material Icons class. + uses-material-design: true + + # To add assets to your application, add an assets section, like this: + assets: + - assets/mobilenet_v1_1.0_224.txt + - assets/mobilenet_v1_1.0_224.tflite + - assets/yolov2_tiny.tflite + - assets/yolov2_tiny.txt + - assets/ssd_mobilenet.tflite + - assets/ssd_mobilenet.txt + - assets/deeplabv3_257_mv_gpu.tflite + - assets/deeplabv3_257_mv_gpu.txt + - assets/posenet_mv1_075_float_from_checkpoints.tflite + + + # An image asset can refer to one or more resolution-specific "variants", see + # https://flutter.dev/assets-and-images/#resolution-aware. + + # For details regarding adding assets from package dependencies, see + # https://flutter.dev/assets-and-images/#from-packages + + # To add custom fonts to your application, add a fonts section here, + # in this "flutter" section. Each entry in this list should have a + # "family" key with the font family name, and a "fonts" key with a + # list giving the asset and other descriptors for the font. For + # example: + # fonts: + # - family: Schyler + # fonts: + # - asset: fonts/Schyler-Regular.ttf + # - asset: fonts/Schyler-Italic.ttf + # style: italic + # - family: Trajan Pro + # fonts: + # - asset: fonts/TrajanPro.ttf + # - asset: fonts/TrajanPro_Bold.ttf + # weight: 700 + # + # For details regarding fonts from package dependencies, + # see https://flutter.dev/custom-fonts/#from-packages diff --git a/.history/pubspec_20211104092600.yaml b/.history/pubspec_20211104092600.yaml new file mode 100644 index 0000000..538653d --- /dev/null +++ b/.history/pubspec_20211104092600.yaml @@ -0,0 +1,70 @@ +name: tflite +description: A Flutter plugin for accessing TensorFlow Lite. Supports both iOS and Android. +version: 1.1.3 +homepage: https://github.com/shaqian/flutter_tflite + +environment: + sdk: ">=2.12.0 <3.0.0" + flutter: ">=1.20.0" + +dependencies: + flutter: + sdk: flutter + flutter_web_plugins: + sdk: flutter + +dev_dependencies: + flutter_test: + sdk: flutter + flutter_lints: ^1.0.0 + +# For information on the generic Dart part of this file, see the +# following page: https://dart.dev/tools/pub/pubspec + +# The following section is specific to Flutter. +flutter: + # This section identifies this Flutter project as a plugin project. + # The 'pluginClass' and Android 'package' identifiers should not ordinarily + # be modified. They are used by the tooling to maintain consistency when + # adding or updating assets for this project. + plugin: + platforms: + android: + package: sq.flutter.tflite + pluginClass: TflitePlugin + ios: + pluginClass: TflitePlugin + web: + pluginClass: TfliteWeb + fileName: tflite_web.dart + + # To add assets to your plugin package, add an assets section, like this: + # assets: + # - images/a_dot_burr.jpeg + # - images/a_dot_ham.jpeg + # + # For details regarding assets in packages, see + # https://flutter.dev/assets-and-images/#from-packages + # + # An image asset can refer to one or more resolution-specific "variants", see + # https://flutter.dev/assets-and-images/#resolution-aware. + + # To add custom fonts to your plugin package, add a fonts section here, + # in this "flutter" section. Each entry in this list should have a + # "family" key with the font family name, and a "fonts" key with a + # list giving the asset and other descriptors for the font. For + # example: + # fonts: + # - family: Schyler + # fonts: + # - asset: fonts/Schyler-Regular.ttf + # - asset: fonts/Schyler-Italic.ttf + # style: italic + # - family: Trajan Pro + # fonts: + # - asset: fonts/TrajanPro.ttf + # - asset: fonts/TrajanPro_Bold.ttf + # weight: 700 + # + # For details regarding fonts in packages, see + # https://flutter.dev/custom-fonts/#from-packages diff --git a/.history/pubspec_20211104092801.yaml b/.history/pubspec_20211104092801.yaml new file mode 100644 index 0000000..538653d --- /dev/null +++ b/.history/pubspec_20211104092801.yaml @@ -0,0 +1,70 @@ +name: tflite +description: A Flutter plugin for accessing TensorFlow Lite. Supports both iOS and Android. +version: 1.1.3 +homepage: https://github.com/shaqian/flutter_tflite + +environment: + sdk: ">=2.12.0 <3.0.0" + flutter: ">=1.20.0" + +dependencies: + flutter: + sdk: flutter + flutter_web_plugins: + sdk: flutter + +dev_dependencies: + flutter_test: + sdk: flutter + flutter_lints: ^1.0.0 + +# For information on the generic Dart part of this file, see the +# following page: https://dart.dev/tools/pub/pubspec + +# The following section is specific to Flutter. +flutter: + # This section identifies this Flutter project as a plugin project. + # The 'pluginClass' and Android 'package' identifiers should not ordinarily + # be modified. They are used by the tooling to maintain consistency when + # adding or updating assets for this project. + plugin: + platforms: + android: + package: sq.flutter.tflite + pluginClass: TflitePlugin + ios: + pluginClass: TflitePlugin + web: + pluginClass: TfliteWeb + fileName: tflite_web.dart + + # To add assets to your plugin package, add an assets section, like this: + # assets: + # - images/a_dot_burr.jpeg + # - images/a_dot_ham.jpeg + # + # For details regarding assets in packages, see + # https://flutter.dev/assets-and-images/#from-packages + # + # An image asset can refer to one or more resolution-specific "variants", see + # https://flutter.dev/assets-and-images/#resolution-aware. + + # To add custom fonts to your plugin package, add a fonts section here, + # in this "flutter" section. Each entry in this list should have a + # "family" key with the font family name, and a "fonts" key with a + # list giving the asset and other descriptors for the font. For + # example: + # fonts: + # - family: Schyler + # fonts: + # - asset: fonts/Schyler-Regular.ttf + # - asset: fonts/Schyler-Italic.ttf + # style: italic + # - family: Trajan Pro + # fonts: + # - asset: fonts/TrajanPro.ttf + # - asset: fonts/TrajanPro_Bold.ttf + # weight: 700 + # + # For details regarding fonts in packages, see + # https://flutter.dev/custom-fonts/#from-packages diff --git a/.history/pubspec_20211104092836.yaml b/.history/pubspec_20211104092836.yaml new file mode 100644 index 0000000..b801677 --- /dev/null +++ b/.history/pubspec_20211104092836.yaml @@ -0,0 +1,70 @@ +name: tflite +description: A Flutter plugin for accessing TensorFlow Lite. Supports both iOS and Android, Web under development. +version: 1.1.3 +homepage: https://github.com/shaqian/flutter_tflite + +environment: + sdk: ">=2.12.0 <3.0.0" + flutter: ">=1.20.0" + +dependencies: + flutter: + sdk: flutter + flutter_web_plugins: + sdk: flutter + +dev_dependencies: + flutter_test: + sdk: flutter + flutter_lints: ^1.0.0 + +# For information on the generic Dart part of this file, see the +# following page: https://dart.dev/tools/pub/pubspec + +# The following section is specific to Flutter. +flutter: + # This section identifies this Flutter project as a plugin project. + # The 'pluginClass' and Android 'package' identifiers should not ordinarily + # be modified. They are used by the tooling to maintain consistency when + # adding or updating assets for this project. + plugin: + platforms: + android: + package: sq.flutter.tflite + pluginClass: TflitePlugin + ios: + pluginClass: TflitePlugin + web: + pluginClass: TfliteWeb + fileName: tflite_web.dart + + # To add assets to your plugin package, add an assets section, like this: + # assets: + # - images/a_dot_burr.jpeg + # - images/a_dot_ham.jpeg + # + # For details regarding assets in packages, see + # https://flutter.dev/assets-and-images/#from-packages + # + # An image asset can refer to one or more resolution-specific "variants", see + # https://flutter.dev/assets-and-images/#resolution-aware. + + # To add custom fonts to your plugin package, add a fonts section here, + # in this "flutter" section. Each entry in this list should have a + # "family" key with the font family name, and a "fonts" key with a + # list giving the asset and other descriptors for the font. For + # example: + # fonts: + # - family: Schyler + # fonts: + # - asset: fonts/Schyler-Regular.ttf + # - asset: fonts/Schyler-Italic.ttf + # style: italic + # - family: Trajan Pro + # fonts: + # - asset: fonts/TrajanPro.ttf + # - asset: fonts/TrajanPro_Bold.ttf + # weight: 700 + # + # For details regarding fonts in packages, see + # https://flutter.dev/custom-fonts/#from-packages diff --git a/analysis_options.yaml b/analysis_options.yaml new file mode 100644 index 0000000..013b99c --- /dev/null +++ b/analysis_options.yaml @@ -0,0 +1,43 @@ +# This file configures the analyzer, which statically analyzes Dart code to +# check for errors, warnings, and lints. +# +# The issues identified by the analyzer are surfaced in the UI of Dart-enabled +# IDEs (https://dart.dev/tools#ides-and-editors). The analyzer can also be +# invoked from the command line by running `flutter analyze`. + +# The following line activates a set of recommended lints for Flutter apps, +# packages, and plugins designed to encourage good coding practices. +include: package:flutter_lints/flutter.yaml + +analyzer: + errors: + todo: ignore + +linter: + # The lint rules applied to this project can be customized in the + # section below to disable rules from the `package:flutter_lints/flutter.yaml` + # included above or to enable additional rules. A list of all available lints + # and their documentation is published at + # https://dart-lang.github.io/linter/lints/index.html. + # + # Instead of disabling a lint rule for the entire project in the + # section below, it can also be suppressed for a single line of code + # or a specific dart file by using the `// ignore: name_of_lint` and + # `// ignore_for_file: name_of_lint` syntax on the line or in the file + # producing the lint. + rules: + #- always_use_package_imports + - always_declare_return_types + - cancel_subscriptions + - close_sinks + - comment_references + - one_member_abstracts + - only_throw_errors + - package_api_docs + - prefer_final_in_for_each + - prefer_single_quotes + # avoid_print: false # Uncomment to disable the `avoid_print` rule + # prefer_single_quotes: true # Uncomment to enable the `prefer_single_quotes` rule + +# Additional information about this file can be found at +# https://dart.dev/guides/language/analysis-options diff --git a/android/.classpath b/android/.classpath deleted file mode 100644 index eb19361..0000000 --- a/android/.classpath +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - diff --git a/android/.idea/.gitignore b/android/.idea/.gitignore new file mode 100644 index 0000000..26d3352 --- /dev/null +++ b/android/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/android/.idea/.name b/android/.idea/.name new file mode 100644 index 0000000..f62d0c7 --- /dev/null +++ b/android/.idea/.name @@ -0,0 +1 @@ +tflite \ No newline at end of file diff --git a/android/.idea/compiler.xml b/android/.idea/compiler.xml new file mode 100644 index 0000000..fb7f4a8 --- /dev/null +++ b/android/.idea/compiler.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/android/.idea/gradle.xml b/android/.idea/gradle.xml new file mode 100644 index 0000000..b220fcc --- /dev/null +++ b/android/.idea/gradle.xml @@ -0,0 +1,19 @@ + + + + + + + \ No newline at end of file diff --git a/android/.idea/jarRepositories.xml b/android/.idea/jarRepositories.xml new file mode 100644 index 0000000..d2ce72d --- /dev/null +++ b/android/.idea/jarRepositories.xml @@ -0,0 +1,25 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/android/.idea/misc.xml b/android/.idea/misc.xml index 9eeb5ee..860da66 100644 --- a/android/.idea/misc.xml +++ b/android/.idea/misc.xml @@ -1,5 +1,8 @@ + + + diff --git a/android/.idea/modules.xml b/android/.idea/modules.xml new file mode 100644 index 0000000..6e1e501 --- /dev/null +++ b/android/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/android/.project b/android/.project index 5fb2462..e28190d 100644 --- a/android/.project +++ b/android/.project @@ -1,15 +1,10 @@ - tflite - Project tflite created by Buildship. + android + Project android created by Buildship. - - org.eclipse.jdt.core.javabuilder - - - org.eclipse.buildship.core.gradleprojectbuilder @@ -17,7 +12,17 @@ - org.eclipse.jdt.core.javanature org.eclipse.buildship.core.gradleprojectnature + + + 1635544201826 + + 30 + + org.eclipse.core.resources.regexFilterMatcher + node_modules|.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__ + + + diff --git a/android/.settings/org.eclipse.buildship.core.prefs b/android/.settings/org.eclipse.buildship.core.prefs index 6aa97a9..6606b3d 100644 --- a/android/.settings/org.eclipse.buildship.core.prefs +++ b/android/.settings/org.eclipse.buildship.core.prefs @@ -1,2 +1,13 @@ -connection.project.dir=../example/android +arguments= +auto.sync=false +build.scans.enabled=false +connection.gradle.distribution=GRADLE_DISTRIBUTION(VERSION(7.0-rc-1)) +connection.project.dir= eclipse.preferences.version=1 +gradle.user.home= +java.home=C\:/Program Files/Eclipse Foundation/jdk-17.0.0.35-hotspot +jvm.arguments= +offline.mode=false +override.workspace.settings=true +show.console.view=true +show.executions.view=true diff --git a/android/build.gradle b/android/build.gradle index 8002459..fc6bdcb 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -1,39 +1,39 @@ group 'sq.flutter.tflite' -version '1.0-SNAPSHOT' +version '1.0' buildscript { repositories { google() - jcenter() + mavenCentral() } dependencies { - classpath 'com.android.tools.build:gradle:3.6.3' + classpath 'com.android.tools.build:gradle:4.1.0' } } rootProject.allprojects { repositories { google() - jcenter() + mavenCentral() } } apply plugin: 'com.android.library' android { - compileSdkVersion 28 + compileSdkVersion 31 - defaultConfig { - minSdkVersion 19 - testInstrumentationRunner 'androidx.test.runner.AndroidJUnitRunner' - } - lintOptions { - disable 'InvalidPackage' + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 } + defaultConfig { + minSdkVersion 17 + } dependencies { - compile 'org.tensorflow:tensorflow-lite:+' - compile 'org.tensorflow:tensorflow-lite-gpu:+' + implementation 'org.tensorflow:tensorflow-lite:+' + implementation 'org.tensorflow:tensorflow-lite-gpu:+' } } diff --git a/android/gradle.properties b/android/gradle.properties index 4167249..94adc3a 100644 --- a/android/gradle.properties +++ b/android/gradle.properties @@ -1,4 +1,3 @@ org.gradle.jvmargs=-Xmx1536M android.useAndroidX=true android.enableJetifier=true -android.enableR8=true \ No newline at end of file diff --git a/android/gradle/wrapper/gradle-wrapper.properties b/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000..3c9d085 --- /dev/null +++ b/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,5 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-6.7-all.zip diff --git a/android/src/main/java/sq/flutter/tflite/TflitePlugin.java b/android/src/main/java/sq/flutter/tflite/TflitePlugin.java index 2579d35..84bffc6 100644 --- a/android/src/main/java/sq/flutter/tflite/TflitePlugin.java +++ b/android/src/main/java/sq/flutter/tflite/TflitePlugin.java @@ -1,5 +1,6 @@ package sq.flutter.tflite; +import android.app.Activity; import android.content.Context; import android.content.res.AssetFileDescriptor; import android.content.res.AssetManager; @@ -19,22 +20,15 @@ import android.renderscript.Type; import android.util.Log; -import io.flutter.plugin.common.MethodCall; -import io.flutter.plugin.common.MethodChannel; -import io.flutter.plugin.common.MethodChannel.MethodCallHandler; -import io.flutter.plugin.common.MethodChannel.Result; -import io.flutter.plugin.common.PluginRegistry.Registrar; -import org.tensorflow.lite.DataType; -import org.tensorflow.lite.Interpreter; -import org.tensorflow.lite.Tensor; +import androidx.annotation.NonNull; +import org.tensorflow.lite.*; import org.tensorflow.lite.gpu.GpuDelegate; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; -import java.io.FileOutputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; @@ -51,1536 +45,1608 @@ import java.util.PriorityQueue; import java.util.Vector; +import io.flutter.embedding.engine.plugins.FlutterPlugin; +import io.flutter.embedding.engine.plugins.activity.ActivityAware; +import io.flutter.embedding.engine.plugins.activity.ActivityPluginBinding; +import io.flutter.plugin.common.BinaryMessenger; +import io.flutter.plugin.common.MethodCall; +import io.flutter.plugin.common.MethodChannel; +import io.flutter.plugin.common.MethodChannel.Result; +import io.flutter.plugin.common.PluginRegistry; -public class TflitePlugin implements MethodCallHandler { - private final Registrar mRegistrar; - private Interpreter tfLite; - private boolean tfLiteBusy = false; - private int inputSize = 0; - private Vector labels; - float[][] labelProb; - private static final int BYTES_PER_CHANNEL = 4; - - String[] partNames = { - "nose", "leftEye", "rightEye", "leftEar", "rightEar", "leftShoulder", - "rightShoulder", "leftElbow", "rightElbow", "leftWrist", "rightWrist", - "leftHip", "rightHip", "leftKnee", "rightKnee", "leftAnkle", "rightAnkle" - }; - - String[][] poseChain = { - {"nose", "leftEye"}, {"leftEye", "leftEar"}, {"nose", "rightEye"}, - {"rightEye", "rightEar"}, {"nose", "leftShoulder"}, - {"leftShoulder", "leftElbow"}, {"leftElbow", "leftWrist"}, - {"leftShoulder", "leftHip"}, {"leftHip", "leftKnee"}, - {"leftKnee", "leftAnkle"}, {"nose", "rightShoulder"}, - {"rightShoulder", "rightElbow"}, {"rightElbow", "rightWrist"}, - {"rightShoulder", "rightHip"}, {"rightHip", "rightKnee"}, - {"rightKnee", "rightAnkle"} - }; - - Map partsIds = new HashMap<>(); - List parentToChildEdges = new ArrayList<>(); - List childToParentEdges = new ArrayList<>(); - - public static void registerWith(Registrar registrar) { - final MethodChannel channel = new MethodChannel(registrar.messenger(), "tflite"); - channel.setMethodCallHandler(new TflitePlugin(registrar)); - } - - private TflitePlugin(Registrar registrar) { - this.mRegistrar = registrar; - } - @Override - public void onMethodCall(MethodCall call, Result result) { - if (call.method.equals("loadModel")) { - try { - String res = loadModel((HashMap) call.arguments); - result.success(res); - } catch (Exception e) { - result.error("Failed to load model", e.getMessage(), e); - } - } else if (call.method.equals("runModelOnImage")) { - try { - new RunModelOnImage((HashMap) call.arguments, result).executeTfliteTask(); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("runModelOnBinary")) { - try { - new RunModelOnBinary((HashMap) call.arguments, result).executeTfliteTask(); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("runModelOnFrame")) { - try { - new RunModelOnFrame((HashMap) call.arguments, result).executeTfliteTask(); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("detectObjectOnImage")) { - try { - detectObjectOnImage((HashMap) call.arguments, result); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("detectObjectOnBinary")) { - try { - detectObjectOnBinary((HashMap) call.arguments, result); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("detectObjectOnFrame")) { - try { - detectObjectOnFrame((HashMap) call.arguments, result); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("close")) { - close(); - } else if (call.method.equals("runPix2PixOnImage")) { - try { - new RunPix2PixOnImage((HashMap) call.arguments, result).executeTfliteTask(); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("runPix2PixOnBinary")) { - try { - new RunPix2PixOnBinary((HashMap) call.arguments, result).executeTfliteTask(); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("runPix2PixOnFrame")) { - try { - new RunPix2PixOnFrame((HashMap) call.arguments, result).executeTfliteTask(); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("runSegmentationOnImage")) { - try { - new RunSegmentationOnImage((HashMap) call.arguments, result).executeTfliteTask(); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("runSegmentationOnBinary")) { - try { - new RunSegmentationOnBinary((HashMap) call.arguments, result).executeTfliteTask(); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("runSegmentationOnFrame")) { - try { - new RunSegmentationOnFrame((HashMap) call.arguments, result).executeTfliteTask(); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("runPoseNetOnImage")) { - try { - runPoseNetOnImage((HashMap) call.arguments, result); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("runPoseNetOnBinary")) { - try { - runPoseNetOnBinary((HashMap) call.arguments, result); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else if (call.method.equals("runPoseNetOnFrame")) { - try { - runPoseNetOnFrame((HashMap) call.arguments, result); - } catch (Exception e) { - result.error("Failed to run model", e.getMessage(), e); - } - } else { - result.error("Invalid method", call.method.toString(), ""); +/** + * TflitePlugin + */ +public class TflitePlugin implements FlutterPlugin, MethodChannel.MethodCallHandler, ActivityAware { + /// The MethodChannel that will the communication between Flutter and native Android + /// + /// This local reference serves to register the plugin with the Flutter Engine and unregister it + /// when the Flutter Engine is detached from the Activity + private MethodChannel channel; + private Context applicationContext; + @SuppressWarnings("deprecation") + private PluginRegistry.Registrar registrar; + /** + * Plugin registration. + */ + @SuppressWarnings("deprecation") + public void registerWith(io.flutter.plugin.common.PluginRegistry.Registrar registrar) { + final TflitePlugin instance = new TflitePlugin(); + this.registrar = registrar; + instance.onAttachedToEngine(registrar.activity(), registrar.messenger()); } - } - - private String loadModel(HashMap args) throws IOException { - String model = args.get("model").toString(); - Object isAssetObj = args.get("isAsset"); - boolean isAsset = isAssetObj == null ? false : (boolean) isAssetObj; - MappedByteBuffer buffer = null; - String key = null; - AssetManager assetManager = null; - if (isAsset) { - assetManager = mRegistrar.context().getAssets(); - key = mRegistrar.lookupKeyForAsset(model); - AssetFileDescriptor fileDescriptor = assetManager.openFd(key); - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); - FileChannel fileChannel = inputStream.getChannel(); - long startOffset = fileDescriptor.getStartOffset(); - long declaredLength = fileDescriptor.getDeclaredLength(); - buffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); - } else { - FileInputStream inputStream = new FileInputStream(new File(model)); - FileChannel fileChannel = inputStream.getChannel(); - long declaredLength = fileChannel.size(); - buffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, declaredLength); + private void onAttachedToEngine(Context applicationContext, BinaryMessenger messenger) { + this.applicationContext = applicationContext; + MethodChannel methodChannel = new MethodChannel(messenger, "plugins.flutter.io/tensor"); + methodChannel.setMethodCallHandler(this); } - int numThreads = (int) args.get("numThreads"); - Boolean useGpuDelegate = (Boolean) args.get("useGpuDelegate"); - if (useGpuDelegate == null) { - useGpuDelegate = false; + @Override + public void onAttachedToActivity(@NonNull ActivityPluginBinding binding) { + Activity activity = binding.getActivity(); } - final Interpreter.Options tfliteOptions = new Interpreter.Options(); - tfliteOptions.setNumThreads(numThreads); - if (useGpuDelegate){ - GpuDelegate delegate = new GpuDelegate(); - tfliteOptions.addDelegate(delegate); + @Override + public void onDetachedFromActivityForConfigChanges() { + } - tfLite = new Interpreter(buffer, tfliteOptions); - String labels = args.get("labels").toString(); + @Override + public void onReattachedToActivityForConfigChanges(@NonNull ActivityPluginBinding binding) { - if (labels.length() > 0) { - if (isAsset) { - key = mRegistrar.lookupKeyForAsset(labels); - loadLabels(assetManager, key); - } else { - loadLabels(null, labels); - } } - return "success"; - } + @Override + public void onDetachedFromActivity() { - private void loadLabels(AssetManager assetManager, String path) { - BufferedReader br; - try { - if (assetManager != null) { - br = new BufferedReader(new InputStreamReader(assetManager.open(path))); - } else { - br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(path)))); - } - String line; - labels = new Vector<>(); - while ((line = br.readLine()) != null) { - labels.add(line); - } - labelProb = new float[1][labels.size()]; - br.close(); - } catch (IOException e) { - throw new RuntimeException("Failed to read label file", e); } - } - private List> GetTopN(int numResults, float threshold) { - PriorityQueue> pq = - new PriorityQueue<>( - 1, - new Comparator>() { - @Override - public int compare(Map lhs, Map rhs) { - return Float.compare((float) rhs.get("confidence"), (float) lhs.get("confidence")); - } - }); - - for (int i = 0; i < labels.size(); ++i) { - float confidence = labelProb[0][i]; - if (confidence > threshold) { - Map res = new HashMap<>(); - res.put("index", i); - res.put("label", labels.size() > i ? labels.get(i) : "unknown"); - res.put("confidence", confidence); - pq.add(res); - } + @Override + public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { + channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), "tflite"); + channel.setMethodCallHandler(this); + Context context = flutterPluginBinding.getApplicationContext(); } - final ArrayList> recognitions = new ArrayList<>(); - int recognitionsSize = Math.min(pq.size(), numResults); - for (int i = 0; i < recognitionsSize; ++i) { - recognitions.add(pq.poll()); + @Override + public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) { + if (call.method.equals("getPlatformVersion")) { + result.success("Android " + android.os.Build.VERSION.RELEASE); + } else { + result.notImplemented(); + } + if (call.method.equals("loadModel")) { + try { + String res = loadModel((HashMap) call.arguments); + result.success(res); + } catch (Exception e) { + result.error("Failed to load model", e.getMessage(), e); + } + } else if (call.method.equals("runModelOnImage")) { + try { + new RunModelOnImage((HashMap) call.arguments, result).executeTfliteTask(); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("runModelOnBinary")) { + try { + new RunModelOnBinary((HashMap) call.arguments, result).executeTfliteTask(); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("runModelOnFrame")) { + try { + new RunModelOnFrame((HashMap) call.arguments, result).executeTfliteTask(); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("detectObjectOnImage")) { + try { + detectObjectOnImage((HashMap) call.arguments, result); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("detectObjectOnBinary")) { + try { + detectObjectOnBinary((HashMap) call.arguments, result); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("detectObjectOnFrame")) { + try { + detectObjectOnFrame((HashMap) call.arguments, result); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("close")) { + close(); + } else if (call.method.equals("runPix2PixOnImage")) { + try { + new RunPix2PixOnImage((HashMap) call.arguments, result).executeTfliteTask(); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("runPix2PixOnBinary")) { + try { + new RunPix2PixOnBinary((HashMap) call.arguments, result).executeTfliteTask(); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("runPix2PixOnFrame")) { + try { + new RunPix2PixOnFrame((HashMap) call.arguments, result).executeTfliteTask(); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("runSegmentationOnImage")) { + try { + new RunSegmentationOnImage((HashMap) call.arguments, result).executeTfliteTask(); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("runSegmentationOnBinary")) { + try { + new RunSegmentationOnBinary((HashMap) call.arguments, result).executeTfliteTask(); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("runSegmentationOnFrame")) { + try { + new RunSegmentationOnFrame((HashMap) call.arguments, result).executeTfliteTask(); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("runPoseNetOnImage")) { + try { + runPoseNetOnImage((HashMap) call.arguments, result); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("runPoseNetOnBinary")) { + try { + runPoseNetOnBinary((HashMap) call.arguments, result); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else if (call.method.equals("runPoseNetOnFrame")) { + try { + runPoseNetOnFrame((HashMap) call.arguments, result); + } catch (Exception e) { + result.error("Failed to run model", e.getMessage(), e); + } + } else { + result.error("Invalid method", call.method.toString(), ""); + } } - return recognitions; + @Override + public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { + channel.setMethodCallHandler(null); + } + + private Interpreter tfLite; + private boolean tfLiteBusy = false; + private int inputSize = 0; + private Vector labels; + float[][] labelProb; + private static final int BYTES_PER_CHANNEL = 4; + + String[] partNames = { + "nose", "leftEye", "rightEye", "leftEar", "rightEar", "leftShoulder", + "rightShoulder", "leftElbow", "rightElbow", "leftWrist", "rightWrist", + "leftHip", "rightHip", "leftKnee", "rightKnee", "leftAnkle", "rightAnkle" + }; + + String[][] poseChain = { + {"nose", "leftEye"}, {"leftEye", "leftEar"}, {"nose", "rightEye"}, + {"rightEye", "rightEar"}, {"nose", "leftShoulder"}, + {"leftShoulder", "leftElbow"}, {"leftElbow", "leftWrist"}, + {"leftShoulder", "leftHip"}, {"leftHip", "leftKnee"}, + {"leftKnee", "leftAnkle"}, {"nose", "rightShoulder"}, + {"rightShoulder", "rightElbow"}, {"rightElbow", "rightWrist"}, + {"rightShoulder", "rightHip"}, {"rightHip", "rightKnee"}, + {"rightKnee", "rightAnkle"} + }; + + + Map partsIds = new HashMap<>(); + List parentToChildEdges = new ArrayList<>(); + List childToParentEdges = new ArrayList<>(); + + /*private TflitePlugin(Registrar registrar) { + this.mRegistrar = registrar; } +*/ + + + private String loadModel(HashMap args) throws IOException { + String model = args.get("model").toString(); + Object isAssetObj = args.get("isAsset"); + boolean isAsset = isAssetObj == null ? false : (boolean) isAssetObj; + MappedByteBuffer buffer = null; + String key = null; + AssetManager assetManager = registrar.context().getAssets(); + if (isAsset) { + assetManager = applicationContext.getAssets(); + key = registrar.lookupKeyForAsset(model); + AssetFileDescriptor fileDescriptor = assetManager.openFd(key); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + buffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } else { + FileInputStream inputStream = new FileInputStream(new File(model)); + FileChannel fileChannel = inputStream.getChannel(); + long declaredLength = fileChannel.size(); + buffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, declaredLength); + } - Bitmap feedOutput(ByteBuffer imgData, float mean, float std) { - Tensor tensor = tfLite.getOutputTensor(0); - int outputSize = tensor.shape()[1]; - Bitmap bitmapRaw = Bitmap.createBitmap(outputSize, outputSize, Bitmap.Config.ARGB_8888); - - if (tensor.dataType() == DataType.FLOAT32) { - for (int i = 0; i < outputSize; ++i) { - for (int j = 0; j < outputSize; ++j) { - int pixelValue = 0xFF << 24; - pixelValue |= ((Math.round(imgData.getFloat() * std + mean) & 0xFF) << 16); - pixelValue |= ((Math.round(imgData.getFloat() * std + mean) & 0xFF) << 8); - pixelValue |= ((Math.round(imgData.getFloat() * std + mean) & 0xFF)); - bitmapRaw.setPixel(j, i, pixelValue); - } - } - } else { - for (int i = 0; i < outputSize; ++i) { - for (int j = 0; j < outputSize; ++j) { - int pixelValue = 0xFF << 24; - pixelValue |= ((imgData.get() & 0xFF) << 16); - pixelValue |= ((imgData.get() & 0xFF) << 8); - pixelValue |= ((imgData.get() & 0xFF)); - bitmapRaw.setPixel(j, i, pixelValue); - } - } + int numThreads = (int) args.get("numThreads"); + Boolean useGpuDelegate = (Boolean) args.get("useGpuDelegate"); + if (useGpuDelegate == null) { + useGpuDelegate = false; + } + + final Interpreter.Options tfliteOptions = new Interpreter.Options(); + tfliteOptions.setNumThreads(numThreads); + if (useGpuDelegate) { + GpuDelegate delegate = new GpuDelegate(); + tfliteOptions.addDelegate(delegate); + } + tfLite = new Interpreter(buffer, tfliteOptions); + + String labels = args.get("labels").toString(); + + if (labels.length() > 0) { + if (isAsset) { + key = registrar.lookupKeyForAsset(labels); + loadLabels(assetManager, key); + } else { + loadLabels(null, labels); + } + } + + return "success"; } - return bitmapRaw; - } - ByteBuffer feedInputTensor(Bitmap bitmapRaw, float mean, float std) throws IOException { - Tensor tensor = tfLite.getInputTensor(0); - int[] shape = tensor.shape(); - inputSize = shape[1]; - int inputChannels = shape[3]; - - int bytePerChannel = tensor.dataType() == DataType.UINT8 ? 1 : BYTES_PER_CHANNEL; - ByteBuffer imgData = ByteBuffer.allocateDirect(1 * inputSize * inputSize * inputChannels * bytePerChannel); - imgData.order(ByteOrder.nativeOrder()); - - Bitmap bitmap = bitmapRaw; - if (bitmapRaw.getWidth() != inputSize || bitmapRaw.getHeight() != inputSize) { - Matrix matrix = getTransformationMatrix(bitmapRaw.getWidth(), bitmapRaw.getHeight(), - inputSize, inputSize, false); - bitmap = Bitmap.createBitmap(inputSize, inputSize, Bitmap.Config.ARGB_8888); - final Canvas canvas = new Canvas(bitmap); - if (inputChannels == 1){ - Paint paint = new Paint(); - ColorMatrix cm = new ColorMatrix(); - cm.setSaturation(0); - ColorMatrixColorFilter f = new ColorMatrixColorFilter(cm); - paint.setColorFilter(f); - canvas.drawBitmap(bitmapRaw, matrix, paint); - } else { - canvas.drawBitmap(bitmapRaw, matrix, null); - } + private void loadLabels(AssetManager assetManager, String path) { + BufferedReader br; + try { + if (assetManager != null) { + br = new BufferedReader(new InputStreamReader(assetManager.open(path))); + } else { + br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(path)))); + } + String line; + labels = new Vector<>(); + while ((line = br.readLine()) != null) { + labels.add(line); + } + labelProb = new float[1][labels.size()]; + br.close(); + } catch (IOException e) { + throw new RuntimeException("Failed to read label file", e); + } } - if (tensor.dataType() == DataType.FLOAT32) { - for (int i = 0; i < inputSize; ++i) { - for (int j = 0; j < inputSize; ++j) { - int pixelValue = bitmap.getPixel(j, i); - if (inputChannels > 1){ - imgData.putFloat((((pixelValue >> 16) & 0xFF) - mean) / std); - imgData.putFloat((((pixelValue >> 8) & 0xFF) - mean) / std); - imgData.putFloat(((pixelValue & 0xFF) - mean) / std); - } else { - imgData.putFloat((((pixelValue >> 16 | pixelValue >> 8 | pixelValue) & 0xFF) - mean) / std); - } - } - } - } else { - for (int i = 0; i < inputSize; ++i) { - for (int j = 0; j < inputSize; ++j) { - int pixelValue = bitmap.getPixel(j, i); - if (inputChannels > 1){ - imgData.put((byte) ((pixelValue >> 16) & 0xFF)); - imgData.put((byte) ((pixelValue >> 8) & 0xFF)); - imgData.put((byte) (pixelValue & 0xFF)); - } else { - imgData.put((byte) ((pixelValue >> 16 | pixelValue >> 8 | pixelValue) & 0xFF)); - } - } - } + private List> GetTopN(int numResults, float threshold) { + PriorityQueue> pq = + new PriorityQueue<>( + 1, + new Comparator>() { + @Override + public int compare(Map lhs, Map rhs) { + return Float.compare((float) rhs.get("confidence"), (float) lhs.get("confidence")); + } + }); + + for (int i = 0; i < labels.size(); ++i) { + float confidence = labelProb[0][i]; + if (confidence > threshold) { + Map res = new HashMap<>(); + res.put("index", i); + res.put("label", labels.size() > i ? labels.get(i) : "unknown"); + res.put("confidence", confidence); + pq.add(res); + } + } + + final ArrayList> recognitions = new ArrayList<>(); + int recognitionsSize = Math.min(pq.size(), numResults); + for (int i = 0; i < recognitionsSize; ++i) { + recognitions.add(pq.poll()); + } + + return recognitions; } - return imgData; - } + Bitmap feedOutput(ByteBuffer imgData, float mean, float std) { + Tensor tensor = tfLite.getOutputTensor(0); + int outputSize = tensor.shape()[1]; + Bitmap bitmapRaw = Bitmap.createBitmap(outputSize, outputSize, Bitmap.Config.ARGB_8888); + + if (tensor.dataType() == DataType.FLOAT32) { + for (int i = 0; i < outputSize; ++i) { + for (int j = 0; j < outputSize; ++j) { + int pixelValue = 0xFF << 24; + pixelValue |= ((Math.round(imgData.getFloat() * std + mean) & 0xFF) << 16); + pixelValue |= ((Math.round(imgData.getFloat() * std + mean) & 0xFF) << 8); + pixelValue |= ((Math.round(imgData.getFloat() * std + mean) & 0xFF)); + bitmapRaw.setPixel(j, i, pixelValue); + } + } + } else { + for (int i = 0; i < outputSize; ++i) { + for (int j = 0; j < outputSize; ++j) { + int pixelValue = 0xFF << 24; + pixelValue |= ((imgData.get() & 0xFF) << 16); + pixelValue |= ((imgData.get() & 0xFF) << 8); + pixelValue |= ((imgData.get() & 0xFF)); + bitmapRaw.setPixel(j, i, pixelValue); + } + } + } + return bitmapRaw; + } + + ByteBuffer feedInputTensor(Bitmap bitmapRaw, float mean, float std) throws IOException { + Tensor tensor = tfLite.getInputTensor(0); + int[] shape = tensor.shape(); + inputSize = shape[1]; + int inputChannels = shape[3]; + + int bytePerChannel = tensor.dataType() == DataType.UINT8 ? 1 : BYTES_PER_CHANNEL; + ByteBuffer imgData = ByteBuffer.allocateDirect(1 * inputSize * inputSize * inputChannels * bytePerChannel); + imgData.order(ByteOrder.nativeOrder()); + + Bitmap bitmap = bitmapRaw; + if (bitmapRaw.getWidth() != inputSize || bitmapRaw.getHeight() != inputSize) { + Matrix matrix = getTransformationMatrix(bitmapRaw.getWidth(), bitmapRaw.getHeight(), + inputSize, inputSize, false); + bitmap = Bitmap.createBitmap(inputSize, inputSize, Bitmap.Config.ARGB_8888); + final Canvas canvas = new Canvas(bitmap); + if (inputChannels == 1) { + Paint paint = new Paint(); + ColorMatrix cm = new ColorMatrix(); + cm.setSaturation(0); + ColorMatrixColorFilter f = new ColorMatrixColorFilter(cm); + paint.setColorFilter(f); + canvas.drawBitmap(bitmapRaw, matrix, paint); + } else { + canvas.drawBitmap(bitmapRaw, matrix, null); + } + } - ByteBuffer feedInputTensorImage(String path, float mean, float std) throws IOException { - InputStream inputStream = new FileInputStream(path.replace("file://", "")); - Bitmap bitmapRaw = BitmapFactory.decodeStream(inputStream); + if (tensor.dataType() == DataType.FLOAT32) { + for (int i = 0; i < inputSize; ++i) { + for (int j = 0; j < inputSize; ++j) { + int pixelValue = bitmap.getPixel(j, i); + if (inputChannels > 1) { + imgData.putFloat((((pixelValue >> 16) & 0xFF) - mean) / std); + imgData.putFloat((((pixelValue >> 8) & 0xFF) - mean) / std); + imgData.putFloat(((pixelValue & 0xFF) - mean) / std); + } else { + imgData.putFloat((((pixelValue >> 16 | pixelValue >> 8 | pixelValue) & 0xFF) - mean) / std); + } + } + } + } else { + for (int i = 0; i < inputSize; ++i) { + for (int j = 0; j < inputSize; ++j) { + int pixelValue = bitmap.getPixel(j, i); + if (inputChannels > 1) { + imgData.put((byte) ((pixelValue >> 16) & 0xFF)); + imgData.put((byte) ((pixelValue >> 8) & 0xFF)); + imgData.put((byte) (pixelValue & 0xFF)); + } else { + imgData.put((byte) ((pixelValue >> 16 | pixelValue >> 8 | pixelValue) & 0xFF)); + } + } + } + } - return feedInputTensor(bitmapRaw, mean, std); - } + return imgData; + } - ByteBuffer feedInputTensorFrame(List bytesList, int imageHeight, int imageWidth, float mean, float std, int rotation) throws IOException { - ByteBuffer Y = ByteBuffer.wrap(bytesList.get(0)); - ByteBuffer U = ByteBuffer.wrap(bytesList.get(1)); - ByteBuffer V = ByteBuffer.wrap(bytesList.get(2)); + ByteBuffer feedInputTensorImage(String path, float mean, float std) throws IOException { + InputStream inputStream = new FileInputStream(path.replace("file://", "")); + Bitmap bitmapRaw = BitmapFactory.decodeStream(inputStream); - int Yb = Y.remaining(); - int Ub = U.remaining(); - int Vb = V.remaining(); + return feedInputTensor(bitmapRaw, mean, std); + } - byte[] data = new byte[Yb + Ub + Vb]; + ByteBuffer feedInputTensorFrame(List bytesList, int imageHeight, int imageWidth, float mean, float std, int rotation) throws IOException { + ByteBuffer Y = ByteBuffer.wrap(bytesList.get(0)); + ByteBuffer U = ByteBuffer.wrap(bytesList.get(1)); + ByteBuffer V = ByteBuffer.wrap(bytesList.get(2)); - Y.get(data, 0, Yb); - V.get(data, Yb, Vb); - U.get(data, Yb + Vb, Ub); + int Yb = Y.remaining(); + int Ub = U.remaining(); + int Vb = V.remaining(); - Bitmap bitmapRaw = Bitmap.createBitmap(imageWidth, imageHeight, Bitmap.Config.ARGB_8888); - Allocation bmData = renderScriptNV21ToRGBA888( - mRegistrar.context(), - imageWidth, - imageHeight, - data); - bmData.copyTo(bitmapRaw); + byte[] data = new byte[Yb + Ub + Vb]; - Matrix matrix = new Matrix(); - matrix.postRotate(rotation); - bitmapRaw = Bitmap.createBitmap(bitmapRaw, 0, 0, bitmapRaw.getWidth(), bitmapRaw.getHeight(), matrix, true); + Y.get(data, 0, Yb); + V.get(data, Yb, Vb); + U.get(data, Yb + Vb, Ub); - return feedInputTensor(bitmapRaw, mean, std); - } + Bitmap bitmapRaw = Bitmap.createBitmap(imageWidth, imageHeight, Bitmap.Config.ARGB_8888); + Allocation bmData = renderScriptNV21ToRGBA888( + registrar.context(), + imageWidth, + imageHeight, + data); + bmData.copyTo(bitmapRaw); - public Allocation renderScriptNV21ToRGBA888(Context context, int width, int height, byte[] nv21) { - // https://stackoverflow.com/a/36409748 - RenderScript rs = RenderScript.create(context); - ScriptIntrinsicYuvToRGB yuvToRgbIntrinsic = ScriptIntrinsicYuvToRGB.create(rs, Element.U8_4(rs)); + Matrix matrix = new Matrix(); + matrix.postRotate(rotation); + bitmapRaw = Bitmap.createBitmap(bitmapRaw, 0, 0, bitmapRaw.getWidth(), bitmapRaw.getHeight(), matrix, true); - Type.Builder yuvType = new Type.Builder(rs, Element.U8(rs)).setX(nv21.length); - Allocation in = Allocation.createTyped(rs, yuvType.create(), Allocation.USAGE_SCRIPT); + return feedInputTensor(bitmapRaw, mean, std); + } - Type.Builder rgbaType = new Type.Builder(rs, Element.RGBA_8888(rs)).setX(width).setY(height); - Allocation out = Allocation.createTyped(rs, rgbaType.create(), Allocation.USAGE_SCRIPT); + public Allocation renderScriptNV21ToRGBA888(Context context, int width, int height, byte[] nv21) { + // https://stackoverflow.com/a/36409748 + RenderScript rs = RenderScript.create(context); + ScriptIntrinsicYuvToRGB yuvToRgbIntrinsic = ScriptIntrinsicYuvToRGB.create(rs, Element.U8_4(rs)); - in.copyFrom(nv21); + Type.Builder yuvType = new Type.Builder(rs, Element.U8(rs)).setX(nv21.length); + Allocation in = Allocation.createTyped(rs, yuvType.create(), Allocation.USAGE_SCRIPT); - yuvToRgbIntrinsic.setInput(in); - yuvToRgbIntrinsic.forEach(out); - return out; - } + Type.Builder rgbaType = new Type.Builder(rs, Element.RGBA_8888(rs)).setX(width).setY(height); + Allocation out = Allocation.createTyped(rs, rgbaType.create(), Allocation.USAGE_SCRIPT); - private abstract class TfliteTask extends AsyncTask { - Result result; - boolean asynch; + in.copyFrom(nv21); - TfliteTask(HashMap args, Result result) { - if (tfLiteBusy) throw new RuntimeException("Interpreter busy"); - else tfLiteBusy = true; - Object asynch = args.get("asynch"); - this.asynch = asynch == null ? false : (boolean) asynch; - this.result = result; + yuvToRgbIntrinsic.setInput(in); + yuvToRgbIntrinsic.forEach(out); + return out; } - abstract void runTflite(); - abstract void onRunTfliteDone(); + private abstract class TfliteTask extends AsyncTask { + Result result; + boolean asynch; - public void executeTfliteTask() { - if (asynch) execute(); - else { - runTflite(); - tfLiteBusy = false; - onRunTfliteDone(); - } - } + TfliteTask(HashMap args, Result result) { + if (tfLiteBusy) throw new RuntimeException("Interpreter busy"); + else tfLiteBusy = true; + Object asynch = args.get("asynch"); + this.asynch = asynch == null ? false : (boolean) asynch; + this.result = result; + } - protected Void doInBackground(Void... backgroundArguments) { - runTflite(); - return null; - } + abstract void runTflite(); - protected void onPostExecute(Void backgroundResult) { - tfLiteBusy = false; - onRunTfliteDone(); - } - } + abstract void onRunTfliteDone(); - private class RunModelOnImage extends TfliteTask { - int NUM_RESULTS; - float THRESHOLD; - ByteBuffer input; - long startTime; - - RunModelOnImage(HashMap args, Result result) throws IOException { - super(args, result); - - String path = args.get("path").toString(); - double mean = (double) (args.get("imageMean")); - float IMAGE_MEAN = (float) mean; - double std = (double) (args.get("imageStd")); - float IMAGE_STD = (float) std; - NUM_RESULTS = (int) args.get("numResults"); - double threshold = (double) args.get("threshold"); - THRESHOLD = (float) threshold; - - startTime = SystemClock.uptimeMillis(); - input = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD); - } + public void executeTfliteTask() { + if (asynch) execute(); + else { + runTflite(); + tfLiteBusy = false; + onRunTfliteDone(); + } + } - protected void runTflite() { - tfLite.run(input, labelProb); - } + protected Void doInBackground(Void... backgroundArguments) { + runTflite(); + return null; + } - protected void onRunTfliteDone() { - Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); - result.success(GetTopN(NUM_RESULTS, THRESHOLD)); + protected void onPostExecute(Void backgroundResult) { + tfLiteBusy = false; + onRunTfliteDone(); + } } - } - private class RunModelOnBinary extends TfliteTask { - int NUM_RESULTS; - float THRESHOLD; - ByteBuffer imgData; + private class RunModelOnImage extends TfliteTask { + int NUM_RESULTS; + float THRESHOLD; + ByteBuffer input; + long startTime; - RunModelOnBinary(HashMap args, Result result) throws IOException { - super(args, result); + RunModelOnImage(HashMap args, Result result) throws IOException { + super(args, result); - byte[] binary = (byte[]) args.get("binary"); - NUM_RESULTS = (int) args.get("numResults"); - double threshold = (double) args.get("threshold"); - THRESHOLD = (float) threshold; + String path = args.get("path").toString(); + double mean = (double) (args.get("imageMean")); + float IMAGE_MEAN = (float) mean; + double std = (double) (args.get("imageStd")); + float IMAGE_STD = (float) std; + NUM_RESULTS = (int) args.get("numResults"); + double threshold = (double) args.get("threshold"); + THRESHOLD = (float) threshold; - imgData = ByteBuffer.wrap(binary); - } + startTime = SystemClock.uptimeMillis(); + input = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD); + } - protected void runTflite() { - tfLite.run(imgData, labelProb); - } + protected void runTflite() { + tfLite.run(input, labelProb); + } - protected void onRunTfliteDone() { - result.success(GetTopN(NUM_RESULTS, THRESHOLD)); + protected void onRunTfliteDone() { + Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); + result.success(GetTopN(NUM_RESULTS, THRESHOLD)); + } } - } - private class RunModelOnFrame extends TfliteTask { - int NUM_RESULTS; - float THRESHOLD; - long startTime; - ByteBuffer imgData; - - RunModelOnFrame(HashMap args, Result result) throws IOException { - super(args, result); - - List bytesList = (ArrayList) args.get("bytesList"); - double mean = (double) (args.get("imageMean")); - float IMAGE_MEAN = (float) mean; - double std = (double) (args.get("imageStd")); - float IMAGE_STD = (float) std; - int imageHeight = (int) (args.get("imageHeight")); - int imageWidth = (int) (args.get("imageWidth")); - int rotation = (int) (args.get("rotation")); - NUM_RESULTS = (int) args.get("numResults"); - double threshold = (double) args.get("threshold"); - THRESHOLD = (float) threshold; - - startTime = SystemClock.uptimeMillis(); - - imgData = feedInputTensorFrame(bytesList, imageHeight, imageWidth, IMAGE_MEAN, IMAGE_STD, rotation); - } + private class RunModelOnBinary extends TfliteTask { + int NUM_RESULTS; + float THRESHOLD; + ByteBuffer imgData; - protected void runTflite() { - tfLite.run(imgData, labelProb); - } + RunModelOnBinary(HashMap args, Result result) throws IOException { + super(args, result); - protected void onRunTfliteDone() { - Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); - result.success(GetTopN(NUM_RESULTS, THRESHOLD)); - } - } + byte[] binary = (byte[]) args.get("binary"); + NUM_RESULTS = (int) args.get("numResults"); + double threshold = (double) args.get("threshold"); + THRESHOLD = (float) threshold; - void detectObjectOnImage(HashMap args, Result result) throws IOException { - String path = args.get("path").toString(); - String model = args.get("model").toString(); - double mean = (double) (args.get("imageMean")); - float IMAGE_MEAN = (float) mean; - double std = (double) (args.get("imageStd")); - float IMAGE_STD = (float) std; - double threshold = (double) args.get("threshold"); - float THRESHOLD = (float) threshold; - List ANCHORS = (ArrayList) args.get("anchors"); - int BLOCK_SIZE = (int) args.get("blockSize"); - int NUM_BOXES_PER_BLOCK = (int) args.get("numBoxesPerBlock"); - int NUM_RESULTS_PER_CLASS = (int) args.get("numResultsPerClass"); - - ByteBuffer imgData = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD); - - if (model.equals("SSDMobileNet")) { - new RunSSDMobileNet(args, imgData, NUM_RESULTS_PER_CLASS, THRESHOLD, result).executeTfliteTask(); - } else { - new RunYOLO(args, imgData, BLOCK_SIZE, NUM_BOXES_PER_BLOCK, ANCHORS, THRESHOLD, NUM_RESULTS_PER_CLASS, result).executeTfliteTask(); - } - } + imgData = ByteBuffer.wrap(binary); + } - void detectObjectOnBinary(HashMap args, Result result) throws IOException { - byte[] binary = (byte[]) args.get("binary"); - String model = args.get("model").toString(); - double threshold = (double) args.get("threshold"); - float THRESHOLD = (float) threshold; - List ANCHORS = (ArrayList) args.get("anchors"); - int BLOCK_SIZE = (int) args.get("blockSize"); - int NUM_BOXES_PER_BLOCK = (int) args.get("numBoxesPerBlock"); - int NUM_RESULTS_PER_CLASS = (int) args.get("numResultsPerClass"); - - ByteBuffer imgData = ByteBuffer.wrap(binary); - - if (model.equals("SSDMobileNet")) { - new RunSSDMobileNet(args, imgData, NUM_RESULTS_PER_CLASS, THRESHOLD, result).executeTfliteTask(); - } else { - new RunYOLO(args, imgData, BLOCK_SIZE, NUM_BOXES_PER_BLOCK, ANCHORS, THRESHOLD, NUM_RESULTS_PER_CLASS, result).executeTfliteTask(); - } - } + protected void runTflite() { + tfLite.run(imgData, labelProb); + } - void detectObjectOnFrame(HashMap args, Result result) throws IOException { - List bytesList = (ArrayList) args.get("bytesList"); - String model = args.get("model").toString(); - double mean = (double) (args.get("imageMean")); - float IMAGE_MEAN = (float) mean; - double std = (double) (args.get("imageStd")); - float IMAGE_STD = (float) std; - int imageHeight = (int) (args.get("imageHeight")); - int imageWidth = (int) (args.get("imageWidth")); - int rotation = (int) (args.get("rotation")); - double threshold = (double) args.get("threshold"); - float THRESHOLD = (float) threshold; - int NUM_RESULTS_PER_CLASS = (int) args.get("numResultsPerClass"); - - List ANCHORS = (ArrayList) args.get("anchors"); - int BLOCK_SIZE = (int) args.get("blockSize"); - int NUM_BOXES_PER_BLOCK = (int) args.get("numBoxesPerBlock"); - - ByteBuffer imgData = feedInputTensorFrame(bytesList, imageHeight, imageWidth, IMAGE_MEAN, IMAGE_STD, rotation); - - if (model.equals("SSDMobileNet")) { - new RunSSDMobileNet(args, imgData, NUM_RESULTS_PER_CLASS, THRESHOLD, result).executeTfliteTask(); - } else { - new RunYOLO(args, imgData, BLOCK_SIZE, NUM_BOXES_PER_BLOCK, ANCHORS, THRESHOLD, NUM_RESULTS_PER_CLASS, result).executeTfliteTask(); + protected void onRunTfliteDone() { + result.success(GetTopN(NUM_RESULTS, THRESHOLD)); + } } - } - private class RunSSDMobileNet extends TfliteTask { - int num; - int numResultsPerClass; - float threshold; - float[][][] outputLocations; - float[][] outputClasses; - float[][] outputScores; - float[] numDetections = new float[1]; - Object[] inputArray; - Map outputMap = new HashMap<>(); - long startTime; - - RunSSDMobileNet(HashMap args, ByteBuffer imgData, int numResultsPerClass, float threshold, Result result) { - super(args, result); - this.num = tfLite.getOutputTensor(0).shape()[1]; - this.numResultsPerClass = numResultsPerClass; - this.threshold = threshold; - this.outputLocations = new float[1][num][4]; - this.outputClasses = new float[1][num]; - this.outputScores = new float[1][num]; - this.inputArray = new Object[]{imgData}; - - outputMap.put(0, outputLocations); - outputMap.put(1, outputClasses); - outputMap.put(2, outputScores); - outputMap.put(3, numDetections); - - startTime = SystemClock.uptimeMillis(); - } + private class RunModelOnFrame extends TfliteTask { + int NUM_RESULTS; + float THRESHOLD; + long startTime; + ByteBuffer imgData; - protected void runTflite() { - tfLite.runForMultipleInputsOutputs(inputArray, outputMap); - } + RunModelOnFrame(HashMap args, Result result) throws IOException { + super(args, result); - protected void onRunTfliteDone() { - Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); + List bytesList = (ArrayList) args.get("bytesList"); + double mean = (double) (args.get("imageMean")); + float IMAGE_MEAN = (float) mean; + double std = (double) (args.get("imageStd")); + float IMAGE_STD = (float) std; + int imageHeight = (int) (args.get("imageHeight")); + int imageWidth = (int) (args.get("imageWidth")); + int rotation = (int) (args.get("rotation")); + NUM_RESULTS = (int) args.get("numResults"); + double threshold = (double) args.get("threshold"); + THRESHOLD = (float) threshold; - Map counters = new HashMap<>(); - final List> results = new ArrayList<>(); + startTime = SystemClock.uptimeMillis(); - for (int i = 0; i < numDetections[0]; ++i) { - if (outputScores[0][i] < threshold) continue; + imgData = feedInputTensorFrame(bytesList, imageHeight, imageWidth, IMAGE_MEAN, IMAGE_STD, rotation); + } - String detectedClass = labels.get((int) outputClasses[0][i] + 1); + protected void runTflite() { + tfLite.run(imgData, labelProb); + } - if (counters.get(detectedClass) == null) { - counters.put(detectedClass, 1); + protected void onRunTfliteDone() { + Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); + result.success(GetTopN(NUM_RESULTS, THRESHOLD)); + } + } + + void detectObjectOnImage(HashMap args, Result result) throws IOException { + String path = args.get("path").toString(); + String model = args.get("model").toString(); + double mean = (double) (args.get("imageMean")); + float IMAGE_MEAN = (float) mean; + double std = (double) (args.get("imageStd")); + float IMAGE_STD = (float) std; + double threshold = (double) args.get("threshold"); + float THRESHOLD = (float) threshold; + List ANCHORS = (ArrayList) args.get("anchors"); + int BLOCK_SIZE = (int) args.get("blockSize"); + int NUM_BOXES_PER_BLOCK = (int) args.get("numBoxesPerBlock"); + int NUM_RESULTS_PER_CLASS = (int) args.get("numResultsPerClass"); + + ByteBuffer imgData = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD); + + if (model.equals("SSDMobileNet")) { + new RunSSDMobileNet(args, imgData, NUM_RESULTS_PER_CLASS, THRESHOLD, result).executeTfliteTask(); } else { - int count = counters.get(detectedClass); - if (count >= numResultsPerClass) { - continue; - } else { - counters.put(detectedClass, count + 1); - } - } - - Map rect = new HashMap<>(); - float ymin = Math.max(0, outputLocations[0][i][0]); - float xmin = Math.max(0, outputLocations[0][i][1]); - float ymax = outputLocations[0][i][2]; - float xmax = outputLocations[0][i][3]; - rect.put("x", xmin); - rect.put("y", ymin); - rect.put("w", Math.min(1 - xmin, xmax - xmin)); - rect.put("h", Math.min(1 - ymin, ymax - ymin)); - - Map ret = new HashMap<>(); - ret.put("rect", rect); - ret.put("confidenceInClass", outputScores[0][i]); - ret.put("detectedClass", detectedClass); - - results.add(ret); - } - - result.success(results); + new RunYOLO(args, imgData, BLOCK_SIZE, NUM_BOXES_PER_BLOCK, ANCHORS, THRESHOLD, NUM_RESULTS_PER_CLASS, result).executeTfliteTask(); + } } - } - private class RunYOLO extends TfliteTask { - ByteBuffer imgData; - int blockSize; - int numBoxesPerBlock; - List anchors; - float threshold; - int numResultsPerClass; - long startTime; - int gridSize; - int numClasses; - final float[][][][] output; - - RunYOLO(HashMap args, - ByteBuffer imgData, - int blockSize, - int numBoxesPerBlock, - List anchors, - float threshold, - int numResultsPerClass, - Result result) { - super(args, result); - this.imgData = imgData; - this.blockSize = blockSize; - this.numBoxesPerBlock = numBoxesPerBlock; - this.anchors = anchors; - this.threshold = threshold; - this.numResultsPerClass = numResultsPerClass; - this.startTime = SystemClock.uptimeMillis(); - - Tensor tensor = tfLite.getInputTensor(0); - inputSize = tensor.shape()[1]; - - this.gridSize = inputSize / blockSize; - this.numClasses = labels.size(); - this.output = new float[1][gridSize][gridSize][(numClasses + 5) * numBoxesPerBlock]; + void detectObjectOnBinary(HashMap args, Result result) throws IOException { + byte[] binary = (byte[]) args.get("binary"); + String model = args.get("model").toString(); + double threshold = (double) args.get("threshold"); + float THRESHOLD = (float) threshold; + List ANCHORS = (ArrayList) args.get("anchors"); + int BLOCK_SIZE = (int) args.get("blockSize"); + int NUM_BOXES_PER_BLOCK = (int) args.get("numBoxesPerBlock"); + int NUM_RESULTS_PER_CLASS = (int) args.get("numResultsPerClass"); + + ByteBuffer imgData = ByteBuffer.wrap(binary); + + if (model.equals("SSDMobileNet")) { + new RunSSDMobileNet(args, imgData, NUM_RESULTS_PER_CLASS, THRESHOLD, result).executeTfliteTask(); + } else { + new RunYOLO(args, imgData, BLOCK_SIZE, NUM_BOXES_PER_BLOCK, ANCHORS, THRESHOLD, NUM_RESULTS_PER_CLASS, result).executeTfliteTask(); + } } - protected void runTflite() { - tfLite.run(imgData, output); + void detectObjectOnFrame(HashMap args, Result result) throws IOException { + List bytesList = (ArrayList) args.get("bytesList"); + String model = args.get("model").toString(); + double mean = (double) (args.get("imageMean")); + float IMAGE_MEAN = (float) mean; + double std = (double) (args.get("imageStd")); + float IMAGE_STD = (float) std; + int imageHeight = (int) (args.get("imageHeight")); + int imageWidth = (int) (args.get("imageWidth")); + int rotation = (int) (args.get("rotation")); + double threshold = (double) args.get("threshold"); + float THRESHOLD = (float) threshold; + int NUM_RESULTS_PER_CLASS = (int) args.get("numResultsPerClass"); + + List ANCHORS = (ArrayList) args.get("anchors"); + int BLOCK_SIZE = (int) args.get("blockSize"); + int NUM_BOXES_PER_BLOCK = (int) args.get("numBoxesPerBlock"); + + ByteBuffer imgData = feedInputTensorFrame(bytesList, imageHeight, imageWidth, IMAGE_MEAN, IMAGE_STD, rotation); + + if (model.equals("SSDMobileNet")) { + new RunSSDMobileNet(args, imgData, NUM_RESULTS_PER_CLASS, THRESHOLD, result).executeTfliteTask(); + } else { + new RunYOLO(args, imgData, BLOCK_SIZE, NUM_BOXES_PER_BLOCK, ANCHORS, THRESHOLD, NUM_RESULTS_PER_CLASS, result).executeTfliteTask(); + } } - protected void onRunTfliteDone() { - Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); + private class RunSSDMobileNet extends TfliteTask { + int num; + int numResultsPerClass; + float threshold; + float[][][] outputLocations; + float[][] outputClasses; + float[][] outputScores; + float[] numDetections = new float[1]; + Object[] inputArray; + Map outputMap = new HashMap<>(); + long startTime; + + RunSSDMobileNet(HashMap args, ByteBuffer imgData, int numResultsPerClass, float threshold, Result result) { + super(args, result); + this.num = tfLite.getOutputTensor(0).shape()[1]; + this.numResultsPerClass = numResultsPerClass; + this.threshold = threshold; + this.outputLocations = new float[1][num][4]; + this.outputClasses = new float[1][num]; + this.outputScores = new float[1][num]; + this.inputArray = new Object[]{imgData}; + + outputMap.put(0, outputLocations); + outputMap.put(1, outputClasses); + outputMap.put(2, outputScores); + outputMap.put(3, numDetections); + + startTime = SystemClock.uptimeMillis(); + } - PriorityQueue> pq = - new PriorityQueue<>( - 1, - new Comparator>() { - @Override - public int compare(Map lhs, Map rhs) { - return Float.compare((float) rhs.get("confidenceInClass"), (float) lhs.get("confidenceInClass")); - } - }); + protected void runTflite() { + tfLite.runForMultipleInputsOutputs(inputArray, outputMap); + } - for (int y = 0; y < gridSize; ++y) { - for (int x = 0; x < gridSize; ++x) { - for (int b = 0; b < numBoxesPerBlock; ++b) { - final int offset = (numClasses + 5) * b; + protected void onRunTfliteDone() { + Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); - final float confidence = sigmoid(output[0][y][x][offset + 4]); + Map counters = new HashMap<>(); + final List> results = new ArrayList<>(); - final float[] classes = new float[numClasses]; - for (int c = 0; c < numClasses; ++c) { - classes[c] = output[0][y][x][offset + 5 + c]; - } - softmax(classes); - - int detectedClass = -1; - float maxClass = 0; - for (int c = 0; c < numClasses; ++c) { - if (classes[c] > maxClass) { - detectedClass = c; - maxClass = classes[c]; - } + for (int i = 0; i < numDetections[0]; ++i) { + if (outputScores[0][i] < threshold) continue; + + String detectedClass = labels.get((int) outputClasses[0][i] + 1); + + if (counters.get(detectedClass) == null) { + counters.put(detectedClass, 1); + } else { + int count = counters.get(detectedClass); + if (count >= numResultsPerClass) { + continue; + } else { + counters.put(detectedClass, count + 1); + } + } + + Map rect = new HashMap<>(); + float ymin = Math.max(0, outputLocations[0][i][0]); + float xmin = Math.max(0, outputLocations[0][i][1]); + float ymax = outputLocations[0][i][2]; + float xmax = outputLocations[0][i][3]; + rect.put("x", xmin); + rect.put("y", ymin); + rect.put("w", Math.min(1 - xmin, xmax - xmin)); + rect.put("h", Math.min(1 - ymin, ymax - ymin)); + + Map ret = new HashMap<>(); + ret.put("rect", rect); + ret.put("confidenceInClass", outputScores[0][i]); + ret.put("detectedClass", detectedClass); + + results.add(ret); } - final float confidenceInClass = maxClass * confidence; - if (confidenceInClass > threshold) { - final float xPos = (x + sigmoid(output[0][y][x][offset + 0])) * blockSize; - final float yPos = (y + sigmoid(output[0][y][x][offset + 1])) * blockSize; + result.success(results); + } + } - final float w = (float) (Math.exp(output[0][y][x][offset + 2]) * anchors.get(2 * b + 0)) * blockSize; - final float h = (float) (Math.exp(output[0][y][x][offset + 3]) * anchors.get(2 * b + 1)) * blockSize; + private class RunYOLO extends TfliteTask { + ByteBuffer imgData; + int blockSize; + int numBoxesPerBlock; + List anchors; + float threshold; + int numResultsPerClass; + long startTime; + int gridSize; + int numClasses; + final float[][][][] output; + + RunYOLO(HashMap args, + ByteBuffer imgData, + int blockSize, + int numBoxesPerBlock, + List anchors, + float threshold, + int numResultsPerClass, + Result result) { + super(args, result); + this.imgData = imgData; + this.blockSize = blockSize; + this.numBoxesPerBlock = numBoxesPerBlock; + this.anchors = anchors; + this.threshold = threshold; + this.numResultsPerClass = numResultsPerClass; + this.startTime = SystemClock.uptimeMillis(); + + Tensor tensor = tfLite.getInputTensor(0); + inputSize = tensor.shape()[1]; + + this.gridSize = inputSize / blockSize; + this.numClasses = labels.size(); + this.output = new float[1][gridSize][gridSize][(numClasses + 5) * numBoxesPerBlock]; + } - final float xmin = Math.max(0, (xPos - w / 2) / inputSize); - final float ymin = Math.max(0, (yPos - h / 2) / inputSize); + protected void runTflite() { + tfLite.run(imgData, output); + } - Map rect = new HashMap<>(); - rect.put("x", xmin); - rect.put("y", ymin); - rect.put("w", Math.min(1 - xmin, w / inputSize)); - rect.put("h", Math.min(1 - ymin, h / inputSize)); + protected void onRunTfliteDone() { + Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); + + PriorityQueue> pq = + new PriorityQueue<>( + 1, + new Comparator>() { + @Override + public int compare(Map lhs, Map rhs) { + return Float.compare((float) rhs.get("confidenceInClass"), (float) lhs.get("confidenceInClass")); + } + }); + + for (int y = 0; y < gridSize; ++y) { + for (int x = 0; x < gridSize; ++x) { + for (int b = 0; b < numBoxesPerBlock; ++b) { + final int offset = (numClasses + 5) * b; + + final float confidence = sigmoid(output[0][y][x][offset + 4]); + + final float[] classes = new float[numClasses]; + for (int c = 0; c < numClasses; ++c) { + classes[c] = output[0][y][x][offset + 5 + c]; + } + softmax(classes); + + int detectedClass = -1; + float maxClass = 0; + for (int c = 0; c < numClasses; ++c) { + if (classes[c] > maxClass) { + detectedClass = c; + maxClass = classes[c]; + } + } + + final float confidenceInClass = maxClass * confidence; + if (confidenceInClass > threshold) { + final float xPos = (x + sigmoid(output[0][y][x][offset + 0])) * blockSize; + final float yPos = (y + sigmoid(output[0][y][x][offset + 1])) * blockSize; + + final float w = (float) (Math.exp(output[0][y][x][offset + 2]) * anchors.get(2 * b + 0)) * blockSize; + final float h = (float) (Math.exp(output[0][y][x][offset + 3]) * anchors.get(2 * b + 1)) * blockSize; + + final float xmin = Math.max(0, (xPos - w / 2) / inputSize); + final float ymin = Math.max(0, (yPos - h / 2) / inputSize); + + Map rect = new HashMap<>(); + rect.put("x", xmin); + rect.put("y", ymin); + rect.put("w", Math.min(1 - xmin, w / inputSize)); + rect.put("h", Math.min(1 - ymin, h / inputSize)); + + Map ret = new HashMap<>(); + ret.put("rect", rect); + ret.put("confidenceInClass", confidenceInClass); + ret.put("detectedClass", labels.get(detectedClass)); + + pq.add(ret); + } + } + } + } - Map ret = new HashMap<>(); - ret.put("rect", rect); - ret.put("confidenceInClass", confidenceInClass); - ret.put("detectedClass", labels.get(detectedClass)); + Map counters = new HashMap<>(); + List> results = new ArrayList<>(); + + for (int i = 0; i < pq.size(); ++i) { + Map ret = pq.poll(); + String detectedClass = ret.get("detectedClass").toString(); + + if (counters.get(detectedClass) == null) { + counters.put(detectedClass, 1); + } else { + int count = counters.get(detectedClass); + if (count >= numResultsPerClass) { + continue; + } else { + counters.put(detectedClass, count + 1); + } + } + results.add(ret); + } + result.success(results); + } + } - pq.add(ret); + private class RunPix2PixOnImage extends TfliteTask { + String path, outputType; + float IMAGE_MEAN, IMAGE_STD; + long startTime; + ByteBuffer input, output; + + RunPix2PixOnImage(HashMap args, Result result) throws IOException { + super(args, result); + path = args.get("path").toString(); + double mean = (double) (args.get("imageMean")); + IMAGE_MEAN = (float) mean; + double std = (double) (args.get("imageStd")); + IMAGE_STD = (float) std; + + outputType = args.get("outputType").toString(); + startTime = SystemClock.uptimeMillis(); + input = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD); + output = ByteBuffer.allocateDirect(input.limit()); + output.order(ByteOrder.nativeOrder()); + if (input.limit() == 0) { + result.error("Unexpected input position, bad file?", null, null); + return; + } + if (output.position() != 0) { + result.error("Unexpected output position", null, null); + return; } - } } - } - Map counters = new HashMap<>(); - List> results = new ArrayList<>(); + protected void runTflite() { + tfLite.run(input, output); + } - for (int i = 0; i < pq.size(); ++i) { - Map ret = pq.poll(); - String detectedClass = ret.get("detectedClass").toString(); + protected void onRunTfliteDone() { + Log.v("time", "Generating took " + (SystemClock.uptimeMillis() - startTime)); + if (output.position() != input.limit()) { + result.error("Mismatching input/output position", null, null); + return; + } - if (counters.get(detectedClass) == null) { - counters.put(detectedClass, 1); - } else { - int count = counters.get(detectedClass); - if (count >= numResultsPerClass) { - continue; - } else { - counters.put(detectedClass, count + 1); - } - } - results.add(ret); - } - result.success(results); - } - } + output.flip(); + Bitmap bitmapRaw = feedOutput(output, IMAGE_MEAN, IMAGE_STD); - private class RunPix2PixOnImage extends TfliteTask { - String path, outputType; - float IMAGE_MEAN, IMAGE_STD; - long startTime; - ByteBuffer input, output; - - RunPix2PixOnImage(HashMap args, Result result) throws IOException { - super(args, result); - path = args.get("path").toString(); - double mean = (double) (args.get("imageMean")); - IMAGE_MEAN = (float) mean; - double std = (double) (args.get("imageStd")); - IMAGE_STD = (float) std; - - outputType = args.get("outputType").toString(); - startTime = SystemClock.uptimeMillis(); - input = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD); - output = ByteBuffer.allocateDirect(input.limit()); - output.order(ByteOrder.nativeOrder()); - if (input.limit() == 0) { - result.error("Unexpected input position, bad file?", null, null); - return; - } - if (output.position() != 0) { - result.error("Unexpected output position", null, null); - return; - } + if (outputType.equals("png")) { + result.success(compressPNG(bitmapRaw)); + } else { + result.success(bitmapRaw); + } + } } - protected void runTflite() { - tfLite.run(input, output); - } + private class RunPix2PixOnBinary extends TfliteTask { + long startTime; + String outputType; + ByteBuffer input, output; - protected void onRunTfliteDone() { - Log.v("time", "Generating took " + (SystemClock.uptimeMillis() - startTime)); - if (output.position() != input.limit()) { - result.error("Mismatching input/output position", null, null); - return; - } - - output.flip(); - Bitmap bitmapRaw = feedOutput(output, IMAGE_MEAN, IMAGE_STD); - - if (outputType.equals("png")) { - result.success(compressPNG(bitmapRaw)); - } else { - result.success(bitmapRaw); - } - } - } + RunPix2PixOnBinary(HashMap args, Result result) throws IOException { + super(args, result); + byte[] binary = (byte[]) args.get("binary"); + outputType = args.get("outputType").toString(); + startTime = SystemClock.uptimeMillis(); + input = ByteBuffer.wrap(binary); + output = ByteBuffer.allocateDirect(input.limit()); + output.order(ByteOrder.nativeOrder()); - private class RunPix2PixOnBinary extends TfliteTask { - long startTime; - String outputType; - ByteBuffer input, output; - - RunPix2PixOnBinary(HashMap args, Result result) throws IOException { - super(args, result); - byte[] binary = (byte[]) args.get("binary"); - outputType = args.get("outputType").toString(); - startTime = SystemClock.uptimeMillis(); - input = ByteBuffer.wrap(binary); - output = ByteBuffer.allocateDirect(input.limit()); - output.order(ByteOrder.nativeOrder()); - - if (input.limit() == 0) { - result.error("Unexpected input position, bad file?", null, null); - return; - } - if (output.position() != 0) { - result.error("Unexpected output position", null, null); - return; - } - } + if (input.limit() == 0) { + result.error("Unexpected input position, bad file?", null, null); + return; + } + if (output.position() != 0) { + result.error("Unexpected output position", null, null); + return; + } + } - protected void runTflite() { - tfLite.run(input, output); - } + protected void runTflite() { + tfLite.run(input, output); + } - protected void onRunTfliteDone() { - Log.v("time", "Generating took " + (SystemClock.uptimeMillis() - startTime)); - if (output.position() != input.limit()) { - result.error("Mismatching input/output position", null, null); - return; - } + protected void onRunTfliteDone() { + Log.v("time", "Generating took " + (SystemClock.uptimeMillis() - startTime)); + if (output.position() != input.limit()) { + result.error("Mismatching input/output position", null, null); + return; + } - output.flip(); - result.success(output.array()); + output.flip(); + result.success(output.array()); + } } - } - private class RunPix2PixOnFrame extends TfliteTask { - long startTime; - String outputType; - float IMAGE_MEAN, IMAGE_STD; - ByteBuffer input, output; - - RunPix2PixOnFrame(HashMap args, Result result) throws IOException { - super(args, result); - List bytesList = (ArrayList) args.get("bytesList"); - double mean = (double) (args.get("imageMean")); - IMAGE_MEAN = (float) mean; - double std = (double) (args.get("imageStd")); - IMAGE_STD = (float) std; - int imageHeight = (int) (args.get("imageHeight")); - int imageWidth = (int) (args.get("imageWidth")); - int rotation = (int) (args.get("rotation")); - - outputType = args.get("outputType").toString(); - startTime = SystemClock.uptimeMillis(); - input = feedInputTensorFrame(bytesList, imageHeight, imageWidth, IMAGE_MEAN, IMAGE_STD, rotation); - output = ByteBuffer.allocateDirect(input.limit()); - output.order(ByteOrder.nativeOrder()); - - if (input.limit() == 0) { - result.error("Unexpected input position, bad file?", null, null); - return; - } - if (output.position() != 0) { - result.error("Unexpected output position", null, null); - return; - } - } + private class RunPix2PixOnFrame extends TfliteTask { + long startTime; + String outputType; + float IMAGE_MEAN, IMAGE_STD; + ByteBuffer input, output; + + RunPix2PixOnFrame(HashMap args, Result result) throws IOException { + super(args, result); + List bytesList = (ArrayList) args.get("bytesList"); + double mean = (double) (args.get("imageMean")); + IMAGE_MEAN = (float) mean; + double std = (double) (args.get("imageStd")); + IMAGE_STD = (float) std; + int imageHeight = (int) (args.get("imageHeight")); + int imageWidth = (int) (args.get("imageWidth")); + int rotation = (int) (args.get("rotation")); + + outputType = args.get("outputType").toString(); + startTime = SystemClock.uptimeMillis(); + input = feedInputTensorFrame(bytesList, imageHeight, imageWidth, IMAGE_MEAN, IMAGE_STD, rotation); + output = ByteBuffer.allocateDirect(input.limit()); + output.order(ByteOrder.nativeOrder()); + + if (input.limit() == 0) { + result.error("Unexpected input position, bad file?", null, null); + return; + } + if (output.position() != 0) { + result.error("Unexpected output position", null, null); + return; + } + } - protected void runTflite() { - tfLite.run(input, output); - } + protected void runTflite() { + tfLite.run(input, output); + } + + protected void onRunTfliteDone() { + Log.v("time", "Generating took " + (SystemClock.uptimeMillis() - startTime)); + if (output.position() != input.limit()) { + result.error("Mismatching input/output position", null, null); + return; + } + + output.flip(); + Bitmap bitmapRaw = feedOutput(output, IMAGE_MEAN, IMAGE_STD); - protected void onRunTfliteDone() { - Log.v("time", "Generating took " + (SystemClock.uptimeMillis() - startTime)); - if (output.position() != input.limit()) { - result.error("Mismatching input/output position", null, null); - return; - } - - output.flip(); - Bitmap bitmapRaw = feedOutput(output, IMAGE_MEAN, IMAGE_STD); - - if (outputType.equals("png")) { - result.success(compressPNG(bitmapRaw)); - } else { - result.success(bitmapRaw); - } + if (outputType.equals("png")) { + result.success(compressPNG(bitmapRaw)); + } else { + result.success(bitmapRaw); + } + } } - } - private class RunSegmentationOnImage extends TfliteTask { - List labelColors; - String outputType; - long startTime; - ByteBuffer input, output; + private class RunSegmentationOnImage extends TfliteTask { + List labelColors; + String outputType; + long startTime; + ByteBuffer input, output; - RunSegmentationOnImage(HashMap args, Result result) throws IOException { - super(args, result); + RunSegmentationOnImage(HashMap args, Result result) throws IOException { + super(args, result); - String path = args.get("path").toString(); - double mean = (double) (args.get("imageMean")); - float IMAGE_MEAN = (float) mean; - double std = (double) (args.get("imageStd")); - float IMAGE_STD = (float) std; + String path = args.get("path").toString(); + double mean = (double) (args.get("imageMean")); + float IMAGE_MEAN = (float) mean; + double std = (double) (args.get("imageStd")); + float IMAGE_STD = (float) std; - labelColors = (ArrayList) args.get("labelColors"); - outputType = args.get("outputType").toString(); + labelColors = (ArrayList) args.get("labelColors"); + outputType = args.get("outputType").toString(); - startTime = SystemClock.uptimeMillis(); - input = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD); - output = ByteBuffer.allocateDirect(tfLite.getOutputTensor(0).numBytes()); - output.order(ByteOrder.nativeOrder()); - } + startTime = SystemClock.uptimeMillis(); + input = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD); + output = ByteBuffer.allocateDirect(tfLite.getOutputTensor(0).numBytes()); + output.order(ByteOrder.nativeOrder()); + } - protected void runTflite() { - tfLite.run(input, output); - } + protected void runTflite() { + tfLite.run(input, output); + } - protected void onRunTfliteDone() { - Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); + protected void onRunTfliteDone() { + Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); - if (input.limit() == 0) { - result.error("Unexpected input position, bad file?", null, null); - return; - } - if (output.position() != output.limit()) { - result.error("Unexpected output position", null, null); - return; - } - output.flip(); + if (input.limit() == 0) { + result.error("Unexpected input position, bad file?", null, null); + return; + } + if (output.position() != output.limit()) { + result.error("Unexpected output position", null, null); + return; + } + output.flip(); - result.success(fetchArgmax(output, labelColors, outputType)); + result.success(fetchArgmax(output, labelColors, outputType)); + } } - } - private class RunSegmentationOnBinary extends TfliteTask { - List labelColors; - String outputType; - long startTime; - ByteBuffer input, output; + private class RunSegmentationOnBinary extends TfliteTask { + List labelColors; + String outputType; + long startTime; + ByteBuffer input, output; - RunSegmentationOnBinary(HashMap args, Result result) throws IOException { - super(args, result); + RunSegmentationOnBinary(HashMap args, Result result) throws IOException { + super(args, result); - byte[] binary = (byte[]) args.get("binary"); - labelColors = (ArrayList) args.get("labelColors"); - outputType = args.get("outputType").toString(); + byte[] binary = (byte[]) args.get("binary"); + labelColors = (ArrayList) args.get("labelColors"); + outputType = args.get("outputType").toString(); - startTime = SystemClock.uptimeMillis(); - input = ByteBuffer.wrap(binary); - output = ByteBuffer.allocateDirect(tfLite.getOutputTensor(0).numBytes()); - output.order(ByteOrder.nativeOrder()); - } + startTime = SystemClock.uptimeMillis(); + input = ByteBuffer.wrap(binary); + output = ByteBuffer.allocateDirect(tfLite.getOutputTensor(0).numBytes()); + output.order(ByteOrder.nativeOrder()); + } - protected void runTflite() { - tfLite.run(input, output); - } + protected void runTflite() { + tfLite.run(input, output); + } - protected void onRunTfliteDone() { - Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); + protected void onRunTfliteDone() { + Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); - if (input.limit() == 0) { - result.error("Unexpected input position, bad file?", null, null); - return; - } - if (output.position() != output.limit()) { - result.error("Unexpected output position", null, null); - return; - } - output.flip(); + if (input.limit() == 0) { + result.error("Unexpected input position, bad file?", null, null); + return; + } + if (output.position() != output.limit()) { + result.error("Unexpected output position", null, null); + return; + } + output.flip(); - result.success(fetchArgmax(output, labelColors, outputType)); + result.success(fetchArgmax(output, labelColors, outputType)); + } } - } - private class RunSegmentationOnFrame extends TfliteTask { - List labelColors; - String outputType; - long startTime; - ByteBuffer input, output; - - RunSegmentationOnFrame(HashMap args, Result result) throws IOException { - super(args, result); - - List bytesList = (ArrayList) args.get("bytesList"); - double mean = (double) (args.get("imageMean")); - float IMAGE_MEAN = (float) mean; - double std = (double) (args.get("imageStd")); - float IMAGE_STD = (float) std; - int imageHeight = (int) (args.get("imageHeight")); - int imageWidth = (int) (args.get("imageWidth")); - int rotation = (int) (args.get("rotation")); - labelColors = (ArrayList) args.get("labelColors"); - outputType = args.get("outputType").toString(); - - startTime = SystemClock.uptimeMillis(); - input = feedInputTensorFrame(bytesList, imageHeight, imageWidth, IMAGE_MEAN, IMAGE_STD, rotation); - output = ByteBuffer.allocateDirect(tfLite.getOutputTensor(0).numBytes()); - output.order(ByteOrder.nativeOrder()); - } + private class RunSegmentationOnFrame extends TfliteTask { + List labelColors; + String outputType; + long startTime; + ByteBuffer input, output; + + RunSegmentationOnFrame(HashMap args, Result result) throws IOException { + super(args, result); + + List bytesList = (ArrayList) args.get("bytesList"); + double mean = (double) (args.get("imageMean")); + float IMAGE_MEAN = (float) mean; + double std = (double) (args.get("imageStd")); + float IMAGE_STD = (float) std; + int imageHeight = (int) (args.get("imageHeight")); + int imageWidth = (int) (args.get("imageWidth")); + int rotation = (int) (args.get("rotation")); + labelColors = (ArrayList) args.get("labelColors"); + outputType = args.get("outputType").toString(); + + startTime = SystemClock.uptimeMillis(); + input = feedInputTensorFrame(bytesList, imageHeight, imageWidth, IMAGE_MEAN, IMAGE_STD, rotation); + output = ByteBuffer.allocateDirect(tfLite.getOutputTensor(0).numBytes()); + output.order(ByteOrder.nativeOrder()); + } - protected void runTflite() { - tfLite.run(input, output); - } + protected void runTflite() { + tfLite.run(input, output); + } - protected void onRunTfliteDone() { - Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); + protected void onRunTfliteDone() { + Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); - if (input.limit() == 0) { - result.error("Unexpected input position, bad file?", null, null); - return; - } - if (output.position() != output.limit()) { - result.error("Unexpected output position", null, null); - return; - } - output.flip(); + if (input.limit() == 0) { + result.error("Unexpected input position, bad file?", null, null); + return; + } + if (output.position() != output.limit()) { + result.error("Unexpected output position", null, null); + return; + } + output.flip(); - result.success(fetchArgmax(output, labelColors, outputType)); + result.success(fetchArgmax(output, labelColors, outputType)); + } } - } - byte[] fetchArgmax(ByteBuffer output, List labelColors, String outputType) { - Tensor outputTensor = tfLite.getOutputTensor(0); - int outputBatchSize = outputTensor.shape()[0]; - assert outputBatchSize == 1; - int outputHeight = outputTensor.shape()[1]; - int outputWidth = outputTensor.shape()[2]; - int outputChannels = outputTensor.shape()[3]; + byte[] fetchArgmax(ByteBuffer output, List labelColors, String outputType) { + Tensor outputTensor = tfLite.getOutputTensor(0); + int outputBatchSize = outputTensor.shape()[0]; + assert outputBatchSize == 1; + int outputHeight = outputTensor.shape()[1]; + int outputWidth = outputTensor.shape()[2]; + int outputChannels = outputTensor.shape()[3]; - Bitmap outputArgmax = null; - byte[] outputBytes = new byte[outputWidth * outputHeight * 4]; - if (outputType.equals("png")) { - outputArgmax = Bitmap.createBitmap(outputWidth, outputHeight, Bitmap.Config.ARGB_8888); - } + Bitmap outputArgmax = null; + byte[] outputBytes = new byte[outputWidth * outputHeight * 4]; + if (outputType.equals("png")) { + outputArgmax = Bitmap.createBitmap(outputWidth, outputHeight, Bitmap.Config.ARGB_8888); + } - if (outputTensor.dataType() == DataType.FLOAT32) { - for (int i = 0; i < outputHeight; ++i) { - for (int j = 0; j < outputWidth; ++j) { - int maxIndex = 0; - float maxValue = 0.0f; - for (int c = 0; c < outputChannels; ++c) { - float outputValue = output.getFloat(); - if (outputValue > maxValue) { - maxIndex = c; - maxValue = outputValue; + if (outputTensor.dataType() == DataType.FLOAT32) { + for (int i = 0; i < outputHeight; ++i) { + for (int j = 0; j < outputWidth; ++j) { + int maxIndex = 0; + float maxValue = 0.0f; + for (int c = 0; c < outputChannels; ++c) { + float outputValue = output.getFloat(); + if (outputValue > maxValue) { + maxIndex = c; + maxValue = outputValue; + } + } + int labelColor = labelColors.get(maxIndex).intValue(); + if (outputType.equals("png")) { + outputArgmax.setPixel(j, i, labelColor); + } else { + setPixel(outputBytes, i * outputWidth + j, labelColor); + } + } } - } - int labelColor = labelColors.get(maxIndex).intValue(); - if (outputType.equals("png")) { - outputArgmax.setPixel(j, i, labelColor); - } else { - setPixel(outputBytes, i * outputWidth + j, labelColor); - } - } - } - } else { - for (int i = 0; i < outputHeight; ++i) { - for (int j = 0; j < outputWidth; ++j) { - int maxIndex = 0; - int maxValue = 0; - for (int c = 0; c < outputChannels; ++c) { - int outputValue = output.get(); - if (outputValue > maxValue) { - maxIndex = c; - maxValue = outputValue; + } else { + for (int i = 0; i < outputHeight; ++i) { + for (int j = 0; j < outputWidth; ++j) { + int maxIndex = 0; + int maxValue = 0; + for (int c = 0; c < outputChannels; ++c) { + int outputValue = output.get(); + if (outputValue > maxValue) { + maxIndex = c; + maxValue = outputValue; + } + } + int labelColor = labelColors.get(maxIndex).intValue(); + if (outputType.equals("png")) { + outputArgmax.setPixel(j, i, labelColor); + } else { + setPixel(outputBytes, i * outputWidth + j, labelColor); + } + } } - } - int labelColor = labelColors.get(maxIndex).intValue(); - if (outputType.equals("png")) { - outputArgmax.setPixel(j, i, labelColor); - } else { - setPixel(outputBytes, i * outputWidth + j, labelColor); - } - } - } - } - if (outputType.equals("png")) { - return compressPNG(outputArgmax); - } else { - return outputBytes; + } + if (outputType.equals("png")) { + return compressPNG(outputArgmax); + } else { + return outputBytes; + } } - } - void setPixel(byte[] rgba, int index, long color) { - rgba[index * 4] = (byte) ((color >> 16) & 0xFF); - rgba[index * 4 + 1] = (byte) ((color >> 8) & 0xFF); - rgba[index * 4 + 2] = (byte) (color & 0xFF); - rgba[index * 4 + 3] = (byte) ((color >> 24) & 0xFF); - } + void setPixel(byte[] rgba, int index, long color) { + rgba[index * 4] = (byte) ((color >> 16) & 0xFF); + rgba[index * 4 + 1] = (byte) ((color >> 8) & 0xFF); + rgba[index * 4 + 2] = (byte) (color & 0xFF); + rgba[index * 4 + 3] = (byte) ((color >> 24) & 0xFF); + } - byte[] compressPNG(Bitmap bitmap) { - // https://stackoverflow.com/questions/4989182/converting-java-bitmap-to-byte-array#4989543 - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - bitmap.compress(Bitmap.CompressFormat.PNG, 100, stream); - byte[] byteArray = stream.toByteArray(); - // bitmap.recycle(); - return byteArray; - } + byte[] compressPNG(Bitmap bitmap) { + // https://stackoverflow.com/questions/4989182/converting-java-bitmap-to-byte-array#4989543 + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + bitmap.compress(Bitmap.CompressFormat.PNG, 100, stream); + byte[] byteArray = stream.toByteArray(); + // bitmap.recycle(); + return byteArray; + } - void runPoseNetOnImage(HashMap args, Result result) throws IOException { - String path = args.get("path").toString(); - double mean = (double) (args.get("imageMean")); - float IMAGE_MEAN = (float) mean; - double std = (double) (args.get("imageStd")); - float IMAGE_STD = (float) std; - int numResults = (int) args.get("numResults"); - double threshold = (double) args.get("threshold"); - int nmsRadius = (int) args.get("nmsRadius"); + void runPoseNetOnImage(HashMap args, Result result) throws IOException { + String path = args.get("path").toString(); + double mean = (double) (args.get("imageMean")); + float IMAGE_MEAN = (float) mean; + double std = (double) (args.get("imageStd")); + float IMAGE_STD = (float) std; + int numResults = (int) args.get("numResults"); + double threshold = (double) args.get("threshold"); + int nmsRadius = (int) args.get("nmsRadius"); - ByteBuffer imgData = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD); + ByteBuffer imgData = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD); - new RunPoseNet(args, imgData, numResults, threshold, nmsRadius, result).executeTfliteTask(); - } + new RunPoseNet(args, imgData, numResults, threshold, nmsRadius, result).executeTfliteTask(); + } - void runPoseNetOnBinary(HashMap args, Result result) throws IOException { - byte[] binary = (byte[]) args.get("binary"); - int numResults = (int) args.get("numResults"); - double threshold = (double) args.get("threshold"); - int nmsRadius = (int) args.get("nmsRadius"); + void runPoseNetOnBinary(HashMap args, Result result) throws IOException { + byte[] binary = (byte[]) args.get("binary"); + int numResults = (int) args.get("numResults"); + double threshold = (double) args.get("threshold"); + int nmsRadius = (int) args.get("nmsRadius"); - ByteBuffer imgData = ByteBuffer.wrap(binary); + ByteBuffer imgData = ByteBuffer.wrap(binary); - new RunPoseNet(args, imgData, numResults, threshold, nmsRadius, result).executeTfliteTask(); - } + new RunPoseNet(args, imgData, numResults, threshold, nmsRadius, result).executeTfliteTask(); + } - void runPoseNetOnFrame(HashMap args, Result result) throws IOException { - List bytesList = (ArrayList) args.get("bytesList"); - double mean = (double) (args.get("imageMean")); - float IMAGE_MEAN = (float) mean; - double std = (double) (args.get("imageStd")); - float IMAGE_STD = (float) std; - int imageHeight = (int) (args.get("imageHeight")); - int imageWidth = (int) (args.get("imageWidth")); - int rotation = (int) (args.get("rotation")); - int numResults = (int) args.get("numResults"); - double threshold = (double) args.get("threshold"); - int nmsRadius = (int) args.get("nmsRadius"); - - ByteBuffer imgData = feedInputTensorFrame(bytesList, imageHeight, imageWidth, IMAGE_MEAN, IMAGE_STD, rotation); - - new RunPoseNet(args, imgData, numResults, threshold, nmsRadius, result).executeTfliteTask(); - } + void runPoseNetOnFrame(HashMap args, Result result) throws IOException { + List bytesList = (ArrayList) args.get("bytesList"); + double mean = (double) (args.get("imageMean")); + float IMAGE_MEAN = (float) mean; + double std = (double) (args.get("imageStd")); + float IMAGE_STD = (float) std; + int imageHeight = (int) (args.get("imageHeight")); + int imageWidth = (int) (args.get("imageWidth")); + int rotation = (int) (args.get("rotation")); + int numResults = (int) args.get("numResults"); + double threshold = (double) args.get("threshold"); + int nmsRadius = (int) args.get("nmsRadius"); - void initPoseNet(Map outputMap) { - if (partsIds.size() == 0) { - for (int i = 0; i < partNames.length; ++i) - partsIds.put(partNames[i], i); + ByteBuffer imgData = feedInputTensorFrame(bytesList, imageHeight, imageWidth, IMAGE_MEAN, IMAGE_STD, rotation); - for (int i = 0; i < poseChain.length; ++i) { - parentToChildEdges.add(partsIds.get(poseChain[i][1])); - childToParentEdges.add(partsIds.get(poseChain[i][0])); - } + new RunPoseNet(args, imgData, numResults, threshold, nmsRadius, result).executeTfliteTask(); } - for (int i = 0; i < tfLite.getOutputTensorCount(); i++) { - int[] shape = tfLite.getOutputTensor(i).shape(); - float[][][][] output = new float[shape[0]][shape[1]][shape[2]][shape[3]]; - outputMap.put(i, output); - } - } + void initPoseNet(Map outputMap) { + if (partsIds.size() == 0) { + for (int i = 0; i < partNames.length; ++i) + partsIds.put(partNames[i], i); - private class RunPoseNet extends TfliteTask { - long startTime; - Object[] input; - Map outputMap = new HashMap<>(); - int numResults; - double threshold; - int nmsRadius; - - int localMaximumRadius = 1; - int outputStride = 16; - - RunPoseNet(HashMap args, - ByteBuffer imgData, - int numResults, - double threshold, - int nmsRadius, - Result result) throws IOException { - super(args, result); - this.numResults = numResults; - this.threshold = threshold; - this.nmsRadius = nmsRadius; - - input = new Object[]{imgData}; - initPoseNet(outputMap); - - startTime = SystemClock.uptimeMillis(); - } + for (int i = 0; i < poseChain.length; ++i) { + parentToChildEdges.add(partsIds.get(poseChain[i][1])); + childToParentEdges.add(partsIds.get(poseChain[i][0])); + } + } - protected void runTflite() { - tfLite.runForMultipleInputsOutputs(input, outputMap); + for (int i = 0; i < tfLite.getOutputTensorCount(); i++) { + int[] shape = tfLite.getOutputTensor(i).shape(); + float[][][][] output = new float[shape[0]][shape[1]][shape[2]][shape[3]]; + outputMap.put(i, output); + } } - protected void onRunTfliteDone() { - Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); + private class RunPoseNet extends TfliteTask { + long startTime; + Object[] input; + Map outputMap = new HashMap<>(); + int numResults; + double threshold; + int nmsRadius; + + int localMaximumRadius = 1; + int outputStride = 16; + + RunPoseNet(HashMap args, + ByteBuffer imgData, + int numResults, + double threshold, + int nmsRadius, + Result result) throws IOException { + super(args, result); + this.numResults = numResults; + this.threshold = threshold; + this.nmsRadius = nmsRadius; + + input = new Object[]{imgData}; + initPoseNet(outputMap); + + startTime = SystemClock.uptimeMillis(); + } + + protected void runTflite() { + tfLite.runForMultipleInputsOutputs(input, outputMap); + } - float[][][] scores = ((float[][][][]) outputMap.get(0))[0]; - float[][][] offsets = ((float[][][][]) outputMap.get(1))[0]; - float[][][] displacementsFwd = ((float[][][][]) outputMap.get(2))[0]; - float[][][] displacementsBwd = ((float[][][][]) outputMap.get(3))[0]; + protected void onRunTfliteDone() { + Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime)); - PriorityQueue> pq = buildPartWithScoreQueue(scores, threshold, localMaximumRadius); + float[][][] scores = ((float[][][][]) outputMap.get(0))[0]; + float[][][] offsets = ((float[][][][]) outputMap.get(1))[0]; + float[][][] displacementsFwd = ((float[][][][]) outputMap.get(2))[0]; + float[][][] displacementsBwd = ((float[][][][]) outputMap.get(3))[0]; - int numParts = scores[0][0].length; - int numEdges = parentToChildEdges.size(); - int sqaredNmsRadius = nmsRadius * nmsRadius; + PriorityQueue> pq = buildPartWithScoreQueue(scores, threshold, localMaximumRadius); - List> results = new ArrayList<>(); + int numParts = scores[0][0].length; + int numEdges = parentToChildEdges.size(); + int sqaredNmsRadius = nmsRadius * nmsRadius; - while (results.size() < numResults && pq.size() > 0) { - Map root = pq.poll(); - float[] rootPoint = getImageCoords(root, outputStride, numParts, offsets); + List> results = new ArrayList<>(); - if (withinNmsRadiusOfCorrespondingPoint( - results, sqaredNmsRadius, rootPoint[0], rootPoint[1], (int) root.get("partId"))) - continue; + while (results.size() < numResults && pq.size() > 0) { + Map root = pq.poll(); + float[] rootPoint = getImageCoords(root, outputStride, numParts, offsets); - Map keypoint = new HashMap<>(); - keypoint.put("score", root.get("score")); - keypoint.put("part", partNames[(int) root.get("partId")]); - keypoint.put("y", rootPoint[0] / inputSize); - keypoint.put("x", rootPoint[1] / inputSize); - - Map> keypoints = new HashMap<>(); - keypoints.put((int) root.get("partId"), keypoint); - - for (int edge = numEdges - 1; edge >= 0; --edge) { - int sourceKeypointId = parentToChildEdges.get(edge); - int targetKeypointId = childToParentEdges.get(edge); - if (keypoints.containsKey(sourceKeypointId) && !keypoints.containsKey(targetKeypointId)) { - keypoint = traverseToTargetKeypoint(edge, keypoints.get(sourceKeypointId), - targetKeypointId, scores, offsets, outputStride, displacementsBwd); - keypoints.put(targetKeypointId, keypoint); - } - } - - for (int edge = 0; edge < numEdges; ++edge) { - int sourceKeypointId = childToParentEdges.get(edge); - int targetKeypointId = parentToChildEdges.get(edge); - if (keypoints.containsKey(sourceKeypointId) && !keypoints.containsKey(targetKeypointId)) { - keypoint = traverseToTargetKeypoint(edge, keypoints.get(sourceKeypointId), - targetKeypointId, scores, offsets, outputStride, displacementsFwd); - keypoints.put(targetKeypointId, keypoint); - } - } - - Map result = new HashMap<>(); - result.put("keypoints", keypoints); - result.put("score", getInstanceScore(keypoints, numParts)); - results.add(result); - } - - result.success(results); - } - } + if (withinNmsRadiusOfCorrespondingPoint( + results, sqaredNmsRadius, rootPoint[0], rootPoint[1], (int) root.get("partId"))) + continue; - PriorityQueue> buildPartWithScoreQueue(float[][][] scores, - double threshold, - int localMaximumRadius) { - PriorityQueue> pq = - new PriorityQueue<>( - 1, - new Comparator>() { - @Override - public int compare(Map lhs, Map rhs) { - return Float.compare((float) rhs.get("score"), (float) lhs.get("score")); - } - }); - - for (int heatmapY = 0; heatmapY < scores.length; ++heatmapY) { - for (int heatmapX = 0; heatmapX < scores[0].length; ++heatmapX) { - for (int keypointId = 0; keypointId < scores[0][0].length; ++keypointId) { - float score = sigmoid(scores[heatmapY][heatmapX][keypointId]); - if (score < threshold) continue; - - if (scoreIsMaximumInLocalWindow( - keypointId, score, heatmapY, heatmapX, localMaximumRadius, scores)) { - Map res = new HashMap<>(); - res.put("score", score); - res.put("y", heatmapY); - res.put("x", heatmapX); - res.put("partId", keypointId); - pq.add(res); - } - } - } - } + Map keypoint = new HashMap<>(); + keypoint.put("score", root.get("score")); + keypoint.put("part", partNames[(int) root.get("partId")]); + keypoint.put("y", rootPoint[0] / inputSize); + keypoint.put("x", rootPoint[1] / inputSize); - return pq; - } + Map> keypoints = new HashMap<>(); + keypoints.put((int) root.get("partId"), keypoint); - boolean scoreIsMaximumInLocalWindow(int keypointId, - float score, - int heatmapY, - int heatmapX, - int localMaximumRadius, - float[][][] scores) { - boolean localMaximum = true; - int height = scores.length; - int width = scores[0].length; - - int yStart = Math.max(heatmapY - localMaximumRadius, 0); - int yEnd = Math.min(heatmapY + localMaximumRadius + 1, height); - for (int yCurrent = yStart; yCurrent < yEnd; ++yCurrent) { - int xStart = Math.max(heatmapX - localMaximumRadius, 0); - int xEnd = Math.min(heatmapX + localMaximumRadius + 1, width); - for (int xCurrent = xStart; xCurrent < xEnd; ++xCurrent) { - if (sigmoid(scores[yCurrent][xCurrent][keypointId]) > score) { - localMaximum = false; - break; - } - } - if (!localMaximum) { - break; - } - } + for (int edge = numEdges - 1; edge >= 0; --edge) { + int sourceKeypointId = parentToChildEdges.get(edge); + int targetKeypointId = childToParentEdges.get(edge); + if (keypoints.containsKey(sourceKeypointId) && !keypoints.containsKey(targetKeypointId)) { + keypoint = traverseToTargetKeypoint(edge, keypoints.get(sourceKeypointId), + targetKeypointId, scores, offsets, outputStride, displacementsBwd); + keypoints.put(targetKeypointId, keypoint); + } + } - return localMaximum; - } + for (int edge = 0; edge < numEdges; ++edge) { + int sourceKeypointId = childToParentEdges.get(edge); + int targetKeypointId = parentToChildEdges.get(edge); + if (keypoints.containsKey(sourceKeypointId) && !keypoints.containsKey(targetKeypointId)) { + keypoint = traverseToTargetKeypoint(edge, keypoints.get(sourceKeypointId), + targetKeypointId, scores, offsets, outputStride, displacementsFwd); + keypoints.put(targetKeypointId, keypoint); + } + } - float[] getImageCoords(Map keypoint, - int outputStride, - int numParts, - float[][][] offsets) { - int heatmapY = (int) keypoint.get("y"); - int heatmapX = (int) keypoint.get("x"); - int keypointId = (int) keypoint.get("partId"); - float offsetY = offsets[heatmapY][heatmapX][keypointId]; - float offsetX = offsets[heatmapY][heatmapX][keypointId + numParts]; + Map result = new HashMap<>(); + result.put("keypoints", keypoints); + result.put("score", getInstanceScore(keypoints, numParts)); + results.add(result); + } - float y = heatmapY * outputStride + offsetY; - float x = heatmapX * outputStride + offsetX; + result.success(results); + } + } - return new float[]{y, x}; - } + PriorityQueue> buildPartWithScoreQueue(float[][][] scores, + double threshold, + int localMaximumRadius) { + PriorityQueue> pq = + new PriorityQueue<>( + 1, + new Comparator>() { + @Override + public int compare(Map lhs, Map rhs) { + return Float.compare((float) rhs.get("score"), (float) lhs.get("score")); + } + }); + + for (int heatmapY = 0; heatmapY < scores.length; ++heatmapY) { + for (int heatmapX = 0; heatmapX < scores[0].length; ++heatmapX) { + for (int keypointId = 0; keypointId < scores[0][0].length; ++keypointId) { + float score = sigmoid(scores[heatmapY][heatmapX][keypointId]); + if (score < threshold) continue; + + if (scoreIsMaximumInLocalWindow( + keypointId, score, heatmapY, heatmapX, localMaximumRadius, scores)) { + Map res = new HashMap<>(); + res.put("score", score); + res.put("y", heatmapY); + res.put("x", heatmapX); + res.put("partId", keypointId); + pq.add(res); + } + } + } + } - boolean withinNmsRadiusOfCorrespondingPoint(List> poses, - float squaredNmsRadius, - float y, - float x, - int keypointId) { - for (Map pose : poses) { - Map keypoints = (Map) pose.get("keypoints"); - Map correspondingKeypoint = (Map) keypoints.get(keypointId); - float _x = (float) correspondingKeypoint.get("x") * inputSize - x; - float _y = (float) correspondingKeypoint.get("y") * inputSize - y; - float squaredDistance = _x * _x + _y * _y; - if (squaredDistance <= squaredNmsRadius) - return true; + return pq; + } + + boolean scoreIsMaximumInLocalWindow(int keypointId, + float score, + int heatmapY, + int heatmapX, + int localMaximumRadius, + float[][][] scores) { + boolean localMaximum = true; + int height = scores.length; + int width = scores[0].length; + + int yStart = Math.max(heatmapY - localMaximumRadius, 0); + int yEnd = Math.min(heatmapY + localMaximumRadius + 1, height); + for (int yCurrent = yStart; yCurrent < yEnd; ++yCurrent) { + int xStart = Math.max(heatmapX - localMaximumRadius, 0); + int xEnd = Math.min(heatmapX + localMaximumRadius + 1, width); + for (int xCurrent = xStart; xCurrent < xEnd; ++xCurrent) { + if (sigmoid(scores[yCurrent][xCurrent][keypointId]) > score) { + localMaximum = false; + break; + } + } + if (!localMaximum) { + break; + } + } + + return localMaximum; + } + + float[] getImageCoords(Map keypoint, + int outputStride, + int numParts, + float[][][] offsets) { + int heatmapY = (int) keypoint.get("y"); + int heatmapX = (int) keypoint.get("x"); + int keypointId = (int) keypoint.get("partId"); + float offsetY = offsets[heatmapY][heatmapX][keypointId]; + float offsetX = offsets[heatmapY][heatmapX][keypointId + numParts]; + + float y = heatmapY * outputStride + offsetY; + float x = heatmapX * outputStride + offsetX; + + return new float[]{y, x}; + } + + boolean withinNmsRadiusOfCorrespondingPoint(List> poses, + float squaredNmsRadius, + float y, + float x, + int keypointId) { + for (Map pose : poses) { + Map keypoints = (Map) pose.get("keypoints"); + Map correspondingKeypoint = (Map) keypoints.get(keypointId); + float _x = (float) correspondingKeypoint.get("x") * inputSize - x; + float _y = (float) correspondingKeypoint.get("y") * inputSize - y; + float squaredDistance = _x * _x + _y * _y; + if (squaredDistance <= squaredNmsRadius) + return true; + } + + return false; } - return false; - } + Map traverseToTargetKeypoint(int edgeId, + Map sourceKeypoint, + int targetKeypointId, + float[][][] scores, + float[][][] offsets, + int outputStride, + float[][][] displacements) { + int height = scores.length; + int width = scores[0].length; + int numKeypoints = scores[0][0].length; + float sourceKeypointY = (float) sourceKeypoint.get("y") * inputSize; + float sourceKeypointX = (float) sourceKeypoint.get("x") * inputSize; - Map traverseToTargetKeypoint(int edgeId, - Map sourceKeypoint, - int targetKeypointId, - float[][][] scores, - float[][][] offsets, - int outputStride, - float[][][] displacements) { - int height = scores.length; - int width = scores[0].length; - int numKeypoints = scores[0][0].length; - float sourceKeypointY = (float) sourceKeypoint.get("y") * inputSize; - float sourceKeypointX = (float) sourceKeypoint.get("x") * inputSize; - - int[] sourceKeypointIndices = getStridedIndexNearPoint(sourceKeypointY, sourceKeypointX, - outputStride, height, width); - - float[] displacement = getDisplacement(edgeId, sourceKeypointIndices, displacements); - - float[] displacedPoint = new float[]{ - sourceKeypointY + displacement[0], - sourceKeypointX + displacement[1] - }; + int[] sourceKeypointIndices = getStridedIndexNearPoint(sourceKeypointY, sourceKeypointX, + outputStride, height, width); - float[] targetKeypoint = displacedPoint; + float[] displacement = getDisplacement(edgeId, sourceKeypointIndices, displacements); - final int offsetRefineStep = 2; - for (int i = 0; i < offsetRefineStep; i++) { - int[] targetKeypointIndices = getStridedIndexNearPoint(targetKeypoint[0], targetKeypoint[1], - outputStride, height, width); + float[] displacedPoint = new float[]{ + sourceKeypointY + displacement[0], + sourceKeypointX + displacement[1] + }; - int targetKeypointY = targetKeypointIndices[0]; - int targetKeypointX = targetKeypointIndices[1]; + float[] targetKeypoint = displacedPoint; - float offsetY = offsets[targetKeypointY][targetKeypointX][targetKeypointId]; - float offsetX = offsets[targetKeypointY][targetKeypointX][targetKeypointId + numKeypoints]; + final int offsetRefineStep = 2; + for (int i = 0; i < offsetRefineStep; i++) { + int[] targetKeypointIndices = getStridedIndexNearPoint(targetKeypoint[0], targetKeypoint[1], + outputStride, height, width); - targetKeypoint = new float[]{ - targetKeypointY * outputStride + offsetY, - targetKeypointX * outputStride + offsetX - }; - } + int targetKeypointY = targetKeypointIndices[0]; + int targetKeypointX = targetKeypointIndices[1]; - int[] targetKeypointIndices = getStridedIndexNearPoint(targetKeypoint[0], targetKeypoint[1], - outputStride, height, width); + float offsetY = offsets[targetKeypointY][targetKeypointX][targetKeypointId]; + float offsetX = offsets[targetKeypointY][targetKeypointX][targetKeypointId + numKeypoints]; - float score = sigmoid(scores[targetKeypointIndices[0]][targetKeypointIndices[1]][targetKeypointId]); + targetKeypoint = new float[]{ + targetKeypointY * outputStride + offsetY, + targetKeypointX * outputStride + offsetX + }; + } - Map keypoint = new HashMap<>(); - keypoint.put("score", score); - keypoint.put("part", partNames[targetKeypointId]); - keypoint.put("y", targetKeypoint[0] / inputSize); - keypoint.put("x", targetKeypoint[1] / inputSize); + int[] targetKeypointIndices = getStridedIndexNearPoint(targetKeypoint[0], targetKeypoint[1], + outputStride, height, width); - return keypoint; - } + float score = sigmoid(scores[targetKeypointIndices[0]][targetKeypointIndices[1]][targetKeypointId]); - int[] getStridedIndexNearPoint(float _y, float _x, int outputStride, int height, int width) { - int y_ = Math.round(_y / outputStride); - int x_ = Math.round(_x / outputStride); - int y = y_ < 0 ? 0 : y_ > height - 1 ? height - 1 : y_; - int x = x_ < 0 ? 0 : x_ > width - 1 ? width - 1 : x_; - return new int[]{y, x}; - } + Map keypoint = new HashMap<>(); + keypoint.put("score", score); + keypoint.put("part", partNames[targetKeypointId]); + keypoint.put("y", targetKeypoint[0] / inputSize); + keypoint.put("x", targetKeypoint[1] / inputSize); - float[] getDisplacement(int edgeId, int[] keypoint, float[][][] displacements) { - int numEdges = displacements[0][0].length / 2; - int y = keypoint[0]; - int x = keypoint[1]; - return new float[]{displacements[y][x][edgeId], displacements[y][x][edgeId + numEdges]}; - } + return keypoint; + } - float getInstanceScore(Map> keypoints, int numKeypoints) { - float scores = 0; - for (Map.Entry> keypoint : keypoints.entrySet()) - scores += (float) keypoint.getValue().get("score"); - return scores / numKeypoints; - } + int[] getStridedIndexNearPoint(float _y, float _x, int outputStride, int height, int width) { + int y_ = Math.round(_y / outputStride); + int x_ = Math.round(_x / outputStride); + int y = y_ < 0 ? 0 : y_ > height - 1 ? height - 1 : y_; + int x = x_ < 0 ? 0 : x_ > width - 1 ? width - 1 : x_; + return new int[]{y, x}; + } - private float sigmoid(final float x) { - return (float) (1. / (1. + Math.exp(-x))); - } + float[] getDisplacement(int edgeId, int[] keypoint, float[][][] displacements) { + int numEdges = displacements[0][0].length / 2; + int y = keypoint[0]; + int x = keypoint[1]; + return new float[]{displacements[y][x][edgeId], displacements[y][x][edgeId + numEdges]}; + } - private void softmax(final float[] vals) { - float max = Float.NEGATIVE_INFINITY; - for (final float val : vals) { - max = Math.max(max, val); + float getInstanceScore(Map> keypoints, int numKeypoints) { + float scores = 0; + for (Map.Entry> keypoint : keypoints.entrySet()) + scores += (float) keypoint.getValue().get("score"); + return scores / numKeypoints; } - float sum = 0.0f; - for (int i = 0; i < vals.length; ++i) { - vals[i] = (float) Math.exp(vals[i] - max); - sum += vals[i]; + + private float sigmoid(final float x) { + return (float) (1. / (1. + Math.exp(-x))); } - for (int i = 0; i < vals.length; ++i) { - vals[i] = vals[i] / sum; + + private void softmax(final float[] vals) { + float max = Float.NEGATIVE_INFINITY; + for (final float val : vals) { + max = Math.max(max, val); + } + float sum = 0.0f; + for (int i = 0; i < vals.length; ++i) { + vals[i] = (float) Math.exp(vals[i] - max); + sum += vals[i]; + } + for (int i = 0; i < vals.length; ++i) { + vals[i] = vals[i] / sum; + } } - } - private static Matrix getTransformationMatrix(final int srcWidth, - final int srcHeight, - final int dstWidth, - final int dstHeight, - final boolean maintainAspectRatio) { - final Matrix matrix = new Matrix(); - - if (srcWidth != dstWidth || srcHeight != dstHeight) { - final float scaleFactorX = dstWidth / (float) srcWidth; - final float scaleFactorY = dstHeight / (float) srcHeight; - - if (maintainAspectRatio) { - final float scaleFactor = Math.max(scaleFactorX, scaleFactorY); - matrix.postScale(scaleFactor, scaleFactor); - } else { - matrix.postScale(scaleFactorX, scaleFactorY); - } + private static Matrix getTransformationMatrix(final int srcWidth, + final int srcHeight, + final int dstWidth, + final int dstHeight, + final boolean maintainAspectRatio) { + final Matrix matrix = new Matrix(); + + if (srcWidth != dstWidth || srcHeight != dstHeight) { + final float scaleFactorX = dstWidth / (float) srcWidth; + final float scaleFactorY = dstHeight / (float) srcHeight; + + if (maintainAspectRatio) { + final float scaleFactor = Math.max(scaleFactorX, scaleFactorY); + matrix.postScale(scaleFactor, scaleFactor); + } else { + matrix.postScale(scaleFactorX, scaleFactorY); + } + } + + matrix.invert(new Matrix()); + return matrix; + } + + private void close() { + if (tfLite != null) + tfLite.close(); + labels = null; + labelProb = null; } - matrix.invert(new Matrix()); - return matrix; - } - private void close() { - if (tfLite != null) - tfLite.close(); - labels = null; - labelProb = null; - } } diff --git a/example/.gitignore b/example/.gitignore index bc6181c..0fa6b67 100644 --- a/example/.gitignore +++ b/example/.gitignore @@ -1,23 +1,46 @@ +# Miscellaneous +*.class +*.log +*.pyc +*.swp .DS_Store -.dart_tool/ +.atom/ +.buildlog/ +.history +.svn/ -.packages -.pub/ +# IntelliJ related +*.iml +*.ipr +*.iws +.idea/ -build/ +# The .vscode folder contains launch configuration and tasks you configure in +# VS Code which you may wish to be included in version control, so this line +# is commented out by default. +#.vscode/ +# Flutter/Dart/Pub related +**/doc/api/ +**/ios/Flutter/.last_build_id +.dart_tool/ .flutter-plugins .flutter-plugins-dependencies +.packages +.pub-cache/ +.pub/ +/build/ -flutter_export_environment.sh -Flutter.podspec +# Web related +lib/generated_plugin_registrant.dart -# IntelliJ -*.iml -.idea/workspace.xml -.idea/tasks.xml -.idea/gradle.xml -.idea/assetWizardSettings.xml -.idea/dictionaries -.idea/libraries -.idea/caches \ No newline at end of file +# Symbolication related +app.*.symbols + +# Obfuscation related +app.*.map.json + +# Android Studio will place build artifacts here +/android/app/debug +/android/app/profile +/android/app/release diff --git a/example/.metadata b/example/.metadata index 1634cfb..be0f63d 100644 --- a/example/.metadata +++ b/example/.metadata @@ -4,5 +4,7 @@ # This file should be version controlled and should not be manually edited. version: - revision: 3b309bda072a6b326e8aa4591a5836af600923ce - channel: beta + revision: 4cc385b4b84ac2f816d939a49ea1f328c4e0b48e + channel: stable + +project_type: app diff --git a/example/README.md b/example/README.md index 849606f..8e39bb3 100644 --- a/example/README.md +++ b/example/README.md @@ -1,35 +1,16 @@ # tflite_example -Use tflite plugin to run model on images. The image is captured by camera or selected from gallery (with the help of [image_picker](https://pub.dartlang.org/packages/image_picker) plugin). +Demonstrates how to use the tflite plugin. -![](yolo.jpg) +## Getting Started -## Prerequisites +This project is a starting point for a Flutter application. -Create a `assets` folder. From https://github.com/shaqian/flutter_tflite/tree/master/example/assets -dowload the following files and place them in `assets` folder. - - mobilenet_v1_1.0_224.tflite - - mobilenet_v1_1.0_224.txt - - ssd_mobilenet.tflite - - ssd_mobilenet.txt - - yolov2_tiny.tflite - - yolov2_tiny.txt - - deeplabv3_257_mv_gpu.tflite - - deeplabv3_257_mv_gpu.txt - - posenet_mv1_075_float_from_checkpoints.tflite +A few resources to get you started if this is your first Flutter project: -## Install +- [Lab: Write your first Flutter app](https://flutter.dev/docs/get-started/codelab) +- [Cookbook: Useful Flutter samples](https://flutter.dev/docs/cookbook) -``` -flutter packages get -``` - -## Run - -``` -flutter run -``` - -## Caveat - -```recognizeImageBinary(image)``` (sample code for ```runModelOnBinary```) is slow on iOS when decoding image due to a [known issue](https://github.com/brendan-duncan/image/issues/55) with image package. +For help getting started with Flutter, view our +[online documentation](https://flutter.dev/docs), which offers tutorials, +samples, guidance on mobile development, and a full API reference. diff --git a/example/analysis_options.yaml b/example/analysis_options.yaml new file mode 100644 index 0000000..61b6c4d --- /dev/null +++ b/example/analysis_options.yaml @@ -0,0 +1,29 @@ +# This file configures the analyzer, which statically analyzes Dart code to +# check for errors, warnings, and lints. +# +# The issues identified by the analyzer are surfaced in the UI of Dart-enabled +# IDEs (https://dart.dev/tools#ides-and-editors). The analyzer can also be +# invoked from the command line by running `flutter analyze`. + +# The following line activates a set of recommended lints for Flutter apps, +# packages, and plugins designed to encourage good coding practices. +include: package:flutter_lints/flutter.yaml + +linter: + # The lint rules applied to this project can be customized in the + # section below to disable rules from the `package:flutter_lints/flutter.yaml` + # included above or to enable additional rules. A list of all available lints + # and their documentation is published at + # https://dart-lang.github.io/linter/lints/index.html. + # + # Instead of disabling a lint rule for the entire project in the + # section below, it can also be suppressed for a single line of code + # or a specific dart file by using the `// ignore: name_of_lint` and + # `// ignore_for_file: name_of_lint` syntax on the line or in the file + # producing the lint. + rules: + # avoid_print: false # Uncomment to disable the `avoid_print` rule + # prefer_single_quotes: true # Uncomment to enable the `prefer_single_quotes` rule + +# Additional information about this file can be found at +# https://dart.dev/guides/language/analysis-options diff --git a/example/android/.gitignore b/example/android/.gitignore index 65b7315..6f56801 100644 --- a/example/android/.gitignore +++ b/example/android/.gitignore @@ -1,10 +1,13 @@ -*.iml -*.class -.gradle +gradle-wrapper.jar +/.gradle +/captures/ +/gradlew +/gradlew.bat /local.properties -/.idea/workspace.xml -/.idea/libraries -.DS_Store -/build -/captures GeneratedPluginRegistrant.java + +# Remember to never publicly share your keystore. +# See https://flutter.dev/docs/deployment/android#reference-the-keystore-from-the-app +key.properties +**/*.keystore +**/*.jks diff --git a/example/android/.project b/example/android/.project index 3964dd3..e31d95a 100644 --- a/example/android/.project +++ b/example/android/.project @@ -1,7 +1,7 @@ - android - Project android created by Buildship. + android_ + Project android_ created by Buildship. @@ -14,4 +14,15 @@ org.eclipse.buildship.core.gradleprojectnature + + + 1635544201843 + + 30 + + org.eclipse.core.resources.regexFilterMatcher + node_modules|.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__ + + + diff --git a/example/android/.settings/org.eclipse.buildship.core.prefs b/example/android/.settings/org.eclipse.buildship.core.prefs index e889521..016f0a1 100644 --- a/example/android/.settings/org.eclipse.buildship.core.prefs +++ b/example/android/.settings/org.eclipse.buildship.core.prefs @@ -1,2 +1,13 @@ +arguments= +auto.sync=false +build.scans.enabled=false +connection.gradle.distribution=GRADLE_DISTRIBUTION(WRAPPER) connection.project.dir= eclipse.preferences.version=1 +gradle.user.home= +java.home=C\:/Program Files/Eclipse Foundation/jdk-17.0.0.35-hotspot +jvm.arguments= +offline.mode=false +override.workspace.settings=true +show.console.view=true +show.executions.view=true diff --git a/example/android/app/.classpath b/example/android/app/.classpath deleted file mode 100644 index eb19361..0000000 --- a/example/android/app/.classpath +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - diff --git a/example/android/app/.project b/example/android/app/.project deleted file mode 100644 index ac485d7..0000000 --- a/example/android/app/.project +++ /dev/null @@ -1,23 +0,0 @@ - - - app - Project app created by Buildship. - - - - - org.eclipse.jdt.core.javabuilder - - - - - org.eclipse.buildship.core.gradleprojectbuilder - - - - - - org.eclipse.jdt.core.javanature - org.eclipse.buildship.core.gradleprojectnature - - diff --git a/example/android/app/.settings/org.eclipse.buildship.core.prefs b/example/android/app/.settings/org.eclipse.buildship.core.prefs deleted file mode 100644 index b1886ad..0000000 --- a/example/android/app/.settings/org.eclipse.buildship.core.prefs +++ /dev/null @@ -1,2 +0,0 @@ -connection.project.dir=.. -eclipse.preferences.version=1 diff --git a/example/android/app/build.gradle b/example/android/app/build.gradle index a627ccf..3e59f62 100644 --- a/example/android/app/build.gradle +++ b/example/android/app/build.gradle @@ -25,24 +25,20 @@ apply plugin: 'com.android.application' apply from: "$flutterRoot/packages/flutter_tools/gradle/flutter.gradle" android { - compileSdkVersion 28 + compileSdkVersion 31 - lintOptions { - disable 'InvalidPackage' - } - - aaptOptions { - noCompress 'tflite' + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 } defaultConfig { // TODO: Specify your own unique Application ID (https://developer.android.com/studio/build/application-id.html). - applicationId "sq.flutter.tfliteexample" - minSdkVersion 19 - targetSdkVersion 28 + applicationId "sq.flutter.tflite_example" + minSdkVersion 17 + targetSdkVersion 31 versionCode flutterVersionCode.toInteger() versionName flutterVersionName - testInstrumentationRunner 'androidx.test.runner.AndroidJUnitRunner' } buildTypes { @@ -57,9 +53,3 @@ android { flutter { source '../..' } - -dependencies { - testImplementation 'junit:junit:4.12' - androidTestImplementation 'androidx.test.ext:junit:1.1.1' - androidTestImplementation 'androidx.test.espresso:espresso-core:3.1.0' -} diff --git a/example/android/app/src/debug/AndroidManifest.xml b/example/android/app/src/debug/AndroidManifest.xml new file mode 100644 index 0000000..94529f8 --- /dev/null +++ b/example/android/app/src/debug/AndroidManifest.xml @@ -0,0 +1,7 @@ + + + + diff --git a/example/android/app/src/main/AndroidManifest.xml b/example/android/app/src/main/AndroidManifest.xml index aa3336f..08c92d0 100644 --- a/example/android/app/src/main/AndroidManifest.xml +++ b/example/android/app/src/main/AndroidManifest.xml @@ -1,39 +1,33 @@ - - - - - - + - + + android:name="io.flutter.embedding.android.NormalTheme" + android:resource="@style/NormalTheme" + /> + + + diff --git a/example/android/app/src/main/java/sq/flutter/tflite_example/MainActivity.java b/example/android/app/src/main/java/sq/flutter/tflite_example/MainActivity.java new file mode 100644 index 0000000..f21902a --- /dev/null +++ b/example/android/app/src/main/java/sq/flutter/tflite_example/MainActivity.java @@ -0,0 +1,6 @@ +package sq.flutter.tflite_example; + +import io.flutter.embedding.android.FlutterActivity; + +public class MainActivity extends FlutterActivity { +} diff --git a/example/android/app/src/main/java/sq/flutter/tfliteexample/MainActivity.java b/example/android/app/src/main/java/sq/flutter/tfliteexample/MainActivity.java deleted file mode 100644 index 968a340..0000000 --- a/example/android/app/src/main/java/sq/flutter/tfliteexample/MainActivity.java +++ /dev/null @@ -1,13 +0,0 @@ -package sq.flutter.tfliteexample; - -import android.os.Bundle; -import io.flutter.app.FlutterActivity; -import io.flutter.plugins.GeneratedPluginRegistrant; - -public class MainActivity extends FlutterActivity { - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - GeneratedPluginRegistrant.registerWith(this); - } -} diff --git a/example/android/app/src/main/res/drawable-v21/launch_background.xml b/example/android/app/src/main/res/drawable-v21/launch_background.xml new file mode 100644 index 0000000..f74085f --- /dev/null +++ b/example/android/app/src/main/res/drawable-v21/launch_background.xml @@ -0,0 +1,12 @@ + + + + + + + + diff --git a/example/android/app/src/main/res/values-night/styles.xml b/example/android/app/src/main/res/values-night/styles.xml new file mode 100644 index 0000000..449a9f9 --- /dev/null +++ b/example/android/app/src/main/res/values-night/styles.xml @@ -0,0 +1,18 @@ + + + + + + + diff --git a/example/android/app/src/main/res/values/styles.xml b/example/android/app/src/main/res/values/styles.xml index 00fa441..d74aa35 100644 --- a/example/android/app/src/main/res/values/styles.xml +++ b/example/android/app/src/main/res/values/styles.xml @@ -1,8 +1,18 @@ - + + diff --git a/example/android/app/src/profile/AndroidManifest.xml b/example/android/app/src/profile/AndroidManifest.xml new file mode 100644 index 0000000..94529f8 --- /dev/null +++ b/example/android/app/src/profile/AndroidManifest.xml @@ -0,0 +1,7 @@ + + + + diff --git a/example/android/build.gradle b/example/android/build.gradle index 83f114c..8bd9635 100644 --- a/example/android/build.gradle +++ b/example/android/build.gradle @@ -1,26 +1,24 @@ buildscript { repositories { google() - jcenter() + mavenCentral() } dependencies { - classpath 'com.android.tools.build:gradle:3.6.1' + classpath 'com.android.tools.build:gradle:4.1.3' } } allprojects { repositories { google() - jcenter() + mavenCentral() } } rootProject.buildDir = '../build' subprojects { project.buildDir = "${rootProject.buildDir}/${project.name}" -} -subprojects { project.evaluationDependsOn(':app') } diff --git a/example/android/gradle.properties b/example/android/gradle.properties index 29bf260..94adc3a 100644 --- a/example/android/gradle.properties +++ b/example/android/gradle.properties @@ -1,4 +1,3 @@ org.gradle.jvmargs=-Xmx1536M -target-platform=android-arm64 android.useAndroidX=true android.enableJetifier=true diff --git a/example/android/gradle/wrapper/gradle-wrapper.properties b/example/android/gradle/wrapper/gradle-wrapper.properties index 46510f3..b8793d3 100644 --- a/example/android/gradle/wrapper/gradle-wrapper.properties +++ b/example/android/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ -#Sat Mar 28 00:33:22 ICT 2020 +#Fri Jun 23 08:50:38 CEST 2017 distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-5.6.4-all.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.0.2-all.zip diff --git a/example/android/settings.gradle b/example/android/settings.gradle index 5a2f14f..44e62bc 100644 --- a/example/android/settings.gradle +++ b/example/android/settings.gradle @@ -1,15 +1,11 @@ include ':app' -def flutterProjectRoot = rootProject.projectDir.parentFile.toPath() +def localPropertiesFile = new File(rootProject.projectDir, "local.properties") +def properties = new Properties() -def plugins = new Properties() -def pluginsFile = new File(flutterProjectRoot.toFile(), '.flutter-plugins') -if (pluginsFile.exists()) { - pluginsFile.withReader('UTF-8') { reader -> plugins.load(reader) } -} +assert localPropertiesFile.exists() +localPropertiesFile.withReader("UTF-8") { reader -> properties.load(reader) } -plugins.each { name, path -> - def pluginDirectory = flutterProjectRoot.resolve(path).resolve('android').toFile() - include ":$name" - project(":$name").projectDir = pluginDirectory -} +def flutterSdkPath = properties.getProperty("flutter.sdk") +assert flutterSdkPath != null, "flutter.sdk not set in local.properties" +apply from: "$flutterSdkPath/packages/flutter_tools/gradle/app_plugin_loader.gradle" diff --git a/example/android/settings_aar.gradle b/example/android/settings_aar.gradle deleted file mode 100644 index e7b4def..0000000 --- a/example/android/settings_aar.gradle +++ /dev/null @@ -1 +0,0 @@ -include ':app' diff --git a/example/ios/.gitignore b/example/ios/.gitignore index 79cc4da..151026b 100644 --- a/example/ios/.gitignore +++ b/example/ios/.gitignore @@ -1,45 +1,33 @@ -.idea/ -.vagrant/ -.sconsign.dblite -.svn/ - -.DS_Store -*.swp -profile - -DerivedData/ -build/ -GeneratedPluginRegistrant.h -GeneratedPluginRegistrant.m - -.generated/ - -*.pbxuser *.mode1v3 *.mode2v3 +*.moved-aside +*.pbxuser *.perspectivev3 - -!default.pbxuser +**/*sync/ +.sconsign.dblite +.tags* +**/.vagrant/ +**/DerivedData/ +Icon? +**/Pods/ +**/.symlinks/ +profile +xcuserdata +**/.generated/ +Flutter/App.framework +Flutter/Flutter.framework +Flutter/Flutter.podspec +Flutter/Generated.xcconfig +Flutter/ephemeral/ +Flutter/app.flx +Flutter/app.zip +Flutter/flutter_assets/ +Flutter/flutter_export_environment.sh +ServiceDefinitions.json +Runner/GeneratedPluginRegistrant.* + +# Exceptions to above rules. !default.mode1v3 !default.mode2v3 +!default.pbxuser !default.perspectivev3 - -xcuserdata - -*.moved-aside - -*.pyc -*sync/ -Icon? -.tags* - -/Flutter/app.flx -/Flutter/app.zip -/Flutter/flutter_assets/ -/Flutter/App.framework -/Flutter/Flutter.framework -/Flutter/Generated.xcconfig -/ServiceDefinitions.json - -Pods/ -.symlinks/ diff --git a/example/ios/Flutter/AppFrameworkInfo.plist b/example/ios/Flutter/AppFrameworkInfo.plist index 9367d48..8d4492f 100644 --- a/example/ios/Flutter/AppFrameworkInfo.plist +++ b/example/ios/Flutter/AppFrameworkInfo.plist @@ -21,6 +21,6 @@ CFBundleVersion 1.0 MinimumOSVersion - 8.0 + 9.0 diff --git a/example/ios/Flutter/Debug.xcconfig b/example/ios/Flutter/Debug.xcconfig index e8efba1..592ceee 100644 --- a/example/ios/Flutter/Debug.xcconfig +++ b/example/ios/Flutter/Debug.xcconfig @@ -1,2 +1 @@ -#include "Pods/Target Support Files/Pods-Runner/Pods-Runner.debug.xcconfig" #include "Generated.xcconfig" diff --git a/example/ios/Flutter/Release.xcconfig b/example/ios/Flutter/Release.xcconfig index 399e934..592ceee 100644 --- a/example/ios/Flutter/Release.xcconfig +++ b/example/ios/Flutter/Release.xcconfig @@ -1,2 +1 @@ -#include "Pods/Target Support Files/Pods-Runner/Pods-Runner.release.xcconfig" #include "Generated.xcconfig" diff --git a/example/ios/Podfile b/example/ios/Podfile deleted file mode 100644 index f7d6a5e..0000000 --- a/example/ios/Podfile +++ /dev/null @@ -1,38 +0,0 @@ -# Uncomment this line to define a global platform for your project -# platform :ios, '9.0' - -# CocoaPods analytics sends network stats synchronously affecting flutter build latency. -ENV['COCOAPODS_DISABLE_STATS'] = 'true' - -project 'Runner', { - 'Debug' => :debug, - 'Profile' => :release, - 'Release' => :release, -} - -def flutter_root - generated_xcode_build_settings_path = File.expand_path(File.join('..', 'Flutter', 'Generated.xcconfig'), __FILE__) - unless File.exist?(generated_xcode_build_settings_path) - raise "#{generated_xcode_build_settings_path} must exist. If you're running pod install manually, make sure flutter pub get is executed first" - end - - File.foreach(generated_xcode_build_settings_path) do |line| - matches = line.match(/FLUTTER_ROOT\=(.*)/) - return matches[1].strip if matches - end - raise "FLUTTER_ROOT not found in #{generated_xcode_build_settings_path}. Try deleting Generated.xcconfig, then run flutter pub get" -end - -require File.expand_path(File.join('packages', 'flutter_tools', 'bin', 'podhelper'), flutter_root) - -flutter_ios_podfile_setup - -target 'Runner' do - flutter_install_all_ios_pods File.dirname(File.realpath(__FILE__)) -end - -post_install do |installer| - installer.pods_project.targets.each do |target| - flutter_additional_ios_build_settings(target) - end -end diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock deleted file mode 100644 index 6286399..0000000 --- a/example/ios/Podfile.lock +++ /dev/null @@ -1,35 +0,0 @@ -PODS: - - Flutter (1.0.0) - - image_picker (0.0.1): - - Flutter - - TensorFlowLiteC (2.2.0) - - tflite (1.1.2): - - Flutter - - TensorFlowLiteC - -DEPENDENCIES: - - Flutter (from `Flutter`) - - image_picker (from `.symlinks/plugins/image_picker/ios`) - - tflite (from `.symlinks/plugins/tflite/ios`) - -SPEC REPOS: - trunk: - - TensorFlowLiteC - -EXTERNAL SOURCES: - Flutter: - :path: Flutter - image_picker: - :path: ".symlinks/plugins/image_picker/ios" - tflite: - :path: ".symlinks/plugins/tflite/ios" - -SPEC CHECKSUMS: - Flutter: 434fef37c0980e73bb6479ef766c45957d4b510c - image_picker: a211f28b95a560433c00f5cd3773f4710a20404d - TensorFlowLiteC: b3ab9e867b0b71052ca102a32a786555b330b02e - tflite: f0403a894740019d63ab5662253bba5b2dd37296 - -PODFILE CHECKSUM: 8e679eca47255a8ca8067c4c67aab20e64cb974d - -COCOAPODS: 1.10.1 diff --git a/example/ios/Runner.xcodeproj/project.pbxproj b/example/ios/Runner.xcodeproj/project.pbxproj index 1252660..bf7a2c6 100644 --- a/example/ios/Runner.xcodeproj/project.pbxproj +++ b/example/ios/Runner.xcodeproj/project.pbxproj @@ -9,13 +9,11 @@ /* Begin PBXBuildFile section */ 1498D2341E8E89220040F4C2 /* GeneratedPluginRegistrant.m in Sources */ = {isa = PBXBuildFile; fileRef = 1498D2331E8E89220040F4C2 /* GeneratedPluginRegistrant.m */; }; 3B3967161E833CAA004F5970 /* AppFrameworkInfo.plist in Resources */ = {isa = PBXBuildFile; fileRef = 3B3967151E833CAA004F5970 /* AppFrameworkInfo.plist */; }; - 9740EEB41CF90195004384FC /* Debug.xcconfig in Resources */ = {isa = PBXBuildFile; fileRef = 9740EEB21CF90195004384FC /* Debug.xcconfig */; }; 978B8F6F1D3862AE00F588F7 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 7AFFD8EE1D35381100E5BB4D /* AppDelegate.m */; }; 97C146F31CF9000F007C117D /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 97C146F21CF9000F007C117D /* main.m */; }; 97C146FC1CF9000F007C117D /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FA1CF9000F007C117D /* Main.storyboard */; }; 97C146FE1CF9000F007C117D /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FD1CF9000F007C117D /* Assets.xcassets */; }; 97C147011CF9000F007C117D /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FF1CF9000F007C117D /* LaunchScreen.storyboard */; }; - A8FCB07931B147D0C738D807 /* libPods-Runner.a in Frameworks */ = {isa = PBXBuildFile; fileRef = A4A034B01AB21E851714E03C /* libPods-Runner.a */; }; /* End PBXBuildFile section */ /* Begin PBXCopyFilesBuildPhase section */ @@ -38,7 +36,6 @@ 7AFA3C8E1D35360C0083082E /* Release.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; name = Release.xcconfig; path = Flutter/Release.xcconfig; sourceTree = ""; }; 7AFFD8ED1D35381100E5BB4D /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; 7AFFD8EE1D35381100E5BB4D /* AppDelegate.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = ""; }; - 864E0E2308AE5F3A9409E901 /* Pods-Runner.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Runner.release.xcconfig"; path = "Pods/Target Support Files/Pods-Runner/Pods-Runner.release.xcconfig"; sourceTree = ""; }; 9740EEB21CF90195004384FC /* Debug.xcconfig */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.xcconfig; name = Debug.xcconfig; path = Flutter/Debug.xcconfig; sourceTree = ""; }; 9740EEB31CF90195004384FC /* Generated.xcconfig */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.xcconfig; name = Generated.xcconfig; path = Flutter/Generated.xcconfig; sourceTree = ""; }; 97C146EE1CF9000F007C117D /* Runner.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = Runner.app; sourceTree = BUILT_PRODUCTS_DIR; }; @@ -47,8 +44,6 @@ 97C146FD1CF9000F007C117D /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; 97C147001CF9000F007C117D /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = ""; }; 97C147021CF9000F007C117D /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; - A4A034B01AB21E851714E03C /* libPods-Runner.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-Runner.a"; sourceTree = BUILT_PRODUCTS_DIR; }; - E0C0C115F9024C6ADB3B2DB5 /* Pods-Runner.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Runner.debug.xcconfig"; path = "Pods/Target Support Files/Pods-Runner/Pods-Runner.debug.xcconfig"; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -56,30 +51,12 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( - A8FCB07931B147D0C738D807 /* libPods-Runner.a in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; /* End PBXFrameworksBuildPhase section */ /* Begin PBXGroup section */ - 7670CC45CF9B055E20C18D9C /* Frameworks */ = { - isa = PBXGroup; - children = ( - A4A034B01AB21E851714E03C /* libPods-Runner.a */, - ); - name = Frameworks; - sourceTree = ""; - }; - 8EE3D73475BA2048B61051C2 /* Pods */ = { - isa = PBXGroup; - children = ( - E0C0C115F9024C6ADB3B2DB5 /* Pods-Runner.debug.xcconfig */, - 864E0E2308AE5F3A9409E901 /* Pods-Runner.release.xcconfig */, - ); - name = Pods; - sourceTree = ""; - }; 9740EEB11CF90186004384FC /* Flutter */ = { isa = PBXGroup; children = ( @@ -97,8 +74,7 @@ 9740EEB11CF90186004384FC /* Flutter */, 97C146F01CF9000F007C117D /* Runner */, 97C146EF1CF9000F007C117D /* Products */, - 8EE3D73475BA2048B61051C2 /* Pods */, - 7670CC45CF9B055E20C18D9C /* Frameworks */, + CF3B75C9A7D2FA2A4C99F110 /* Frameworks */, ); sourceTree = ""; }; @@ -141,7 +117,6 @@ isa = PBXNativeTarget; buildConfigurationList = 97C147051CF9000F007C117D /* Build configuration list for PBXNativeTarget "Runner" */; buildPhases = ( - FAEC5F0CFA3366178E53C4C5 /* [CP] Check Pods Manifest.lock */, 9740EEB61CF901F6004384FC /* Run Script */, 97C146EA1CF9000F007C117D /* Sources */, 97C146EB1CF9000F007C117D /* Frameworks */, @@ -164,21 +139,19 @@ 97C146E61CF9000F007C117D /* Project object */ = { isa = PBXProject; attributes = { - LastUpgradeCheck = 0910; - ORGANIZATIONNAME = "The Chromium Authors"; + LastUpgradeCheck = 1020; + ORGANIZATIONNAME = ""; TargetAttributes = { 97C146ED1CF9000F007C117D = { CreatedOnToolsVersion = 7.3.1; - DevelopmentTeam = ZJG3P98JS9; }; }; }; buildConfigurationList = 97C146E91CF9000F007C117D /* Build configuration list for PBXProject "Runner" */; - compatibilityVersion = "Xcode 3.2"; - developmentRegion = English; + compatibilityVersion = "Xcode 9.3"; + developmentRegion = en; hasScannedForEncodings = 0; knownRegions = ( - English, en, Base, ); @@ -199,7 +172,6 @@ files = ( 97C147011CF9000F007C117D /* LaunchScreen.storyboard in Resources */, 3B3967161E833CAA004F5970 /* AppFrameworkInfo.plist in Resources */, - 9740EEB41CF90195004384FC /* Debug.xcconfig in Resources */, 97C146FE1CF9000F007C117D /* Assets.xcassets in Resources */, 97C146FC1CF9000F007C117D /* Main.storyboard in Resources */, ); @@ -236,24 +208,6 @@ shellPath = /bin/sh; shellScript = "/bin/sh \"$FLUTTER_ROOT/packages/flutter_tools/bin/xcode_backend.sh\" build"; }; - FAEC5F0CFA3366178E53C4C5 /* [CP] Check Pods Manifest.lock */ = { - isa = PBXShellScriptBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - inputPaths = ( - "${PODS_PODFILE_DIR_PATH}/Podfile.lock", - "${PODS_ROOT}/Manifest.lock", - ); - name = "[CP] Check Pods Manifest.lock"; - outputPaths = ( - "$(DERIVED_FILE_DIR)/Pods-Runner-checkManifestLockResult.txt", - ); - runOnlyForDeploymentPostprocessing = 0; - shellPath = /bin/sh; - shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; - showEnvVarsInLog = 0; - }; /* End PBXShellScriptBuildPhase section */ /* Begin PBXSourcesBuildPhase section */ @@ -289,6 +243,71 @@ /* End PBXVariantGroup section */ /* Begin XCBuildConfiguration section */ + 249021D3217E4FDB00AE95B9 /* Profile */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 9.0; + MTL_ENABLE_DEBUG_INFO = NO; + SDKROOT = iphoneos; + SUPPORTED_PLATFORMS = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + VALIDATE_PRODUCT = YES; + }; + name = Profile; + }; + 249021D4217E4FDB00AE95B9 /* Profile */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = 7AFA3C8E1D35360C0083082E /* Release.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CURRENT_PROJECT_VERSION = "$(FLUTTER_BUILD_NUMBER)"; + ENABLE_BITCODE = NO; + INFOPLIST_FILE = Runner/Info.plist; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + PRODUCT_BUNDLE_IDENTIFIER = sq.flutter.tfliteExample; + PRODUCT_NAME = "$(TARGET_NAME)"; + VERSIONING_SYSTEM = "apple-generic"; + }; + name = Profile; + }; 97C147031CF9000F007C117D /* Debug */ = { isa = XCBuildConfiguration; buildSettings = { @@ -302,12 +321,14 @@ CLANG_WARN_BOOL_CONVERSION = YES; CLANG_WARN_COMMA = YES; CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; CLANG_WARN_EMPTY_BODY = YES; CLANG_WARN_ENUM_CONVERSION = YES; CLANG_WARN_INFINITE_RECURSION = YES; CLANG_WARN_INT_CONVERSION = YES; CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; @@ -334,7 +355,7 @@ GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 8.0; + IPHONEOS_DEPLOYMENT_TARGET = 9.0; MTL_ENABLE_DEBUG_INFO = YES; ONLY_ACTIVE_ARCH = YES; SDKROOT = iphoneos; @@ -355,12 +376,14 @@ CLANG_WARN_BOOL_CONVERSION = YES; CLANG_WARN_COMMA = YES; CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; CLANG_WARN_EMPTY_BODY = YES; CLANG_WARN_ENUM_CONVERSION = YES; CLANG_WARN_INFINITE_RECURSION = YES; CLANG_WARN_INT_CONVERSION = YES; CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; @@ -381,10 +404,10 @@ GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 8.0; + IPHONEOS_DEPLOYMENT_TARGET = 9.0; MTL_ENABLE_DEBUG_INFO = NO; - ONLY_ACTIVE_ARCH = YES; SDKROOT = iphoneos; + SUPPORTED_PLATFORMS = iphoneos; TARGETED_DEVICE_FAMILY = "1,2"; VALIDATE_PRODUCT = YES; }; @@ -396,27 +419,9 @@ buildSettings = { ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; CURRENT_PROJECT_VERSION = "$(FLUTTER_BUILD_NUMBER)"; - DEVELOPMENT_TEAM = ZJG3P98JS9; ENABLE_BITCODE = NO; - FRAMEWORK_SEARCH_PATHS = ( - "$(inherited)", - "$(PROJECT_DIR)/Flutter", - ); - HEADER_SEARCH_PATHS = ( - "$(inherited)", - "'${SRCROOT}/Pods/TensorFlowLite/Frameworks/tensorflow_lite.framework/Headers'", - "\"${PODS_ROOT}/Headers/Public\"", - "\"${PODS_ROOT}/Headers/Public/Flutter\"", - "\"${PODS_ROOT}/Headers/Public/TensorFlowLite\"", - "\"${PODS_ROOT}/Headers/Public/tflite\"", - ); INFOPLIST_FILE = Runner/Info.plist; - IPHONEOS_DEPLOYMENT_TARGET = 9.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ( - "$(inherited)", - "$(PROJECT_DIR)/Flutter", - ); PRODUCT_BUNDLE_IDENTIFIER = sq.flutter.tfliteExample; PRODUCT_NAME = "$(TARGET_NAME)"; VERSIONING_SYSTEM = "apple-generic"; @@ -429,27 +434,9 @@ buildSettings = { ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; CURRENT_PROJECT_VERSION = "$(FLUTTER_BUILD_NUMBER)"; - DEVELOPMENT_TEAM = ZJG3P98JS9; ENABLE_BITCODE = NO; - FRAMEWORK_SEARCH_PATHS = ( - "$(inherited)", - "$(PROJECT_DIR)/Flutter", - ); - HEADER_SEARCH_PATHS = ( - "$(inherited)", - "'${SRCROOT}/Pods/TensorFlowLite/Frameworks/tensorflow_lite.framework/Headers'", - "\"${PODS_ROOT}/Headers/Public\"", - "\"${PODS_ROOT}/Headers/Public/Flutter\"", - "\"${PODS_ROOT}/Headers/Public/TensorFlowLite\"", - "\"${PODS_ROOT}/Headers/Public/tflite\"", - ); INFOPLIST_FILE = Runner/Info.plist; - IPHONEOS_DEPLOYMENT_TARGET = 9.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ( - "$(inherited)", - "$(PROJECT_DIR)/Flutter", - ); PRODUCT_BUNDLE_IDENTIFIER = sq.flutter.tfliteExample; PRODUCT_NAME = "$(TARGET_NAME)"; VERSIONING_SYSTEM = "apple-generic"; @@ -464,6 +451,7 @@ buildConfigurations = ( 97C147031CF9000F007C117D /* Debug */, 97C147041CF9000F007C117D /* Release */, + 249021D3217E4FDB00AE95B9 /* Profile */, ); defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; @@ -473,6 +461,7 @@ buildConfigurations = ( 97C147061CF9000F007C117D /* Debug */, 97C147071CF9000F007C117D /* Release */, + 249021D4217E4FDB00AE95B9 /* Profile */, ); defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; diff --git a/example/ios/Runner.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/example/ios/Runner.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist new file mode 100644 index 0000000..18d9810 --- /dev/null +++ b/example/ios/Runner.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist @@ -0,0 +1,8 @@ + + + + + IDEDidComputeMac32BitWarning + + + diff --git a/example/ios/Runner.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings b/example/ios/Runner.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings new file mode 100644 index 0000000..f9b0d7c --- /dev/null +++ b/example/ios/Runner.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings @@ -0,0 +1,8 @@ + + + + + PreviewsEnabled + + + diff --git a/example/ios/Runner.xcodeproj/xcshareddata/xcschemes/Runner.xcscheme b/example/ios/Runner.xcodeproj/xcshareddata/xcschemes/Runner.xcscheme index 6c78381..a28140c 100644 --- a/example/ios/Runner.xcodeproj/xcshareddata/xcschemes/Runner.xcscheme +++ b/example/ios/Runner.xcodeproj/xcshareddata/xcschemes/Runner.xcscheme @@ -1,6 +1,6 @@ + + - - + + + + - - diff --git a/example/ios/Runner.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings b/example/ios/Runner.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings new file mode 100644 index 0000000..f9b0d7c --- /dev/null +++ b/example/ios/Runner.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings @@ -0,0 +1,8 @@ + + + + + PreviewsEnabled + + + diff --git a/example/ios/Runner/AppDelegate.m b/example/ios/Runner/AppDelegate.m index 59a72e9..70e8393 100644 --- a/example/ios/Runner/AppDelegate.m +++ b/example/ios/Runner/AppDelegate.m @@ -1,5 +1,5 @@ -#include "AppDelegate.h" -#include "GeneratedPluginRegistrant.h" +#import "AppDelegate.h" +#import "GeneratedPluginRegistrant.h" @implementation AppDelegate diff --git a/example/ios/Runner/Assets.xcassets/AppIcon.appiconset/Icon-App-1024x1024@1x.png b/example/ios/Runner/Assets.xcassets/AppIcon.appiconset/Icon-App-1024x1024@1x.png index 3d43d11..dc9ada4 100644 Binary files a/example/ios/Runner/Assets.xcassets/AppIcon.appiconset/Icon-App-1024x1024@1x.png and b/example/ios/Runner/Assets.xcassets/AppIcon.appiconset/Icon-App-1024x1024@1x.png differ diff --git a/example/ios/Runner/Info.plist b/example/ios/Runner/Info.plist index 3fef927..bb0a45e 100644 --- a/example/ios/Runner/Info.plist +++ b/example/ios/Runner/Info.plist @@ -3,7 +3,7 @@ CFBundleDevelopmentRegion - en + $(DEVELOPMENT_LANGUAGE) CFBundleExecutable $(EXECUTABLE_NAME) CFBundleIdentifier @@ -41,11 +41,5 @@ UIViewControllerBasedStatusBarAppearance - NSPhotoLibraryUsageDescription - We need your permission to access photo gallery - NSCameraUsageDescription - We need your permission to use phone camera - NSMicrophoneUsageDescription - We need your permission to use microsphone diff --git a/example/lib/main.dart b/example/lib/main.dart index 726cde9..963577f 100644 --- a/example/lib/main.dart +++ b/example/lib/main.dart @@ -9,7 +9,7 @@ import 'package:image/image.dart' as img; import 'package:tflite/tflite.dart'; import 'package:image_picker/image_picker.dart'; -void main() => runApp(new App()); +void main() => runApp(const App()); const String mobile = "MobileNet"; const String ssd = "SSD MobileNet"; @@ -18,6 +18,8 @@ const String deeplab = "DeepLab"; const String posenet = "PoseNet"; class App extends StatelessWidget { + const App({Key? key}) : super(key: key); + @override Widget build(BuildContext context) { return MaterialApp( @@ -27,28 +29,31 @@ class App extends StatelessWidget { } class MyApp extends StatefulWidget { + MyApp({Key? key}) : super(key: key); + @override - _MyAppState createState() => new _MyAppState(); + _MyAppState createState() => _MyAppState(); } class _MyAppState extends State { - File _image; - List _recognitions; + File? _image; + List? _recognitions; String _model = mobile; - double _imageHeight; - double _imageWidth; + double? _imageHeight; + double? _imageWidth; bool _busy = false; Future predictImagePicker() async { - var image = await ImagePicker.pickImage(source: ImageSource.gallery); + var image = await ImagePicker().pickImage(source: ImageSource.gallery); if (image == null) return; setState(() { _busy = true; }); - predictImage(image); + File file = File(image.path); + predictImage(file); } - Future predictImage(File image) async { + Future predictImage(File? image) async { if (image == null) return; switch (_model) { @@ -69,9 +74,7 @@ class _MyAppState extends State { // await recognizeImageBinary(image); } - new FileImage(image) - .resolve(new ImageConfiguration()) - .addListener(ImageStreamListener((ImageInfo info, bool _) { + FileImage(image).resolve(const ImageConfiguration()).addListener(ImageStreamListener((ImageInfo info, bool _) { setState(() { _imageHeight = info.image.height.toDouble(); _imageWidth = info.image.width.toDouble(); @@ -100,7 +103,7 @@ class _MyAppState extends State { Future loadModel() async { Tflite.close(); try { - String res; + String? res; switch (_model) { case yolo: res = await Tflite.loadModel( @@ -136,14 +139,13 @@ class _MyAppState extends State { // useGpuDelegate: true, ); } - print(res); + debugPrint(res); } on PlatformException { - print('Failed to load model.'); + debugPrint('Failed to load model.'); } } - Uint8List imageToByteListFloat32( - img.Image image, int inputSize, double mean, double std) { + Uint8List imageToByteListFloat32(img.Image image, int inputSize, double mean, double std) { var convertedBytes = Float32List(1 * inputSize * inputSize * 3); var buffer = Float32List.view(convertedBytes.buffer); int pixelIndex = 0; @@ -174,7 +176,7 @@ class _MyAppState extends State { } Future recognizeImage(File image) async { - int startTime = new DateTime.now().millisecondsSinceEpoch; + int startTime = DateTime.now().millisecondsSinceEpoch; var recognitions = await Tflite.runModelOnImage( path: image.path, numResults: 6, @@ -185,12 +187,12 @@ class _MyAppState extends State { setState(() { _recognitions = recognitions; }); - int endTime = new DateTime.now().millisecondsSinceEpoch; - print("Inference took ${endTime - startTime}ms"); + int endTime = DateTime.now().millisecondsSinceEpoch; + debugPrint("Inference took ${endTime - startTime}ms"); } Future recognizeImageBinary(File image) async { - int startTime = new DateTime.now().millisecondsSinceEpoch; + int startTime = DateTime.now().millisecondsSinceEpoch; var imageBytes = (await rootBundle.load(image.path)).buffer; img.Image oriImage = img.decodeJpg(imageBytes.asUint8List()); img.Image resizedImage = img.copyResize(oriImage, height: 224, width: 224); @@ -202,12 +204,12 @@ class _MyAppState extends State { setState(() { _recognitions = recognitions; }); - int endTime = new DateTime.now().millisecondsSinceEpoch; - print("Inference took ${endTime - startTime}ms"); + int endTime = DateTime.now().millisecondsSinceEpoch; + debugPrint("Inference took ${endTime - startTime}ms"); } Future yolov2Tiny(File image) async { - int startTime = new DateTime.now().millisecondsSinceEpoch; + int startTime = DateTime.now().millisecondsSinceEpoch; var recognitions = await Tflite.detectObjectOnImage( path: image.path, model: "YOLO", @@ -228,12 +230,12 @@ class _MyAppState extends State { setState(() { _recognitions = recognitions; }); - int endTime = new DateTime.now().millisecondsSinceEpoch; - print("Inference took ${endTime - startTime}ms"); + int endTime = DateTime.now().millisecondsSinceEpoch; + debugPrint("Inference took ${endTime - startTime}ms"); } Future ssdMobileNet(File image) async { - int startTime = new DateTime.now().millisecondsSinceEpoch; + int startTime = DateTime.now().millisecondsSinceEpoch; var recognitions = await Tflite.detectObjectOnImage( path: image.path, numResultsPerClass: 1, @@ -248,12 +250,12 @@ class _MyAppState extends State { setState(() { _recognitions = recognitions; }); - int endTime = new DateTime.now().millisecondsSinceEpoch; - print("Inference took ${endTime - startTime}ms"); + int endTime = DateTime.now().millisecondsSinceEpoch; + debugPrint("Inference took ${endTime - startTime}ms"); } Future segmentMobileNet(File image) async { - int startTime = new DateTime.now().millisecondsSinceEpoch; + int startTime = DateTime.now().millisecondsSinceEpoch; var recognitions = await Tflite.runSegmentationOnImage( path: image.path, imageMean: 127.5, @@ -263,24 +265,24 @@ class _MyAppState extends State { setState(() { _recognitions = recognitions; }); - int endTime = new DateTime.now().millisecondsSinceEpoch; - print("Inference took ${endTime - startTime}"); + int endTime = DateTime.now().millisecondsSinceEpoch; + debugPrint("Inference took ${endTime - startTime}"); } Future poseNet(File image) async { - int startTime = new DateTime.now().millisecondsSinceEpoch; + int startTime = DateTime.now().millisecondsSinceEpoch; var recognitions = await Tflite.runPoseNetOnImage( path: image.path, numResults: 2, ); - print(recognitions); + debugPrint(recognitions.toString()); setState(() { _recognitions = recognitions; }); - int endTime = new DateTime.now().millisecondsSinceEpoch; - print("Inference took ${endTime - startTime}ms"); + int endTime = DateTime.now().millisecondsSinceEpoch; + debugPrint("Inference took ${endTime - startTime}ms"); } onSelect(model) async { @@ -291,12 +293,14 @@ class _MyAppState extends State { }); await loadModel(); - if (_image != null) + if (_image != null) { + predictImage(_image); predictImage(_image); - else + } else { setState(() { _busy = false; }); + } } List renderBoxes(Size screen) { @@ -304,9 +308,9 @@ class _MyAppState extends State { if (_imageHeight == null || _imageWidth == null) return []; double factorX = screen.width; - double factorY = _imageHeight / _imageWidth * screen.width; - Color blue = Color.fromRGBO(37, 213, 253, 1.0); - return _recognitions.map((re) { + double factorY = _imageHeight! / _imageWidth! * screen.width; + Color blue = const Color.fromRGBO(37, 213, 253, 1.0); + return _recognitions!.map((re) { return Positioned( left: re["rect"]["x"] * factorX, top: re["rect"]["y"] * factorY, @@ -314,7 +318,7 @@ class _MyAppState extends State { height: re["rect"]["h"] * factorY, child: Container( decoration: BoxDecoration( - borderRadius: BorderRadius.all(Radius.circular(8.0)), + borderRadius: const BorderRadius.all(Radius.circular(8.0)), border: Border.all( color: blue, width: 2, @@ -338,12 +342,11 @@ class _MyAppState extends State { if (_imageHeight == null || _imageWidth == null) return []; double factorX = screen.width; - double factorY = _imageHeight / _imageWidth * screen.width; + double factorY = _imageHeight! / _imageWidth! * screen.width; var lists = []; - _recognitions.forEach((re) { - var color = Color((Random().nextDouble() * 0xFFFFFF).toInt() << 0) - .withOpacity(1.0); + _recognitions!.forEach((re) { + var color = Color((Random().nextDouble() * 0xFFFFFF).toInt() << 0).withOpacity(1.0); var list = re["keypoints"].values.map((k) { return Positioned( left: k["x"] * factorX - 6, @@ -377,21 +380,18 @@ class _MyAppState extends State { left: 0.0, width: size.width, child: _image == null - ? Text('No image selected.') + ? const Text('No image selected.') : Container( - decoration: BoxDecoration( - image: DecorationImage( - alignment: Alignment.topCenter, - image: MemoryImage(_recognitions), - fit: BoxFit.fill)), - child: Opacity(opacity: 0.3, child: Image.file(_image))), + decoration: + BoxDecoration(image: DecorationImage(alignment: Alignment.topCenter, image: MemoryImage(_image!.readAsBytesSync()), fit: BoxFit.fill)), + child: Opacity(opacity: 0.3, child: Image.file(_image!))), )); } else { stackChildren.add(Positioned( top: 0.0, left: 0.0, width: size.width, - child: _image == null ? Text('No image selected.') : Image.file(_image), + child: _image == null ? const Text('No image selected.') : Image.file(_image!), )); } @@ -399,7 +399,7 @@ class _MyAppState extends State { stackChildren.add(Center( child: Column( children: _recognitions != null - ? _recognitions.map((res) { + ? _recognitions!.map((res) { return Text( "${res["index"]} - ${res["label"]}: ${res["confidence"].toStringAsFixed(3)}", style: TextStyle( @@ -466,7 +466,7 @@ class _MyAppState extends State { floatingActionButton: FloatingActionButton( onPressed: predictImagePicker, tooltip: 'Pick Image', - child: Icon(Icons.image), + child: const Icon(Icons.image), ), ); } diff --git a/example/pubspec.yaml b/example/pubspec.yaml index ceb496c..7449391 100644 --- a/example/pubspec.yaml +++ b/example/pubspec.yaml @@ -1,39 +1,50 @@ name: tflite_example description: Demonstrates how to use the tflite plugin. -# The following defines the version and build number for your application. -# A version number is three numbers separated by dots, like 1.2.43 -# followed by an optional build number separated by a +. -# Both the version and the builder number may be overridden in flutter -# build by specifying --build-name and --build-number, respectively. -# Read more about versioning at semver.org. -version: 1.0.0+1 +# The following line prevents the package from being accidentally published to +# pub.dev using `flutter pub publish`. This is preferred for private packages. +publish_to: 'none' # Remove this line if you wish to publish to pub.dev environment: - sdk: ">=2.0.0-dev.68.0 <3.0.0" + sdk: ">=2.12.0 <3.0.0" +# Dependencies specify other packages that your package needs in order to work. +# To automatically upgrade your package dependencies to the latest versions +# consider running `flutter pub upgrade --major-versions`. Alternatively, +# dependencies can be manually updated by changing the version numbers below to +# the latest version available on pub.dev. To see which dependencies have newer +# versions available, run `flutter pub outdated`. dependencies: flutter: sdk: flutter + tflite: + # When depending on this package from a real application you should use: + # tflite: ^x.y.z + # See https://dart.dev/tools/pub/dependencies#version-constraints + # The example app is bundled with the plugin so we use a path dependency on + # the parent directory to use the current plugin's version. + path: ../ + # The following adds the Cupertino Icons font to your application. # Use with the CupertinoIcons class for iOS style icons. - cupertino_icons: ^0.1.2 + cupertino_icons: ^1.0.2 + image_picker: ^0.8.4+3 + image: ^3.0.8 dev_dependencies: flutter_test: sdk: flutter - image_picker: ^0.6.7 - - image: ^2.1.4 + # The "flutter_lints" package below contains a set of recommended lints to + # encourage good coding practices. The lint set provided by the package is + # activated in the `analysis_options.yaml` file located at the root of your + # package. See that file for information about deactivating specific lint + # rules and activating additional ones. + flutter_lints: ^1.0.0 - tflite: - path: ../ - - test: ^1.12.0 # For information on the generic Dart part of this file, see the -# following page: https://www.dartlang.org/tools/pub/pubspec +# following page: https://dart.dev/tools/pub/pubspec # The following section is specific to Flutter. flutter: @@ -54,13 +65,13 @@ flutter: - assets/deeplabv3_257_mv_gpu.tflite - assets/deeplabv3_257_mv_gpu.txt - assets/posenet_mv1_075_float_from_checkpoints.tflite - + # An image asset can refer to one or more resolution-specific "variants", see - # https://flutter.io/assets-and-images/#resolution-aware. + # https://flutter.dev/assets-and-images/#resolution-aware. # For details regarding adding assets from package dependencies, see - # https://flutter.io/assets-and-images/#from-packages + # https://flutter.dev/assets-and-images/#from-packages # To add custom fonts to your application, add a fonts section here, # in this "flutter" section. Each entry in this list should have a @@ -80,4 +91,4 @@ flutter: # weight: 700 # # For details regarding fonts from package dependencies, - # see https://flutter.io/custom-fonts/#from-packages + # see https://flutter.dev/custom-fonts/#from-packages diff --git a/example/test/widget_test.dart b/example/test/widget_test.dart index 00a807f..cdc58a7 100644 --- a/example/test/widget_test.dart +++ b/example/test/widget_test.dart @@ -1,8 +1,9 @@ // This is a basic Flutter widget test. -// To perform an interaction with a widget in your test, use the WidgetTester utility that Flutter -// provides. For example, you can send tap and scroll gestures. You can also use WidgetTester to -// find child widgets in the widget tree, read text, and verify that the values of widget properties -// are correct. +// +// To perform an interaction with a widget in your test, use the WidgetTester +// utility that Flutter provides. For example, you can send tap and scroll +// gestures. You can also use WidgetTester to find child widgets in the widget +// tree, read text, and verify that the values of widget properties are correct. import 'package:flutter/material.dart'; import 'package:flutter_test/flutter_test.dart'; @@ -12,14 +13,14 @@ import 'package:tflite_example/main.dart'; void main() { testWidgets('Verify Platform version', (WidgetTester tester) async { // Build our app and trigger a frame. - await tester.pumpWidget(new MyApp()); + await tester.pumpWidget(MyApp()); // Verify that platform version is retrieved. expect( - find.byWidgetPredicate( - (Widget widget) => - widget is Text && widget.data.startsWith('Running on:'), - ), - findsOneWidget); + find.byWidgetPredicate( + (Widget widget) => widget is Text && widget.data!.startsWith('Running on:'), + ), + findsOneWidget, + ); }); } diff --git a/example/tflite_example_android.iml b/example/tflite_example_android.iml deleted file mode 100644 index b050030..0000000 --- a/example/tflite_example_android.iml +++ /dev/null @@ -1,27 +0,0 @@ - - - - - - - - - - - - - - - - - - - diff --git a/example/web/favicon.png b/example/web/favicon.png new file mode 100644 index 0000000..8aaa46a Binary files /dev/null and b/example/web/favicon.png differ diff --git a/example/web/icons/Icon-192.png b/example/web/icons/Icon-192.png new file mode 100644 index 0000000..b749bfe Binary files /dev/null and b/example/web/icons/Icon-192.png differ diff --git a/example/web/icons/Icon-512.png b/example/web/icons/Icon-512.png new file mode 100644 index 0000000..88cfd48 Binary files /dev/null and b/example/web/icons/Icon-512.png differ diff --git a/example/web/icons/Icon-maskable-192.png b/example/web/icons/Icon-maskable-192.png new file mode 100644 index 0000000..eb9b4d7 Binary files /dev/null and b/example/web/icons/Icon-maskable-192.png differ diff --git a/example/web/icons/Icon-maskable-512.png b/example/web/icons/Icon-maskable-512.png new file mode 100644 index 0000000..d69c566 Binary files /dev/null and b/example/web/icons/Icon-maskable-512.png differ diff --git a/example/web/index.html b/example/web/index.html new file mode 100644 index 0000000..7e667c6 --- /dev/null +++ b/example/web/index.html @@ -0,0 +1,101 @@ + + + + + + + + + + + + + + + + + tflite_example + + + + + + + diff --git a/example/web/manifest.json b/example/web/manifest.json new file mode 100644 index 0000000..9626c7e --- /dev/null +++ b/example/web/manifest.json @@ -0,0 +1,35 @@ +{ + "name": "tflite_example", + "short_name": "tflite_example", + "start_url": ".", + "display": "standalone", + "background_color": "#0175C2", + "theme_color": "#0175C2", + "description": "Demonstrates how to use the tflite plugin.", + "orientation": "portrait-primary", + "prefer_related_applications": false, + "icons": [ + { + "src": "icons/Icon-192.png", + "sizes": "192x192", + "type": "image/png" + }, + { + "src": "icons/Icon-512.png", + "sizes": "512x512", + "type": "image/png" + }, + { + "src": "icons/Icon-maskable-192.png", + "sizes": "192x192", + "type": "image/png", + "purpose": "maskable" + }, + { + "src": "icons/Icon-maskable-512.png", + "sizes": "512x512", + "type": "image/png", + "purpose": "maskable" + } + ] +} diff --git a/example/yolo.jpg b/example/yolo.jpg deleted file mode 100644 index 33f6714..0000000 Binary files a/example/yolo.jpg and /dev/null differ diff --git a/lib/tflite.dart b/lib/tflite.dart index 66039fd..5c82a8d 100644 --- a/lib/tflite.dart +++ b/lib/tflite.dart @@ -1,36 +1,26 @@ import 'dart:async'; + +import 'package:flutter/services.dart'; import 'dart:typed_data'; import 'dart:ui' show Color; -import 'package:flutter/services.dart'; class Tflite { - static const MethodChannel _channel = const MethodChannel('tflite'); + static const MethodChannel _channel = MethodChannel('tflite'); + + static Future get platformVersion async { + final String? version = await _channel.invokeMethod('getPlatformVersion'); + return version; + } - static Future loadModel( - {required String model, - String labels = "", - int numThreads = 1, - bool isAsset = true, - bool useGpuDelegate = false}) async { + static Future loadModel({required String model, String labels = "", int numThreads = 1, bool isAsset = true, bool useGpuDelegate = false}) async { return await _channel.invokeMethod( 'loadModel', - { - "model": model, - "labels": labels, - "numThreads": numThreads, - "isAsset": isAsset, - 'useGpuDelegate': useGpuDelegate - }, + {"model": model, "labels": labels, "numThreads": numThreads, "isAsset": isAsset, 'useGpuDelegate': useGpuDelegate}, ); } static Future runModelOnImage( - {required String path, - double imageMean = 117.0, - double imageStd = 1.0, - int numResults = 5, - double threshold = 0.1, - bool asynch = true}) async { + {required String path, double imageMean = 117.0, double imageStd = 1.0, int numResults = 5, double threshold = 0.1, bool asynch = true}) async { return await _channel.invokeMethod( 'runModelOnImage', { @@ -44,11 +34,7 @@ class Tflite { ); } - static Future runModelOnBinary( - {required Uint8List binary, - int numResults = 5, - double threshold = 0.1, - bool asynch = true}) async { + static Future runModelOnBinary({required Uint8List binary, int numResults = 5, double threshold = 0.1, bool asynch = true}) async { return await _channel.invokeMethod( 'runModelOnBinary', { @@ -86,18 +72,7 @@ class Tflite { ); } - static const anchors = [ - 0.57273, - 0.677385, - 1.87446, - 2.06253, - 3.33843, - 5.47434, - 7.88282, - 3.52778, - 9.77052, - 9.16828 - ]; + static const anchors = [0.57273, 0.677385, 1.87446, 2.06253, 3.33843, 5.47434, 7.88282, 3.52778, 9.77052, 9.16828]; static Future detectObjectOnImage({ required String path, @@ -164,7 +139,7 @@ class Tflite { double imageStd = 127.5, double threshold = 0.1, int numResultsPerClass = 5, - int rotation: 90, // Android only + int rotation = 90, // Android only // Used in YOLO only List anchors = anchors, int blockSize = 32, @@ -196,11 +171,7 @@ class Tflite { } static Future runPix2PixOnImage( - {required String path, - double imageMean = 0, - double imageStd = 255.0, - String outputType = "png", - bool asynch = true}) async { + {required String path, double imageMean = 0, double imageStd = 255.0, String outputType = "png", bool asynch = true}) async { return await _channel.invokeMethod( 'runPix2PixOnImage', { @@ -213,10 +184,7 @@ class Tflite { ); } - static Future runPix2PixOnBinary( - {required Uint8List binary, - String outputType = "png", - bool asynch = true}) async { + static Future runPix2PixOnBinary({required Uint8List binary, String outputType = "png", bool asynch = true}) async { return await _channel.invokeMethod( 'runPix2PixOnBinary', { @@ -233,7 +201,7 @@ class Tflite { int imageWidth = 720, double imageMean = 0, double imageStd = 255.0, - int rotation: 90, // Android only + int rotation = 90, // Android only String outputType = "png", bool asynch = true, }) async { @@ -254,36 +222,31 @@ class Tflite { // https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/loader/pascal_voc_loader.py static List pascalVOCLabelColors = [ - Color.fromARGB(255, 0, 0, 0).value, // background - Color.fromARGB(255, 128, 0, 0).value, // aeroplane - Color.fromARGB(255, 0, 128, 0).value, // biyclce - Color.fromARGB(255, 128, 128, 0).value, // bird - Color.fromARGB(255, 0, 0, 128).value, // boat - Color.fromARGB(255, 128, 0, 128).value, // bottle - Color.fromARGB(255, 0, 128, 128).value, // bus - Color.fromARGB(255, 128, 128, 128).value, // car - Color.fromARGB(255, 64, 0, 0).value, // cat - Color.fromARGB(255, 192, 0, 0).value, // chair - Color.fromARGB(255, 64, 128, 0).value, // cow - Color.fromARGB(255, 192, 128, 0).value, // diningtable - Color.fromARGB(255, 64, 0, 128).value, // dog - Color.fromARGB(255, 192, 0, 128).value, // horse - Color.fromARGB(255, 64, 128, 128).value, // motorbike - Color.fromARGB(255, 192, 128, 128).value, // person - Color.fromARGB(255, 0, 64, 0).value, // potted plant - Color.fromARGB(255, 128, 64, 0).value, // sheep - Color.fromARGB(255, 0, 192, 0).value, // sofa - Color.fromARGB(255, 128, 192, 0).value, // train - Color.fromARGB(255, 0, 64, 128).value, // tv-monitor + const Color.fromARGB(255, 0, 0, 0).value, // background + const Color.fromARGB(255, 128, 0, 0).value, // aeroplane + const Color.fromARGB(255, 0, 128, 0).value, // biyclce + const Color.fromARGB(255, 128, 128, 0).value, // bird + const Color.fromARGB(255, 0, 0, 128).value, // boat + const Color.fromARGB(255, 128, 0, 128).value, // bottle + const Color.fromARGB(255, 0, 128, 128).value, // bus + const Color.fromARGB(255, 128, 128, 128).value, // car + const Color.fromARGB(255, 64, 0, 0).value, // cat + const Color.fromARGB(255, 192, 0, 0).value, // chair + const Color.fromARGB(255, 64, 128, 0).value, // cow + const Color.fromARGB(255, 192, 128, 0).value, // diningtable + const Color.fromARGB(255, 64, 0, 128).value, // dog + const Color.fromARGB(255, 192, 0, 128).value, // horse + const Color.fromARGB(255, 64, 128, 128).value, // motorbike + const Color.fromARGB(255, 192, 128, 128).value, // person + const Color.fromARGB(255, 0, 64, 0).value, // potted plant + const Color.fromARGB(255, 128, 64, 0).value, // sheep + const Color.fromARGB(255, 0, 192, 0).value, // sofa + const Color.fromARGB(255, 128, 192, 0).value, // train + const Color.fromARGB(255, 0, 64, 128).value, // tv-monitor ]; static Future runSegmentationOnImage( - {required String path, - double imageMean = 0, - double imageStd = 255.0, - List? labelColors, - String outputType = "png", - bool asynch = true}) async { + {required String path, double imageMean = 0, double imageStd = 255.0, List? labelColors, String outputType = "png", bool asynch = true}) async { return await _channel.invokeMethod( 'runSegmentationOnImage', { @@ -297,11 +260,7 @@ class Tflite { ); } - static Future runSegmentationOnBinary( - {required Uint8List binary, - List? labelColors, - String outputType = "png", - bool asynch = true}) async { + static Future runSegmentationOnBinary({required Uint8List binary, List? labelColors, String outputType = "png", bool asynch = true}) async { return await _channel.invokeMethod( 'runSegmentationOnBinary', { @@ -319,7 +278,7 @@ class Tflite { int imageWidth = 720, double imageMean = 0, double imageStd = 255.0, - int rotation: 90, // Android only + int rotation = 90, // Android only List? labelColors, String outputType = "png", bool asynch = true}) async { @@ -362,11 +321,7 @@ class Tflite { } static Future runPoseNetOnBinary( - {required Uint8List binary, - int numResults = 5, - double threshold = 0.5, - int nmsRadius = 20, - bool asynch = true}) async { + {required Uint8List binary, int numResults = 5, double threshold = 0.5, int nmsRadius = 20, bool asynch = true}) async { return await _channel.invokeMethod( 'runPoseNetOnBinary', { @@ -385,7 +340,7 @@ class Tflite { int imageWidth = 720, double imageMean = 127.5, double imageStd = 127.5, - int rotation: 90, // Android only + int rotation = 90, // Android only int numResults = 5, double threshold = 0.5, int nmsRadius = 20, diff --git a/lib/tflite_web.dart b/lib/tflite_web.dart new file mode 100644 index 0000000..d357b1b --- /dev/null +++ b/lib/tflite_web.dart @@ -0,0 +1,44 @@ +import 'dart:async'; +// In order to *not* need this ignore, consider extracting the "web" version +// of your plugin as a separate package, instead of inlining it in the same +// package as the core of your plugin. +// ignore: avoid_web_libraries_in_flutter +import 'dart:html' as html show window; + +import 'package:flutter/services.dart'; +import 'package:flutter_web_plugins/flutter_web_plugins.dart'; + +/// A web implementation of the Tflite plugin. +class TfliteWeb { + static void registerWith(Registrar registrar) { + final MethodChannel channel = MethodChannel( + 'tflite', + const StandardMethodCodec(), + registrar, + ); + + final pluginInstance = TfliteWeb(); + channel.setMethodCallHandler(pluginInstance.handleMethodCall); + } + + /// Handles method calls over the MethodChannel of this plugin. + /// Note: Check the "federated" architecture for a new way of doing this: + /// https://flutter.dev/go/federated-plugins + Future handleMethodCall(MethodCall call) async { + switch (call.method) { + case 'getPlatformVersion': + return getPlatformVersion(); + default: + throw PlatformException( + code: 'Unimplemented', + details: 'tflite for web doesn\'t implement \'${call.method}\'', + ); + } + } + + /// Returns a [String] containing the version of the platform. + Future getPlatformVersion() { + final version = html.window.navigator.userAgent; + return Future.value(version); + } +} diff --git a/pubspec.yaml b/pubspec.yaml index 2cda2a4..ac2b02c 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -1,64 +1,70 @@ -name: tflite -description: A Flutter plugin for accessing TensorFlow Lite. Supports both iOS and Android. -version: 1.1.2 -homepage: https://github.com/shaqian/flutter_tflite - -environment: - sdk: '>=2.12.0 <3.0.0' - flutter: ">=1.10.0" - -dependencies: - flutter: - sdk: flutter - - meta: ^1.3.0 - -dev_dependencies: - flutter_test: - sdk: flutter - test: ^1.16.5 - -# For information on the generic Dart part of this file, see the -# following page: https://www.dartlang.org/tools/pub/pubspec - -# The following section is specific to Flutter. -flutter: - plugin: - platforms: - android: - package: sq.flutter.tflite - pluginClass: TflitePlugin - ios: - pluginClass: TflitePlugin - - - # To add assets to your plugin package, add an assets section, like this: - # assets: - # - images/a_dot_burr.jpeg - # - images/a_dot_ham.jpeg - # - # For details regarding assets in packages, see - # https://flutter.io/assets-and-images/#from-packages - # - # An image asset can refer to one or more resolution-specific "variants", see - # https://flutter.io/assets-and-images/#resolution-aware. - - # To add custom fonts to your plugin package, add a fonts section here, - # in this "flutter" section. Each entry in this list should have a - # "family" key with the font family name, and a "fonts" key with a - # list giving the asset and other descriptors for the font. For - # example: - # fonts: - # - family: Schyler - # fonts: - # - asset: fonts/Schyler-Regular.ttf - # - asset: fonts/Schyler-Italic.ttf - # style: italic - # - family: Trajan Pro - # fonts: - # - asset: fonts/TrajanPro.ttf - # - asset: fonts/TrajanPro_Bold.ttf - # weight: 700 - # - # For details regarding fonts in packages, see - # https://flutter.io/custom-fonts/#from-packages +name: tflite +description: A Flutter plugin for accessing TensorFlow Lite. Supports both iOS and Android, Web under development. +version: 1.1.3 +homepage: https://github.com/shaqian/flutter_tflite + +environment: + sdk: ">=2.12.0 <3.0.0" + flutter: ">=1.20.0" + +dependencies: + flutter: + sdk: flutter + flutter_web_plugins: + sdk: flutter + +dev_dependencies: + flutter_test: + sdk: flutter + flutter_lints: ^1.0.0 + +# For information on the generic Dart part of this file, see the +# following page: https://dart.dev/tools/pub/pubspec + +# The following section is specific to Flutter. +flutter: + # This section identifies this Flutter project as a plugin project. + # The 'pluginClass' and Android 'package' identifiers should not ordinarily + # be modified. They are used by the tooling to maintain consistency when + # adding or updating assets for this project. + plugin: + platforms: + android: + package: sq.flutter.tflite + pluginClass: TflitePlugin + ios: + pluginClass: TflitePlugin + web: + pluginClass: TfliteWeb + fileName: tflite_web.dart + + # To add assets to your plugin package, add an assets section, like this: + # assets: + # - images/a_dot_burr.jpeg + # - images/a_dot_ham.jpeg + # + # For details regarding assets in packages, see + # https://flutter.dev/assets-and-images/#from-packages + # + # An image asset can refer to one or more resolution-specific "variants", see + # https://flutter.dev/assets-and-images/#resolution-aware. + + # To add custom fonts to your plugin package, add a fonts section here, + # in this "flutter" section. Each entry in this list should have a + # "family" key with the font family name, and a "fonts" key with a + # list giving the asset and other descriptors for the font. For + # example: + # fonts: + # - family: Schyler + # fonts: + # - asset: fonts/Schyler-Regular.ttf + # - asset: fonts/Schyler-Italic.ttf + # style: italic + # - family: Trajan Pro + # fonts: + # - asset: fonts/TrajanPro.ttf + # - asset: fonts/TrajanPro_Bold.ttf + # weight: 700 + # + # For details regarding fonts in packages, see + # https://flutter.dev/custom-fonts/#from-packages