diff --git a/.gitignore b/.gitignore index 4d0a1205b7c19..b25763334f227 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ # build, distribute, and bins (+ python proto bindings) +build.*/ build build_*/ .build_debug/* diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index d42f23134805e..28efa6ea8cb3f 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -343,7 +343,7 @@ if (onnxruntime_USE_ROCM) if (ROCM_VERSION_DEV VERSION_LESS "6.2") message(FATAL_ERROR "CMAKE_HIP_ARCHITECTURES is not set when ROCm version < 6.2") else() - set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx1030;gfx1100;gfx1101;gfx942;gfx950;gfx1200;gfx1201;gfx1150") + set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx1030;gfx1100;gfx1101;gfx942;gfx950;gfx1200;gfx1201") endif() endif() diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake index 567bcaf581842..90c0f447800c6 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -2,21 +2,11 @@ # Licensed under the MIT License. add_definitions(-DUSE_MIGRAPHX=1) - set(BUILD_LIBRARY_ONLY 1) - add_definitions("-DONNX_ML=1") - add_definitions("-DONNX_NAMESPACE=onnx") - include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR}) - set(MIGRAPHX_ROOT ${onnxruntime_MIGRAPHX_HOME}) - include_directories(${onnx_SOURCE_DIR}) + include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR} ${onnx_SOURCE_DIR}) set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) + if (CMAKE_COMPILER_IS_GNUCC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers") endif() - set(CXX_VERSION_DEFINED TRUE) - set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") - endif() # Add search paths for default rocm installation list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm $ENV{HIP_PATH}) @@ -33,8 +23,6 @@ find_package(hip REQUIRED) find_package(migraphx REQUIRED PATHS ${AMD_MIGRAPHX_HOME}) - set(migraphx_libs migraphx::c hip::host) - file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h" "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc" @@ -44,15 +32,15 @@ source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs}) onnxruntime_add_shared_library(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime) + add_dependencies(onnxruntime_providers_migraphx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE migraphx::c hip::host ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) + target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/migraphx/onnxruntime) set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime") - target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1) - if(MSVC) + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1 ONNX_ML=1 ONNX_NAMESPACE=onnx) + if(CMAKE_SYSTEM_NAME STREQUAL "Windows") set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS /DEF:${ONNXRUNTIME_ROOT}/core/providers/migraphx/symbols.def) - target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32 shlwapi) else() target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare) set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 9797d8019f2d3..6c7ebdf67e43f 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -609,7 +609,6 @@ endif() if(onnxruntime_USE_MIGRAPHX) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() if(onnxruntime_USE_ROCM) @@ -694,9 +693,6 @@ endif() if(onnxruntime_USE_MIGRAPHX) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/migraphx/*) - list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/migraphx/migraphx_execution_provider_utils.h") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() if(onnxruntime_USE_NNAPI_BUILTIN) diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj index 6779fd60bcd0a..28a837f44ae07 100644 --- a/csharp/OnnxRuntime.CSharp.proj +++ b/csharp/OnnxRuntime.CSharp.proj @@ -163,6 +163,11 @@ CMake creates a target to this project Targets ="CreateNativePackage" /> + + + diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 77c35aac65b92..839c695a23791 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -336,6 +336,13 @@ public struct OrtApi public IntPtr GetModelEditorApi; public IntPtr CreateTensorWithDataAndDeleterAsOrtValue; public IntPtr SessionOptionsSetLoadCancellationFlag; + + public IntPtr CreateMIGraphXProviderOptions; + public IntPtr UpdateMIGraphXProviderOptions; + public IntPtr GetMIGraphXProviderOptionsAsString; + public IntPtr ReleaseMIGraphXProviderOptions; + public IntPtr UpdateMIGraphXProviderOptionsWithValue; + public IntPtr GetMIGraphXProviderOptionsByName; } internal static class NativeMethods @@ -573,6 +580,18 @@ static NativeMethods() OrtUpdateROCMProviderOptions = (DOrtUpdateROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateROCMProviderOptions, typeof(DOrtUpdateROCMProviderOptions)); OrtGetROCMProviderOptionsAsString = (DOrtGetROCMProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetROCMProviderOptionsAsString, typeof(DOrtGetROCMProviderOptionsAsString)); OrtReleaseROCMProviderOptions = (DOrtReleaseROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseROCMProviderOptions, typeof(DOrtReleaseROCMProviderOptions)); + SessionOptionsAppendExecutionProvider_MIGraphX = (DSessionOptionsAppendExecutionProvider_MIGraphX)Marshal.GetDelegateForFunctionPointer( + api_.SessionOptionsAppendExecutionProvider_MIGraphX, typeof(DSessionOptionsAppendExecutionProvider_MIGraphX)); + OrtCreateMIGraphXProviderOptions = (DOrtCreateMIGraphXProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateMIGraphXProviderOptions, typeof(DOrtCreateMIGraphXProviderOptions)); + OrtUpdateMIGraphXProviderOptions = (DOrtUpdateMIGraphXProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateMIGraphXProviderOptions, typeof(DOrtUpdateMIGraphXProviderOptions)); + OrtGetMIGraphXProviderOptionsAsString = (DOrtGetMIGraphXProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetMIGraphXProviderOptionsAsString, typeof(DOrtGetMIGraphXProviderOptionsAsString)); + OrtReleaseMIGraphXProviderOptions = (DOrtReleaseMIGraphXProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseMIGraphXProviderOptions, typeof(DOrtReleaseMIGraphXProviderOptions)); + OrtUpdateMIGraphXProviderOptionsWithValue = + (DOrtUpdateMIGraphXProviderOptionsWithValue)Marshal.GetDelegateForFunctionPointer( + api_.UpdateMIGraphXProviderOptionsWithValue, typeof(DOrtUpdateMIGraphXProviderOptionsWithValue)); + OrtGetMIGraphXProviderOptionsByName = + (DOrtGetMIGraphXProviderOptionsByName)Marshal.GetDelegateForFunctionPointer( + api_.GetMIGraphXProviderOptionsByName, typeof(DOrtGetMIGraphXProviderOptionsByName)); OrtCreateAndRegisterAllocatorV2 = (DCreateAndRegisterAllocatorV2)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocatorV2, typeof(DCreateAndRegisterAllocatorV2)); OrtRunAsync = (DOrtRunAsync)Marshal.GetDelegateForFunctionPointer(api_.RunAsync, typeof(DOrtRunAsync)); CreateLoraAdapter = (DCreateLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.CreateLoraAdapter, @@ -799,6 +818,80 @@ internal class NativeLib #endregion +#region Provider Options API + /// + /// Creates native OrtMIGraphXProviderOptions instance + /// + /// (output) native instance of OrtMIGraphXProviderOptions + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateMIGraphXProviderOptions( + out IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance); + public static DOrtCreateMIGraphXProviderOptions OrtCreateMIGraphXProviderOptions; + + /// + /// Updates native OrtMIGraphXProviderOptions instance using given key/value pairs + /// + /// native instance of OrtMIGraphXProviderOptions + /// configuration keys of OrtMIGraphXProviderOptions + /// configuration values of OrtMIGraphXProviderOptions + /// number of configuration keys + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtUpdateMIGraphXProviderOptions( + IntPtr /*(OrtMIGraphXProviderOptions*)*/ migraphxProviderOptionsInstance, + IntPtr[] /*(const char* const *)*/ providerOptionsKeys, + IntPtr[] /*(const char* const *)*/ providerOptionsValues, + UIntPtr /*(size_t)*/ numKeys); + public static DOrtUpdateMIGraphXProviderOptions OrtUpdateMIGraphXProviderOptions; + + /// + /// Get native OrtMIGraphXProviderOptions in serialized string + /// + /// instance of OrtAllocator + /// is a UTF-8 null terminated string allocated using 'allocator' + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtGetMIGraphXProviderOptionsAsString( + IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(char**)*/ ptr); + public static DOrtGetMIGraphXProviderOptionsAsString OrtGetMIGraphXProviderOptionsAsString; + + /// + /// Releases native OrtMIGraphXProviderOptions instance + /// + /// native instance of OrtMIGraphXProviderOptions to be released + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtReleaseMIGraphXProviderOptions(IntPtr /*(OrtMIGraphXProviderOptions*)*/ migraphxProviderOptionsInstance); + public static DOrtReleaseMIGraphXProviderOptions OrtReleaseMIGraphXProviderOptions; + + /// + /// Update native OrtMIGraphXProviderOptions with value + /// + /// native instance of OrtMIGraphXProviderOptions to be released + /// configuration key of OrtMIGraphXProviderOptions + /// configuration value of OrtMIGraphXProviderOptions + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr DOrtUpdateMIGraphXProviderOptionsWithValue( + IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance, + IntPtr /*(char*)*/ providerOptionsKey, + IntPtr /*(char*)*/ providerOptionsValue); + public static DOrtUpdateMIGraphXProviderOptionsWithValue OrtUpdateMIGraphXProviderOptionsWithValue; + + /// + /// Get native OrtMIGraphXProviderOptions value by name + /// + /// native instance of OrtMIGraphXProviderOptions to be released + /// configuration key of OrtMIGraphXProviderOptions + /// configuration value of OrtMIGraphXProviderOptions to return + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr DOrtGetMIGraphXProviderOptionsByName( + IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance, + IntPtr /*(char*)*/ providerOptionsKey, + out IntPtr /*(char**)*/ providerOptionsValue); + public static DOrtGetMIGraphXProviderOptionsByName OrtGetMIGraphXProviderOptionsByName; + + +#endregion + #region Status API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate ErrorCode DOrtGetErrorCode(IntPtr /*(OrtStatus*)*/ status); @@ -1160,6 +1253,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id); + + [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] + public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int use_arena, int device_id); #endif /// /// Append a TensorRT EP instance (configured based on given provider options) to the native OrtSessionOptions instance @@ -1221,6 +1317,18 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DSessionOptionsAppendExecutionProvider_ROCM SessionOptionsAppendExecutionProvider_ROCM; + /// + /// Append a MIGraphX EP instance (configured based on given provider options) to the native OrtSessionOptions instance + /// + /// Native OrtSessionOptions instance + /// Native OrtMIGraphXProviderOptions instance + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_MIGraphX( + IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const OrtMIGraphXProviderOptions*)*/ migraphxProviderOptions); + + public static DSessionOptionsAppendExecutionProvider_MIGraphX SessionOptionsAppendExecutionProvider_MIGraphX; + /// /// Free Dimension override (by denotation) /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs index 1b9cd7572170b..335b4ef8b3f65 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs @@ -291,6 +291,142 @@ protected override bool ReleaseHandle() } +/// + /// Holds the options for configuring an MIGraphX Execution Provider instance + /// + public class OrtMIGraphXProviderOptions : SafeHandle + { + internal IntPtr Handle + { + get + { + return handle; + } + } + + public int DeviceId + { + get { return _deviceId; } + set + { + UpdateProviderOptionWithValue(_deviceIdPtr, value.ToString()); + _deviceId = value; + } + } + private IntPtr _deviceIdPtr = Marshal.StringToHGlobalAnsi("device_id"); + private int _deviceId = 0; + + public string ModelCacheDir + { + get { return _modelCacheDir; } + set + { + UpdateProviderOptionWithValue(_modelCacheDirPtr, value); + _modelCacheDir = value; + } + } + + private IntPtr _modelCacheDirPtr = Marshal.StringToHGlobalAnsi("migraphx_model_cache_dir"); + private string _modelCacheDir = ""; + + #region Constructor + + /// + /// Constructs an empty OrtMIGraphXProviderOptions instance + /// + public OrtMIGraphXProviderOptions() : base(IntPtr.Zero, true) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateMIGraphXProviderOptions(out handle)); + } + + #endregion + + #region Finalizer + + ~OrtMIGraphXProviderOptions() + { + Marshal.FreeHGlobal(_deviceIdPtr); + Marshal.FreeHGlobal(_modelCacheDirPtr); + } + + #endregion + + #region Public Methods + + /// + /// Get MIGraphX EP provider options + /// + /// return C# UTF-16 encoded string + public string GetOptions() + { + var allocator = OrtAllocator.DefaultInstance; + // Process provider options string + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetMIGraphXProviderOptionsAsString(handle, + allocator.Pointer, out IntPtr providerOptions)); + return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions, allocator); + } + + /// + /// Updates the configuration knobs of OrtMIGraphXProviderOptions that will eventually be used to configure a MIGraphX EP + /// + /// Array of keys to set that correspond with values. + /// Array of values to set that correspond with keys. + /// The number of key/value pairs in the arrays. + private static IntPtr UpdateMIGraphXProviderOptions(IntPtr handle, IntPtr[] keys, IntPtr[] values, UIntPtr count) + { + return NativeMethods.OrtUpdateMIGraphXProviderOptions(handle, keys, values, count); + } + + /// + /// Updates the configuration knobs of OrtMIGraphXProviderOptions that will eventually be used to configure a MIGraphX EP + /// + /// key/value pairs used to configure a MIGraphX Execution Provider + public void UpdateOptions(Dictionary providerOptions) + { + ProviderOptionsUpdater.Update(providerOptions, handle, UpdateMIGraphXProviderOptions); + } + + #endregion + + #region Public Properties + + /// + /// Overrides SafeHandle.IsInvalid + /// + /// returns true if handle is equal to Zero + public override bool IsInvalid { get { return handle == IntPtr.Zero; } } + + #endregion + + #region Private Methods + + private void UpdateProviderOptionWithValue(IntPtr key, string value) + { + IntPtr valuePtr = Marshal.StringToHGlobalAnsi(value); + var nativeStatus = NativeMethods.OrtUpdateMIGraphXProviderOptionsWithValue(handle, key, valuePtr); + Marshal.FreeHGlobal(valuePtr); + NativeApiStatus.VerifySuccess(nativeStatus); + } + + #endregion + + #region SafeHandle + /// + /// Overrides SafeHandle.ReleaseHandle() to properly dispose of + /// the native instance of OrtMIGraphXProviderOptions + /// + /// always returns true + protected override bool ReleaseHandle() + { + NativeMethods.OrtReleaseMIGraphXProviderOptions(handle); + handle = IntPtr.Zero; + return true; + } + + #endregion + } + + /// /// This helper class contains methods to handle values of provider options /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 9b0f183f03681..3e1515cf5b78c 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -42,6 +42,9 @@ public class SessionOptions : SafeHandle private static string[] cudaDelayLoadedLibs = { }; private static string[] trtDelayLoadedLibs = { }; + // Delay-loaded MIGraphX DLLs. Currently, delayload is disabled. See cmake/CMakeLists.txt for more information. + private static string[] migxDelayLoadedLibs = { }; + #region Constructor and Factory methods /// @@ -189,6 +192,28 @@ public static SessionOptions MakeSessionOptionWithRocmProvider(OrtROCMProviderOp throw; } } + + /// + /// A helper method to construct a SessionOptions object for MIGraaphX execution provider. + /// Use only if MIGraphX is installed and you have the onnxruntime package specific to this Execution Provider. + /// + /// MIGraphX EP provider options + /// A SessionsOptions() object configured for execution on provider options + public static SessionOptions MakeSessionOptionWithMIGraphXProvider(OrtMIGraphXProviderOptions migxProviderOptions) + { + CheckMIGraphXExecutionProviderDLLs(); + SessionOptions options = new SessionOptions(); + try + { + options.AppendExecutionProvider_MIGraphX(migxProviderOptions); + return options; + } + catch (Exception) + { + options.Dispose(); + throw; + } + } #endregion #region ExecutionProviderAppends @@ -331,12 +356,25 @@ public void AppendExecutionProvider_ROCm(OrtROCMProviderOptions rocmProviderOpti public void AppendExecutionProvider_MIGraphX(int deviceId = 0) { #if __MOBILE__ - throw new NotSupportedException($"The MIGraphX Execution Provider is not supported in this build"); + throw new NotSupportedException("The MIGraphX Execution Provider is not supported in this build"); #else NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_MIGraphX(handle, deviceId)); #endif } + /// + /// Use only if you have the onnxruntime package specific to this Execution Provider. + /// + /// device identification + public void AppendExecutionProvider_MIGraphX(OrtMIGraphXProviderOptions migraphxProviderOptions) + { +#if __MOBILE__ + throw new NotSupportedException($"The AMD Nitris Execution Provider is not supported in this build"); +#else + NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider_MIGraphX(handle, migraphxProviderOptions.Handle)); +#endif + } + /// /// Use only if you have the onnxruntime package specific to this Execution Provider. /// @@ -885,6 +923,27 @@ private static bool CheckRocmExecutionProviderDLLs() return true; } + private static bool CheckMIGraphXExecutionProviderDLLs() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + foreach (var dll in migxDelayLoadedLibs) + { + IntPtr handle = LoadLibrary(dll); + if (handle != IntPtr.Zero) + continue; + var sysdir = new StringBuilder(String.Empty, 2048); + GetSystemDirectory(sysdir, (uint)sysdir.Capacity); + throw new OnnxRuntimeException( + ErrorCode.NoSuchFile, + $"kernel32.LoadLibrary():'{dll}' not found. MIGraphX are required for GPU execution. " + + $". Verify it is available in the system directory={sysdir}. Else copy it to the output folder." + ); + } + } + return true; + } + #endregion #region SafeHandle diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index af9b517b77e5c..6e2b4fa6aba3f 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -675,10 +675,10 @@ typedef struct OrtMIGraphXProviderOptions { int migraphx_bf16_enable; // MIGraphX BF16 precision. Default 0 = false, nonzero = true int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true - int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true + int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, nonzero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name const char* migraphx_cache_dir; // MIGraphX model cache directory - bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false + int migraphx_exhaustive_tune; // MIGraphX tuned compile. Default = false, nonzero = true /** \brief MIGraphX memory limit (To use all possible memory pass in maximum size_t) * Defaults to SIZE_MAX. @@ -5265,6 +5265,88 @@ struct OrtApi { * \since Version 1.22. */ const OrtEpApi*(ORT_API_CALL* GetEpApi)(); + + /// @} + /// \name OrtMIGraphXProviderOptions + /// @{ + + /** \brief Create an OrtMIGraphXProviderOptions + * + * \param[out] out Newly created ::OrtMIGraphXProviderOptions. Must be released with OrtApi::ReleaseMIGraphXProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.xx. + */ + ORT_API2_STATUS(CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out); + + /** \brief Set options in a MIGraphX Execution Provider. + * + * For example, key="device_id" and value="0" + * + * \param[in] migraphx_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values + * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.xx. + */ + ORT_API2_STATUS(UpdateMIGraphXProviderOptions, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** + * Get serialized MIGraphX provider options string. + * + * For example, "device_id=0;;......" + * + * \param migraphx_options - OrtMIGraphXProviderOptions instance + * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() + * the specified allocator will be used to allocate continuous buffers for output strings and lengths. + * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.xx. + */ + ORT_API2_STATUS(GetMIGraphXProviderOptionsAsString, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); + + /** \brief Release an ::OrtMIGraphXProviderOptions + * + * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does + * + * \since Version 1.xx. + */ + void(ORT_API_CALL* ReleaseMIGraphXProviderOptions)(_Frees_ptr_opt_ OrtMIGraphXProviderOptions* input); + + /** + * Update MIGraphX EP provider option where its data type is pointer, for example 'user_compute_stream'. + * If the data type of the provider option can be represented by string please use UpdateMIGraphXProviderOptions. + * + * Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer. + * + * \param migraphx_options - OrtMIGraphXProviderOptions instance + * \param key - Name of the provider option + * \param value - A pointer to the instance that will be assigned to this provider option + * + * \since Version 1.xx. + */ + ORT_API2_STATUS(UpdateMIGraphXProviderOptionsWithValue, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _In_ void* value); + + /** + * Get MIGraphX EP provider option where its data type is pointer. + * If the data type of the provider option can be represented by string please use GetMIGraphXProviderOptionsAsString. + * + * \param migraphx_options - OrtMIGraphXProviderOptions instance + * \param key - Name of the provider option + * \param ptr - A pointer to the instance that is kept by the provider option + * + * \since Version 1.xx. + */ + ORT_API2_STATUS(GetMIGraphXProviderOptionsByName, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _Outptr_ void** ptr); }; /* diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index c0d8a4f02bbc3..776fd5fec367f 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -562,7 +562,10 @@ static D3D12_COMMAND_LIST_TYPE CalculateCommandListType(ID3D12Device* d3d12_devi sizeof(feature_levels) )); - auto use_compute_command_list = (feature_levels.MaxSupportedFeatureLevel <= D3D_FEATURE_LEVEL_1_0_CORE); + // Use compute queue whenever possible on supported hardware to avoid TDR and maintain UI QoS + // Core and generic devices only have compute queues, DX11 has "immediate" submission, DX12 has both + auto use_compute_command_list = (feature_levels.MaxSupportedFeatureLevel <= D3D_FEATURE_LEVEL_1_0_CORE) || + (feature_levels.MaxSupportedFeatureLevel >= D3D_FEATURE_LEVEL_12_0); if (use_compute_command_list) { diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc index cf9f44f4cd8f0..17dfdf4519b16 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc @@ -23,11 +23,11 @@ void MIGraphXAllocator::CheckDevice() const { #endif } -void* MIGraphXAllocator::Alloc(size_t size) { +void* MIGraphXAllocator::Alloc(const size_t size) { CheckDevice(); void* p = nullptr; if (size > 0) { - HIP_CALL_THROW(hipMalloc((void**)&p, size)); + HIP_CALL_THROW(hipMalloc(&p, size)); } return p; } @@ -37,7 +37,7 @@ void MIGraphXAllocator::Free(void* p) { (void)hipFree(p); // do not throw error since it's OK for hipFree to fail during shutdown } -void* MIGraphXExternalAllocator::Alloc(size_t size) { +void* MIGraphXExternalAllocator::Alloc(const size_t size) { void* p = nullptr; if (size > 0) { p = alloc_(size); @@ -51,27 +51,27 @@ void* MIGraphXExternalAllocator::Alloc(size_t size) { void MIGraphXExternalAllocator::Free(void* p) { free_(p); - std::lock_guard lock(lock_); - auto it = reserved_.find(p); - if (it != reserved_.end()) { + std::lock_guard lock(lock_); + if (const auto it = reserved_.find(p); it != reserved_.end()) { reserved_.erase(it); if (empty_cache_) empty_cache_(); } } -void* MIGraphXExternalAllocator::Reserve(size_t size) { +void* MIGraphXExternalAllocator::Reserve(const size_t size) { void* p = Alloc(size); - if (!p) return nullptr; - std::lock_guard lock(lock_); - ORT_ENFORCE(reserved_.find(p) == reserved_.end()); - reserved_.insert(p); + if (p != nullptr) { + std::lock_guard lock(lock_); + ORT_ENFORCE(reserved_.find(p) == reserved_.end()); + reserved_.insert(p); + } return p; } -void* MIGraphXPinnedAllocator::Alloc(size_t size) { +void* MIGraphXPinnedAllocator::Alloc(const size_t size) { void* p = nullptr; if (size > 0) { - HIP_CALL_THROW(hipHostMalloc((void**)&p, size)); + HIP_CALL_THROW(hipHostMalloc(&p, size)); } return p; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h index 2a84445897391..4842a36993f45 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h @@ -11,26 +11,26 @@ namespace onnxruntime { class MIGraphXAllocator : public IAllocator { public: - MIGraphXAllocator(int device_id, const char* name) + MIGraphXAllocator(const OrtDevice::DeviceId device_id, const char* name) : IAllocator( - OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(device_id)), + OrtMemoryInfo(name, OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id), device_id, OrtMemTypeDefault)) {} - virtual void* Alloc(size_t size) override; - virtual void Free(void* p) override; + void* Alloc(size_t size) override; + void Free(void* p) override; private: void CheckDevice() const; }; -class MIGraphXExternalAllocator : public MIGraphXAllocator { +class MIGraphXExternalAllocator final : public MIGraphXAllocator { typedef void* (*ExternalAlloc)(size_t size); typedef void (*ExternalFree)(void* p); typedef void (*ExternalEmptyCache)(); public: - MIGraphXExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache) + MIGraphXExternalAllocator(const OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache) : MIGraphXAllocator(device_id, name) { alloc_ = reinterpret_cast(alloc); free_ = reinterpret_cast(free); @@ -51,10 +51,10 @@ class MIGraphXExternalAllocator : public MIGraphXAllocator { class MIGraphXPinnedAllocator final : public IAllocator { public: - MIGraphXPinnedAllocator(const int device_id, const char* name) + MIGraphXPinnedAllocator(const OrtDevice::DeviceId device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast(device_id)), + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, device_id), device_id, OrtMemTypeCPUOutput)) {} void* Alloc(size_t size) override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 998c4af3f7576..6c588bb29c823 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1,11 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License -#include #include +#include +#include #include -#include +#include #include -#include +#include +#include #include "core/providers/shared_library/provider_api.h" #define ORT_API_MANUAL_INIT @@ -40,7 +42,7 @@ namespace onnxruntime { class Memcpy final : public OpKernel { public: - Memcpy(const OpKernelInfo& info) : OpKernel(info) {} + explicit Memcpy(const OpKernelInfo& info) : OpKernel(info) {} Status Compute(OpKernelContext* ctx) const override { const auto* X = ctx->Input(0); @@ -56,16 +58,13 @@ class Memcpy final : public OpKernel { } }; -template -KernelCreateInfo BuildKernelCreateInfo(); - ONNX_OPERATOR_KERNEL_EX( MemcpyFromHost, kOnnxDomain, 1, kMIGraphXExecutionProvider, - (*KernelDefBuilder::Create()) - .InputMemoryType(OrtMemTypeCPUInput, 0) + KernelDefBuilder::Create() + ->InputMemoryType(OrtMemTypeCPUInput, 0) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), Memcpy); @@ -74,14 +73,11 @@ ONNX_OPERATOR_KERNEL_EX( kOnnxDomain, 1, kMIGraphXExecutionProvider, - (*KernelDefBuilder::Create()) - .OutputMemoryType(OrtMemTypeCPUOutput, 0) + KernelDefBuilder::Create() + ->OutputMemoryType(OrtMemTypeCPUOutput, 0) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), Memcpy); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMIGraphXExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMIGraphXExecutionProvider, kOnnxDomain, 1, MemcpyToHost); - static std::shared_ptr s_kernel_registry; void InitializeRegistry() { @@ -106,16 +102,13 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c } MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info) { + : IExecutionProvider{kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info) { InitProviderOrtApi(); get_flags_from_session_info(info); metadef_id_generator_ = ModelMetadefIdGenerator::Create(); get_flags_from_env(); } -MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { -} - void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info) { // Set GPU device to be used HIP_CALL_THROW(hipSetDevice(info_.device_id)); @@ -178,17 +171,17 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut void MIGraphXExecutionProvider::get_flags_from_env() { LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX ENV Override Variables Set:"; - // whether fp16 is enable - const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP16Enable); + // whether fp16 is enabled + const std::string fp16_enable_env = GetEnvironmentVar(migraphx_env_vars::kFP16Enable); if (!fp16_enable_env.empty()) { - fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); + fp16_enable_ = std::stoi(fp16_enable_env) != 0; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP16_ENABLE: " << fp16_enable_; } - const std::string bf16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kBF16Enable); + const std::string bf16_enable_env = GetEnvironmentVar(migraphx_env_vars::kBF16Enable); if (!bf16_enable_env.empty()) { #if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4 && HIP_VERSION_PATCH >= 2) - bf16_enable_ = (std::stoi(bf16_enable_env) == 0 ? false : true); + bf16_enable_ = std::stoi(bf16_enable_env) != 0; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_BF16_ENABLE: " << fp16_enable_; #else LOGS_DEFAULT(WARNING) << "MIGraphX: BF16 Quantization requires ROCm 6.4.2 or greater"; @@ -201,10 +194,10 @@ void MIGraphXExecutionProvider::get_flags_from_env() { } // whether fp8 quantization is enabled - const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); + const std::string fp8_enable_env = GetEnvironmentVar(migraphx_env_vars::kFP8Enable); if (!fp8_enable_env.empty()) { #if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) - fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true); + fp8_enable_ = std::stoi(fp8_enable_env) != 0; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP8_ENABLE: " << fp8_enable_; #else LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater"; @@ -213,9 +206,9 @@ void MIGraphXExecutionProvider::get_flags_from_env() { } // whether int8 is enabled - const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); + const std::string int8_enable_env = GetEnvironmentVar(migraphx_env_vars::kINT8Enable); if (!int8_enable_env.empty()) { - int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); + int8_enable_ = std::stoi(int8_enable_env) != 0; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_ENABLE: " << int8_enable_; } @@ -225,23 +218,22 @@ void MIGraphXExecutionProvider::get_flags_from_env() { if (int8_enable_ || fp8_enable_) { const std::string int8_calibration_cache_name_env = - onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); + GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); if (!int8_calibration_cache_name_env.empty()) { int8_calibration_cache_name_ = int8_calibration_cache_name_env; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CALIBRATION_TABLE_NAME: " << int8_calibration_cache_name_; } - const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath); + const std::string cache_path = GetEnvironmentVar(migraphx_env_vars::kCachePath); if (!cache_path.empty()) { calibration_cache_path_ = cache_path; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CACHE_PATH: " << calibration_cache_path_; } const std::string int8_use_native_migraphx_calibration_table_env = - onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable); + GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable); if (!int8_use_native_migraphx_calibration_table_env.empty()) { - int8_use_native_migraphx_calibration_table_ = - (std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? false : true); + int8_use_native_migraphx_calibration_table_ = std::stoi(int8_use_native_migraphx_calibration_table_env) != 0; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE: " << int8_use_native_migraphx_calibration_table_; } @@ -269,21 +261,21 @@ void MIGraphXExecutionProvider::get_flags_from_env() { } // dump unsupported ops - const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kDumpModelOps); + const std::string dump_model_ops_env = GetEnvironmentVar(migraphx_env_vars::kDumpModelOps); if (!dump_model_ops_env.empty()) { - dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); + dump_model_ops_ = std::stoi(dump_model_ops_env) != 0; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_DUMP_MODEL_OPS: " << dump_model_ops_; } // Allow for exhaustive tune during compile - const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); + const std::string exhaustive_tune_env = GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); if (!exhaustive_tune_env.empty()) { - exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true); + exhaustive_tune_ = std::stoi(exhaustive_tune_env) != 0; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_EXHAUSTIVE_TUNE_OPS: " << exhaustive_tune_; } } -void MIGraphXExecutionProvider::print_migraphx_ep_flags() { +void MIGraphXExecutionProvider::print_migraphx_ep_flags() const { LOGS_DEFAULT(VERBOSE) << "\n " << migraphx_provider_option::kDeviceId << ": " << info_.device_id << "\n " << migraphx_provider_option::kFp16Enable << ": " << fp16_enable_ << "\n " << migraphx_provider_option::kBf16Enable << ": " << bf16_enable_ @@ -297,14 +289,14 @@ void MIGraphXExecutionProvider::print_migraphx_ep_flags() { << "\n " << migraphx_provider_option::kModelCacheDir << ": " << model_cache_path_; } -AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, - size_t migx_mem_limit, +AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(const OrtDevice::DeviceId device_id, + const size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy, MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) { if (external_allocator_info.UseExternalAllocator()) { - AllocatorCreationInfo default_memory_info( + const AllocatorCreationInfo default_memory_info( [external_allocator_info](OrtDevice::DeviceId id) { return std::make_unique(id, HIP, external_allocator_info.alloc, @@ -315,39 +307,38 @@ AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::Devic false); return CreateAllocator(default_memory_info); - } else { - AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP); - }, - device_id, - true, - {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(migx_mem_limit, static_cast(arena_extend_strategy), - -1, -1, -1, -1L)}, - // make it stream aware - true, - // enable cross stream sharing? - false); - - // ROCM malloc/free is expensive so always use an arena - return CreateAllocator(default_memory_info); } + const AllocatorCreationInfo default_memory_info( + [](OrtDevice::DeviceId id) { + return std::make_unique(id, HIP); + }, + device_id, + true, + {default_memory_arena_cfg ? *default_memory_arena_cfg + : OrtArenaCfg(migx_mem_limit, static_cast(arena_extend_strategy), + -1, -1, -1, -1L)}, + // make it stream aware + true, + // enable cross stream sharing? + false); + + // ROCM malloc/free is expensive so always use an arena + return CreateAllocator(default_memory_info); } std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { - AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, onnxruntime::CUDA); }, info_.device_id); - AllocatorCreationInfo pinned_allocator_info( - [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, onnxruntime::CUDA_PINNED); + const AllocatorCreationInfo default_memory_info( + [](const OrtDevice::DeviceId device_id) { return std::make_unique(device_id, CUDA); }, info_.device_id); + const AllocatorCreationInfo pinned_allocator_info( + [](const OrtDevice::DeviceId device_id) { + return std::make_unique(device_id, CUDA_PINNED); }, 0); - return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; + return {CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; } -std::unique_ptr MIGraphXExecutionProvider::GetDataTransfer() const { - return std::make_unique(); +std::unique_ptr MIGraphXExecutionProvider::GetDataTransfer() const { + return std::make_unique(); } static bool IsTypeSupported(const NodeArg* node_arg) { @@ -382,7 +373,7 @@ static bool IsTypeSupported(const NodeArg* node_arg) { } } -static bool getMIGraphXType(ONNXTensorElementDataType type, +static bool getMIGraphXType(const ONNXTensorElementDataType type, migraphx_shape_datatype_t& mgx_type) { mgx_type = migraphx_shape_float_type; switch (type) { @@ -411,8 +402,7 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, mgx_type = migraphx_shape_fp8e5m2fnuz_type; break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: - mgx_type = migraphx_shape_int8_type; - break; + // No `break` intentional case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: mgx_type = migraphx_shape_int8_type; break; @@ -426,8 +416,7 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, mgx_type = migraphx_shape_int64_type; break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: - mgx_type = migraphx_shape_uint8_type; - break; + // No `break` intentional case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: mgx_type = migraphx_shape_uint8_type; break; @@ -444,8 +433,8 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, mgx_type = migraphx_shape_bool_type; break; default: - LOGS_DEFAULT(WARNING) << "MiGraphx: unsupported data type " << type << ", fallback to CPU"; - LOGS_DEFAULT(WARNING) << "implementation"; + LOGS_DEFAULT(WARNING) << "MIGraphX: unsupported data type " << type + << ", fallback to CPU" << "implementation"; return false; } @@ -462,7 +451,7 @@ std::vector toVector(const ONNX_NAMESPACE::int64s& nums) { return result; } -static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node) { +static bool IsUnsupportedOpMode(const GraphViewer& graph_viewer, const Node* node) { std::vector input_nodes; const auto& optype = node->OpType(); if (optype == "ArgMax" or optype == "ArgMin") { @@ -636,7 +625,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } } } else if (optype == "Split") { - // cannot process input dim of 0 size + // cannot process input dim of size 0 const auto arg_s = node->InputDefs()[0]->Shape(); if (arg_s != nullptr) { const auto& tensor_dims = arg_s->dim(); @@ -678,24 +667,21 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return false; } -void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector>& clusters, +void SubgraphPostProcessing(const GraphViewer& graph_viewer, std::vector>& clusters, [[maybe_unused]] const logging::Logger& logger) { // Then check whether a subgraph should fall back to CPU - // 1. Check whether a subgraph contains a RNN operator + // 1. Check whether a subgraph contains an RNN operator std::unordered_set rnn_names = {"RNN", "GRU", "LSTM"}; std::unordered_set op_names = {"AveragePool", "Conv", "Gemm", "LRN", "MatMul", "MaxPool"}; - auto it = std::remove_if(clusters.begin(), clusters.end(), [&](auto git) { + const auto it = std::remove_if(clusters.begin(), clusters.end(), [&](auto git) { for (auto index : git) { - auto node = graph_viewer.GetNode(index); - if (node->OpType() == "Reshape") { - const auto& args = node->InputDefs(); - if (args.size() == 2) { - std::vector node_inputs; - if (canEvalNodeArgument(graph_viewer, node, {1}, node_inputs)) { - return (not std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto index) { - return std::find(git.begin(), git.end(), index) != git.end(); - })); + if (auto node = graph_viewer.GetNode(index); node->OpType() == "Reshape") { + if (const auto& args = node->InputDefs(); args.size() == 2) { + if (std::vector node_inputs; canEvalNodeArgument(graph_viewer, node, {1}, node_inputs)) { + return !std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto i) { + return std::find(git.begin(), git.end(), i) != git.end(); + }); } else { return true; } @@ -717,7 +703,7 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v const auto& node = graph_viewer.GetNode(nid); const auto& op_type = node->OpType(); if (op_names.count(op_type) > 0) { - // check number of elements in input + // check the number of elements in input auto inputs = node->InputDefs(); if (std::any_of(inputs.begin(), inputs.end(), [&](auto& arg) { const auto& arg_s = arg->Shape(); @@ -747,7 +733,7 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v } static bool IsNodeSupported(const std::set& op_set, - const onnxruntime::GraphViewer& graph_viewer, + const GraphViewer& graph_viewer, const NodeIndex node_idx, [[maybe_unused]] const logging::Logger& logger) { const auto& node = graph_viewer.GetNode(node_idx); @@ -763,7 +749,7 @@ static bool IsNodeSupported(const std::set& op_set, // check data type bool are_types_supported = true; - node->ForEachDef([&are_types_supported](const onnxruntime::NodeArg& node_arg, bool /*is_input*/) { + node->ForEachDef([&are_types_supported](const NodeArg& node_arg, bool /*is_input*/) { are_types_supported &= IsTypeSupported(&node_arg); }); @@ -800,7 +786,7 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st } // Find inputs and outputs of the subgraph - std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); + std::unique_ptr sub_graph = IndexedSubGraph::Create(); std::unordered_map fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; std::unordered_set erased; int input_order = 0; @@ -881,15 +867,15 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st // Sort inputs and outputs by the order they were added std::multimap inputs, outputs; - for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) { - inputs.insert(std::pair(it->second, it->first)); + for (auto& [fst, snd] : fused_inputs) { + inputs.insert({snd, fst}); } - for (auto it = fused_outputs.begin(), end = fused_outputs.end(); it != end; ++it) { - outputs.insert(std::pair(it->second, it->first)); + for (auto& [fst, snd] : fused_outputs) { + outputs.insert({snd, fst}); } - // It is possible that an output of an node is put bebind the output of an later + // It is possible that an output of a node is put behind the output of a later // node in the graph output list. So we should sort the output name according // to the graph output names std::vector output_names; @@ -906,7 +892,7 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st } for (auto& name : graph_output_names) { - if (std::find(graph_out_names.begin(), graph_out_names.end(), name) != graph_out_names.end()) + if (graph_out_names.find(name) != graph_out_names.end()) output_names.push_back(name); } @@ -1105,7 +1091,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, if (IsNodeSupported(mgx_supported_ops, graph_viewer, node_idx, logger)) { // Collect inputs that are initializers graph_viewer.GetNode(node_idx)->ForEachDef([&mgx_required_initializers, - &graph_viewer](const onnxruntime::NodeArg& node_arg, bool is_input) { + &graph_viewer](const NodeArg& node_arg, bool is_input) { if(is_input && graph_viewer.GetAllInitializedTensors().count(node_arg.Name())) { mgx_required_initializers.insert(node_arg.Name()); } }, @@ -1119,7 +1105,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, } // Returns a vector clusters(or node_idx). For each unsupported node, the graph -// is split into 3 parts. supported_cluster + (UNsupported_node + rest_of_the_graph). +// is split into 3 parts. supported_cluster + (Unsupported_node + rest_of_the_graph). // This functions returns vector of all supported_subgraphx by amdmigraphx static std::vector> GetPartitionedSubgraphs(const std::vector& topological_order, @@ -1136,7 +1122,7 @@ GetPartitionedSubgraphs(const std::vector& topological_order, if (!this_subgraph.empty()) { mgx_subgraphx.push_back(std::move(this_subgraph)); } - // Point prev to node idx past this unsuported node. + // Point prev to node idx past this unsupported node. prev = ++it; } @@ -1150,7 +1136,7 @@ GetPartitionedSubgraphs(const std::vector& topological_order, } std::vector> -MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, +MIGraphXExecutionProvider::GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { @@ -1178,7 +1164,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v // If all ops are supported, no partitioning is required. Short-circuit and avoid splitting. if (unsupported_nodes.empty()) { - auto node_indices = graph_viewer.GetNodesInTopologicalOrder(); + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); auto sub_graph = GetSubGraph(node_indices, graph_viewer); result.push_back(ComputeCapability::Create(std::move(sub_graph))); } else { // unsupported_nodes_idx.empty() @@ -1257,9 +1243,9 @@ bool get_input_output_names(const GraphViewer& graph, // Useful to default to EP to trigger the compile if file doesn't exist or loading fails. bool load_precompiled_model(migraphx::program& prog, const std::filesystem::path& path) try { if (!path.empty() && exists(path)) { - LOGS_DEFAULT(INFO) << "Attempting to load model at:" << path.string(); + LOGS_DEFAULT(VERBOSE) << "Attempting to load model at:" << path.string(); prog = migraphx::load(path.string().c_str()); - LOGS_DEFAULT(INFO) << "load model : Success"; + LOGS_DEFAULT(VERBOSE) << "load model : Success"; return true; } return false; @@ -1269,26 +1255,26 @@ bool load_precompiled_model(migraphx::program& prog, const std::filesystem::path void save_compiled_model(const migraphx::program& prog, const std::filesystem::path& path) { if (!path.empty()) { - LOGS_DEFAULT(INFO) << "Model Save at " << path << ": Begin"; + LOGS_DEFAULT(VERBOSE) << "Model Save at " << path << ": Begin"; migraphx::file_options fo; fo.set_file_format("msgpack"); save(prog, path.string().c_str(), fo); - LOGS_DEFAULT(INFO) << "Model Save: Complete"; + LOGS_DEFAULT(VERBOSE) << "Model Save: Complete"; } } -// Order matters here especially if the program uses mixed quantization +// Order matters here, especially if the program uses mixed quantization // Calibrate on full precision for int8/fp8 and then quantize down to fp16 -void calibrate_and_quantize(migraphx::program& prog, +void calibrate_and_quantize(const migraphx::program& prog, const migraphx::target& t, - const migraphx::program_parameters quant_params, - bool fp16_enable, - bool bf16_enable, - bool int8_enable, - bool fp8_enable, - bool int8_calibration_cache_available, + const migraphx::program_parameters& quant_params, + const bool fp16_enable, + const bool bf16_enable, + const bool int8_enable, + const bool fp8_enable, + const bool int8_calibration_cache_available, std::unordered_map& dynamic_range_map) { - // Read in the calibration data and map it to an migraphx paramater map for the calibration ops + // Read in the calibration data and map it to a migraphx parameter map for the calibration ops if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { LOGS_DEFAULT(WARNING) << "Quantizing input program"; @@ -1296,8 +1282,8 @@ void calibrate_and_quantize(migraphx::program& prog, // Add all calibration data read in from int8 table for (auto& [cal_key, cal_val] : dynamic_range_map) { - auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); - quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); + const auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); + quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, &cal_val)); } // perform static quantization on the programs @@ -1336,9 +1322,9 @@ void calibrate_and_quantize(migraphx::program& prog, #endif } -void compile_program(migraphx::program& prog, +void compile_program(const migraphx::program& prog, const migraphx::target& t, - bool exhaustive_tune) { + const bool exhaustive_tune) { LOGS_DEFAULT(WARNING) << "Model Compile: Begin"; migraphx::compile_options co; co.set_fast_math(false); @@ -1377,7 +1363,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& const Node& fused_node = fused_node_graph.fused_node; std::filesystem::path model_cache_file; - auto mxr_filename_prefix = to_hex(MIGraphX_Version) + "-" + GenerateGraphId(graph_body_viewer) + "-" + make_hash(std::string_view(device_prop_.gcnArchName)) + "-"; + auto mxr_filename_prefix = to_hex(MIGraphX_Version) + "-" + GenerateGraphId(graph_body_viewer) + "-" + make_hash(std::string_view{device_prop_.gcnArchName}) + "-"; // Get model input names (only first layer) const Graph* cur_graph = &graph_body_viewer.GetGraph(); @@ -1433,7 +1419,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (!no_input_shape) { if (!load_precompiled_model(prog, model_cache_file)) { - LOGS_DEFAULT(INFO) << "No input shapes detected quantizing model"; + LOGS_DEFAULT(VERBOSE) << "No input shapes detected quantizing model"; #ifndef ENABLE_TRAINING_CORE #ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH options.set_external_data_path(model_path_.parent_path().string()); @@ -1462,13 +1448,12 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& map_input_index_[fused_node.Name()] = input_name_index; map_no_input_shape_[fused_node.Name()] = no_input_shape; NodeComputeInfo compute_info; - compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { - std::unique_ptr p = std::make_unique(); + compute_info.create_state_func = [=](const ComputeContext* context, FunctionState* state) { + auto p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, map_no_input_shape_[context->node_name], fp16_enable_, bf16_enable_, fp8_enable_, int8_enable_, - int8_calibration_cache_available_, dynamic_range_map_, - model_cache_path_.string(), dump_model_ops_}; + int8_calibration_cache_available_, dynamic_range_map_, model_cache_path_.string(), dump_model_ops_}; *state = p.release(); return 0; }; @@ -1480,7 +1465,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& compute_info.compute_func = [this, mxr_filename_prefix](FunctionState state, const OrtApi* api, OrtKernelContext* context) { Ort::KernelContext ctx(context); - MIGraphXFuncState* mgx_state = reinterpret_cast(state); + auto mgx_state = static_cast(state); std::unordered_map& map_input_name_index = mgx_state->input_name_indexes; std::unordered_map& map_dynamic_range = mgx_state->dynamic_range_map; @@ -1502,7 +1487,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::vector input_shapes; if (no_input_shape) { - LOGS_DEFAULT(INFO) << "Missing input shape setting input parameters again"; + LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again"; for (auto& it : map_input_name_index) { auto& name = it.first; auto& index = it.second; @@ -1514,7 +1499,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& input_shape_match = false; } } else { - LOGS_DEFAULT(INFO) << "Assigning inputs, and parameters from compiled model"; + LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model"; param_shapes = prog.get_parameter_shapes(); auto prog_output_shapes = prog.get_output_shapes(); @@ -1546,8 +1531,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } } - // input shapes are different, needs to re-parse onnx and - // re-compile the program + // input shapes are different, needs to reparse onnx and recompile the program if (!input_shape_match) { std::filesystem::path model_cache_file; // empty cache path means the MXR caching is disabled - always compile @@ -1602,7 +1586,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (param_shapes.size() > 0) { for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { - LOGS_DEFAULT(INFO) << "Setting parameters for:" << name; + LOGS_DEFAULT(VERBOSE) << "Setting parameters for:" << name; auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); @@ -1616,21 +1600,21 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - LOGS_DEFAULT(INFO) << "Writing Raw tensor data "; + LOGS_DEFAULT(VERBOSE) << "Writing Raw tensor data "; m.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } - // It is a output argument + // It is an output argument else { - auto compute_output_index = [](const std::string& name) -> int { - std::string out_name_prefix = "#output_"; - auto pos = name.find(out_name_prefix); - if (pos == std::string::npos) { + auto compute_output_index = [](const std::string_view sv) -> int { + constexpr std::string_view out_name_prefix = "#output_"; + const auto pos = sv.find(out_name_prefix); + if (pos == std::string_view::npos) { return -1; } - std::string index_str = name.substr(pos + out_name_prefix.length()); - return std::stoi(index_str); + const auto index_str = sv.substr(pos + out_name_prefix.length()); + return ToInteger(Trim(index_str, std::isdigit)); }; int output_index = compute_output_index(name); @@ -1652,13 +1636,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& { // lock to avoid race condition - std::lock_guard lock(*(mgx_state->mgx_mu_ptr)); + std::lock_guard lock(*mgx_state->mgx_mu_ptr); void* rocm_stream; Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream)); auto prog_outputs = prog.run_async(m, static_cast(rocm_stream)); - // In case of input parameters are reused as output parameter call hipMemcpy + // In the case of input parameters are reused as output parameter calls hipMemcpy auto output_num = prog_outputs.size(); if (prog_output_indices.size() < output_num) { for (std::size_t i = 0; i < output_num; ++i) { @@ -1677,7 +1661,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& static_cast(rocm_stream))); } } - }; + } return Status::OK(); }; @@ -1693,30 +1677,26 @@ void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegis RegisterMIGraphXStreamHandles(stream_handle_registry, OrtDevice::GPU, allocator, true, stream_, false /*TODO:external_stream_*/); } -OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { - if (mem_type == OrtMemTypeCPUInput) return OrtDevice(); - if (mem_type == OrtMemTypeCPUOutput) return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device id always be 0*/); +OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(const OrtMemType mem_type) const { + if (mem_type == OrtMemTypeCPUInput) return {}; + if (mem_type == OrtMemTypeCPUOutput) return {OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device id always be 0*/}; return default_device_; } Status MIGraphXExecutionProvider::Sync() const { - HIP_CALL_THROW(hipStreamSynchronize(static_cast(nullptr))); - - auto status = hipStreamQuery(stream_); - if (status != hipSuccess) { - return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::EP_FAIL); + HIP_CALL_THROW(hipStreamSynchronize(nullptr)); + if (hipStreamQuery(stream_) != hipSuccess) { + return {common::ONNXRUNTIME, common::EP_FAIL}; } return Status::OK(); } -Status MIGraphXExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { +Status MIGraphXExecutionProvider::OnRunStart(const RunOptions& /*run_options*/) { return Status::OK(); } -Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) { - auto status = hipStreamQuery(stream_); - - if (status != hipSuccess) { +Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const RunOptions& /*run_options*/) { + if (hipStreamQuery(stream_) != hipSuccess) { HIP_CALL_THROW(hipStreamSynchronize(stream_)); } return Status::OK(); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index cf1ad9711c2aa..c5c1f0f2f1650 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -27,7 +27,6 @@ constexpr auto kCachePath = "ORT_MIGRAPHX_CACHE_PATH"; constexpr auto kINT8UseNativeMIGraphXCalibrationTable = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH"; constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"; - } // namespace migraphx_env_vars // Information to construct kernel function state. @@ -54,32 +53,32 @@ struct MIGraphXFuncState { }; // Logical device representation. -class MIGraphXExecutionProvider : public IExecutionProvider { +class MIGraphXExecutionProvider final : public IExecutionProvider { public: explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); - ~MIGraphXExecutionProvider(); + ~MIGraphXExecutionProvider() override = default; void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info); void get_flags_from_env(); - void print_migraphx_ep_flags(); + void print_migraphx_ep_flags() const; Status Sync() const override; - Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunStart(const RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const RunOptions& run_options) override; std::vector> - GetCapability(const onnxruntime::GraphViewer& graph_viewer, + GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; - common::Status Compile(const std::vector& fused_nodes, - std::vector& node_compute_funcs) override; + Status Compile(const std::vector& fused_nodes, + std::vector& node_compute_funcs) override; - virtual std::shared_ptr GetKernelRegistry() const override; - std::unique_ptr GetDataTransfer() const override; + std::shared_ptr GetKernelRegistry() const override; + std::unique_ptr GetDataTransfer() const override; static AllocatorPtr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy, MIGraphXExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg); @@ -111,7 +110,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { migraphx::target t_; std::mutex mgx_mu_; hipStream_t stream_ = nullptr; - hipDeviceProp_t device_prop_; + hipDeviceProp_t device_prop_{}; bool exhaustive_tune_ = false; mutable std::filesystem::path model_path_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index cb25db032ebf2..a1389a2e4b680 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -332,4 +332,28 @@ inline std::string GenerateGraphId(const GraphViewer& graph_viewer) { return std::string{s.data(), ptr}; } +inline std::string_view TrimLeft(std::string_view sv, int (*fn)(int) = std::isspace) { + return sv.substr(0, sv.end() - std::find_if(sv.begin(), sv.end(), [fn](int ch) { + return fn(ch); + })); +} + +inline std::string_view TrimRight(std::string_view sv, int (*fn)(int) = std::isspace) { + return sv.substr(sv.end() - std::find_if(sv.rbegin(), sv.rend(), [fn](int ch) { + return fn(ch); + }).base()); +} + +inline std::string_view Trim(std::string_view sv, int (*fn)(int) = std::isspace) { + return TrimRight(TrimLeft(sv, fn), fn); +} + +inline int ToInteger(const std::string_view sv) { + int result = 0; + if (auto [_, ec] = std::from_chars(sv.data(), sv.data() + sv.length(), result); ec == std::errc()) { + return result; + } + ORT_THROW("invalid input for conversion to integer"); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 4a088c8c9b6ee..30650005bbc21 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -1,29 +1,35 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License #include +#include #include +#include + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#include +#include +#endif #include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_provider_factory.h" -#include "migraphx_execution_provider.h" -#include "migraphx_execution_provider_info.h" -#include "migraphx_provider_factory_creator.h" -#include "migraphx_allocator.h" -#include "gpu_data_transfer.h" +#include "core/providers/migraphx/migraphx_execution_provider.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_provider_factory_creator.h" +#include "core/providers/migraphx/migraphx_allocator.h" +#include "core/providers/migraphx/gpu_data_transfer.h" #include "core/framework/provider_options.h" #include "core/session/onnxruntime_c_api.h" -using namespace onnxruntime; - namespace onnxruntime { void InitializeRegistry(); void DeleteRegistry(); -struct MIGraphXProviderFactory : IExecutionProviderFactory { - MIGraphXProviderFactory(const MIGraphXExecutionProviderInfo& info) : info_{info} {} - ~MIGraphXProviderFactory() override {} +struct MIGraphXProviderFactory final : IExecutionProviderFactory { + explicit MIGraphXProviderFactory(MIGraphXExecutionProviderInfo info) : info_{std::move(info)} {} + ~MIGraphXProviderFactory() override = default; std::unique_ptr CreateProvider() override; @@ -36,15 +42,15 @@ std::unique_ptr MIGraphXProviderFactory::CreateProvider() { } struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { - std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) override { + std::unique_ptr CreateMIGraphXAllocator(const OrtDevice::DeviceId device_id, const char* name) override { return std::make_unique(device_id, name); } - std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override { + std::unique_ptr CreateMIGraphXPinnedAllocator(const OrtDevice::DeviceId device_id, const char* name) override { return std::make_unique(device_id, name); } - void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) override { + void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, const size_t count) override { // hipMemcpy() operates on the default stream HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyHostToDevice)); @@ -53,24 +59,26 @@ struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { // The function will return once the pageable buffer has been copied to the staging memory for DMA transfer // to device memory, but the DMA to final destination may not have completed. - HIP_CALL_THROW(hipStreamSynchronize(0)); + HIP_CALL_THROW(hipStreamSynchronize(nullptr)); } // Used by onnxruntime_pybind_state.cc - void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override { + void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, const size_t count) override { // For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed. HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); } - std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { - return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, migx_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); + std::shared_ptr CreateMIGraphXAllocator(const OrtDevice::DeviceId device_id, const size_t mem_limit, const ArenaExtendStrategy arena_extend_strategy, const MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { + return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); } } g_info; -struct MIGraphX_Provider : Provider { +struct MIGraphX_Provider final : Provider { void* GetInfo() override { return &g_info; } - std::shared_ptr CreateExecutionProviderFactory(int device_id) override { + virtual ~MIGraphX_Provider() = default; + + std::shared_ptr CreateExecutionProviderFactory(const int device_id) override { MIGraphXExecutionProviderInfo info; info.device_id = static_cast(device_id); info.target_device = "gpu"; @@ -96,14 +104,14 @@ struct MIGraphX_Provider : Provider { if (options.migraphx_cache_dir != nullptr) { info.model_cache_dir = options.migraphx_cache_dir; } - info.arena_extend_strategy = static_cast(options.migraphx_arena_extend_strategy); + info.arena_extend_strategy = static_cast(options.migraphx_arena_extend_strategy); info.mem_limit = options.migraphx_mem_limit; return std::make_shared(info); } void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { - auto internal_options = onnxruntime::MIGraphXExecutionProviderInfo::FromProviderOptions(options); - auto& migx_options = *reinterpret_cast(provider_options); + auto internal_options = MIGraphXExecutionProviderInfo::FromProviderOptions(options); + auto& migx_options = *static_cast(provider_options); migx_options.device_id = internal_options.device_id; migx_options.migraphx_fp16_enable = internal_options.fp16_enable; migx_options.migraphx_bf16_enable = internal_options.bf16_enable; @@ -123,7 +131,7 @@ struct MIGraphX_Provider : Provider { strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size); #endif dest[str_size] = '\0'; - migx_options.migraphx_int8_calibration_table_name = (const char*)dest; + migx_options.migraphx_int8_calibration_table_name = static_cast(dest); } migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; @@ -135,10 +143,23 @@ struct MIGraphX_Provider : Provider { ProviderOptions GetProviderOptions(const void* provider_options) override { auto& options = *reinterpret_cast(provider_options); - return onnxruntime::MIGraphXExecutionProviderInfo::ToProviderOptions(options); + return MIGraphXExecutionProviderInfo::ToProviderOptions(options); } void Initialize() override { +#ifdef _WIN32 + HMODULE module = nullptr; + if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + static_cast(static_cast(InitializeRegistry)), + &module) != 0) { + char buffer[MAX_PATH]; + if (GetModuleFileName(module, buffer, sizeof(buffer)) != 0) { + PathRemoveFileSpec(buffer); + SetDllDirectory(buffer); + } + } +#endif InitializeRegistry(); } @@ -151,7 +172,6 @@ struct MIGraphX_Provider : Provider { } // namespace onnxruntime extern "C" { - ORT_API(onnxruntime::Provider*, GetProvider) { return &onnxruntime::g_provider; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h index d1c9457bafa0f..313603a4ecbf0 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h @@ -1,7 +1,11 @@ -// Copyright 2019 AMD AMDMIGraphX +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License -#include "core/framework/provider_options.h" -#include "onnxruntime_c_api.h" +#pragma once + +#include +#include "core/framework/ortdevice.h" +#include "core/session/onnxruntime_c_api.h" namespace onnxruntime { class IAllocator; @@ -12,11 +16,11 @@ enum class ArenaExtendStrategy : int32_t; struct MIGraphXExecutionProviderExternalAllocatorInfo; struct ProviderInfo_MIGraphX { - virtual std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0; - virtual std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; virtual void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) = 0; virtual void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) = 0; - virtual std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; + virtual std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, ArenaExtendStrategy arena_extend_strategy, MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; protected: ~ProviderInfo_MIGraphX() = default; // Can only be destroyed through a subclass instance diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h index 02d30ad0f6fbb..db169b9e2f5a9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h @@ -6,6 +6,7 @@ #include #include "core/providers/providers.h" +#include "core/framework/provider_options.h" struct OrtMIGraphXProviderOptions; @@ -14,5 +15,6 @@ namespace onnxruntime { struct MIGraphXProviderFactoryCreator { static std::shared_ptr Create(int device_id); static std::shared_ptr Create(const OrtMIGraphXProviderOptions* options); + static std::shared_ptr Create(const ProviderOptions&); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 02696524042e7..994082c272fd2 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3029,7 +3029,12 @@ static constexpr OrtApi ort_api_1_to_22 = { &OrtApis::GetEpApi, // End of Version 22 - DO NOT MODIFY ABOVE (see above text for more information) -}; + &OrtApis::CreateMIGraphXProviderOptions, + &OrtApis::UpdateMIGraphXProviderOptions, + &OrtApis::GetMIGraphXProviderOptionsAsString, + &OrtApis::ReleaseMIGraphXProviderOptions, + &OrtApis::UpdateMIGraphXProviderOptionsWithValue, + &OrtApis::GetMIGraphXProviderOptionsByName}; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. static_assert(sizeof(OrtApiBase) == sizeof(void*) * 2, "New methods can't be added to OrtApiBase as it is not versioned"); diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index addeb36d4087d..4692733d5b9e8 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -597,4 +597,15 @@ ORT_API(const OrtKeyValuePairs*, EpDevice_EpOptions, _In_ const OrtEpDevice* ep_ ORT_API(const OrtHardwareDevice*, EpDevice_Device, _In_ const OrtEpDevice* ep_device); ORT_API(const OrtEpApi*, GetEpApi); +ORT_API_STATUS_IMPL(CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out); +ORT_API_STATUS_IMPL(UpdateMIGraphXProviderOptions, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + size_t num_keys); +ORT_API_STATUS_IMPL(GetMIGraphXProviderOptionsAsString, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); +ORT_API(void, ReleaseMIGraphXProviderOptions, _Frees_ptr_opt_ OrtMIGraphXProviderOptions*); + +ORT_API_STATUS_IMPL(UpdateMIGraphXProviderOptionsWithValue, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _In_ void* value); +ORT_API_STATUS_IMPL(GetMIGraphXProviderOptionsByName, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _Outptr_ void** ptr); + } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 7fcaee48581f6..8d079ff485254 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -6,6 +6,7 @@ #include #include +#include #include "core/common/inlined_containers.h" #include "core/common/path_string.h" @@ -114,6 +115,7 @@ using EtwRegistrationManager_EtwInternalCallback = EtwRegistrationManager::EtwIn #include "core/providers/rocm/rocm_provider_factory.h" #include "core/providers/dnnl/dnnl_provider_factory.h" #include "core/providers/migraphx/migraphx_provider_factory.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/providers/openvino/openvino_provider_factory.h" #include "core/providers/tensorrt/tensorrt_provider_factory.h" #include "core/providers/tensorrt/tensorrt_provider_options.h" @@ -1961,7 +1963,7 @@ std::shared_ptr DnnlProviderFactoryCreator::Create(in return s_library_dnnl.Get().CreateExecutionProviderFactory(use_arena); } -std::shared_ptr MIGraphXProviderFactoryCreator::Create(int device_id) { +std::shared_ptr MIGraphXProviderFactoryCreator::Create(const int device_id) { return s_library_migraphx.Get().CreateExecutionProviderFactory(device_id); } @@ -2065,6 +2067,12 @@ std::shared_ptr NvProviderFactoryCreator::Create( return nullptr; } +std::shared_ptr MIGraphXProviderFactoryCreator::Create(const ProviderOptions& provider_options) { + OrtMIGraphXProviderOptions migraphx_options; + s_library_migraphx.Get().UpdateProviderOptions(&migraphx_options, provider_options); + return s_library_migraphx.Get().CreateExecutionProviderFactory(&migraphx_options); +} + std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) { return s_library_migraphx.Get().CreateExecutionProviderFactory(provider_options); } @@ -2655,7 +2663,8 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptions, defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || \ defined(USE_CANN) || \ defined(USE_DNNL) || \ - defined(USE_ROCM) + defined(USE_ROCM) || \ + defined(USE_MIGRAPHX) static std::string BuildOptionsString(const onnxruntime::ProviderOptions::iterator& begin, const onnxruntime::ProviderOptions::iterator& end) { std::ostringstream options; @@ -3170,3 +3179,123 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_ return nullptr; API_IMPL_END } + +ORT_API_STATUS_IMPL(OrtApis::CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out) { + API_IMPL_BEGIN +#ifdef USE_MIGRAPHX + auto migraphx_options = std::make_unique(); + memset(migraphx_options.get(), 0, sizeof(OrtMIGraphXProviderOptions)); + *out = migraphx_options.release(); + return nullptr; +#else + ORT_UNUSED_PARAMETER(out); + return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptions, + _Inout_ OrtMIGraphXProviderOptions* migraphx_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + size_t num_keys) { + API_IMPL_BEGIN +#ifdef USE_MIGRAPHX + onnxruntime::ProviderOptions provider_options_map; + for (size_t i = 0; i != num_keys; ++i) { + if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' || + provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "key/value cannot be empty"); + } + + provider_options_map[provider_options_keys[i]] = provider_options_values[i]; + } + + onnxruntime::s_library_migraphx.Get().UpdateProviderOptions(reinterpret_cast(migraphx_options), provider_options_map); + return nullptr; +#else + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(provider_options_keys); + ORT_UNUSED_PARAMETER(provider_options_values); + ORT_UNUSED_PARAMETER(num_keys); + return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsAsString, + _In_ const OrtMIGraphXProviderOptions* migraphx_options, + _Inout_ OrtAllocator* allocator, + _Outptr_ char** ptr) { + API_IMPL_BEGIN +#ifdef USE_MIGRAPHX + onnxruntime::ProviderOptions options = + onnxruntime::s_library_migraphx.Get().GetProviderOptions(reinterpret_cast(migraphx_options)); + std::string options_str = BuildOptionsString(options.begin(), options.end()); + *ptr = onnxruntime::StrDup(options_str, allocator); + return nullptr; +#else + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(allocator); + ORT_UNUSED_PARAMETER(ptr); + return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); +#endif + API_IMPL_END +} + +ORT_API(void, OrtApis::ReleaseMIGraphXProviderOptions, _Frees_ptr_opt_ OrtMIGraphXProviderOptions* ptr) { +#ifdef USE_MIGRAPHX + std::unique_ptr p(ptr); + if (ptr->migraphx_cache_dir != nullptr) { + onnxruntime::AllocatorDefaultFree(const_cast(ptr->migraphx_cache_dir)); + } +#else + ORT_UNUSED_PARAMETER(ptr); +#endif +} + +ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptionsWithValue, + _Inout_ OrtMIGraphXProviderOptions* migraphx_options, + _In_ const char* key, + _In_ void* value) { + API_IMPL_BEGIN +#ifdef USE_MIGRAPHX + auto sv = std::string_view{key}; + OrtAllocator* allocator; + GetAllocatorWithDefaultOptions(&allocator); + if (sv == onnxruntime::migraphx_provider_option::kDeviceId) { + auto dv = std::string_view{static_cast(value)}; + if (std::from_chars(dv.data(), dv.data() + dv.length(), migraphx_options->device_id).ec == std::errc::invalid_argument) { + ORT_THROW("Cannot convert from string to integer - invalid argument"); + } + } else if (sv == onnxruntime::migraphx_provider_option::kModelCacheDir) { + auto sd = std::string_view{static_cast(value)}; + migraphx_options->migraphx_cache_dir = onnxruntime::StrDup(sd.data(), allocator); + } else { + ORT_THROW("Unsupported provider option name: '" + std::string{sv} + "'"); + } + return nullptr; +#else + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(key); + ORT_UNUSED_PARAMETER(value); + return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsByName, + _In_ const OrtMIGraphXProviderOptions* migraphx_options, + _In_ const char* key, + _Outptr_ void** ptr) { + API_IMPL_BEGIN +#ifdef USE_MIGRAPHX + return nullptr; +#else + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(key); + ORT_UNUSED_PARAMETER(ptr); + return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); +#endif + API_IMPL_END +} diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 1f74ee3b3f2ee..4c0d3e8d22354 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -101,6 +101,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, VitisAI, CoreML, NvTensorRtRtx, // TensorRt EP for RTX GPUs. + MIGraphX }; struct EpToAppend { @@ -109,7 +110,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, const char* canonical_name = nullptr; }; - static std::array supported_eps = { + static std::array supported_eps = { EpToAppend{EpID::DML, "DML", kDmlExecutionProvider}, EpToAppend{EpID::QNN, "QNN", kQnnExecutionProvider}, EpToAppend{EpID::OpenVINO, "OpenVINO", kOpenVINOExecutionProvider}, @@ -121,7 +122,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, EpToAppend{EpID::JS, "JS", kJsExecutionProvider}, EpToAppend{EpID::VitisAI, "VitisAI", kVitisAIExecutionProvider}, EpToAppend{EpID::CoreML, "CoreML", kCoreMLExecutionProvider}, - EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider}}; + EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider}, + EpToAppend{EpID::MIGraphX, "MIGraphX", kMIGraphXExecutionProvider}}; ProviderOptions provider_options; OrtStatus* status = ParseProviderOptions(provider_options_keys, @@ -279,6 +281,14 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); +#endif + break; + } + case EpID::MIGraphX: { +#if defined(USE_MIGRAPHX) + options->provider_factories.push_back(MIGraphXProviderFactoryCreator::Create(provider_options)); +#else + status = create_not_supported_status(); #endif break; } @@ -657,4 +667,55 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, ORT_UNUSED_PARAMETER(num_keys); return CreateNotEnabledStatus("VitisAI"); } + +ORT_API_STATUS_IMPL(OrtApis::CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out) { + ORT_UNUSED_PARAMETER(out); + return CreateNotEnabledStatus("MIGraphX"); +} + +ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptions, + _Inout_ OrtMIGraphXProviderOptions* migraphx_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + size_t num_keys) { + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(provider_options_keys); + ORT_UNUSED_PARAMETER(provider_options_values); + ORT_UNUSED_PARAMETER(num_keys); + return CreateNotEnabledStatus("MIGraphX"); +} + +ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsAsString, + _In_ const OrtMIGraphXProviderOptions* migraphx_options, _Inout_ OrtAllocator* allocator, + _Outptr_ char** ptr) { + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(allocator); + ORT_UNUSED_PARAMETER(ptr); + return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build."); +} + +ORT_API(void, OrtApis::ReleaseMIGraphXProviderOptions, _Frees_ptr_opt_ OrtMIGraphXProviderOptions* ptr) { + ORT_UNUSED_PARAMETER(ptr); +} + +ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptionsWithValue, + _Inout_ OrtMIGraphXProviderOptions* migraphx_options, + _In_ const char* key, + _In_ void* value) { + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(key); + ORT_UNUSED_PARAMETER(value); + return CreateNotEnabledStatus("MIGraphX"); +} + +ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsByName, + _In_ const OrtMIGraphXProviderOptions* migraphx_options, + _In_ const char* key, + _Outptr_ void** ptr) { + ORT_UNUSED_PARAMETER(migraphx_options); + ORT_UNUSED_PARAMETER(key); + ORT_UNUSED_PARAMETER(ptr); + return CreateNotEnabledStatus("MIGraphX"); +} + #endif diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 0428b19357d51..55026d9aab986 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -207,6 +207,7 @@ std::unique_ptr GetGPUDataTransfer() { #endif #ifdef USE_MIGRAPHX + void CpuToMIGraphXMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_MIGraphX().MIGraphXMemcpy_HostToDevice(dst, src, num_bytes); } diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.cc b/onnxruntime/python/onnxruntime_pybind_state_common.cc index a7e8e0345e9d1..2e52c192ec521 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.cc +++ b/onnxruntime/python/onnxruntime_pybind_state_common.cc @@ -51,7 +51,6 @@ onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExten onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo migx_external_allocator_info{}; #endif - void DlpackCapsuleDestructor(PyObject* data) { DLManagedTensor* dlmanaged_tensor = reinterpret_cast(PyCapsule_GetPointer(data, "dltensor")); if (dlmanaged_tensor) { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 82502f17e808e..9740a857f7577 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -37,13 +37,13 @@ struct OrtStatus { #define BACKEND_PROC "CPU" #endif -#if USE_DNNL +#ifdef USE_DNNL #define BACKEND_DNNL "-DNNL" #else #define BACKEND_DNNL "" #endif -#if USE_MIGRAPHX +#ifdef USE_MIGRAPHX #define BACKEND_MIGRAPHX "-MIGRAPHX" #else #define BACKEND_MIGRAPHX "" diff --git a/setup.py b/setup.py index 1e426ea8e060b..33e47757d05e0 100644 --- a/setup.py +++ b/setup.py @@ -237,8 +237,10 @@ def run(self): rocm_dependencies = [ "libamd_comgr.so.2", + "libamd_comgr.so.3", "libamdhip64.so.5", "libamdhip64.so.6", + "libamdhip64.so.7", "libdrm.so.2", "libdrm_amdgpu.so.1", "libelf.so.1", @@ -250,13 +252,18 @@ def run(self): "libnuma.so.1", "librccl.so.1", "libhipblas.so.2", + "libhipblas.so.3" + "libhipblaslt.so.1", "librocblas.so.3", "librocblas.so.4", + "librocblas.so.5", "librocfft.so.0", "libroctx64.so.4", "librocm_smi64.so.5", "librocm_smi64.so.6", + "librocm_smi64.so.7", "libroctracer64.so.4", + "librocsolver.so.0", "libtinfo.so.6", "libmigraphx_c.so.3", "libmigraphx.so.2", @@ -410,6 +417,7 @@ def finalize_options(self): libs.extend(["onnxruntime_providers_nv_tensorrt_rtx.dll"]) libs.extend(["onnxruntime_providers_openvino.dll"]) libs.extend(["onnxruntime_providers_cuda.dll"]) + libs.extend(["onnxruntime_providers_migraphx.dll"]) libs.extend(["onnxruntime_providers_vitisai.dll"]) libs.extend(["onnxruntime_providers_qnn.dll"]) # DirectML Libs diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 52b7837b142b5..c23911b1cac76 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1894,6 +1894,7 @@ def build_nuget_package( use_winml, use_qnn, use_dml, + use_migraphx, enable_training_apis, msbuild_extra_options, ): @@ -1931,6 +1932,9 @@ def build_nuget_package( elif use_tensorrt: execution_provider = "/p:ExecutionProvider=tensorrt" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.TensorRT" + elif use_migraphx: + execution_provider = "/p:ExecutionProvider=migraphx" + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.MIGraphX" elif use_dnnl: execution_provider = "/p:ExecutionProvider=dnnl" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.DNNL" @@ -2518,6 +2522,7 @@ def main(): getattr(args, "use_winml", False), args.use_qnn, getattr(args, "use_dml", False), + args.use_migraphx, args.enable_training_apis, normalize_arg_list(args.msbuild_extra_options), ) diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index dbb7be829494e..abc848e5525ca 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -608,8 +608,9 @@ def add_execution_provider_args(parser: argparse.ArgumentParser) -> None: rocm_group.add_argument("--use_rocm", action="store_true", help="Enable ROCm EP.") rocm_group.add_argument("--rocm_version", help="ROCm stack version.") rocm_group.add_argument("--rocm_home", help="Path to ROCm installation directory.") - rocm_group.add_argument("--rocm_gfx_arch", help='Provide gfx arch. Example --rocm_gfx_arch gfx942' - ' or --rocm_gfx_arch "gfx90a;gfx942"') + rocm_group.add_argument( + "--rocm_gfx_arch", help='Provide gfx arch. Example --rocm_gfx_arch gfx942 or --rocm_gfx_arch "gfx90a;gfx942"' + ) # ROCm-specific profiling rocm_group.add_argument( "--enable_rocm_profiling", diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 419fdd47458f7..737931b7c0c5e 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -44,7 +44,13 @@ def get_package_name(os, cpu_arch, ep, is_training_package): def is_this_file_needed(ep, filename, package_name): if package_name == "Microsoft.ML.OnnxRuntime.Gpu": return False - return (ep != "cuda" or "cuda" in filename) and (ep != "tensorrt" or "cuda" not in filename) + if package_name == "Microsoft.ML.OnnxRuntime.MIGraphX": + return False + return ( + (ep != "cuda" or "cuda" in filename) + and (ep != "tensorrt" or "cuda" not in filename) + and (ep != "migraphx" or "migraphx" not in filename) + ) # nuget_artifacts_dir: the directory with uncompressed C API tarball/zip files @@ -64,7 +70,8 @@ def generate_file_list_for_ep(nuget_artifacts_dir, ep, files_list, include_pdbs, if ( child_file.suffix in suffixes and is_this_file_needed(ep, child_file.name, package_name) - and package_name != "Microsoft.ML.OnnxRuntime.Gpu.Linux" + and package_name + not in ["Microsoft.ML.OnnxRuntime.Gpu.Linux", "Microsoft.ML.OnnxRuntime.MIGraphX.Linux"] ): files_list.append( '' @@ -95,6 +102,7 @@ def generate_file_list_for_ep(nuget_artifacts_dir, ep, files_list, include_pdbs, child_file.suffix == ".so" and is_this_file_needed(ep, child_file.name, package_name) and package_name != "Microsoft.ML.OnnxRuntime.Gpu.Windows" + and package_name != "Microsoft.ML.OnnxRuntime.MIGraphX.Windows" ): files_list.append( '' @@ -138,7 +146,7 @@ def parse_arguments(): required=False, default="None", type=str, - choices=["cuda", "dnnl", "openvino", "tensorrt", "snpe", "qnn", "None"], + choices=["cuda", "dnnl", "openvino", "migraphx", "tensorrt", "snpe", "qnn", "None"], help="The selected execution provider for this build.", ) parser.add_argument("--sdk_info", required=False, default="", type=str, help="dependency SDK information.") @@ -182,6 +190,12 @@ def generate_description(line_list, package_name): description = "This package contains Linux native shared library artifacts for ONNX Runtime with CUDA." elif "Microsoft.ML.OnnxRuntime.Gpu.Windows" in package_name: description = "This package contains Windows native shared library artifacts for ONNX Runtime with CUDA." + elif "Microsoft.ML.OnnxRuntime.MIGraphX.Linux" in package_name: + description = "This package contains Linux native shared library artifacts for ONNX Runtime with the MIGraphX." + elif "Microsoft.ML.OnnxRuntime.MIGraphX.Windows" in package_name: + description = ( + "This package contains Windows native shared library artifacts for ONNX Runtime with the MIGraphX." + ) elif "Intel.ML.OnnxRuntime" in package_name: description = "This package contains native shared library artifacts for ONNX Runtime with OpenVINO." elif "Microsoft.ML.OnnxRuntime" in package_name: # This is a Microsoft.ML.OnnxRuntime.* package @@ -224,6 +238,8 @@ def add_common_dependencies(xml_text, package_name, version): if package_name == "Microsoft.ML.OnnxRuntime.Gpu": xml_text.append('') xml_text.append('') + if package_name == "Microsoft.ML.OnnxRuntime.MIGraphX": + xml_text.append('') def generate_dependencies(xml_text, package_name, version): @@ -358,6 +374,9 @@ def generate_files(line_list, args): is_windowsai_package = args.package_name == "Microsoft.AI.MachineLearning" is_snpe_package = args.package_name == "Microsoft.ML.OnnxRuntime.Snpe" is_qnn_package = args.package_name == "Microsoft.ML.OnnxRuntime.QNN" + is_migraphx_package = args.package_name == "Microsoft.ML.OnnxRuntime.MIGraphX" + is_migraphx_win_sub_package = args.package_name == "Microsoft.ML.OnnxRuntime.MIGraphX.Windows" + is_migraphx_linux_sub_package = args.package_name == "Microsoft.ML.OnnxRuntime.MIGraphX.Linux" is_training_package = args.package_name in [ "Microsoft.ML.OnnxRuntime.Training", "Microsoft.ML.OnnxRuntime.Training.Gpu", @@ -383,6 +402,7 @@ def generate_files(line_list, args): "openvino_ep_shared_lib": "onnxruntime_providers_openvino.dll", "cuda_ep_shared_lib": "onnxruntime_providers_cuda.dll", "qnn_ep_shared_lib": "onnxruntime_providers_qnn.dll", + "migraphx_ep_shared_lib": "onnxruntime_providers_migraphx.dll", "onnxruntime_perf_test": "onnxruntime_perf_test.exe", "onnx_test_runner": "onnx_test_runner.exe", } @@ -420,7 +440,7 @@ def generate_files(line_list, args): include_dir = f"{build_dir}\\native\\include" # Sub.Gpu packages do not include the onnxruntime headers - if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu": + if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu" and args.package_name != "Microsoft.ML.OnnxRuntime.MIGraphX": files_list.append( "' ) + if args.execution_provider == "migraphx" or (is_migraphx_win_sub_package and not is_ado_packaging_build): + files_list.append( + "' + ) + files_list.append( + "' + ) + # process all other library dependencies - if is_cpu_package or is_cuda_gpu_package or is_dml_package or is_mklml_package: + if is_cpu_package or is_cuda_gpu_package or is_migraphx_package or is_dml_package or is_mklml_package: # Process dnnl dependency if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies["dnnl"])): files_list.append( @@ -898,6 +938,9 @@ def generate_files(line_list, args): or is_cuda_gpu_linux_sub_package or is_cuda_gpu_win_sub_package or is_rocm_gpu_package + or is_migraphx_package + or is_migraphx_linux_sub_package + or is_migraphx_win_sub_package or is_dml_package or is_mklml_package or is_snpe_package