diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm index c6e8fd732e70..f3eedc97b9d6 100644 --- a/aten/src/ATen/mps/MPSDevice.mm +++ b/aten/src/ATen/mps/MPSDevice.mm @@ -22,7 +22,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de } #endif - TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2"); + TORCH_CHECK([device supportsFamily:MTLGPUFamilyCommon2], "Missing Metal support for MTLGPUFamilyCommon2"); return languageVersion; } @@ -78,29 +78,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de // Check that MacOS 12.3+ version of MPS framework is available // Create the MPSGraph and check method introduced in 12.3+ // which is used by MPS backend. - id mpsCD = NSClassFromString(@"MPSGraph"); - - if ([mpsCD instancesRespondToSelector:@selector - (LSTMWithSourceTensor:recurrentWeight:inputWeight:bias:initState:initCell:descriptor:name:)] == NO) { - return; - } - - NSArray* devices = [MTLCopyAllDevices() autorelease]; - for (unsigned long i = 0; i < [devices count]; i++) { - id device = devices[i]; - if ([device isLowPower]) { // exclude Intel GPUs - continue; - } - if (![device supportsFamily:MTLGPUFamilyMac2]) { - // Exclude devices that does not support Metal 2.0 - // Virtualised MPS device on MacOS 12.6 should fail this check - TORCH_WARN("Skipping device ", [[device name] UTF8String], " that does not support Metal 2.0"); - continue; - } - _mtl_device = [device retain]; - break; - } - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device); + _mtl_device = [MTLCreateSystemDefaultDevice() retain]; } bool MPSDevice::isMacOS13Plus(MacOSVersion version) const { diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index 4d10b3db1aa3..7282acf1bc9b 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -338,6 +338,7 @@ def generate( # 1. https://cmake.org/cmake/help/latest/manual/cmake.1.html#synopsis # 2. https://stackoverflow.com/a/27169347 args.append(base_dir) + args.append('-DUSE_OPENMP=OFF') self.run(args, env=my_env) def build(self, my_env: Dict[str, str]) -> None: diff --git a/torch/__init__.py b/torch/__init__.py index 1c4d5e45b305..12815a05ad38 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -21,6 +21,8 @@ import pdb import importlib import importlib.util +import json +import draw_things # multipy/deploy is setting this import before importing torch, this is the most # reliable way we have to detect if we're running within deploy. @@ -223,7 +225,10 @@ def load_shared_libraries(library_path): if library_path: global_deps_lib_path = os.path.join(library_path, 'lib', lib_name) try: - ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) + root = os.path.splitext(global_deps_lib_path)[0] + actual_rel_path = open(f'{root}.fwork', 'r').read().rstrip('\n') + actual_path = f'{draw_things.CONTEXT.app_dir}/{actual_rel_path}' + ctypes.CDLL(actual_path, mode=ctypes.RTLD_GLOBAL) except OSError as err: # Can only happen for wheel with cuda libs as PYPI deps # As PyTorch is not purelib, but nvidia-*-cu12 is