diff --git a/ai_image_test/lib/engines/onnx_directml_engine.dart b/ai_image_test/lib/engines/onnx_directml_engine.dart index c8ba676..9658735 100644 --- a/ai_image_test/lib/engines/onnx_directml_engine.dart +++ b/ai_image_test/lib/engines/onnx_directml_engine.dart @@ -60,22 +60,46 @@ class OnnxDirectMLEngine implements InferenceEngine { bool gpuInitialized = false; - // Try to initialize DirectML on Windows + // Try to initialize DirectML on Windows. + // + // NOTE: The high-level `onnxruntime` Dart package (v1.4.x) does NOT + // expose an `appendExecutionProvider_DML` method on OrtSessionOptions. + // True DirectML GPU acceleration therefore requires: + // 1. `onnxruntime-directml.dll` (from Microsoft.ML.OnnxRuntime.DirectML) + // placed next to the app executable, AND + // 2. Calling OrtSessionOptions_AppendExecutionProvider_DML via the raw + // C API (see OnnxRuntimeDirectML in onnx_directml_ffi.dart). + // + // Because the Dart package manages the OrtSessionOptions* pointer + // internally and does not expose it, we cannot inject the DirectML + // provider into the high-level session options object at this time. + // Instead, XNNPACK is used as the best CPU-accelerated fallback. + // Replace the lines below with a proper DirectML injection once the + // package exposes the raw options pointer or a dedicated pub package + // (e.g. `onnxruntime_directml`) becomes available. if (device == 'DirectML' && Platform.isWindows) { try { _directML = OnnxRuntimeDirectML(); if (_directML!.initialize()) { - // Note: The current onnxruntime Dart package doesn't expose DirectML provider - // We'll attempt to use it through the standard provider mechanism - // For now, log that DirectML is requested but we'll use CPU/XNNPACK - debugPrint('DirectML requested but not directly supported by onnxruntime package'); - debugPrint('Using XNNPACK as alternative GPU-optimized backend'); + if (_directML!.isDirectMLFunctionAvailable) { + // DirectML DLL is present and the provider function is resolved. + // Ideally we would call: + // _directML!.appendExecutionProviderDML(optionsPtr, 0) + // but the OrtSessionOptions* pointer is not accessible from + // the Dart package's public API. + // Falling back to XNNPACK as the next-best option. + debugPrint('OnnxDirectMLEngine: DirectML DLL found but cannot inject ' + 'provider via Dart package API; using XNNPACK'); + } else { + debugPrint('OnnxDirectMLEngine: onnxruntime-directml.dll not found; ' + 'using XNNPACK (install DirectML DLL for GPU acceleration)'); + } _opts!.appendXnnpackProvider(); _actualDevice = 'XNNPACK'; gpuInitialized = true; } } catch (e) { - debugPrint('DirectML initialization failed: $e'); + debugPrint('OnnxDirectMLEngine: DirectML initialization failed: $e'); } } @@ -149,12 +173,14 @@ class OnnxDirectMLEngine implements InferenceEngine { if (image == null) return null; final resized = img.copyResize(image, width: inputSize, height: inputSize); final data = []; - // NCHW format + // NCHW format: (pixel - 127.5) / 127.5 正規化 → [-1.0, 1.0] + // MobileFaceNet等の顔認証モデルの標準的な前処理 for (int c = 0; c < 3; c++) { for (int y = 0; y < inputSize; y++) { for (int x = 0; x < inputSize; x++) { final px = resized.getPixel(x, y); - data.add((c == 0 ? px.r : c == 1 ? px.g : px.b) / 255.0); + final raw = c == 0 ? px.r : c == 1 ? px.g : px.b; + data.add((raw - 127.5) / 127.5); } } } diff --git a/ai_image_test/lib/engines/onnx_engine.dart b/ai_image_test/lib/engines/onnx_engine.dart index f80aa5c..4e627da 100644 --- a/ai_image_test/lib/engines/onnx_engine.dart +++ b/ai_image_test/lib/engines/onnx_engine.dart @@ -136,12 +136,14 @@ class OnnxEngine implements InferenceEngine { if (image == null) return null; final resized = img.copyResize(image, width: inputSize, height: inputSize); final data = []; - // NCHW format + // NCHW format: (pixel - 127.5) / 127.5 正規化 → [-1.0, 1.0] + // MobileFaceNet等の顔認証モデルの標準的な前処理 for (int c = 0; c < 3; c++) { for (int y = 0; y < inputSize; y++) { for (int x = 0; x < inputSize; x++) { final px = resized.getPixel(x, y); - data.add((c == 0 ? px.r : c == 1 ? px.g : px.b) / 255.0); + final raw = c == 0 ? px.r : c == 1 ? px.g : px.b; + data.add((raw - 127.5) / 127.5); } } } diff --git a/ai_image_test/lib/engines/tflite_engine.dart b/ai_image_test/lib/engines/tflite_engine.dart index 48ab5b0..da8dc2a 100644 --- a/ai_image_test/lib/engines/tflite_engine.dart +++ b/ai_image_test/lib/engines/tflite_engine.dart @@ -119,12 +119,17 @@ class TfliteEngine implements InferenceEngine { } return _runInferenceQuantized(data); } else { - // float32 モデル: 0.0-1.0 正規化入力 + // float32 モデル: (pixel - 127.5) / 127.5 正規化入力 → [-1.0, 1.0] + // MobileFaceNet等の顔認証モデルの標準的な前処理 final data = []; for (int y = 0; y < inputSize; y++) { for (int x = 0; x < inputSize; x++) { final px = resized.getPixel(x, y); - data.addAll([px.r / 255.0, px.g / 255.0, px.b / 255.0]); + data.addAll([ + (px.r - 127.5) / 127.5, + (px.g - 127.5) / 127.5, + (px.b - 127.5) / 127.5, + ]); } } return _runInference(data); diff --git a/ai_image_test/lib/services/onnx_directml_ffi.dart b/ai_image_test/lib/services/onnx_directml_ffi.dart index d2298b8..e91972e 100644 --- a/ai_image_test/lib/services/onnx_directml_ffi.dart +++ b/ai_image_test/lib/services/onnx_directml_ffi.dart @@ -1,42 +1,96 @@ import 'dart:ffi' as ffi; import 'dart:io'; -import 'package:ffi/ffi.dart'; import 'package:flutter/foundation.dart'; +// FFI type definitions for OrtSessionOptions_AppendExecutionProvider_DML +// C signature: +// OrtStatus* OrtSessionOptions_AppendExecutionProvider_DML( +// OrtSessionOptions* options, int device_id) +typedef _AppendDmlNative = ffi.Pointer Function( + ffi.Pointer options, ffi.Int32 deviceId); +typedef _AppendDmlDart = ffi.Pointer Function( + ffi.Pointer options, int deviceId); + /// FFI bindings for ONNX Runtime with DirectML support. -/// +/// /// This provides low-level bindings to the ONNX Runtime C API /// with DirectML execution provider for GPU acceleration on Windows. +/// +/// **Requirements for DirectML support:** +/// The standard `onnxruntime` Dart package bundles `onnxruntime.dll` which +/// does NOT include DirectML. To enable GPU acceleration via DirectML you +/// need the DirectML-enabled DLL (e.g. `onnxruntime-directml.dll` from the +/// Microsoft.ML.OnnxRuntime.DirectML NuGet package) placed next to the app +/// executable, OR a Dart package that bundles that DLL. As of onnxruntime +/// v1.4.1 there is no separate `onnxruntime_directml` pub.dev package, so +/// the DLL swap is the recommended approach on Windows. class OnnxRuntimeDirectML { static const String _libName = 'onnxruntime'; - + late final ffi.DynamicLibrary _lib; bool _initialized = false; + _AppendDmlDart? _appendDml; /// Initialize the DirectML library. - /// Returns true if successful, false otherwise. + /// Returns true if the DirectML-capable DLL was loaded successfully. bool initialize() { if (_initialized) return true; - + + if (!Platform.isWindows) { + throw UnsupportedError('DirectML is only supported on Windows'); + } + try { - // Try to load ONNX Runtime DLL with DirectML support - // User needs to have onnxruntime-directml.dll in the path - if (Platform.isWindows) { - try { - _lib = ffi.DynamicLibrary.open('$_libName.dll'); - } catch (e) { - debugPrint('Failed to load onnxruntime.dll, trying onnxruntime-directml.dll: $e'); - _lib = ffi.DynamicLibrary.open('onnxruntime-directml.dll'); - } - } else { - throw UnsupportedError('DirectML is only supported on Windows'); + // Prefer the DirectML-enabled DLL; fall back to the standard one so + // that the app still runs (without GPU acceleration) when the DirectML + // DLL is absent. + try { + _lib = ffi.DynamicLibrary.open('onnxruntime-directml.dll'); + debugPrint('OnnxRuntimeDirectML: loaded onnxruntime-directml.dll'); + } catch (_) { + _lib = ffi.DynamicLibrary.open('$_libName.dll'); + debugPrint('OnnxRuntimeDirectML: DirectML DLL not found, loaded standard onnxruntime.dll'); + } + + // Try to resolve the DirectML provider function. + // This will succeed only when onnxruntime-directml.dll is loaded. + try { + final fn = _lib.lookup>( + 'OrtSessionOptions_AppendExecutionProvider_DML'); + _appendDml = fn.asFunction<_AppendDmlDart>(); + debugPrint('OnnxRuntimeDirectML: OrtSessionOptions_AppendExecutionProvider_DML resolved'); + } catch (e) { + debugPrint('OnnxRuntimeDirectML: AppendExecutionProvider_DML not found ' + '(standard DLL without DirectML): $e'); + _appendDml = null; } - + _initialized = true; - debugPrint('ONNX Runtime DirectML library loaded successfully'); - return true; + return _appendDml != null; // true only if DirectML is really available + } catch (e) { + debugPrint('OnnxRuntimeDirectML: failed to load ONNX Runtime library: $e'); + return false; + } + } + + /// Returns true if the DirectML execution provider function was resolved. + bool get isDirectMLFunctionAvailable => _initialized && _appendDml != null; + + /// Attempt to append the DirectML execution provider to the given raw + /// OrtSessionOptions* pointer. + /// + /// Returns true on success, false if DirectML is not available or the call + /// fails. [optionsPtr] must be the raw `OrtSessionOptions*` obtained from + /// the underlying ONNX Runtime C API. + bool appendExecutionProviderDML( + ffi.Pointer optionsPtr, int deviceId) { + if (_appendDml == null) return false; + try { + final status = _appendDml!(optionsPtr, deviceId); + // A null status pointer means success in the ORT C API. + return status == ffi.nullptr; } catch (e) { - debugPrint('Failed to load ONNX Runtime DirectML library: $e'); + debugPrint('OnnxRuntimeDirectML: appendExecutionProvider_DML failed: $e'); return false; } } @@ -44,12 +98,11 @@ class OnnxRuntimeDirectML { /// Check if DirectML is available on this system. static bool isAvailable() { if (!Platform.isWindows) return false; - try { final lib = OnnxRuntimeDirectML(); return lib.initialize(); } catch (e) { - debugPrint('DirectML not available: $e'); + debugPrint('OnnxRuntimeDirectML: not available: $e'); return false; } } @@ -57,5 +110,6 @@ class OnnxRuntimeDirectML { /// Dispose resources. void dispose() { _initialized = false; + _appendDml = null; } } diff --git a/ai_image_test/pubspec.yaml b/ai_image_test/pubspec.yaml index e178ddf..583201e 100644 --- a/ai_image_test/pubspec.yaml +++ b/ai_image_test/pubspec.yaml @@ -35,6 +35,12 @@ dependencies: # Use with the CupertinoIcons class for iOS style icons. cupertino_icons: ^1.0.8 tflite_flutter: ^0.12.1 + # NOTE: Windows DirectML (GPU) support requires the DirectML-enabled ONNX + # Runtime DLL. The standard `onnxruntime` package bundles onnxruntime.dll + # WITHOUT DirectML. To enable GPU acceleration on Windows, place + # `onnxruntime-directml.dll` (from the Microsoft.ML.OnnxRuntime.DirectML + # NuGet package) next to the app executable. A future `onnxruntime_directml` + # pub.dev package may provide a drop-in replacement when available. onnxruntime: ^1.4.1 image: ^4.7.2 file_picker: ^10.3.10