-
Notifications
You must be signed in to change notification settings - Fork 531
Remove nvidia-mathdx dependency
#2295
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
This PR removes the nvidia-mathdx dependency to fix cross-platform build issues where x86-specific binaries were being included on non-x86 architectures. Since TransformerEngine only uses the Philox4x32 pseudo-random number generator from mathdx for stochastic rounding in FP4 quantization kernels, the solution re-implements this limited functionality directly in transformer_engine/common/util/curanddx.hpp. The changes update build configurations (CMakeLists.txt, pyproject.toml, CI workflows), modify four CUDA kernel files to use the internal implementation, and maintain backward compatibility by preserving the same Philox4_32-10 algorithm behavior.
Potential Issues
Critical - Integer overflow handling in PRNG state increment
The new curanddx.hpp implementation has potential correctness issues in counter increment logic:
-
Line 66 in
philox_state_incr: The conditionif (nhi <= ctr.y)appears incorrect for overflow detection. Afterctr.y += nhi, overflow occurs whenctr.y < previous_value, but this check comparesnhiwith the newctr.yvalue, which may incorrectly skip carry propagation to higher words. This could truncate the effective counter space from 128 bits to 64 bits. -
Lines 72-80 in
philox_state_incr_hi: No overflow handling whenctr.woverflows afterctr.w += nhi. This could cause state collisions in parallel workloads with many subsequences.
For reference, CUDA's cuRAND Philox implementation uses explicit overflow tracking (e.g., if (new_val < old_val) carry = 1;). These bugs could manifest as duplicate random sequences in distributed training or multi-GPU scenarios where subsequence skipping is used.
Recommendation: Compare against the reference cuRAND Philox implementation or the original mathdx code to verify the carry logic is correct.
Minor - Missing test coverage
The PR description indicates no tests were added ("I have added tests that prove my fix is effective" is unchecked). Given this changes random number generation for stochastic rounding, numerical validation tests comparing outputs with the old mathdx implementation would be valuable to ensure identical behavior.
Minor - Documentation gap
The PR description notes "I have commented my code, particularly in hard-to-understand areas" is unchecked. The new curanddx.hpp has minimal inline documentation explaining the Philox algorithm, designated initializer usage, or the overflow handling strategy.
Confidence
2 out of 5 - While the overall approach is sound (vendoring a small dependency to fix portability issues), the potential overflow bugs in the PRNG implementation create significant risk for numerical correctness in production workloads. The lack of test coverage exacerbates this concern. The rest of the changes (build system, include paths, API migration) are straightforward and low-risk.
8 files reviewed, 1 comment
| __forceinline__ __device__ void philox_state_incr(size_t n) { | ||
| unsigned int nlo = (unsigned int)(n); | ||
| unsigned int nhi = (unsigned int)(n >> 32); | ||
|
|
||
| ctr.x += nlo; | ||
| if (ctr.x < nlo) nhi++; | ||
|
|
||
| ctr.y += nhi; | ||
| if (nhi <= ctr.y) return; | ||
| if (++ctr.z) return; | ||
| ++ctr.w; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: overflow logic has subtle carry propagation; the increment chain ctr.x += nlo; if (ctr.x < nlo) nhi++; ctr.y += nhi; if (nhi <= ctr.y) return; stops early if nhi <= ctr.y, but this skips carrying into ctr.z/ctr.w when ctr.y wraps. If ctr.y overflows after adding nhi, the carry is lost. should the condition on line 67 be if (nhi > ctr.y) (overflow occurred) instead of if (nhi <= ctr.y) return (no overflow)?
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The developer has not added any new commits after the previous feedback regarding the carry propagation bug in curanddx.hpp. No incremental changes were detected beyond the initial PR submission. The previous review correctly identified a critical carry logic issue in philox_state_incr (line 67 of curanddx.hpp) that remains unaddressed: the condition if (nhi <= ctr.y) return; will fail to propagate overflow carries from ctr.y into ctr.z and ctr.w, potentially breaking the PRNG sequence. Since no new code has been added, this summary only confirms that the previously identified issue persists and should be resolved before merging.
PR Description Notes:
- The checklist indicates "I have commented my code, particularly in hard-to-understand areas" is unchecked, which aligns with the observation that the custom Philox implementation in
curanddx.hpplacks inline comments explaining the subtle carry logic. - "I have added tests that prove my fix is effective or that my feature works" is also unchecked. Given that this PR replaces a well-tested external library with a custom implementation containing a potential carry bug, adding unit tests for the Philox RNG state increment and generate functions would be prudent.
- "New and existing unit tests pass locally with my changes" is unchecked, raising concerns about whether the incorrect carry logic in
curanddx.hpphas been validated against the previousnvidia-mathdxbehavior.
8 files reviewed, no comments
|
/te-ci |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
|
||
| __device__ __forceinline__ uint32_t get_rbits(RNG& rng, uint4& random_uint4, int& rnd_idx) { | ||
| __device__ __forceinline__ uint32_t | ||
| get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10>& rng, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could add a comment somewhere that philox4x32_native_state<10> means 10 rounds of philox4_32. The template parameter is the round number.
|
@dorispnvidia A question - putting this implementation here in this way will make it Apache licensed. Is that ok? What is the original license of this file? We could (and probably should) change the license for this specific file to match whatever license is in the original. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Blocking temporarily to make sure we put the right license on the philox implementation.
I will ask Lukasz about it. |
|
Any code we'd develop for direct inclusion into TE should use TE license. Apache 2.0 is ok, code we contribute is only philox implementation we skipped MathDX API or any other pieces that we'd like to protect. |
Description
Remove
nvidia-mathdxdependency.Type of change
Changes
nvidia-mathdxdependency.Checklist: