Skip to content

Conversation

@magaonka-amd
Copy link

@magaonka-amd magaonka-amd commented Jan 6, 2026

ROCm's memory_stats() returns 'bytes_limit' instead of 'bytes_reservable_limit'. Add platform-specific handling to test_lax_full_like_efficient to support both CUDA and ROCm.

Motivation

The test tests/multi_device_test.py::MultiDeviceTest::test_lax_full_like_efficient fails on ROCm with KeyError: 'bytes_reservable_limit'. This occurs because ROCm's memory_stats() API returns different keys than CUDA's implementation. The goal is to make the test platform-agnostic and work correctly on both CUDA and ROCm platforms.

Technical Details

The change adds platform detection using jtu.is_device_rocm() to select the appropriate memory statistics key:

  • ROCm: Uses mem_stats['bytes_limit']
  • CUDA: Uses mem_stats['bytes_reservable_limit']

Below is the screenshot of my test run results without and with fix.
without_keyerror_fix

with_keyerror_fix

ROCm's memory_stats() returns 'bytes_limit' instead of
'bytes_reservable_limit'. Add platform-specific handling to
test_lax_full_like_efficient to support both CUDA and ROCm.
@magaonka-amd magaonka-amd marked this pull request as ready for review January 7, 2026 15:19
@magaonka-amd magaonka-amd requested a review from a team as a code owner January 7, 2026 15:19
@magaonka-amd magaonka-amd changed the base branch from main to rocm-jaxlib-v0.8.0 January 7, 2026 15:52
@magaonka-amd magaonka-amd changed the base branch from rocm-jaxlib-v0.8.0 to main January 7, 2026 15:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant