Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR #24436: Fix variadic reduction shared memory estimation. #24498

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

copybara-service[bot]
Copy link

PR #24436: Fix variadic reduction shared memory estimation.

Imported from GitHub PR #24436

Currently, the logic is broken for variadic reductions with heterogeneous input types, since it always uses the first input's primitive type to estimate the shared memory buffer size. It should be summing up the primitive sizes instead.

Also expand the comments a bit to explain better what's going on there.

This should fix jax-ml/jax#27190.
Copybara import of the project:

--
8bdb3fb by Johannes Reifferscheid [email protected]:

Fix variadic reduction shared memory estimation.

Currently, the logic is broken for variadic reductions with heterogeneous
input types, since it always uses the first input's primitive type to
estimate the shared memory buffer size. It should be summing up the
primitive sizes instead.

Also expand the comments a bit to explain better what's going on there.

Merging this change closes #24436

FUTURE_COPYBARA_INTEGRATE_REVIEW=#24436 from jreiffers:variadic 8bdb3fb

Imported from GitHub PR #24436

Currently, the logic is broken for variadic reductions with heterogeneous input types, since it always uses the first input's primitive type to estimate the shared memory buffer size. It should be summing up the primitive sizes instead.

Also expand the comments a bit to explain better what's going on there.

This should fix jax-ml/jax#27190.
Copybara import of the project:

--
8bdb3fb by Johannes Reifferscheid <[email protected]>:

Fix variadic reduction shared memory estimation.

Currently, the logic is broken for variadic reductions with heterogeneous
input types, since it always uses the first input's primitive type to
estimate the shared memory buffer size. It should be summing up the
primitive sizes instead.

Also expand the comments a bit to explain better what's going on there.

Merging this change closes #24436

FUTURE_COPYBARA_INTEGRATE_REVIEW=#24436 from jreiffers:variadic 8bdb3fb
PiperOrigin-RevId: 743034494
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ptxas error: Fusion uses too much shared data
1 participant