
Changes for PyTorch to work in Draw Things #1


Status: Open. Wants to merge 3 commits into base: 2.4-release.
aten/src/ATen/mps/MPSDevice.mm (26 changes: 2 additions & 24 deletions)
@@ -22,7 +22,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
   }
 #endif
 
-  TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
+  TORCH_CHECK([device supportsFamily:MTLGPUFamilyCommon2], "Missing Metal support for MTLGPUFamilyCommon2");
   return languageVersion;
 }
 
@@ -78,29 +78,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
   // Check that MacOS 12.3+ version of MPS framework is available
   // Create the MPSGraph and check method introduced in 12.3+
   // which is used by MPS backend.
-  id mpsCD = NSClassFromString(@"MPSGraph");
-
-  if ([mpsCD instancesRespondToSelector:@selector
-                 (LSTMWithSourceTensor:recurrentWeight:inputWeight:bias:initState:initCell:descriptor:name:)] == NO) {
-    return;
-  }
-
-  NSArray* devices = [MTLCopyAllDevices() autorelease];
-  for (unsigned long i = 0; i < [devices count]; i++) {
-    id<MTLDevice> device = devices[i];
-    if ([device isLowPower]) { // exclude Intel GPUs
-      continue;
-    }
-    if (![device supportsFamily:MTLGPUFamilyMac2]) {
-      // Exclude devices that does not support Metal 2.0
-      // Virtualised MPS device on MacOS 12.6 should fail this check
-      TORCH_WARN("Skipping device ", [[device name] UTF8String], " that does not support Metal 2.0");
-      continue;
-    }
-    _mtl_device = [device retain];
-    break;
-  }
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
+  _mtl_device = [MTLCreateSystemDefaultDevice() retain];
 }
 
 bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
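Note: with this change the MPS backend initializes on the system default Metal device (MTLCreateSystemDefaultDevice) instead of walking MTLCopyAllDevices() and filtering by GPU family. A minimal sketch of how to check the result from Python, assuming a build of this branch with the MPS backend compiled in:

    # Minimal sketch: confirm the MPS backend comes up on the default Metal device.
    # Assumes a PyTorch build from this branch with MPS support enabled.
    import torch

    if torch.backends.mps.is_available():
        x = torch.ones(4, device="mps")  # allocated on the default Metal device selected above
        print(x * 2)
    else:
        print("MPS backend not available on this machine")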
tools/setup_helpers/cmake.py (1 change: 1 addition & 0 deletions)
@@ -338,6 +338,7 @@ def generate(
         # 1. https://cmake.org/cmake/help/latest/manual/cmake.1.html#synopsis
         # 2. https://stackoverflow.com/a/27169347
         args.append(base_dir)
+        args.append('-DUSE_OPENMP=OFF')
         self.run(args, env=my_env)
 
     def build(self, my_env: Dict[str, str]) -> None:
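Note: the appended flag simply becomes one more -D definition at the end of the cmake invocation that generate() builds, forcing the OpenMP option off for this build. A rough Python sketch of the resulting argument list; the preceding flags and path are illustrative, not taken from this PR:

    # Illustrative only: how the args list passed to self.run() ends up after this change.
    args = ["-GNinja", "-DCMAKE_BUILD_TYPE=Release"]  # hypothetical flags built earlier in generate()
    base_dir = "/path/to/pytorch"                     # hypothetical source directory
    args.append(base_dir)
    args.append('-DUSE_OPENMP=OFF')                   # the line added by this PR
    print(["cmake"] + args)                           # cmake ... /path/to/pytorch -DUSE_OPENMP=OFF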
torch/__init__.py (7 changes: 6 additions & 1 deletion)
@@ -21,6 +21,8 @@
 import pdb
 import importlib
 import importlib.util
+import json
+import draw_things
 
 # multipy/deploy is setting this import before importing torch, this is the most
 # reliable way we have to detect if we're running within deploy.
@@ -223,7 +225,10 @@ def load_shared_libraries(library_path):
     if library_path:
         global_deps_lib_path = os.path.join(library_path, 'lib', lib_name)
         try:
-            ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
+            root = os.path.splitext(global_deps_lib_path)[0]
+            actual_rel_path = open(f'{root}.fwork', 'r').read().rstrip('\n')
+            actual_path = f'{draw_things.CONTEXT.app_dir}/{actual_rel_path}'
+            ctypes.CDLL(actual_path, mode=ctypes.RTLD_GLOBAL)
         except OSError as err:
             # Can only happen for wheel with cuda libs as PYPI deps
             # As PyTorch is not purelib, but nvidia-*-cu12 is
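Note: the loading change replaces the direct dlopen of the global-deps library with an indirection: next to where the dylib would normally sit there is a small .fwork text file whose contents are a path relative to the Draw Things app bundle, and the real library is loaded from there. A standalone sketch of that lookup; the file contents and app_dir value are made up for illustration, and draw_things.CONTEXT.app_dir is assumed to be supplied by the host app:

    # Standalone sketch of the .fwork indirection used above (hypothetical paths).
    import ctypes
    import os

    def resolve_fwork(stub_path: str, app_dir: str) -> str:
        # e.g. stub_path = ".../torch/lib/libtorch_global_deps.dylib"
        root = os.path.splitext(stub_path)[0]          # strip the extension
        with open(f"{root}.fwork", "r") as f:          # redirect file written at packaging time
            rel_path = f.read().rstrip("\n")           # e.g. "Frameworks/libtorch_global_deps.dylib"
        return os.path.join(app_dir, rel_path)         # absolute path inside the app bundle

    # Hypothetical usage mirroring the patched load_shared_libraries():
    # actual = resolve_fwork(global_deps_lib_path, draw_things.CONTEXT.app_dir)
    # ctypes.CDLL(actual, mode=ctypes.RTLD_GLOBAL)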