diff --git a/.gitignore b/.gitignore
index 4d0a1205b7c19..b25763334f227 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
# build, distribute, and bins (+ python proto bindings)
+build.*/
build
build_*/
.build_debug/*
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index d42f23134805e..28efa6ea8cb3f 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -343,7 +343,7 @@ if (onnxruntime_USE_ROCM)
if (ROCM_VERSION_DEV VERSION_LESS "6.2")
message(FATAL_ERROR "CMAKE_HIP_ARCHITECTURES is not set when ROCm version < 6.2")
else()
- set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx1030;gfx1100;gfx1101;gfx942;gfx950;gfx1200;gfx1201;gfx1150")
+ set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx1030;gfx1100;gfx1101;gfx942;gfx950;gfx1200;gfx1201")
endif()
endif()
diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake
index 567bcaf581842..90c0f447800c6 100644
--- a/cmake/onnxruntime_providers_migraphx.cmake
+++ b/cmake/onnxruntime_providers_migraphx.cmake
@@ -2,21 +2,11 @@
# Licensed under the MIT License.
add_definitions(-DUSE_MIGRAPHX=1)
- set(BUILD_LIBRARY_ONLY 1)
- add_definitions("-DONNX_ML=1")
- add_definitions("-DONNX_NAMESPACE=onnx")
- include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR})
- set(MIGRAPHX_ROOT ${onnxruntime_MIGRAPHX_HOME})
- include_directories(${onnx_SOURCE_DIR})
+ include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR} ${onnx_SOURCE_DIR})
set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
- if ( CMAKE_COMPILER_IS_GNUCC )
+ if (CMAKE_COMPILER_IS_GNUCC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers")
endif()
- set(CXX_VERSION_DEFINED TRUE)
- set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS})
- if ( CMAKE_COMPILER_IS_GNUCC )
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter")
- endif()
# Add search paths for default rocm installation
list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm $ENV{HIP_PATH})
@@ -33,8 +23,6 @@
find_package(hip REQUIRED)
find_package(migraphx REQUIRED PATHS ${AMD_MIGRAPHX_HOME})
- set(migraphx_libs migraphx::c hip::host)
-
file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc"
@@ -44,15 +32,15 @@
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs})
onnxruntime_add_shared_library(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface)
- add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
- target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface)
- target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime)
+ add_dependencies(onnxruntime_providers_migraphx ${onnxruntime_EXTERNAL_DEPENDENCIES})
+ target_link_libraries(onnxruntime_providers_migraphx PRIVATE migraphx::c hip::host ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface)
+ target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/migraphx/onnxruntime)
set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX)
set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime")
- target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1)
- if(MSVC)
+ target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1 ONNX_ML=1 ONNX_NAMESPACE=onnx)
+ if(CMAKE_SYSTEM_NAME STREQUAL "Windows")
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS /DEF:${ONNXRUNTIME_ROOT}/core/providers/migraphx/symbols.def)
- target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32)
+ target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32 shlwapi)
else()
target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare)
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations")
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 9797d8019f2d3..6c7ebdf67e43f 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -609,7 +609,6 @@ endif()
if(onnxruntime_USE_MIGRAPHX)
list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx)
- list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared)
endif()
if(onnxruntime_USE_ROCM)
@@ -694,9 +693,6 @@ endif()
if(onnxruntime_USE_MIGRAPHX)
list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/migraphx/*)
- list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/migraphx/migraphx_execution_provider_utils.h")
- list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx)
- list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared)
endif()
if(onnxruntime_USE_NNAPI_BUILTIN)
diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj
index 6779fd60bcd0a..28a837f44ae07 100644
--- a/csharp/OnnxRuntime.CSharp.proj
+++ b/csharp/OnnxRuntime.CSharp.proj
@@ -163,6 +163,11 @@ CMake creates a target to this project
Targets ="CreateNativePackage" />
+
+
+
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index 77c35aac65b92..839c695a23791 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -336,6 +336,13 @@ public struct OrtApi
public IntPtr GetModelEditorApi;
public IntPtr CreateTensorWithDataAndDeleterAsOrtValue;
public IntPtr SessionOptionsSetLoadCancellationFlag;
+
+ public IntPtr CreateMIGraphXProviderOptions;
+ public IntPtr UpdateMIGraphXProviderOptions;
+ public IntPtr GetMIGraphXProviderOptionsAsString;
+ public IntPtr ReleaseMIGraphXProviderOptions;
+ public IntPtr UpdateMIGraphXProviderOptionsWithValue;
+ public IntPtr GetMIGraphXProviderOptionsByName;
}
internal static class NativeMethods
@@ -573,6 +580,18 @@ static NativeMethods()
OrtUpdateROCMProviderOptions = (DOrtUpdateROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateROCMProviderOptions, typeof(DOrtUpdateROCMProviderOptions));
OrtGetROCMProviderOptionsAsString = (DOrtGetROCMProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetROCMProviderOptionsAsString, typeof(DOrtGetROCMProviderOptionsAsString));
OrtReleaseROCMProviderOptions = (DOrtReleaseROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseROCMProviderOptions, typeof(DOrtReleaseROCMProviderOptions));
+ SessionOptionsAppendExecutionProvider_MIGraphX = (DSessionOptionsAppendExecutionProvider_MIGraphX)Marshal.GetDelegateForFunctionPointer(
+ api_.SessionOptionsAppendExecutionProvider_MIGraphX, typeof(DSessionOptionsAppendExecutionProvider_MIGraphX));
+ OrtCreateMIGraphXProviderOptions = (DOrtCreateMIGraphXProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateMIGraphXProviderOptions, typeof(DOrtCreateMIGraphXProviderOptions));
+ OrtUpdateMIGraphXProviderOptions = (DOrtUpdateMIGraphXProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateMIGraphXProviderOptions, typeof(DOrtUpdateMIGraphXProviderOptions));
+ OrtGetMIGraphXProviderOptionsAsString = (DOrtGetMIGraphXProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetMIGraphXProviderOptionsAsString, typeof(DOrtGetMIGraphXProviderOptionsAsString));
+ OrtReleaseMIGraphXProviderOptions = (DOrtReleaseMIGraphXProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseMIGraphXProviderOptions, typeof(DOrtReleaseMIGraphXProviderOptions));
+ OrtUpdateMIGraphXProviderOptionsWithValue =
+ (DOrtUpdateMIGraphXProviderOptionsWithValue)Marshal.GetDelegateForFunctionPointer(
+ api_.UpdateMIGraphXProviderOptionsWithValue, typeof(DOrtUpdateMIGraphXProviderOptionsWithValue));
+ OrtGetMIGraphXProviderOptionsByName =
+ (DOrtGetMIGraphXProviderOptionsByName)Marshal.GetDelegateForFunctionPointer(
+ api_.GetMIGraphXProviderOptionsByName, typeof(DOrtGetMIGraphXProviderOptionsByName));
OrtCreateAndRegisterAllocatorV2 = (DCreateAndRegisterAllocatorV2)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocatorV2, typeof(DCreateAndRegisterAllocatorV2));
OrtRunAsync = (DOrtRunAsync)Marshal.GetDelegateForFunctionPointer(api_.RunAsync, typeof(DOrtRunAsync));
CreateLoraAdapter = (DCreateLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.CreateLoraAdapter,
@@ -799,6 +818,80 @@ internal class NativeLib
#endregion
+#region Provider Options API
+ ///
+ /// Creates native OrtMIGraphXProviderOptions instance
+ ///
+ /// (output) native instance of OrtMIGraphXProviderOptions
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtCreateMIGraphXProviderOptions(
+ out IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance);
+ public static DOrtCreateMIGraphXProviderOptions OrtCreateMIGraphXProviderOptions;
+
+ ///
+ /// Updates native OrtMIGraphXProviderOptions instance using given key/value pairs
+ ///
+ /// native instance of OrtMIGraphXProviderOptions
+ /// configuration keys of OrtMIGraphXProviderOptions
+ /// configuration values of OrtMIGraphXProviderOptions
+ /// number of configuration keys
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtUpdateMIGraphXProviderOptions(
+ IntPtr /*(OrtMIGraphXProviderOptions*)*/ migraphxProviderOptionsInstance,
+ IntPtr[] /*(const char* const *)*/ providerOptionsKeys,
+ IntPtr[] /*(const char* const *)*/ providerOptionsValues,
+ UIntPtr /*(size_t)*/ numKeys);
+ public static DOrtUpdateMIGraphXProviderOptions OrtUpdateMIGraphXProviderOptions;
+
+ ///
+ /// Get native OrtMIGraphXProviderOptions in serialized string
+ ///
+ /// instance of OrtAllocator
+ /// is a UTF-8 null terminated string allocated using 'allocator'
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtGetMIGraphXProviderOptionsAsString(
+ IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance,
+ IntPtr /*(OrtAllocator*)*/ allocator,
+ out IntPtr /*(char**)*/ ptr);
+ public static DOrtGetMIGraphXProviderOptionsAsString OrtGetMIGraphXProviderOptionsAsString;
+
+ ///
+ /// Releases native OrtMIGraphXProviderOptions instance
+ ///
+ /// native instance of OrtMIGraphXProviderOptions to be released
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate void DOrtReleaseMIGraphXProviderOptions(IntPtr /*(OrtMIGraphXProviderOptions*)*/ migraphxProviderOptionsInstance);
+ public static DOrtReleaseMIGraphXProviderOptions OrtReleaseMIGraphXProviderOptions;
+
+ ///
+ /// Update native OrtMIGraphXProviderOptions with value
+ ///
+ /// native instance of OrtMIGraphXProviderOptions to be released
+ /// configuration key of OrtMIGraphXProviderOptions
+ /// configuration value of OrtMIGraphXProviderOptions
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr DOrtUpdateMIGraphXProviderOptionsWithValue(
+ IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance,
+ IntPtr /*(char*)*/ providerOptionsKey,
+ IntPtr /*(char*)*/ providerOptionsValue);
+ public static DOrtUpdateMIGraphXProviderOptionsWithValue OrtUpdateMIGraphXProviderOptionsWithValue;
+
+ ///
+ /// Get native OrtMIGraphXProviderOptions value by name
+ ///
+ /// native instance of OrtMIGraphXProviderOptions to be released
+ /// configuration key of OrtMIGraphXProviderOptions
+ /// configuration value of OrtMIGraphXProviderOptions to return
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr DOrtGetMIGraphXProviderOptionsByName(
+ IntPtr /*(OrtMIGraphXProviderOptions**)*/ migraphxProviderOptionsInstance,
+ IntPtr /*(char*)*/ providerOptionsKey,
+ out IntPtr /*(char**)*/ providerOptionsValue);
+ public static DOrtGetMIGraphXProviderOptionsByName OrtGetMIGraphXProviderOptionsByName;
+
+
+#endregion
+
#region Status API
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate ErrorCode DOrtGetErrorCode(IntPtr /*(OrtStatus*)*/ status);
@@ -1160,6 +1253,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
[DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id);
+
+ [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)]
+ public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int use_arena, int device_id);
#endif
///
/// Append a TensorRT EP instance (configured based on given provider options) to the native OrtSessionOptions instance
@@ -1221,6 +1317,18 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
public static DSessionOptionsAppendExecutionProvider_ROCM SessionOptionsAppendExecutionProvider_ROCM;
+ ///
+ /// Append a MIGraphX EP instance (configured based on given provider options) to the native OrtSessionOptions instance
+ ///
+ /// Native OrtSessionOptions instance
+ /// Native OrtMIGraphXProviderOptions instance
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_MIGraphX(
+ IntPtr /*(OrtSessionOptions*)*/ options,
+ IntPtr /*(const OrtMIGraphXProviderOptions*)*/ migraphxProviderOptions);
+
+ public static DSessionOptionsAppendExecutionProvider_MIGraphX SessionOptionsAppendExecutionProvider_MIGraphX;
+
///
/// Free Dimension override (by denotation)
///
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs
index 1b9cd7572170b..335b4ef8b3f65 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.shared.cs
@@ -291,6 +291,142 @@ protected override bool ReleaseHandle()
}
+///
+ /// Holds the options for configuring an MIGraphX Execution Provider instance
+ ///
+ public class OrtMIGraphXProviderOptions : SafeHandle
+ {
+ internal IntPtr Handle
+ {
+ get
+ {
+ return handle;
+ }
+ }
+
+ public int DeviceId
+ {
+ get { return _deviceId; }
+ set
+ {
+ UpdateProviderOptionWithValue(_deviceIdPtr, value.ToString());
+ _deviceId = value;
+ }
+ }
+ private IntPtr _deviceIdPtr = Marshal.StringToHGlobalAnsi("device_id");
+ private int _deviceId = 0;
+
+ public string ModelCacheDir
+ {
+ get { return _modelCacheDir; }
+ set
+ {
+ UpdateProviderOptionWithValue(_modelCacheDirPtr, value);
+ _modelCacheDir = value;
+ }
+ }
+
+ private IntPtr _modelCacheDirPtr = Marshal.StringToHGlobalAnsi("migraphx_model_cache_dir");
+ private string _modelCacheDir = "";
+
+ #region Constructor
+
+ ///
+ /// Constructs an empty OrtMIGraphXProviderOptions instance
+ ///
+ public OrtMIGraphXProviderOptions() : base(IntPtr.Zero, true)
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateMIGraphXProviderOptions(out handle));
+ }
+
+ #endregion
+
+ #region Finalizer
+
+ ~OrtMIGraphXProviderOptions()
+ {
+ Marshal.FreeHGlobal(_deviceIdPtr);
+ Marshal.FreeHGlobal(_modelCacheDirPtr);
+ }
+
+ #endregion
+
+ #region Public Methods
+
+ ///
+ /// Get MIGraphX EP provider options
+ ///
+ /// return C# UTF-16 encoded string
+ public string GetOptions()
+ {
+ var allocator = OrtAllocator.DefaultInstance;
+ // Process provider options string
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetMIGraphXProviderOptionsAsString(handle,
+ allocator.Pointer, out IntPtr providerOptions));
+ return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions, allocator);
+ }
+
+ ///
+ /// Updates the configuration knobs of OrtMIGraphXProviderOptions that will eventually be used to configure a MIGraphX EP
+ ///
+ /// Array of keys to set that correspond with values.
+ /// Array of values to set that correspond with keys.
+ /// The number of key/value pairs in the arrays.
+ private static IntPtr UpdateMIGraphXProviderOptions(IntPtr handle, IntPtr[] keys, IntPtr[] values, UIntPtr count)
+ {
+ return NativeMethods.OrtUpdateMIGraphXProviderOptions(handle, keys, values, count);
+ }
+
+ ///
+ /// Updates the configuration knobs of OrtMIGraphXProviderOptions that will eventually be used to configure a MIGraphX EP
+ ///
+ /// key/value pairs used to configure a MIGraphX Execution Provider
+ public void UpdateOptions(Dictionary providerOptions)
+ {
+ ProviderOptionsUpdater.Update(providerOptions, handle, UpdateMIGraphXProviderOptions);
+ }
+
+ #endregion
+
+ #region Public Properties
+
+ ///
+ /// Overrides SafeHandle.IsInvalid
+ ///
+ /// returns true if handle is equal to Zero
+ public override bool IsInvalid { get { return handle == IntPtr.Zero; } }
+
+ #endregion
+
+ #region Private Methods
+
+ private void UpdateProviderOptionWithValue(IntPtr key, string value)
+ {
+ IntPtr valuePtr = Marshal.StringToHGlobalAnsi(value);
+ var nativeStatus = NativeMethods.OrtUpdateMIGraphXProviderOptionsWithValue(handle, key, valuePtr);
+ Marshal.FreeHGlobal(valuePtr);
+ NativeApiStatus.VerifySuccess(nativeStatus);
+ }
+
+ #endregion
+
+ #region SafeHandle
+ ///
+ /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
+ /// the native instance of OrtMIGraphXProviderOptions
+ ///
+ /// always returns true
+ protected override bool ReleaseHandle()
+ {
+ NativeMethods.OrtReleaseMIGraphXProviderOptions(handle);
+ handle = IntPtr.Zero;
+ return true;
+ }
+
+ #endregion
+ }
+
+
///
/// This helper class contains methods to handle values of provider options
///
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
index 9b0f183f03681..3e1515cf5b78c 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -42,6 +42,9 @@ public class SessionOptions : SafeHandle
private static string[] cudaDelayLoadedLibs = { };
private static string[] trtDelayLoadedLibs = { };
+ // Delay-loaded MIGraphX DLLs. Currently, delayload is disabled. See cmake/CMakeLists.txt for more information.
+ private static string[] migxDelayLoadedLibs = { };
+
#region Constructor and Factory methods
///
@@ -189,6 +192,28 @@ public static SessionOptions MakeSessionOptionWithRocmProvider(OrtROCMProviderOp
throw;
}
}
+
+ ///
+ /// A helper method to construct a SessionOptions object for MIGraaphX execution provider.
+ /// Use only if MIGraphX is installed and you have the onnxruntime package specific to this Execution Provider.
+ ///
+ /// MIGraphX EP provider options
+ /// A SessionsOptions() object configured for execution on provider options
+ public static SessionOptions MakeSessionOptionWithMIGraphXProvider(OrtMIGraphXProviderOptions migxProviderOptions)
+ {
+ CheckMIGraphXExecutionProviderDLLs();
+ SessionOptions options = new SessionOptions();
+ try
+ {
+ options.AppendExecutionProvider_MIGraphX(migxProviderOptions);
+ return options;
+ }
+ catch (Exception)
+ {
+ options.Dispose();
+ throw;
+ }
+ }
#endregion
#region ExecutionProviderAppends
@@ -331,12 +356,25 @@ public void AppendExecutionProvider_ROCm(OrtROCMProviderOptions rocmProviderOpti
public void AppendExecutionProvider_MIGraphX(int deviceId = 0)
{
#if __MOBILE__
- throw new NotSupportedException($"The MIGraphX Execution Provider is not supported in this build");
+ throw new NotSupportedException("The MIGraphX Execution Provider is not supported in this build");
#else
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_MIGraphX(handle, deviceId));
#endif
}
+ ///
+ /// Use only if you have the onnxruntime package specific to this Execution Provider.
+ ///
+ /// device identification
+ public void AppendExecutionProvider_MIGraphX(OrtMIGraphXProviderOptions migraphxProviderOptions)
+ {
+#if __MOBILE__
+ throw new NotSupportedException($"The AMD Nitris Execution Provider is not supported in this build");
+#else
+ NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider_MIGraphX(handle, migraphxProviderOptions.Handle));
+#endif
+ }
+
///
/// Use only if you have the onnxruntime package specific to this Execution Provider.
///
@@ -885,6 +923,27 @@ private static bool CheckRocmExecutionProviderDLLs()
return true;
}
+ private static bool CheckMIGraphXExecutionProviderDLLs()
+ {
+ if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+ {
+ foreach (var dll in migxDelayLoadedLibs)
+ {
+ IntPtr handle = LoadLibrary(dll);
+ if (handle != IntPtr.Zero)
+ continue;
+ var sysdir = new StringBuilder(String.Empty, 2048);
+ GetSystemDirectory(sysdir, (uint)sysdir.Capacity);
+ throw new OnnxRuntimeException(
+ ErrorCode.NoSuchFile,
+ $"kernel32.LoadLibrary():'{dll}' not found. MIGraphX are required for GPU execution. " +
+ $". Verify it is available in the system directory={sysdir}. Else copy it to the output folder."
+ );
+ }
+ }
+ return true;
+ }
+
#endregion
#region SafeHandle
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index af9b517b77e5c..6e2b4fa6aba3f 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -675,10 +675,10 @@ typedef struct OrtMIGraphXProviderOptions {
int migraphx_bf16_enable; // MIGraphX BF16 precision. Default 0 = false, nonzero = true
int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
- int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
+ int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, nonzero = true
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
const char* migraphx_cache_dir; // MIGraphX model cache directory
- bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false
+ int migraphx_exhaustive_tune; // MIGraphX tuned compile. Default = false, nonzero = true
/** \brief MIGraphX memory limit (To use all possible memory pass in maximum size_t)
* Defaults to SIZE_MAX.
@@ -5265,6 +5265,88 @@ struct OrtApi {
* \since Version 1.22.
*/
const OrtEpApi*(ORT_API_CALL* GetEpApi)();
+
+ /// @}
+ /// \name OrtMIGraphXProviderOptions
+ /// @{
+
+ /** \brief Create an OrtMIGraphXProviderOptions
+ *
+ * \param[out] out Newly created ::OrtMIGraphXProviderOptions. Must be released with OrtApi::ReleaseMIGraphXProviderOptions
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.xx.
+ */
+ ORT_API2_STATUS(CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out);
+
+ /** \brief Set options in a MIGraphX Execution Provider.
+ *
+ * For example, key="device_id" and value="0"
+ *
+ * \param[in] migraphx_options
+ * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys
+ * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values
+ * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.xx.
+ */
+ ORT_API2_STATUS(UpdateMIGraphXProviderOptions, _Inout_ OrtMIGraphXProviderOptions* migraphx_options,
+ _In_reads_(num_keys) const char* const* provider_options_keys,
+ _In_reads_(num_keys) const char* const* provider_options_values,
+ _In_ size_t num_keys);
+
+ /**
+ * Get serialized MIGraphX provider options string.
+ *
+ * For example, "device_id=0;;......"
+ *
+ * \param migraphx_options - OrtMIGraphXProviderOptions instance
+ * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions()
+ * the specified allocator will be used to allocate continuous buffers for output strings and lengths.
+ * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.xx.
+ */
+ ORT_API2_STATUS(GetMIGraphXProviderOptionsAsString, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr);
+
+ /** \brief Release an ::OrtMIGraphXProviderOptions
+ *
+ * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does
+ *
+ * \since Version 1.xx.
+ */
+ void(ORT_API_CALL* ReleaseMIGraphXProviderOptions)(_Frees_ptr_opt_ OrtMIGraphXProviderOptions* input);
+
+ /**
+ * Update MIGraphX EP provider option where its data type is pointer, for example 'user_compute_stream'.
+ * If the data type of the provider option can be represented by string please use UpdateMIGraphXProviderOptions.
+ *
+ * Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer.
+ *
+ * \param migraphx_options - OrtMIGraphXProviderOptions instance
+ * \param key - Name of the provider option
+ * \param value - A pointer to the instance that will be assigned to this provider option
+ *
+ * \since Version 1.xx.
+ */
+ ORT_API2_STATUS(UpdateMIGraphXProviderOptionsWithValue, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _In_ void* value);
+
+ /**
+ * Get MIGraphX EP provider option where its data type is pointer.
+ * If the data type of the provider option can be represented by string please use GetMIGraphXProviderOptionsAsString.
+ *
+ * \param migraphx_options - OrtMIGraphXProviderOptions instance
+ * \param key - Name of the provider option
+ * \param ptr - A pointer to the instance that is kept by the provider option
+ *
+ * \since Version 1.xx.
+ */
+ ORT_API2_STATUS(GetMIGraphXProviderOptionsByName, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _Outptr_ void** ptr);
};
/*
diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc
index c0d8a4f02bbc3..776fd5fec367f 100644
--- a/onnxruntime/core/providers/dml/dml_provider_factory.cc
+++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc
@@ -562,7 +562,10 @@ static D3D12_COMMAND_LIST_TYPE CalculateCommandListType(ID3D12Device* d3d12_devi
sizeof(feature_levels)
));
- auto use_compute_command_list = (feature_levels.MaxSupportedFeatureLevel <= D3D_FEATURE_LEVEL_1_0_CORE);
+ // Use compute queue whenever possible on supported hardware to avoid TDR and maintain UI QoS
+ // Core and generic devices only have compute queues, DX11 has "immediate" submission, DX12 has both
+ auto use_compute_command_list = (feature_levels.MaxSupportedFeatureLevel <= D3D_FEATURE_LEVEL_1_0_CORE) ||
+ (feature_levels.MaxSupportedFeatureLevel >= D3D_FEATURE_LEVEL_12_0);
if (use_compute_command_list)
{
diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc
index cf9f44f4cd8f0..17dfdf4519b16 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc
@@ -23,11 +23,11 @@ void MIGraphXAllocator::CheckDevice() const {
#endif
}
-void* MIGraphXAllocator::Alloc(size_t size) {
+void* MIGraphXAllocator::Alloc(const size_t size) {
CheckDevice();
void* p = nullptr;
if (size > 0) {
- HIP_CALL_THROW(hipMalloc((void**)&p, size));
+ HIP_CALL_THROW(hipMalloc(&p, size));
}
return p;
}
@@ -37,7 +37,7 @@ void MIGraphXAllocator::Free(void* p) {
(void)hipFree(p); // do not throw error since it's OK for hipFree to fail during shutdown
}
-void* MIGraphXExternalAllocator::Alloc(size_t size) {
+void* MIGraphXExternalAllocator::Alloc(const size_t size) {
void* p = nullptr;
if (size > 0) {
p = alloc_(size);
@@ -51,27 +51,27 @@ void* MIGraphXExternalAllocator::Alloc(size_t size) {
void MIGraphXExternalAllocator::Free(void* p) {
free_(p);
- std::lock_guard lock(lock_);
- auto it = reserved_.find(p);
- if (it != reserved_.end()) {
+ std::lock_guard lock(lock_);
+ if (const auto it = reserved_.find(p); it != reserved_.end()) {
reserved_.erase(it);
if (empty_cache_) empty_cache_();
}
}
-void* MIGraphXExternalAllocator::Reserve(size_t size) {
+void* MIGraphXExternalAllocator::Reserve(const size_t size) {
void* p = Alloc(size);
- if (!p) return nullptr;
- std::lock_guard lock(lock_);
- ORT_ENFORCE(reserved_.find(p) == reserved_.end());
- reserved_.insert(p);
+ if (p != nullptr) {
+ std::lock_guard lock(lock_);
+ ORT_ENFORCE(reserved_.find(p) == reserved_.end());
+ reserved_.insert(p);
+ }
return p;
}
-void* MIGraphXPinnedAllocator::Alloc(size_t size) {
+void* MIGraphXPinnedAllocator::Alloc(const size_t size) {
void* p = nullptr;
if (size > 0) {
- HIP_CALL_THROW(hipHostMalloc((void**)&p, size));
+ HIP_CALL_THROW(hipHostMalloc(&p, size));
}
return p;
}
diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h
index 2a84445897391..4842a36993f45 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_allocator.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h
@@ -11,26 +11,26 @@ namespace onnxruntime {
class MIGraphXAllocator : public IAllocator {
public:
- MIGraphXAllocator(int device_id, const char* name)
+ MIGraphXAllocator(const OrtDevice::DeviceId device_id, const char* name)
: IAllocator(
- OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator,
- OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(device_id)),
+ OrtMemoryInfo(name, OrtDeviceAllocator,
+ OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id),
device_id, OrtMemTypeDefault)) {}
- virtual void* Alloc(size_t size) override;
- virtual void Free(void* p) override;
+ void* Alloc(size_t size) override;
+ void Free(void* p) override;
private:
void CheckDevice() const;
};
-class MIGraphXExternalAllocator : public MIGraphXAllocator {
+class MIGraphXExternalAllocator final : public MIGraphXAllocator {
typedef void* (*ExternalAlloc)(size_t size);
typedef void (*ExternalFree)(void* p);
typedef void (*ExternalEmptyCache)();
public:
- MIGraphXExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache)
+ MIGraphXExternalAllocator(const OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache)
: MIGraphXAllocator(device_id, name) {
alloc_ = reinterpret_cast(alloc);
free_ = reinterpret_cast(free);
@@ -51,10 +51,10 @@ class MIGraphXExternalAllocator : public MIGraphXAllocator {
class MIGraphXPinnedAllocator final : public IAllocator {
public:
- MIGraphXPinnedAllocator(const int device_id, const char* name)
+ MIGraphXPinnedAllocator(const OrtDevice::DeviceId device_id, const char* name)
: IAllocator(
OrtMemoryInfo(name, OrtDeviceAllocator,
- OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast(device_id)),
+ OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, device_id),
device_id, OrtMemTypeCPUOutput)) {}
void* Alloc(size_t size) override;
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
index 998c4af3f7576..6c588bb29c823 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
@@ -1,11 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License
-#include
#include
+#include
+#include
#include
-#include
+#include
#include
-#include
+#include
+#include
#include "core/providers/shared_library/provider_api.h"
#define ORT_API_MANUAL_INIT
@@ -40,7 +42,7 @@ namespace onnxruntime {
class Memcpy final : public OpKernel {
public:
- Memcpy(const OpKernelInfo& info) : OpKernel(info) {}
+ explicit Memcpy(const OpKernelInfo& info) : OpKernel(info) {}
Status Compute(OpKernelContext* ctx) const override {
const auto* X = ctx->Input(0);
@@ -56,16 +58,13 @@ class Memcpy final : public OpKernel {
}
};
-template
-KernelCreateInfo BuildKernelCreateInfo();
-
ONNX_OPERATOR_KERNEL_EX(
MemcpyFromHost,
kOnnxDomain,
1,
kMIGraphXExecutionProvider,
- (*KernelDefBuilder::Create())
- .InputMemoryType(OrtMemTypeCPUInput, 0)
+ KernelDefBuilder::Create()
+ ->InputMemoryType(OrtMemTypeCPUInput, 0)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Memcpy);
@@ -74,14 +73,11 @@ ONNX_OPERATOR_KERNEL_EX(
kOnnxDomain,
1,
kMIGraphXExecutionProvider,
- (*KernelDefBuilder::Create())
- .OutputMemoryType(OrtMemTypeCPUOutput, 0)
+ KernelDefBuilder::Create()
+ ->OutputMemoryType(OrtMemTypeCPUOutput, 0)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Memcpy);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMIGraphXExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMIGraphXExecutionProvider, kOnnxDomain, 1, MemcpyToHost);
-
static std::shared_ptr s_kernel_registry;
void InitializeRegistry() {
@@ -106,16 +102,13 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c
}
MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info)
- : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info) {
+ : IExecutionProvider{kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info) {
InitProviderOrtApi();
get_flags_from_session_info(info);
metadef_id_generator_ = ModelMetadefIdGenerator::Create();
get_flags_from_env();
}
-MIGraphXExecutionProvider::~MIGraphXExecutionProvider() {
-}
-
void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info) {
// Set GPU device to be used
HIP_CALL_THROW(hipSetDevice(info_.device_id));
@@ -178,17 +171,17 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut
void MIGraphXExecutionProvider::get_flags_from_env() {
LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX ENV Override Variables Set:";
- // whether fp16 is enable
- const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP16Enable);
+ // whether fp16 is enabled
+ const std::string fp16_enable_env = GetEnvironmentVar(migraphx_env_vars::kFP16Enable);
if (!fp16_enable_env.empty()) {
- fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true);
+ fp16_enable_ = std::stoi(fp16_enable_env) != 0;
LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP16_ENABLE: " << fp16_enable_;
}
- const std::string bf16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kBF16Enable);
+ const std::string bf16_enable_env = GetEnvironmentVar(migraphx_env_vars::kBF16Enable);
if (!bf16_enable_env.empty()) {
#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4 && HIP_VERSION_PATCH >= 2)
- bf16_enable_ = (std::stoi(bf16_enable_env) == 0 ? false : true);
+ bf16_enable_ = std::stoi(bf16_enable_env) != 0;
LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_BF16_ENABLE: " << fp16_enable_;
#else
LOGS_DEFAULT(WARNING) << "MIGraphX: BF16 Quantization requires ROCm 6.4.2 or greater";
@@ -201,10 +194,10 @@ void MIGraphXExecutionProvider::get_flags_from_env() {
}
// whether fp8 quantization is enabled
- const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable);
+ const std::string fp8_enable_env = GetEnvironmentVar(migraphx_env_vars::kFP8Enable);
if (!fp8_enable_env.empty()) {
#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4)
- fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true);
+ fp8_enable_ = std::stoi(fp8_enable_env) != 0;
LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP8_ENABLE: " << fp8_enable_;
#else
LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater";
@@ -213,9 +206,9 @@ void MIGraphXExecutionProvider::get_flags_from_env() {
}
// whether int8 is enabled
- const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable);
+ const std::string int8_enable_env = GetEnvironmentVar(migraphx_env_vars::kINT8Enable);
if (!int8_enable_env.empty()) {
- int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true);
+ int8_enable_ = std::stoi(int8_enable_env) != 0;
LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_ENABLE: " << int8_enable_;
}
@@ -225,23 +218,22 @@ void MIGraphXExecutionProvider::get_flags_from_env() {
if (int8_enable_ || fp8_enable_) {
const std::string int8_calibration_cache_name_env =
- onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName);
+ GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName);
if (!int8_calibration_cache_name_env.empty()) {
int8_calibration_cache_name_ = int8_calibration_cache_name_env;
LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CALIBRATION_TABLE_NAME: " << int8_calibration_cache_name_;
}
- const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath);
+ const std::string cache_path = GetEnvironmentVar(migraphx_env_vars::kCachePath);
if (!cache_path.empty()) {
calibration_cache_path_ = cache_path;
LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CACHE_PATH: " << calibration_cache_path_;
}
const std::string int8_use_native_migraphx_calibration_table_env =
- onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable);
+ GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable);
if (!int8_use_native_migraphx_calibration_table_env.empty()) {
- int8_use_native_migraphx_calibration_table_ =
- (std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? false : true);
+ int8_use_native_migraphx_calibration_table_ = std::stoi(int8_use_native_migraphx_calibration_table_env) != 0;
LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE: "
<< int8_use_native_migraphx_calibration_table_;
}
@@ -269,21 +261,21 @@ void MIGraphXExecutionProvider::get_flags_from_env() {
}
// dump unsupported ops
- const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kDumpModelOps);
+ const std::string dump_model_ops_env = GetEnvironmentVar(migraphx_env_vars::kDumpModelOps);
if (!dump_model_ops_env.empty()) {
- dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true);
+ dump_model_ops_ = std::stoi(dump_model_ops_env) != 0;
LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_DUMP_MODEL_OPS: " << dump_model_ops_;
}
// Allow for exhaustive tune during compile
- const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune);
+ const std::string exhaustive_tune_env = GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune);
if (!exhaustive_tune_env.empty()) {
- exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true);
+ exhaustive_tune_ = std::stoi(exhaustive_tune_env) != 0;
LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_EXHAUSTIVE_TUNE_OPS: " << exhaustive_tune_;
}
}
-void MIGraphXExecutionProvider::print_migraphx_ep_flags() {
+void MIGraphXExecutionProvider::print_migraphx_ep_flags() const {
LOGS_DEFAULT(VERBOSE) << "\n " << migraphx_provider_option::kDeviceId << ": " << info_.device_id
<< "\n " << migraphx_provider_option::kFp16Enable << ": " << fp16_enable_
<< "\n " << migraphx_provider_option::kBf16Enable << ": " << bf16_enable_
@@ -297,14 +289,14 @@ void MIGraphXExecutionProvider::print_migraphx_ep_flags() {
<< "\n " << migraphx_provider_option::kModelCacheDir << ": " << model_cache_path_;
}
-AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id,
- size_t migx_mem_limit,
+AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(const OrtDevice::DeviceId device_id,
+ const size_t migx_mem_limit,
ArenaExtendStrategy arena_extend_strategy,
MIGraphXExecutionProviderExternalAllocatorInfo
external_allocator_info,
const OrtArenaCfg* default_memory_arena_cfg) {
if (external_allocator_info.UseExternalAllocator()) {
- AllocatorCreationInfo default_memory_info(
+ const AllocatorCreationInfo default_memory_info(
[external_allocator_info](OrtDevice::DeviceId id) {
return std::make_unique(id, HIP,
external_allocator_info.alloc,
@@ -315,39 +307,38 @@ AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::Devic
false);
return CreateAllocator(default_memory_info);
- } else {
- AllocatorCreationInfo default_memory_info(
- [](OrtDevice::DeviceId id) {
- return std::make_unique(id, HIP);
- },
- device_id,
- true,
- {default_memory_arena_cfg ? *default_memory_arena_cfg
- : OrtArenaCfg(migx_mem_limit, static_cast(arena_extend_strategy),
- -1, -1, -1, -1L)},
- // make it stream aware
- true,
- // enable cross stream sharing?
- false);
-
- // ROCM malloc/free is expensive so always use an arena
- return CreateAllocator(default_memory_info);
}
+ const AllocatorCreationInfo default_memory_info(
+ [](OrtDevice::DeviceId id) {
+ return std::make_unique(id, HIP);
+ },
+ device_id,
+ true,
+ {default_memory_arena_cfg ? *default_memory_arena_cfg
+ : OrtArenaCfg(migx_mem_limit, static_cast(arena_extend_strategy),
+ -1, -1, -1, -1L)},
+ // make it stream aware
+ true,
+ // enable cross stream sharing?
+ false);
+
+ // ROCM malloc/free is expensive so always use an arena
+ return CreateAllocator(default_memory_info);
}
std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() {
- AllocatorCreationInfo default_memory_info(
- [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, onnxruntime::CUDA); }, info_.device_id);
- AllocatorCreationInfo pinned_allocator_info(
- [](OrtDevice::DeviceId device_id) {
- return std::make_unique(device_id, onnxruntime::CUDA_PINNED);
+ const AllocatorCreationInfo default_memory_info(
+ [](const OrtDevice::DeviceId device_id) { return std::make_unique(device_id, CUDA); }, info_.device_id);
+ const AllocatorCreationInfo pinned_allocator_info(
+ [](const OrtDevice::DeviceId device_id) {
+ return std::make_unique(device_id, CUDA_PINNED);
},
0);
- return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)};
+ return {CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)};
}
-std::unique_ptr MIGraphXExecutionProvider::GetDataTransfer() const {
- return std::make_unique();
+std::unique_ptr MIGraphXExecutionProvider::GetDataTransfer() const {
+ return std::make_unique();
}
static bool IsTypeSupported(const NodeArg* node_arg) {
@@ -382,7 +373,7 @@ static bool IsTypeSupported(const NodeArg* node_arg) {
}
}
-static bool getMIGraphXType(ONNXTensorElementDataType type,
+static bool getMIGraphXType(const ONNXTensorElementDataType type,
migraphx_shape_datatype_t& mgx_type) {
mgx_type = migraphx_shape_float_type;
switch (type) {
@@ -411,8 +402,7 @@ static bool getMIGraphXType(ONNXTensorElementDataType type,
mgx_type = migraphx_shape_fp8e5m2fnuz_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
- mgx_type = migraphx_shape_int8_type;
- break;
+ // No `break` intentional
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
mgx_type = migraphx_shape_int8_type;
break;
@@ -426,8 +416,7 @@ static bool getMIGraphXType(ONNXTensorElementDataType type,
mgx_type = migraphx_shape_int64_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4:
- mgx_type = migraphx_shape_uint8_type;
- break;
+ // No `break` intentional
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
mgx_type = migraphx_shape_uint8_type;
break;
@@ -444,8 +433,8 @@ static bool getMIGraphXType(ONNXTensorElementDataType type,
mgx_type = migraphx_shape_bool_type;
break;
default:
- LOGS_DEFAULT(WARNING) << "MiGraphx: unsupported data type " << type << ", fallback to CPU";
- LOGS_DEFAULT(WARNING) << "implementation";
+ LOGS_DEFAULT(WARNING) << "MIGraphX: unsupported data type " << type
+ << ", fallback to CPU" << "implementation";
return false;
}
@@ -462,7 +451,7 @@ std::vector toVector(const ONNX_NAMESPACE::int64s& nums) {
return result;
}
-static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node) {
+static bool IsUnsupportedOpMode(const GraphViewer& graph_viewer, const Node* node) {
std::vector input_nodes;
const auto& optype = node->OpType();
if (optype == "ArgMax" or optype == "ArgMin") {
@@ -636,7 +625,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
}
}
} else if (optype == "Split") {
- // cannot process input dim of 0 size
+ // cannot process input dim of size 0
const auto arg_s = node->InputDefs()[0]->Shape();
if (arg_s != nullptr) {
const auto& tensor_dims = arg_s->dim();
@@ -678,24 +667,21 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
return false;
}
-void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector>& clusters,
+void SubgraphPostProcessing(const GraphViewer& graph_viewer, std::vector>& clusters,
[[maybe_unused]] const logging::Logger& logger) {
// Then check whether a subgraph should fall back to CPU
- // 1. Check whether a subgraph contains a RNN operator
+ // 1. Check whether a subgraph contains an RNN operator
std::unordered_set rnn_names = {"RNN", "GRU", "LSTM"};
std::unordered_set op_names = {"AveragePool", "Conv", "Gemm", "LRN", "MatMul", "MaxPool"};
- auto it = std::remove_if(clusters.begin(), clusters.end(), [&](auto git) {
+ const auto it = std::remove_if(clusters.begin(), clusters.end(), [&](auto git) {
for (auto index : git) {
- auto node = graph_viewer.GetNode(index);
- if (node->OpType() == "Reshape") {
- const auto& args = node->InputDefs();
- if (args.size() == 2) {
- std::vector node_inputs;
- if (canEvalNodeArgument(graph_viewer, node, {1}, node_inputs)) {
- return (not std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto index) {
- return std::find(git.begin(), git.end(), index) != git.end();
- }));
+ if (auto node = graph_viewer.GetNode(index); node->OpType() == "Reshape") {
+ if (const auto& args = node->InputDefs(); args.size() == 2) {
+ if (std::vector node_inputs; canEvalNodeArgument(graph_viewer, node, {1}, node_inputs)) {
+ return !std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto i) {
+ return std::find(git.begin(), git.end(), i) != git.end();
+ });
} else {
return true;
}
@@ -717,7 +703,7 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v
const auto& node = graph_viewer.GetNode(nid);
const auto& op_type = node->OpType();
if (op_names.count(op_type) > 0) {
- // check number of elements in input
+ // check the number of elements in input
auto inputs = node->InputDefs();
if (std::any_of(inputs.begin(), inputs.end(), [&](auto& arg) {
const auto& arg_s = arg->Shape();
@@ -747,7 +733,7 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v
}
static bool IsNodeSupported(const std::set& op_set,
- const onnxruntime::GraphViewer& graph_viewer,
+ const GraphViewer& graph_viewer,
const NodeIndex node_idx,
[[maybe_unused]] const logging::Logger& logger) {
const auto& node = graph_viewer.GetNode(node_idx);
@@ -763,7 +749,7 @@ static bool IsNodeSupported(const std::set& op_set,
// check data type
bool are_types_supported = true;
- node->ForEachDef([&are_types_supported](const onnxruntime::NodeArg& node_arg, bool /*is_input*/) {
+ node->ForEachDef([&are_types_supported](const NodeArg& node_arg, bool /*is_input*/) {
are_types_supported &= IsTypeSupported(&node_arg);
});
@@ -800,7 +786,7 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st
}
// Find inputs and outputs of the subgraph
- std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create();
+ std::unique_ptr sub_graph = IndexedSubGraph::Create();
std::unordered_map fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
std::unordered_set erased;
int input_order = 0;
@@ -881,15 +867,15 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st
// Sort inputs and outputs by the order they were added
std::multimap inputs, outputs;
- for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) {
- inputs.insert(std::pair(it->second, it->first));
+ for (auto& [fst, snd] : fused_inputs) {
+ inputs.insert({snd, fst});
}
- for (auto it = fused_outputs.begin(), end = fused_outputs.end(); it != end; ++it) {
- outputs.insert(std::pair(it->second, it->first));
+ for (auto& [fst, snd] : fused_outputs) {
+ outputs.insert({snd, fst});
}
- // It is possible that an output of an node is put bebind the output of an later
+ // It is possible that an output of a node is put behind the output of a later
// node in the graph output list. So we should sort the output name according
// to the graph output names
std::vector output_names;
@@ -906,7 +892,7 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st
}
for (auto& name : graph_output_names) {
- if (std::find(graph_out_names.begin(), graph_out_names.end(), name) != graph_out_names.end())
+ if (graph_out_names.find(name) != graph_out_names.end())
output_names.push_back(name);
}
@@ -1105,7 +1091,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer,
if (IsNodeSupported(mgx_supported_ops, graph_viewer, node_idx, logger)) {
// Collect inputs that are initializers
graph_viewer.GetNode(node_idx)->ForEachDef([&mgx_required_initializers,
- &graph_viewer](const onnxruntime::NodeArg& node_arg, bool is_input) {
+ &graph_viewer](const NodeArg& node_arg, bool is_input) {
if(is_input && graph_viewer.GetAllInitializedTensors().count(node_arg.Name())) {
mgx_required_initializers.insert(node_arg.Name());
} },
@@ -1119,7 +1105,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer,
}
// Returns a vector clusters(or node_idx). For each unsupported node, the graph
-// is split into 3 parts. supported_cluster + (UNsupported_node + rest_of_the_graph).
+// is split into 3 parts. supported_cluster + (Unsupported_node + rest_of_the_graph).
// This functions returns vector of all supported_subgraphx by amdmigraphx
static std::vector>
GetPartitionedSubgraphs(const std::vector& topological_order,
@@ -1136,7 +1122,7 @@ GetPartitionedSubgraphs(const std::vector& topological_order,
if (!this_subgraph.empty()) {
mgx_subgraphx.push_back(std::move(this_subgraph));
}
- // Point prev to node idx past this unsuported node.
+ // Point prev to node idx past this unsupported node.
prev = ++it;
}
@@ -1150,7 +1136,7 @@ GetPartitionedSubgraphs(const std::vector& topological_order,
}
std::vector>
-MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
+MIGraphXExecutionProvider::GetCapability(const GraphViewer& graph_viewer,
const IKernelLookup& /*kernel_lookup*/,
const GraphOptimizerRegistry& /* graph_optimizer_registry */,
IResourceAccountant* /* resource_accountant */) const {
@@ -1178,7 +1164,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
// If all ops are supported, no partitioning is required. Short-circuit and avoid splitting.
if (unsupported_nodes.empty()) {
- auto node_indices = graph_viewer.GetNodesInTopologicalOrder();
+ const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
auto sub_graph = GetSubGraph(node_indices, graph_viewer);
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
} else { // unsupported_nodes_idx.empty()
@@ -1257,9 +1243,9 @@ bool get_input_output_names(const GraphViewer& graph,
// Useful to default to EP to trigger the compile if file doesn't exist or loading fails.
bool load_precompiled_model(migraphx::program& prog, const std::filesystem::path& path) try {
if (!path.empty() && exists(path)) {
- LOGS_DEFAULT(INFO) << "Attempting to load model at:" << path.string();
+ LOGS_DEFAULT(VERBOSE) << "Attempting to load model at:" << path.string();
prog = migraphx::load(path.string().c_str());
- LOGS_DEFAULT(INFO) << "load model : Success";
+ LOGS_DEFAULT(VERBOSE) << "load model : Success";
return true;
}
return false;
@@ -1269,26 +1255,26 @@ bool load_precompiled_model(migraphx::program& prog, const std::filesystem::path
void save_compiled_model(const migraphx::program& prog, const std::filesystem::path& path) {
if (!path.empty()) {
- LOGS_DEFAULT(INFO) << "Model Save at " << path << ": Begin";
+ LOGS_DEFAULT(VERBOSE) << "Model Save at " << path << ": Begin";
migraphx::file_options fo;
fo.set_file_format("msgpack");
save(prog, path.string().c_str(), fo);
- LOGS_DEFAULT(INFO) << "Model Save: Complete";
+ LOGS_DEFAULT(VERBOSE) << "Model Save: Complete";
}
}
-// Order matters here especially if the program uses mixed quantization
+// Order matters here, especially if the program uses mixed quantization
// Calibrate on full precision for int8/fp8 and then quantize down to fp16
-void calibrate_and_quantize(migraphx::program& prog,
+void calibrate_and_quantize(const migraphx::program& prog,
const migraphx::target& t,
- const migraphx::program_parameters quant_params,
- bool fp16_enable,
- bool bf16_enable,
- bool int8_enable,
- bool fp8_enable,
- bool int8_calibration_cache_available,
+ const migraphx::program_parameters& quant_params,
+ const bool fp16_enable,
+ const bool bf16_enable,
+ const bool int8_enable,
+ const bool fp8_enable,
+ const bool int8_calibration_cache_available,
std::unordered_map& dynamic_range_map) {
- // Read in the calibration data and map it to an migraphx paramater map for the calibration ops
+ // Read in the calibration data and map it to a migraphx parameter map for the calibration ops
if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) {
LOGS_DEFAULT(WARNING) << "Quantizing input program";
@@ -1296,8 +1282,8 @@ void calibrate_and_quantize(migraphx::program& prog,
// Add all calibration data read in from int8 table
for (auto& [cal_key, cal_val] : dynamic_range_map) {
- auto cal_val_shape = migraphx::shape(migraphx_shape_float_type);
- quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val))));
+ const auto cal_val_shape = migraphx::shape(migraphx_shape_float_type);
+ quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, &cal_val));
}
// perform static quantization on the programs
@@ -1336,9 +1322,9 @@ void calibrate_and_quantize(migraphx::program& prog,
#endif
}
-void compile_program(migraphx::program& prog,
+void compile_program(const migraphx::program& prog,
const migraphx::target& t,
- bool exhaustive_tune) {
+ const bool exhaustive_tune) {
LOGS_DEFAULT(WARNING) << "Model Compile: Begin";
migraphx::compile_options co;
co.set_fast_math(false);
@@ -1377,7 +1363,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector&
const Node& fused_node = fused_node_graph.fused_node;
std::filesystem::path model_cache_file;
- auto mxr_filename_prefix = to_hex(MIGraphX_Version) + "-" + GenerateGraphId(graph_body_viewer) + "-" + make_hash(std::string_view(device_prop_.gcnArchName)) + "-";
+ auto mxr_filename_prefix = to_hex(MIGraphX_Version) + "-" + GenerateGraphId(graph_body_viewer) + "-" + make_hash(std::string_view{device_prop_.gcnArchName}) + "-";
// Get model input names (only first layer)
const Graph* cur_graph = &graph_body_viewer.GetGraph();
@@ -1433,7 +1419,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector&
if (!no_input_shape) {
if (!load_precompiled_model(prog, model_cache_file)) {
- LOGS_DEFAULT(INFO) << "No input shapes detected quantizing model";
+ LOGS_DEFAULT(VERBOSE) << "No input shapes detected quantizing model";
#ifndef ENABLE_TRAINING_CORE
#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH
options.set_external_data_path(model_path_.parent_path().string());
@@ -1462,13 +1448,12 @@ Status MIGraphXExecutionProvider::Compile(const std::vector&
map_input_index_[fused_node.Name()] = input_name_index;
map_no_input_shape_[fused_node.Name()] = no_input_shape;
NodeComputeInfo compute_info;
- compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
- std::unique_ptr p = std::make_unique();
+ compute_info.create_state_func = [=](const ComputeContext* context, FunctionState* state) {
+ auto p = std::make_unique();
*p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name],
map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_,
map_no_input_shape_[context->node_name], fp16_enable_, bf16_enable_, fp8_enable_, int8_enable_,
- int8_calibration_cache_available_, dynamic_range_map_,
- model_cache_path_.string(), dump_model_ops_};
+ int8_calibration_cache_available_, dynamic_range_map_, model_cache_path_.string(), dump_model_ops_};
*state = p.release();
return 0;
};
@@ -1480,7 +1465,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector&
compute_info.compute_func = [this, mxr_filename_prefix](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
Ort::KernelContext ctx(context);
- MIGraphXFuncState* mgx_state = reinterpret_cast(state);
+ auto mgx_state = static_cast(state);
std::unordered_map& map_input_name_index = mgx_state->input_name_indexes;
std::unordered_map& map_dynamic_range = mgx_state->dynamic_range_map;
@@ -1502,7 +1487,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector&
std::vector input_shapes;
if (no_input_shape) {
- LOGS_DEFAULT(INFO) << "Missing input shape setting input parameters again";
+ LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again";
for (auto& it : map_input_name_index) {
auto& name = it.first;
auto& index = it.second;
@@ -1514,7 +1499,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector&
input_shape_match = false;
}
} else {
- LOGS_DEFAULT(INFO) << "Assigning inputs, and parameters from compiled model";
+ LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model";
param_shapes = prog.get_parameter_shapes();
auto prog_output_shapes = prog.get_output_shapes();
@@ -1546,8 +1531,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector&
}
}
- // input shapes are different, needs to re-parse onnx and
- // re-compile the program
+ // input shapes are different, needs to reparse onnx and recompile the program
if (!input_shape_match) {
std::filesystem::path model_cache_file;
// empty cache path means the MXR caching is disabled - always compile
@@ -1602,7 +1586,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector&
if (param_shapes.size() > 0) {
for (auto&& name : param_shapes.names()) {
if (map_input_name_index.count(name) > 0) {
- LOGS_DEFAULT(INFO) << "Setting parameters for:" << name;
+ LOGS_DEFAULT(VERBOSE) << "Setting parameters for:" << name;
auto input_tensor = ctx.GetInput(map_input_name_index[name]);
auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
const auto tensor_shape = tensor_info.GetShape();
@@ -1616,21 +1600,21 @@ Status MIGraphXExecutionProvider::Compile(const std::vector&
LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch";
}
- LOGS_DEFAULT(INFO) << "Writing Raw tensor data ";
+ LOGS_DEFAULT(VERBOSE) << "Writing Raw tensor data ";
m.add(name, migraphx::argument(param_shapes[name],
const_cast(input_tensor.GetTensorRawData())));
}
- // It is a output argument
+ // It is an output argument
else {
- auto compute_output_index = [](const std::string& name) -> int {
- std::string out_name_prefix = "#output_";
- auto pos = name.find(out_name_prefix);
- if (pos == std::string::npos) {
+ auto compute_output_index = [](const std::string_view sv) -> int {
+ constexpr std::string_view out_name_prefix = "#output_";
+ const auto pos = sv.find(out_name_prefix);
+ if (pos == std::string_view::npos) {
return -1;
}
- std::string index_str = name.substr(pos + out_name_prefix.length());
- return std::stoi(index_str);
+ const auto index_str = sv.substr(pos + out_name_prefix.length());
+ return ToInteger(Trim(index_str, std::isdigit));
};
int output_index = compute_output_index(name);
@@ -1652,13 +1636,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector&
{
// lock to avoid race condition
- std::lock_guard lock(*(mgx_state->mgx_mu_ptr));
+ std::lock_guard lock(*mgx_state->mgx_mu_ptr);
void* rocm_stream;
Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream));
auto prog_outputs = prog.run_async(m, static_cast(rocm_stream));
- // In case of input parameters are reused as output parameter call hipMemcpy
+ // In the case of input parameters are reused as output parameter calls hipMemcpy
auto output_num = prog_outputs.size();
if (prog_output_indices.size() < output_num) {
for (std::size_t i = 0; i < output_num; ++i) {
@@ -1677,7 +1661,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector&
static_cast(rocm_stream)));
}
}
- };
+ }
return Status::OK();
};
@@ -1693,30 +1677,26 @@ void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegis
RegisterMIGraphXStreamHandles(stream_handle_registry, OrtDevice::GPU, allocator, true, stream_, false /*TODO:external_stream_*/);
}
-OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
- if (mem_type == OrtMemTypeCPUInput) return OrtDevice();
- if (mem_type == OrtMemTypeCPUOutput) return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device id always be 0*/);
+OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(const OrtMemType mem_type) const {
+ if (mem_type == OrtMemTypeCPUInput) return {};
+ if (mem_type == OrtMemTypeCPUOutput) return {OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device id always be 0*/};
return default_device_;
}
Status MIGraphXExecutionProvider::Sync() const {
- HIP_CALL_THROW(hipStreamSynchronize(static_cast(nullptr)));
-
- auto status = hipStreamQuery(stream_);
- if (status != hipSuccess) {
- return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::EP_FAIL);
+ HIP_CALL_THROW(hipStreamSynchronize(nullptr));
+ if (hipStreamQuery(stream_) != hipSuccess) {
+ return {common::ONNXRUNTIME, common::EP_FAIL};
}
return Status::OK();
}
-Status MIGraphXExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
+Status MIGraphXExecutionProvider::OnRunStart(const RunOptions& /*run_options*/) {
return Status::OK();
}
-Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) {
- auto status = hipStreamQuery(stream_);
-
- if (status != hipSuccess) {
+Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const RunOptions& /*run_options*/) {
+ if (hipStreamQuery(stream_) != hipSuccess) {
HIP_CALL_THROW(hipStreamSynchronize(stream_));
}
return Status::OK();
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
index cf1ad9711c2aa..c5c1f0f2f1650 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
@@ -27,7 +27,6 @@ constexpr auto kCachePath = "ORT_MIGRAPHX_CACHE_PATH";
constexpr auto kINT8UseNativeMIGraphXCalibrationTable = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH";
constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE";
-
} // namespace migraphx_env_vars
// Information to construct kernel function state.
@@ -54,32 +53,32 @@ struct MIGraphXFuncState {
};
// Logical device representation.
-class MIGraphXExecutionProvider : public IExecutionProvider {
+class MIGraphXExecutionProvider final : public IExecutionProvider {
public:
explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info);
- ~MIGraphXExecutionProvider();
+ ~MIGraphXExecutionProvider() override = default;
void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info);
void get_flags_from_env();
- void print_migraphx_ep_flags();
+ void print_migraphx_ep_flags() const;
Status Sync() const override;
- Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
+ Status OnRunStart(const RunOptions& run_options) override;
- Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override;
+ Status OnRunEnd(bool sync_stream, const RunOptions& run_options) override;
std::vector>
- GetCapability(const onnxruntime::GraphViewer& graph_viewer,
+ GetCapability(const GraphViewer& graph_viewer,
const IKernelLookup& /*kernel_lookup*/,
const GraphOptimizerRegistry& /* graph_optimizer_registry */,
IResourceAccountant* /* resource_accountant */) const override;
- common::Status Compile(const std::vector& fused_nodes,
- std::vector& node_compute_funcs) override;
+ Status Compile(const std::vector& fused_nodes,
+ std::vector& node_compute_funcs) override;
- virtual std::shared_ptr GetKernelRegistry() const override;
- std::unique_ptr GetDataTransfer() const override;
+ std::shared_ptr GetKernelRegistry() const override;
+ std::unique_ptr GetDataTransfer() const override;
static AllocatorPtr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy,
MIGraphXExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg);
@@ -111,7 +110,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
migraphx::target t_;
std::mutex mgx_mu_;
hipStream_t stream_ = nullptr;
- hipDeviceProp_t device_prop_;
+ hipDeviceProp_t device_prop_{};
bool exhaustive_tune_ = false;
mutable std::filesystem::path model_path_;
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h
index cb25db032ebf2..a1389a2e4b680 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h
@@ -332,4 +332,28 @@ inline std::string GenerateGraphId(const GraphViewer& graph_viewer) {
return std::string{s.data(), ptr};
}
+inline std::string_view TrimLeft(std::string_view sv, int (*fn)(int) = std::isspace) {
+ return sv.substr(0, sv.end() - std::find_if(sv.begin(), sv.end(), [fn](int ch) {
+ return fn(ch);
+ }));
+}
+
+inline std::string_view TrimRight(std::string_view sv, int (*fn)(int) = std::isspace) {
+ return sv.substr(sv.end() - std::find_if(sv.rbegin(), sv.rend(), [fn](int ch) {
+ return fn(ch);
+ }).base());
+}
+
+inline std::string_view Trim(std::string_view sv, int (*fn)(int) = std::isspace) {
+ return TrimRight(TrimLeft(sv, fn), fn);
+}
+
+inline int ToInteger(const std::string_view sv) {
+ int result = 0;
+ if (auto [_, ec] = std::from_chars(sv.data(), sv.data() + sv.length(), result); ec == std::errc()) {
+ return result;
+ }
+ ORT_THROW("invalid input for conversion to integer");
+}
+
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
index 4a088c8c9b6ee..30650005bbc21 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
@@ -1,29 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License
#include
+#include
#include
+#include
+
+#ifdef _WIN32
+#define WIN32_LEAN_AND_MEAN
+#include
+#include
+#endif
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/migraphx/migraphx_provider_factory.h"
-#include "migraphx_execution_provider.h"
-#include "migraphx_execution_provider_info.h"
-#include "migraphx_provider_factory_creator.h"
-#include "migraphx_allocator.h"
-#include "gpu_data_transfer.h"
+#include "core/providers/migraphx/migraphx_execution_provider.h"
+#include "core/providers/migraphx/migraphx_execution_provider_info.h"
+#include "core/providers/migraphx/migraphx_provider_factory_creator.h"
+#include "core/providers/migraphx/migraphx_allocator.h"
+#include "core/providers/migraphx/gpu_data_transfer.h"
#include "core/framework/provider_options.h"
#include "core/session/onnxruntime_c_api.h"
-using namespace onnxruntime;
-
namespace onnxruntime {
void InitializeRegistry();
void DeleteRegistry();
-struct MIGraphXProviderFactory : IExecutionProviderFactory {
- MIGraphXProviderFactory(const MIGraphXExecutionProviderInfo& info) : info_{info} {}
- ~MIGraphXProviderFactory() override {}
+struct MIGraphXProviderFactory final : IExecutionProviderFactory {
+ explicit MIGraphXProviderFactory(MIGraphXExecutionProviderInfo info) : info_{std::move(info)} {}
+ ~MIGraphXProviderFactory() override = default;
std::unique_ptr CreateProvider() override;
@@ -36,15 +42,15 @@ std::unique_ptr MIGraphXProviderFactory::CreateProvider() {
}
struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX {
- std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) override {
+ std::unique_ptr CreateMIGraphXAllocator(const OrtDevice::DeviceId device_id, const char* name) override {
return std::make_unique(device_id, name);
}
- std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override {
+ std::unique_ptr CreateMIGraphXPinnedAllocator(const OrtDevice::DeviceId device_id, const char* name) override {
return std::make_unique(device_id, name);
}
- void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) override {
+ void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, const size_t count) override {
// hipMemcpy() operates on the default stream
HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyHostToDevice));
@@ -53,24 +59,26 @@ struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX {
// The function will return once the pageable buffer has been copied to the staging memory for DMA transfer
// to device memory, but the DMA to final destination may not have completed.
- HIP_CALL_THROW(hipStreamSynchronize(0));
+ HIP_CALL_THROW(hipStreamSynchronize(nullptr));
}
// Used by onnxruntime_pybind_state.cc
- void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override {
+ void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, const size_t count) override {
// For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed.
HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost));
}
- std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override {
- return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, migx_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg);
+ std::shared_ptr CreateMIGraphXAllocator(const OrtDevice::DeviceId device_id, const size_t mem_limit, const ArenaExtendStrategy arena_extend_strategy, const MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override {
+ return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg);
}
} g_info;
-struct MIGraphX_Provider : Provider {
+struct MIGraphX_Provider final : Provider {
void* GetInfo() override { return &g_info; }
- std::shared_ptr CreateExecutionProviderFactory(int device_id) override {
+ virtual ~MIGraphX_Provider() = default;
+
+ std::shared_ptr CreateExecutionProviderFactory(const int device_id) override {
MIGraphXExecutionProviderInfo info;
info.device_id = static_cast(device_id);
info.target_device = "gpu";
@@ -96,14 +104,14 @@ struct MIGraphX_Provider : Provider {
if (options.migraphx_cache_dir != nullptr) {
info.model_cache_dir = options.migraphx_cache_dir;
}
- info.arena_extend_strategy = static_cast(options.migraphx_arena_extend_strategy);
+ info.arena_extend_strategy = static_cast(options.migraphx_arena_extend_strategy);
info.mem_limit = options.migraphx_mem_limit;
return std::make_shared(info);
}
void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override {
- auto internal_options = onnxruntime::MIGraphXExecutionProviderInfo::FromProviderOptions(options);
- auto& migx_options = *reinterpret_cast(provider_options);
+ auto internal_options = MIGraphXExecutionProviderInfo::FromProviderOptions(options);
+ auto& migx_options = *static_cast(provider_options);
migx_options.device_id = internal_options.device_id;
migx_options.migraphx_fp16_enable = internal_options.fp16_enable;
migx_options.migraphx_bf16_enable = internal_options.bf16_enable;
@@ -123,7 +131,7 @@ struct MIGraphX_Provider : Provider {
strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size);
#endif
dest[str_size] = '\0';
- migx_options.migraphx_int8_calibration_table_name = (const char*)dest;
+ migx_options.migraphx_int8_calibration_table_name = static_cast(dest);
}
migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table;
@@ -135,10 +143,23 @@ struct MIGraphX_Provider : Provider {
ProviderOptions GetProviderOptions(const void* provider_options) override {
auto& options = *reinterpret_cast(provider_options);
- return onnxruntime::MIGraphXExecutionProviderInfo::ToProviderOptions(options);
+ return MIGraphXExecutionProviderInfo::ToProviderOptions(options);
}
void Initialize() override {
+#ifdef _WIN32
+ HMODULE module = nullptr;
+ if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
+ GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
+ static_cast(static_cast(InitializeRegistry)),
+ &module) != 0) {
+ char buffer[MAX_PATH];
+ if (GetModuleFileName(module, buffer, sizeof(buffer)) != 0) {
+ PathRemoveFileSpec(buffer);
+ SetDllDirectory(buffer);
+ }
+ }
+#endif
InitializeRegistry();
}
@@ -151,7 +172,6 @@ struct MIGraphX_Provider : Provider {
} // namespace onnxruntime
extern "C" {
-
ORT_API(onnxruntime::Provider*, GetProvider) {
return &onnxruntime::g_provider;
}
diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h
index d1c9457bafa0f..313603a4ecbf0 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h
@@ -1,7 +1,11 @@
-// Copyright 2019 AMD AMDMIGraphX
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License
-#include "core/framework/provider_options.h"
-#include "onnxruntime_c_api.h"
+#pragma once
+
+#include
+#include "core/framework/ortdevice.h"
+#include "core/session/onnxruntime_c_api.h"
namespace onnxruntime {
class IAllocator;
@@ -12,11 +16,11 @@ enum class ArenaExtendStrategy : int32_t;
struct MIGraphXExecutionProviderExternalAllocatorInfo;
struct ProviderInfo_MIGraphX {
- virtual std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0;
- virtual std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0;
+ virtual std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) = 0;
+ virtual std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) = 0;
virtual void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) = 0;
virtual void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) = 0;
- virtual std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0;
+ virtual std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, ArenaExtendStrategy arena_extend_strategy, MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0;
protected:
~ProviderInfo_MIGraphX() = default; // Can only be destroyed through a subclass instance
diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h
index 02d30ad0f6fbb..db169b9e2f5a9 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h
@@ -6,6 +6,7 @@
#include
#include "core/providers/providers.h"
+#include "core/framework/provider_options.h"
struct OrtMIGraphXProviderOptions;
@@ -14,5 +15,6 @@ namespace onnxruntime {
struct MIGraphXProviderFactoryCreator {
static std::shared_ptr Create(int device_id);
static std::shared_ptr Create(const OrtMIGraphXProviderOptions* options);
+ static std::shared_ptr Create(const ProviderOptions&);
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 02696524042e7..994082c272fd2 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -3029,7 +3029,12 @@ static constexpr OrtApi ort_api_1_to_22 = {
&OrtApis::GetEpApi,
// End of Version 22 - DO NOT MODIFY ABOVE (see above text for more information)
-};
+ &OrtApis::CreateMIGraphXProviderOptions,
+ &OrtApis::UpdateMIGraphXProviderOptions,
+ &OrtApis::GetMIGraphXProviderOptionsAsString,
+ &OrtApis::ReleaseMIGraphXProviderOptions,
+ &OrtApis::UpdateMIGraphXProviderOptionsWithValue,
+ &OrtApis::GetMIGraphXProviderOptionsByName};
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
static_assert(sizeof(OrtApiBase) == sizeof(void*) * 2, "New methods can't be added to OrtApiBase as it is not versioned");
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index addeb36d4087d..4692733d5b9e8 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -597,4 +597,15 @@ ORT_API(const OrtKeyValuePairs*, EpDevice_EpOptions, _In_ const OrtEpDevice* ep_
ORT_API(const OrtHardwareDevice*, EpDevice_Device, _In_ const OrtEpDevice* ep_device);
ORT_API(const OrtEpApi*, GetEpApi);
+ORT_API_STATUS_IMPL(CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out);
+ORT_API_STATUS_IMPL(UpdateMIGraphXProviderOptions, _Inout_ OrtMIGraphXProviderOptions* migraphx_options,
+ _In_reads_(num_keys) const char* const* provider_options_keys,
+ _In_reads_(num_keys) const char* const* provider_options_values,
+ size_t num_keys);
+ORT_API_STATUS_IMPL(GetMIGraphXProviderOptionsAsString, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr);
+ORT_API(void, ReleaseMIGraphXProviderOptions, _Frees_ptr_opt_ OrtMIGraphXProviderOptions*);
+
+ORT_API_STATUS_IMPL(UpdateMIGraphXProviderOptionsWithValue, _Inout_ OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _In_ void* value);
+ORT_API_STATUS_IMPL(GetMIGraphXProviderOptionsByName, _In_ const OrtMIGraphXProviderOptions* migraphx_options, _In_ const char* key, _Outptr_ void** ptr);
+
} // namespace OrtApis
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc
index 7fcaee48581f6..8d079ff485254 100644
--- a/onnxruntime/core/session/provider_bridge_ort.cc
+++ b/onnxruntime/core/session/provider_bridge_ort.cc
@@ -6,6 +6,7 @@
#include
#include
+#include
#include "core/common/inlined_containers.h"
#include "core/common/path_string.h"
@@ -114,6 +115,7 @@ using EtwRegistrationManager_EtwInternalCallback = EtwRegistrationManager::EtwIn
#include "core/providers/rocm/rocm_provider_factory.h"
#include "core/providers/dnnl/dnnl_provider_factory.h"
#include "core/providers/migraphx/migraphx_provider_factory.h"
+#include "core/providers/migraphx/migraphx_execution_provider_info.h"
#include "core/providers/openvino/openvino_provider_factory.h"
#include "core/providers/tensorrt/tensorrt_provider_factory.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
@@ -1961,7 +1963,7 @@ std::shared_ptr DnnlProviderFactoryCreator::Create(in
return s_library_dnnl.Get().CreateExecutionProviderFactory(use_arena);
}
-std::shared_ptr MIGraphXProviderFactoryCreator::Create(int device_id) {
+std::shared_ptr MIGraphXProviderFactoryCreator::Create(const int device_id) {
return s_library_migraphx.Get().CreateExecutionProviderFactory(device_id);
}
@@ -2065,6 +2067,12 @@ std::shared_ptr NvProviderFactoryCreator::Create(
return nullptr;
}
+std::shared_ptr MIGraphXProviderFactoryCreator::Create(const ProviderOptions& provider_options) {
+ OrtMIGraphXProviderOptions migraphx_options;
+ s_library_migraphx.Get().UpdateProviderOptions(&migraphx_options, provider_options);
+ return s_library_migraphx.Get().CreateExecutionProviderFactory(&migraphx_options);
+}
+
std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) {
return s_library_migraphx.Get().CreateExecutionProviderFactory(provider_options);
}
@@ -2655,7 +2663,8 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptions,
defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || \
defined(USE_CANN) || \
defined(USE_DNNL) || \
- defined(USE_ROCM)
+ defined(USE_ROCM) || \
+ defined(USE_MIGRAPHX)
static std::string BuildOptionsString(const onnxruntime::ProviderOptions::iterator& begin,
const onnxruntime::ProviderOptions::iterator& end) {
std::ostringstream options;
@@ -3170,3 +3179,123 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_
return nullptr;
API_IMPL_END
}
+
+ORT_API_STATUS_IMPL(OrtApis::CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out) {
+ API_IMPL_BEGIN
+#ifdef USE_MIGRAPHX
+ auto migraphx_options = std::make_unique();
+ memset(migraphx_options.get(), 0, sizeof(OrtMIGraphXProviderOptions));
+ *out = migraphx_options.release();
+ return nullptr;
+#else
+ ORT_UNUSED_PARAMETER(out);
+ return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build.");
+#endif
+ API_IMPL_END
+}
+
+ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptions,
+ _Inout_ OrtMIGraphXProviderOptions* migraphx_options,
+ _In_reads_(num_keys) const char* const* provider_options_keys,
+ _In_reads_(num_keys) const char* const* provider_options_values,
+ size_t num_keys) {
+ API_IMPL_BEGIN
+#ifdef USE_MIGRAPHX
+ onnxruntime::ProviderOptions provider_options_map;
+ for (size_t i = 0; i != num_keys; ++i) {
+ if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' ||
+ provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "key/value cannot be empty");
+ }
+
+ provider_options_map[provider_options_keys[i]] = provider_options_values[i];
+ }
+
+ onnxruntime::s_library_migraphx.Get().UpdateProviderOptions(reinterpret_cast(migraphx_options), provider_options_map);
+ return nullptr;
+#else
+ ORT_UNUSED_PARAMETER(migraphx_options);
+ ORT_UNUSED_PARAMETER(provider_options_keys);
+ ORT_UNUSED_PARAMETER(provider_options_values);
+ ORT_UNUSED_PARAMETER(num_keys);
+ return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build.");
+#endif
+ API_IMPL_END
+}
+
+ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsAsString,
+ _In_ const OrtMIGraphXProviderOptions* migraphx_options,
+ _Inout_ OrtAllocator* allocator,
+ _Outptr_ char** ptr) {
+ API_IMPL_BEGIN
+#ifdef USE_MIGRAPHX
+ onnxruntime::ProviderOptions options =
+ onnxruntime::s_library_migraphx.Get().GetProviderOptions(reinterpret_cast(migraphx_options));
+ std::string options_str = BuildOptionsString(options.begin(), options.end());
+ *ptr = onnxruntime::StrDup(options_str, allocator);
+ return nullptr;
+#else
+ ORT_UNUSED_PARAMETER(migraphx_options);
+ ORT_UNUSED_PARAMETER(allocator);
+ ORT_UNUSED_PARAMETER(ptr);
+ return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build.");
+#endif
+ API_IMPL_END
+}
+
+ORT_API(void, OrtApis::ReleaseMIGraphXProviderOptions, _Frees_ptr_opt_ OrtMIGraphXProviderOptions* ptr) {
+#ifdef USE_MIGRAPHX
+ std::unique_ptr p(ptr);
+ if (ptr->migraphx_cache_dir != nullptr) {
+ onnxruntime::AllocatorDefaultFree(const_cast(ptr->migraphx_cache_dir));
+ }
+#else
+ ORT_UNUSED_PARAMETER(ptr);
+#endif
+}
+
+ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptionsWithValue,
+ _Inout_ OrtMIGraphXProviderOptions* migraphx_options,
+ _In_ const char* key,
+ _In_ void* value) {
+ API_IMPL_BEGIN
+#ifdef USE_MIGRAPHX
+ auto sv = std::string_view{key};
+ OrtAllocator* allocator;
+ GetAllocatorWithDefaultOptions(&allocator);
+ if (sv == onnxruntime::migraphx_provider_option::kDeviceId) {
+ auto dv = std::string_view{static_cast(value)};
+ if (std::from_chars(dv.data(), dv.data() + dv.length(), migraphx_options->device_id).ec == std::errc::invalid_argument) {
+ ORT_THROW("Cannot convert from string to integer - invalid argument");
+ }
+ } else if (sv == onnxruntime::migraphx_provider_option::kModelCacheDir) {
+ auto sd = std::string_view{static_cast(value)};
+ migraphx_options->migraphx_cache_dir = onnxruntime::StrDup(sd.data(), allocator);
+ } else {
+ ORT_THROW("Unsupported provider option name: '" + std::string{sv} + "'");
+ }
+ return nullptr;
+#else
+ ORT_UNUSED_PARAMETER(migraphx_options);
+ ORT_UNUSED_PARAMETER(key);
+ ORT_UNUSED_PARAMETER(value);
+ return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build.");
+#endif
+ API_IMPL_END
+}
+
+ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsByName,
+ _In_ const OrtMIGraphXProviderOptions* migraphx_options,
+ _In_ const char* key,
+ _Outptr_ void** ptr) {
+ API_IMPL_BEGIN
+#ifdef USE_MIGRAPHX
+ return nullptr;
+#else
+ ORT_UNUSED_PARAMETER(migraphx_options);
+ ORT_UNUSED_PARAMETER(key);
+ ORT_UNUSED_PARAMETER(ptr);
+ return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build.");
+#endif
+ API_IMPL_END
+}
diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc
index 1f74ee3b3f2ee..4c0d3e8d22354 100644
--- a/onnxruntime/core/session/provider_registration.cc
+++ b/onnxruntime/core/session/provider_registration.cc
@@ -101,6 +101,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
VitisAI,
CoreML,
NvTensorRtRtx, // TensorRt EP for RTX GPUs.
+ MIGraphX
};
struct EpToAppend {
@@ -109,7 +110,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
const char* canonical_name = nullptr;
};
- static std::array supported_eps = {
+ static std::array supported_eps = {
EpToAppend{EpID::DML, "DML", kDmlExecutionProvider},
EpToAppend{EpID::QNN, "QNN", kQnnExecutionProvider},
EpToAppend{EpID::OpenVINO, "OpenVINO", kOpenVINOExecutionProvider},
@@ -121,7 +122,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
EpToAppend{EpID::JS, "JS", kJsExecutionProvider},
EpToAppend{EpID::VitisAI, "VitisAI", kVitisAIExecutionProvider},
EpToAppend{EpID::CoreML, "CoreML", kCoreMLExecutionProvider},
- EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider}};
+ EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider},
+ EpToAppend{EpID::MIGraphX, "MIGraphX", kMIGraphXExecutionProvider}};
ProviderOptions provider_options;
OrtStatus* status = ParseProviderOptions(provider_options_keys,
@@ -279,6 +281,14 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value)));
#else
status = create_not_supported_status();
+#endif
+ break;
+ }
+ case EpID::MIGraphX: {
+#if defined(USE_MIGRAPHX)
+ options->provider_factories.push_back(MIGraphXProviderFactoryCreator::Create(provider_options));
+#else
+ status = create_not_supported_status();
#endif
break;
}
@@ -657,4 +667,55 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI,
ORT_UNUSED_PARAMETER(num_keys);
return CreateNotEnabledStatus("VitisAI");
}
+
+ORT_API_STATUS_IMPL(OrtApis::CreateMIGraphXProviderOptions, _Outptr_ OrtMIGraphXProviderOptions** out) {
+ ORT_UNUSED_PARAMETER(out);
+ return CreateNotEnabledStatus("MIGraphX");
+}
+
+ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptions,
+ _Inout_ OrtMIGraphXProviderOptions* migraphx_options,
+ _In_reads_(num_keys) const char* const* provider_options_keys,
+ _In_reads_(num_keys) const char* const* provider_options_values,
+ size_t num_keys) {
+ ORT_UNUSED_PARAMETER(migraphx_options);
+ ORT_UNUSED_PARAMETER(provider_options_keys);
+ ORT_UNUSED_PARAMETER(provider_options_values);
+ ORT_UNUSED_PARAMETER(num_keys);
+ return CreateNotEnabledStatus("MIGraphX");
+}
+
+ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsAsString,
+ _In_ const OrtMIGraphXProviderOptions* migraphx_options, _Inout_ OrtAllocator* allocator,
+ _Outptr_ char** ptr) {
+ ORT_UNUSED_PARAMETER(migraphx_options);
+ ORT_UNUSED_PARAMETER(allocator);
+ ORT_UNUSED_PARAMETER(ptr);
+ return CreateStatus(ORT_FAIL, "MIGraphX execution provider is not enabled in this build.");
+}
+
+ORT_API(void, OrtApis::ReleaseMIGraphXProviderOptions, _Frees_ptr_opt_ OrtMIGraphXProviderOptions* ptr) {
+ ORT_UNUSED_PARAMETER(ptr);
+}
+
+ORT_API_STATUS_IMPL(OrtApis::UpdateMIGraphXProviderOptionsWithValue,
+ _Inout_ OrtMIGraphXProviderOptions* migraphx_options,
+ _In_ const char* key,
+ _In_ void* value) {
+ ORT_UNUSED_PARAMETER(migraphx_options);
+ ORT_UNUSED_PARAMETER(key);
+ ORT_UNUSED_PARAMETER(value);
+ return CreateNotEnabledStatus("MIGraphX");
+}
+
+ORT_API_STATUS_IMPL(OrtApis::GetMIGraphXProviderOptionsByName,
+ _In_ const OrtMIGraphXProviderOptions* migraphx_options,
+ _In_ const char* key,
+ _Outptr_ void** ptr) {
+ ORT_UNUSED_PARAMETER(migraphx_options);
+ ORT_UNUSED_PARAMETER(key);
+ ORT_UNUSED_PARAMETER(ptr);
+ return CreateNotEnabledStatus("MIGraphX");
+}
+
#endif
diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
index 0428b19357d51..55026d9aab986 100644
--- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
@@ -207,6 +207,7 @@ std::unique_ptr GetGPUDataTransfer() {
#endif
#ifdef USE_MIGRAPHX
+
void CpuToMIGraphXMemCpy(void* dst, const void* src, size_t num_bytes) {
GetProviderInfo_MIGraphX().MIGraphXMemcpy_HostToDevice(dst, src, num_bytes);
}
diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.cc b/onnxruntime/python/onnxruntime_pybind_state_common.cc
index a7e8e0345e9d1..2e52c192ec521 100644
--- a/onnxruntime/python/onnxruntime_pybind_state_common.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state_common.cc
@@ -51,7 +51,6 @@ onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExten
onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo migx_external_allocator_info{};
#endif
-
void DlpackCapsuleDestructor(PyObject* data) {
DLManagedTensor* dlmanaged_tensor = reinterpret_cast(PyCapsule_GetPointer(data, "dltensor"));
if (dlmanaged_tensor) {
diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h
index 82502f17e808e..9740a857f7577 100644
--- a/onnxruntime/python/onnxruntime_pybind_state_common.h
+++ b/onnxruntime/python/onnxruntime_pybind_state_common.h
@@ -37,13 +37,13 @@ struct OrtStatus {
#define BACKEND_PROC "CPU"
#endif
-#if USE_DNNL
+#ifdef USE_DNNL
#define BACKEND_DNNL "-DNNL"
#else
#define BACKEND_DNNL ""
#endif
-#if USE_MIGRAPHX
+#ifdef USE_MIGRAPHX
#define BACKEND_MIGRAPHX "-MIGRAPHX"
#else
#define BACKEND_MIGRAPHX ""
diff --git a/setup.py b/setup.py
index 1e426ea8e060b..33e47757d05e0 100644
--- a/setup.py
+++ b/setup.py
@@ -237,8 +237,10 @@ def run(self):
rocm_dependencies = [
"libamd_comgr.so.2",
+ "libamd_comgr.so.3",
"libamdhip64.so.5",
"libamdhip64.so.6",
+ "libamdhip64.so.7",
"libdrm.so.2",
"libdrm_amdgpu.so.1",
"libelf.so.1",
@@ -250,13 +252,18 @@ def run(self):
"libnuma.so.1",
"librccl.so.1",
"libhipblas.so.2",
+ "libhipblas.so.3"
+ "libhipblaslt.so.1",
"librocblas.so.3",
"librocblas.so.4",
+ "librocblas.so.5",
"librocfft.so.0",
"libroctx64.so.4",
"librocm_smi64.so.5",
"librocm_smi64.so.6",
+ "librocm_smi64.so.7",
"libroctracer64.so.4",
+ "librocsolver.so.0",
"libtinfo.so.6",
"libmigraphx_c.so.3",
"libmigraphx.so.2",
@@ -410,6 +417,7 @@ def finalize_options(self):
libs.extend(["onnxruntime_providers_nv_tensorrt_rtx.dll"])
libs.extend(["onnxruntime_providers_openvino.dll"])
libs.extend(["onnxruntime_providers_cuda.dll"])
+ libs.extend(["onnxruntime_providers_migraphx.dll"])
libs.extend(["onnxruntime_providers_vitisai.dll"])
libs.extend(["onnxruntime_providers_qnn.dll"])
# DirectML Libs
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 52b7837b142b5..c23911b1cac76 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -1894,6 +1894,7 @@ def build_nuget_package(
use_winml,
use_qnn,
use_dml,
+ use_migraphx,
enable_training_apis,
msbuild_extra_options,
):
@@ -1931,6 +1932,9 @@ def build_nuget_package(
elif use_tensorrt:
execution_provider = "/p:ExecutionProvider=tensorrt"
package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.TensorRT"
+ elif use_migraphx:
+ execution_provider = "/p:ExecutionProvider=migraphx"
+ package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.MIGraphX"
elif use_dnnl:
execution_provider = "/p:ExecutionProvider=dnnl"
package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.DNNL"
@@ -2518,6 +2522,7 @@ def main():
getattr(args, "use_winml", False),
args.use_qnn,
getattr(args, "use_dml", False),
+ args.use_migraphx,
args.enable_training_apis,
normalize_arg_list(args.msbuild_extra_options),
)
diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py
index dbb7be829494e..abc848e5525ca 100644
--- a/tools/ci_build/build_args.py
+++ b/tools/ci_build/build_args.py
@@ -608,8 +608,9 @@ def add_execution_provider_args(parser: argparse.ArgumentParser) -> None:
rocm_group.add_argument("--use_rocm", action="store_true", help="Enable ROCm EP.")
rocm_group.add_argument("--rocm_version", help="ROCm stack version.")
rocm_group.add_argument("--rocm_home", help="Path to ROCm installation directory.")
- rocm_group.add_argument("--rocm_gfx_arch", help='Provide gfx arch. Example --rocm_gfx_arch gfx942'
- ' or --rocm_gfx_arch "gfx90a;gfx942"')
+ rocm_group.add_argument(
+ "--rocm_gfx_arch", help='Provide gfx arch. Example --rocm_gfx_arch gfx942 or --rocm_gfx_arch "gfx90a;gfx942"'
+ )
# ROCm-specific profiling
rocm_group.add_argument(
"--enable_rocm_profiling",
diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py
index 419fdd47458f7..737931b7c0c5e 100644
--- a/tools/nuget/generate_nuspec_for_native_nuget.py
+++ b/tools/nuget/generate_nuspec_for_native_nuget.py
@@ -44,7 +44,13 @@ def get_package_name(os, cpu_arch, ep, is_training_package):
def is_this_file_needed(ep, filename, package_name):
if package_name == "Microsoft.ML.OnnxRuntime.Gpu":
return False
- return (ep != "cuda" or "cuda" in filename) and (ep != "tensorrt" or "cuda" not in filename)
+ if package_name == "Microsoft.ML.OnnxRuntime.MIGraphX":
+ return False
+ return (
+ (ep != "cuda" or "cuda" in filename)
+ and (ep != "tensorrt" or "cuda" not in filename)
+ and (ep != "migraphx" or "migraphx" not in filename)
+ )
# nuget_artifacts_dir: the directory with uncompressed C API tarball/zip files
@@ -64,7 +70,8 @@ def generate_file_list_for_ep(nuget_artifacts_dir, ep, files_list, include_pdbs,
if (
child_file.suffix in suffixes
and is_this_file_needed(ep, child_file.name, package_name)
- and package_name != "Microsoft.ML.OnnxRuntime.Gpu.Linux"
+ and package_name
+ not in ["Microsoft.ML.OnnxRuntime.Gpu.Linux", "Microsoft.ML.OnnxRuntime.MIGraphX.Linux"]
):
files_list.append(
''
@@ -95,6 +102,7 @@ def generate_file_list_for_ep(nuget_artifacts_dir, ep, files_list, include_pdbs,
child_file.suffix == ".so"
and is_this_file_needed(ep, child_file.name, package_name)
and package_name != "Microsoft.ML.OnnxRuntime.Gpu.Windows"
+ and package_name != "Microsoft.ML.OnnxRuntime.MIGraphX.Windows"
):
files_list.append(
''
@@ -138,7 +146,7 @@ def parse_arguments():
required=False,
default="None",
type=str,
- choices=["cuda", "dnnl", "openvino", "tensorrt", "snpe", "qnn", "None"],
+ choices=["cuda", "dnnl", "openvino", "migraphx", "tensorrt", "snpe", "qnn", "None"],
help="The selected execution provider for this build.",
)
parser.add_argument("--sdk_info", required=False, default="", type=str, help="dependency SDK information.")
@@ -182,6 +190,12 @@ def generate_description(line_list, package_name):
description = "This package contains Linux native shared library artifacts for ONNX Runtime with CUDA."
elif "Microsoft.ML.OnnxRuntime.Gpu.Windows" in package_name:
description = "This package contains Windows native shared library artifacts for ONNX Runtime with CUDA."
+ elif "Microsoft.ML.OnnxRuntime.MIGraphX.Linux" in package_name:
+ description = "This package contains Linux native shared library artifacts for ONNX Runtime with the MIGraphX."
+ elif "Microsoft.ML.OnnxRuntime.MIGraphX.Windows" in package_name:
+ description = (
+ "This package contains Windows native shared library artifacts for ONNX Runtime with the MIGraphX."
+ )
elif "Intel.ML.OnnxRuntime" in package_name:
description = "This package contains native shared library artifacts for ONNX Runtime with OpenVINO."
elif "Microsoft.ML.OnnxRuntime" in package_name: # This is a Microsoft.ML.OnnxRuntime.* package
@@ -224,6 +238,8 @@ def add_common_dependencies(xml_text, package_name, version):
if package_name == "Microsoft.ML.OnnxRuntime.Gpu":
xml_text.append('')
xml_text.append('')
+ if package_name == "Microsoft.ML.OnnxRuntime.MIGraphX":
+ xml_text.append('')
def generate_dependencies(xml_text, package_name, version):
@@ -358,6 +374,9 @@ def generate_files(line_list, args):
is_windowsai_package = args.package_name == "Microsoft.AI.MachineLearning"
is_snpe_package = args.package_name == "Microsoft.ML.OnnxRuntime.Snpe"
is_qnn_package = args.package_name == "Microsoft.ML.OnnxRuntime.QNN"
+ is_migraphx_package = args.package_name == "Microsoft.ML.OnnxRuntime.MIGraphX"
+ is_migraphx_win_sub_package = args.package_name == "Microsoft.ML.OnnxRuntime.MIGraphX.Windows"
+ is_migraphx_linux_sub_package = args.package_name == "Microsoft.ML.OnnxRuntime.MIGraphX.Linux"
is_training_package = args.package_name in [
"Microsoft.ML.OnnxRuntime.Training",
"Microsoft.ML.OnnxRuntime.Training.Gpu",
@@ -383,6 +402,7 @@ def generate_files(line_list, args):
"openvino_ep_shared_lib": "onnxruntime_providers_openvino.dll",
"cuda_ep_shared_lib": "onnxruntime_providers_cuda.dll",
"qnn_ep_shared_lib": "onnxruntime_providers_qnn.dll",
+ "migraphx_ep_shared_lib": "onnxruntime_providers_migraphx.dll",
"onnxruntime_perf_test": "onnxruntime_perf_test.exe",
"onnx_test_runner": "onnx_test_runner.exe",
}
@@ -420,7 +440,7 @@ def generate_files(line_list, args):
include_dir = f"{build_dir}\\native\\include"
# Sub.Gpu packages do not include the onnxruntime headers
- if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu":
+ if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu" and args.package_name != "Microsoft.ML.OnnxRuntime.MIGraphX":
files_list.append(
"'
)
+ if args.execution_provider == "migraphx" or (is_migraphx_win_sub_package and not is_ado_packaging_build):
+ files_list.append(
+ "'
+ )
+ files_list.append(
+ "'
+ )
+
# process all other library dependencies
- if is_cpu_package or is_cuda_gpu_package or is_dml_package or is_mklml_package:
+ if is_cpu_package or is_cuda_gpu_package or is_migraphx_package or is_dml_package or is_mklml_package:
# Process dnnl dependency
if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies["dnnl"])):
files_list.append(
@@ -898,6 +938,9 @@ def generate_files(line_list, args):
or is_cuda_gpu_linux_sub_package
or is_cuda_gpu_win_sub_package
or is_rocm_gpu_package
+ or is_migraphx_package
+ or is_migraphx_linux_sub_package
+ or is_migraphx_win_sub_package
or is_dml_package
or is_mklml_package
or is_snpe_package