@AratiGanesh AratiGanesh commented Jul 29, 2025

This fix addresses a failure on MI200 GPUs, where the test case could not run because the size of the input array was not divisible by the number of available GPUs (7 on the MI200). The fix trims the input array along its leading dimension so that its size is divisible by the number of GPUs, allowing it to be distributed evenly across the devices.
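
To make the arithmetic concrete, here is a small sketch in plain Python (no JAX required). The leading dimension of 8 comes from the test's `randn(8, 4, 5)` input; the helper name `trimmed_rows` and the list of device counts are illustrative assumptions, not part of the change.

```python
def trimmed_rows(rows, num_devices):
    """Largest multiple of num_devices that does not exceed rows."""
    return (rows // num_devices) * num_devices


rows = 8  # leading dimension of the test input
for n in (1, 2, 4, 7, 8, 16):  # hypothetical device counts
    kept = trimmed_rows(rows, n)
    print(f"devices={n:2d} kept={kept} dropped={rows - kept}")
```

With 7 devices one row is dropped. With more than 8 devices the expression evaluates to 0, so the trimmed array would be empty.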

@AratiGanesh AratiGanesh requested a review from a team as a code owner July 29, 2025 20:41
@AratiGanesh AratiGanesh changed the title from "[0.6.0-UT] Fix test shard" to "[0.6.0-UT] Fix Test Shard" Jul 30, 2025
def test_shard_map(self):
    if jtu.is_device_rocm:
        self.skipTest("Skip on ROCm: tests/ffi_test.py::FfiTest::test_shard_map")
    # if jtu.is_device_rocm:

Please remove commented lines.

mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
x = self.rng().randn(8, 4, 5).astype(np.float32)
n = len(jax.devices())
x = x[:(x.shape[0] // n) * n]

This trims the array, but trimming could compromise the mathematical correctness of this test. Instead of trimming, we could pad the array, as I did in #509, to avoid losing information. Could you please explain why this is mathematically correct and why we are not losing information?
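
For comparison, a minimal sketch of the two options on a NumPy array of the same shape as the test input. The padding approach used in #509 is not shown in this thread, so the zero-padding below is an assumption. The helper names and the device count of 7 are hypothetical as well.

```python
import numpy as np


def trim_to_multiple(x, n):
    """Drop trailing rows so that x.shape[0] is a multiple of n (as in this PR)."""
    return x[: (x.shape[0] // n) * n]


def pad_to_multiple(x, n):
    """Zero-pad along axis 0 up to the next multiple of n (assumed alternative)."""
    pad = (-x.shape[0]) % n
    pad_width = [(0, pad)] + [(0, 0)] * (x.ndim - 1)
    return np.pad(x, pad_width)


n = 7  # hypothetical device count
x = np.random.default_rng(0).standard_normal((8, 4, 5)).astype(np.float32)

trimmed = trim_to_multiple(x, n)
padded = pad_to_multiple(x, n)
print(trimmed.shape)  # (7, 4, 5)
print(padded.shape)   # (14, 4, 5)
```

Trimming discards the last row of `x`, while padding keeps all 8 original rows and appends 6 rows of zeros. With padding, the output rows that correspond to the appended zeros would presumably need to be sliced off before comparing against the reference result.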
