@jberchtold-nvidia
Collaborator

Description

Currently we have the following:

  1. We have small end-to-end (e2e) training tests that train a small model for a few epochs as an integration sanity check.
  2. We also provide user-facing examples on an MNIST model and a single-layer encoder, which gradually increase in complexity as parallelism techniques are added.

At present, items 1 and 2 are the same tests. To remove this coupling, this PR introduces the following changes:

  • Example tests are still present but moved to the L1 test suite
  • A new test file, tests/jax/test_distributed_sanity_e2e_train.py, is introduced. It runs in L0 and trains a small model for a few epochs, with the following features:
    • All parallelisms, including single-GPU, are covered in the same file. This simplifies the test code and lets us keep track of all loss/accuracy tolerances in a single place.
    • It uses synthetic data, so it does not require pulling HuggingFace datasets.
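
As a rough illustration of the idea (not the contents of the new test file), a synthetic-data sanity test can be sketched as follows. Everything here is an assumption made for illustration: the `LOSS_TOLERANCES` table, the helper names, the tiny logistic-regression model, and the tolerance value are hypothetical, and only the single-device case is shown.

```python
import jax
import jax.numpy as jnp

# Hypothetical tolerance table: one entry per parallelism configuration,
# so every threshold lives in a single place. Values are illustrative.
LOSS_TOLERANCES = {"single_device": 0.5}


def make_synthetic_batch(key, batch_size=64, num_features=8):
    """Generate a linearly separable binary classification batch (no dataset download)."""
    x_key, w_key = jax.random.split(key)
    x = jax.random.normal(x_key, (batch_size, num_features))
    true_w = jax.random.normal(w_key, (num_features,))
    y = (x @ true_w > 0).astype(jnp.float32)
    return x, y


def loss_fn(params, x, y):
    """Numerically stable binary cross-entropy on logits."""
    logits = x @ params["w"] + params["b"]
    return jnp.mean(
        jnp.maximum(logits, 0) - logits * y + jnp.log1p(jnp.exp(-jnp.abs(logits)))
    )


@jax.jit
def train_step(params, x, y, lr):
    """One step of plain gradient descent."""
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def run_sanity_training(num_steps=200, lr=0.5, seed=0):
    """Train briefly on synthetic data and return (initial_loss, final_loss)."""
    x, y = make_synthetic_batch(jax.random.PRNGKey(seed))
    params = {"w": jnp.zeros(x.shape[1]), "b": jnp.zeros(())}
    initial_loss = float(loss_fn(params, x, y))
    for _ in range(num_steps):
        params = train_step(params, x, y, lr)
    return initial_loss, float(loss_fn(params, x, y))
```

A test for a given configuration would then call `run_sanity_training()` and compare the final loss against the matching entry in `LOSS_TOLERANCES`; distributed variants would presumably add their own entries to the same table rather than scattering thresholds across files.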

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Move examples/jax/... tests to L1 test scripts
  • Introduce a new test file tests/jax/test_distributed_sanity_e2e_train.py to consolidate all parallelism types covered by the examples into a single test file that uses synthetic data and runs in L0

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@jberchtold-nvidia changed the title from "[JAX] E2E encoder sanity test with synthetic data" to "[Draft][JAX] E2E encoder sanity test with synthetic data" on Oct 13, 2025
@jberchtold-nvidia
Collaborator Author

/te-ci L2 jax
