Contributor

@Copilot Copilot AI commented Aug 30, 2025

This PR implements a comprehensive pytest test suite for the 09_gemm_one_shot_all_reduce example and refactors the code to enable proper functional testing in CI environments.

Key Changes

Refactored Example Structure:

  • Extracted gemm_one_shot_all_reduce() function from benchmark.py that takes input matrices and algorithm parameters
  • Function encapsulates the core GEMM + all-reduce computation logic and returns the result matrix
  • Enables reuse between the command-line example (with arg-parsed parameters) and test suite (with parametrized inputs)
  • Maintains all original functionality while improving testability
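The extraction pattern described above can be sketched roughly as follows. This is a CPU-only illustration, not the code in benchmark.py: `gemm_all_reduce_reference`, `matmul`, the `--world-size` flag, and the split-K layout are all assumptions made for the sketch, and the real function runs on GPUs through the project's own kernels.

```python
import argparse


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def gemm_all_reduce_reference(a, b, world_size):
    """Hypothetical stand-in: each simulated rank multiplies one K-slice,
    then the partial products are summed, which is what an all-reduce yields."""
    k = len(b)
    if world_size <= 0 or k % world_size != 0:
        raise ValueError(f"K={k} must be divisible by world_size={world_size}")
    chunk = k // world_size
    rows, cols = len(a), len(b[0])
    result = [[0] * cols for _ in range(rows)]
    for rank in range(world_size):
        lo, hi = rank * chunk, (rank + 1) * chunk
        # Local GEMM on this rank's slice of the K dimension.
        partial = matmul([row[lo:hi] for row in a], b[lo:hi])
        # Accumulate the partial result (the reduction step).
        for i in range(rows):
            for j in range(cols):
                result[i][j] += partial[i][j]
    return result


def main(argv):
    """Command-line path: parse arguments, then call the shared function."""
    parser = argparse.ArgumentParser(description="Illustrative split-K GEMM")
    parser.add_argument("--world-size", type=int, default=2)  # assumed flag name
    args = parser.parse_args(argv)
    a = [[1, 2, 3, 4], [5, 6, 7, 8]]
    b = [[1, 0], [0, 1], [1, 0], [0, 1]]
    result = gemm_all_reduce_reference(a, b, args.world_size)
    print(result)
    return result
```

A test can call `gemm_all_reduce_reference` directly with its own matrices, while a script entry point would call `main(sys.argv[1:])`. Both paths then exercise the same computation.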

Comprehensive Test Suite:

  • Removed error handling for missing AMD GPU libraries, since ROCm is expected in CI
  • Added functional tests that validate actual GEMM results using the extracted function
  • Parametrized the tests over multiple data types (fp16, fp32, bf16) and matrix configurations
  • Added checks for algorithm requirements such as matrix dimension divisibility constraints
  • Reused the existing validation functions to verify mathematical correctness
  • Updated the test tolerance to match the example (atol=2)
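A minimal sketch of how such parametrized tests might be laid out is shown here. The helper `check_dimensions`, the shapes, and the fixed `world_size` are hypothetical and chosen only for illustration; the real suite calls the extracted function on GPU tensors, and the actual divisibility rules are defined by the example itself.

```python
import pytest


def check_dimensions(m, n, k, world_size):
    """Hypothetical constraint check: K must split evenly across ranks.

    Returns the per-rank K-slice length when the shape is valid.
    """
    if min(m, n, k, world_size) <= 0:
        raise ValueError("all dimensions and world_size must be positive")
    if k % world_size != 0:
        raise ValueError(f"K={k} is not divisible by world_size={world_size}")
    return k // world_size


# Stacked decorators produce the cross product: 3 dtypes x 2 shapes = 6 cases.
@pytest.mark.parametrize("dtype_name", ["fp16", "fp32", "bf16"])
@pytest.mark.parametrize("m, n, k", [(128, 128, 256), (256, 64, 512)])
def test_valid_shapes(dtype_name, m, n, k):
    # In a GPU test, dtype_name would select the tensor dtype;
    # in this sketch it only labels the generated case.
    world_size = 4
    assert check_dimensions(m, n, k, world_size) * world_size == k


@pytest.mark.parametrize("k, world_size", [(10, 4), (7, 2)])
def test_indivisible_k_is_rejected(k, world_size):
    # pytest.raises also checks the message against a regex via `match`.
    with pytest.raises(ValueError, match="divisible"):
        check_dimensions(8, 8, k, world_size)
```

Stacking `parametrize` decorators keeps the data-type axis and the shape axis independent, so adding a new dtype or shape extends the matrix without touching the test body.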

Testing Features:

  • Tests core algorithm functionality rather than just import capabilities
  • Validates block size calculations and parameter handling
  • Checks file structure and content organization
  • Follows established patterns from other example tests

The refactored code enables robust continuous integration testing of the GEMM one-shot all-reduce algorithm while maintaining the original example's command-line interface and functionality.

Fixes #63.



@Copilot Copilot AI changed the title from "[WIP] Implement pytest for 09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py" to "Implement pytest for 09_gemm_one_shot_all_reduce example" on Aug 30, 2025
@Copilot Copilot AI requested a review from mawad-amd August 30, 2025 22:10
Copilot finished work on behalf of mawad-amd August 30, 2025 22:10
…g in tests

- Extract gemm_one_shot_all_reduce() function from main() in benchmark.py
- Function takes input matrices and parameters, returns result matrix
- Can be reused by both command-line example and tests
- Remove error catching for missing AMD GPU libraries in tests (ROCm expected in CI)
- Simplify test suite to focus on functional testing using the extracted function
- Add parametrized tests for different data types and matrix sizes
- Tests now properly validate GEMM results using existing validation functions

Co-authored-by: mawad-amd <[email protected]>
@Copilot Copilot AI changed the title from "Implement pytest for 09_gemm_one_shot_all_reduce example" to "Refactor 09_gemm_one_shot_all_reduce example with pytest and testable function" on Aug 31, 2025
@Copilot Copilot AI requested a review from mawad-amd August 31, 2025 03:27
Copilot finished work on behalf of mawad-amd August 31, 2025 03:27
assert result_C.dtype == datatype, f"Expected output dtype {datatype}, got {result_C.dtype}"

# Validate the result using the existing validation function
success = validate_gemm(A, B, result_C, shmem, atol=0.1)
Collaborator

@copilot increase the tolerance and match the one used in the example.

Contributor Author

Updated the tolerance to match the example (atol=2). See commit f3a224d.
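For readers unfamiliar with absolute tolerances, here is a rough illustration of what an `atol` check does. The helpers `max_abs_error` and `within_atol` are hypothetical and are not the project's `validate_gemm`; the numbers are made up for the illustration, not measured from the example.

```python
def max_abs_error(actual, expected):
    """Largest element-wise absolute difference between two matrices."""
    return max(
        abs(x - y)
        for row_a, row_e in zip(actual, expected)
        for x, y in zip(row_a, row_e)
    )


def within_atol(actual, expected, atol):
    """True when every element differs from its reference by at most atol."""
    return max_abs_error(actual, expected) <= atol


# Illustrative values only: large outputs with small rounding-style deviations.
expected = [[1024.0, 512.0], [256.0, 128.0]]
actual = [[1024.5, 511.75], [256.0, 128.25]]

print(max_abs_error(actual, expected))      # largest deviation in this data
print(within_atol(actual, expected, 2))     # looser bound, as in the example
print(within_atol(actual, expected, 0.1))   # tighter bound from the first draft
```

With these made-up values the largest deviation is 0.5, so the check passes at atol=2 and fails at atol=0.1. Whether a given bound is appropriate depends on the output magnitude and the precision of the data type.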

Copilot finished work on behalf of mawad-amd August 31, 2025 04:15
@Copilot Copilot AI requested a review from mawad-amd August 31, 2025 04:15
Successfully merging this pull request may close these issues.

Implement pytest for 09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py