Switch to use CUDA driver APIs in Device constructor #460

Open: wants to merge 8 commits into main

leofang
Member

@leofang leofang commented Feb 21, 2025

Blocked by #459 & #439 (comment).

Before this PR:

In [1]: %timeit Device()
658 ns ± 1.11 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

With this PR:

In [1]: %timeit Device()
412 ns ± 2.22 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

(Bindings are built from the main branch.)

Contributor

copy-pr-bot bot commented Feb 21, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@leofang leofang self-assigned this Feb 22, 2025
@leofang leofang added the blocked This task is currently blocked by other tasks label Feb 22, 2025
@leofang leofang added enhancement Any code-related improvements P1 Medium priority - Should do cuda.core Everything related to the cuda.core module and removed blocked This task is currently blocked by other tasks labels Apr 5, 2025
@leofang leofang added this to the cuda.core beta 4 milestone Apr 5, 2025
@leofang leofang changed the title WIP: Switch to use CUDA driver APIs in Device constructor Switch to use CUDA driver APIs in Device constructor Apr 6, 2025
Member Author

leofang commented Apr 6, 2025

/ok to test

@leofang leofang requested review from rwgk and ksimpson-work April 7, 2025 17:39
@leofang leofang marked this pull request as ready for review April 7, 2025 17:39
@leofang leofang marked this pull request as draft April 7, 2025 22:19
@leofang leofang marked this pull request as ready for review May 24, 2025 02:16
Member Author

leofang commented May 24, 2025

/ok to test c9fac0b

Member Author

leofang commented May 28, 2025

This is ready.

total = handle_return(runtime.cudaGetDeviceCount())
assert_type(device_id, int)
if not (0 <= device_id < total):
total = handle_return(driver.cuDeviceGetCount())
Collaborator

Assuming that the happy path is common, this driver.cuDeviceGetCount() call seems redundant.

Also, assuming the isinstance check is not actually worth the cycles, we could replace the else block here with:

        elif device_id < 0:
            raise ValueError(f"device_id must be >= 0, got {device_id!r}")

Then below (new line 998) we could do this:

        try:
            return devices[device_id]
        except IndexError:
            raise ValueError(f"device_id must be within [0, {len(devices)}), got {device_id!r}") from None

WDYT?
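
To make the suggestion above concrete, here is a minimal, self-contained sketch of the same validation with the device list stubbed out as plain strings. The function name `get_cached_device` and its `devices` parameter are hypothetical and exist only for illustration; they are not code from this PR.

```python
def get_cached_device(devices, device_id):
    """Return devices[device_id], validating the index only when needed.

    `devices` stands in for the per-thread device list; the name and
    signature are assumptions made for this sketch.
    """
    # Reject negative IDs explicitly: Python lists accept negative indices,
    # so devices[-1] would otherwise silently return the last device.
    if device_id < 0:
        raise ValueError(f"device_id must be >= 0, got {device_id!r}")
    try:
        # Happy path: no separate device-count query is needed.
        return devices[device_id]
    except IndexError:
        raise ValueError(
            f"device_id must be within [0, {len(devices)}), got {device_id!r}"
        ) from None
```

With this shape, the upper bound is only computed on the error path. A non-integer `device_id` is not checked up front; it would surface as a `TypeError` from the comparison or the list indexing.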

Comment on lines +963 to +964
else:
    ctx = handle_return(driver.cuCtxGetCurrent())
Collaborator

Is there a specific error code, or set of error codes, that we should be handling here? If the driver.cuCtxGetDevice() call above returns an error that we don't expect, we should probably raise it as an exception instead of letting execution fall through to the driver.cuCtxGetCurrent() call.
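
One way to picture the distinction being asked about is the sketch below, which separates anticipated error codes from unexpected ones. Everything in it is a stand-in: `Result`, `DriverError`, and `current_device_id` are hypothetical names, the enum values are arbitrary, and treating `INVALID_CONTEXT` as the expected code is only a placeholder, since which codes should count as expected is exactly the open question here.

```python
from enum import IntEnum


class Result(IntEnum):
    """Stand-in for driver result codes (values are arbitrary here)."""
    SUCCESS = 0
    INVALID_CONTEXT = 1
    NOT_INITIALIZED = 2


class DriverError(RuntimeError):
    """Hypothetical exception type for unexpected driver errors."""


def current_device_id(get_device, expected=(Result.INVALID_CONTEXT,)):
    """Return the current device ID, or None if an expected error occurred.

    `get_device` is a stand-in callable returning (result_code, device_id),
    mimicking a binding that returns the status alongside the value.
    """
    err, dev = get_device()
    if err == Result.SUCCESS:
        return dev
    if err in expected:
        # Anticipated failure: the caller may take a fallback code path.
        return None
    # Anything else is unexpected and should surface immediately.
    raise DriverError(f"unexpected driver error: {err.name}")
```

The caller would then only reach the fallback branch when the result is `None`, and any other failure is raised at the point where it occurred.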

try:
    devices = _tls.devices
except AttributeError:
    total = handle_return(driver.cuDeviceGetCount())
Collaborator

I think we can reuse total that was already calculated above?

    devices = _tls.devices
except AttributeError:
    total = handle_return(driver.cuDeviceGetCount())
    devices = _tls.devices = []
    for dev_id in range(total):
Collaborator

If someone tries to create a Device with a specific ID, why do we need to initialize all of the devices at that point? Instead of calling driver.cuDeviceGetCount() to get the number of devices to initialize all of _tls.devices, could we use driver.cuDeviceGet and lazily populate _tls.devices as devices are created?

Collaborator

We had discussions about this code back in March. @leofang wrote here:

Each thread always has its own copy of _tls.devices.

I'm still unclear though TBH: Is that why we cannot lazily populate?
