Switch to use CUDA driver APIs in Device constructor #460
/ok to test
/ok to test c9fac0b
This is ready.
total = handle_return(runtime.cudaGetDeviceCount())
assert_type(device_id, int)
if not (0 <= device_id < total):
total = handle_return(driver.cuDeviceGetCount())
Assuming that the happy path is common, this driver.cuDeviceGetCount() call seems redundant.
Also assuming it's not actually worth the cycles checking for isinstance, we could replace the else block here with:
    elif device_id < 0:
        raise ValueError(f"device_id must be >= 0, got {device_id!r}")
Then below (new line 998) we could do this:
    try:
        return devices[device_id]
    except IndexError:
        raise ValueError(f"device_id must be within [0, {len(devices)}), got {device_id!r}")
WDYT?
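For reference, the suggestion above can be exercised in isolation. This is a minimal sketch, not the PR's code: `get_device` is a hypothetical stand-in for the relevant part of the Device constructor, and the strings in `devices` are placeholders for the cached per-thread Device objects.

```python
def get_device(devices, device_id):
    """Return devices[device_id] without querying the device count up front."""
    if device_id < 0:
        # Reject negatives explicitly; otherwise Python would index from the end.
        raise ValueError(f"device_id must be >= 0, got {device_id!r}")
    try:
        # Happy path: a single list lookup, no extra driver call.
        return devices[device_id]
    except IndexError:
        raise ValueError(
            f"device_id must be within [0, {len(devices)}), got {device_id!r}"
        ) from None


devices = ["dev0", "dev1"]  # placeholders, not real Device instances
print(get_device(devices, 1))
```

The explicit negative check matters because a bare `devices[-1]` would succeed and silently return the last device.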
else:
    ctx = handle_return(driver.cuCtxGetCurrent())
Is there a specific error code or set of error codes we should be handling here? If the above driver.cuCtxGetDevice() call returns an error that we don't expect, we should probably raise it as an exception instead of it propagating to the driver.cuCtxGetCurrent() call?
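One way to read this request is sketched below. It runs without a GPU because everything in it is a stand-in: `Result`, `DriverError`, and `current_device_id` are hypothetical names, the enum values are illustrative rather than real driver codes, and the fallback to device 0 is an assumption. The only thing borrowed from the code under review is the convention that a driver call returns an `(error, value)` pair.

```python
from enum import IntEnum


class Result(IntEnum):
    """Stand-in for the driver's result enum; values are illustrative only."""
    SUCCESS = 0
    INVALID_CONTEXT = 1
    NOT_INITIALIZED = 2


class DriverError(RuntimeError):
    """Hypothetical exception type for unexpected driver failures."""


def current_device_id(get_ctx_device, get_current_ctx):
    """Resolve a default device ID from two driver-like callables.

    Each callable returns an (error, value) pair.
    """
    err, dev = get_ctx_device()
    if err == Result.SUCCESS:
        return int(dev)
    if err != Result.INVALID_CONTEXT:
        # Unexpected failure: raise here rather than falling through.
        raise DriverError(f"context device query failed: {err.name}")
    # Expected case (assumed): no context is current on this thread.
    err, _ctx = get_current_ctx()
    if err != Result.SUCCESS:
        raise DriverError(f"current context query failed: {err.name}")
    return 0  # assumption: default to device 0 when no context is current


# Usage with fake driver calls:
print(current_device_id(lambda: (Result.SUCCESS, 3), lambda: (Result.SUCCESS, None)))
```

The point of the sketch is the branch order: only the one anticipated error code reaches the second call, and every other code is raised where it occurs.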
try:
    devices = _tls.devices
except AttributeError:
    total = handle_return(driver.cuDeviceGetCount())
I think we can reuse total that was already calculated above?
    devices = _tls.devices
except AttributeError:
    total = handle_return(driver.cuDeviceGetCount())
    devices = _tls.devices = []
    for dev_id in range(total):
If someone tries to create a Device with a specific ID, why do we need to initialize all of the devices at that point? Instead of calling driver.cuDeviceGetCount() to get the number of devices to initialize all of _tls.devices, could we use driver.cuDeviceGet and lazily populate _tls.devices as devices are created?
We had discussions about this code back in March. @leofang wrote here:
"Each thread always has its own copy of _tls.devices."
I'm still unclear though TBH: Is that why we cannot lazily populate?
Blocked by #459 & #439 (comment).
Before this PR:
With this PR:
(Bindings are built from the main branch.)