
[JAX] Flax with compute dtype inferred from input dtype. #1485

Merged: phu0ngng merged 1 commit into NVIDIA:main from flax_with_input_dtype on Feb 18, 2025

Conversation

phu0ngng (Collaborator)

Description

Flax modules should use dtype only to initialize their parameters, while the compute dtype should follow the data type of the inputs.
This PR adds the capability to infer the compute dtype from the input dtype.
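
As a rough sketch of the intended behavior (the module below is illustrative only, not the actual TransformerEngine module): parameters are created in self.dtype, while the compute dtype is taken from the input, so a module whose parameters live in float32 can still run a bfloat16 forward pass when fed bfloat16 activations.

import jax
import jax.numpy as jnp
import flax.linen as nn

class DenseWithInferredCompute(nn.Module):  # hypothetical example module
    features: int
    dtype: jnp.dtype = jnp.float32  # used only for parameter initialization

    @nn.compact
    def __call__(self, x):
        kernel = self.param(
            "kernel",
            nn.initializers.lecun_normal(),
            (x.shape[-1], self.features),
            self.dtype,  # parameters are stored in self.dtype
        )
        # Compute dtype is inferred from the input rather than from self.dtype.
        return jnp.dot(x, kernel.astype(x.dtype))

# Parameters stay in float32, but a bfloat16 input yields a bfloat16 output.
layer = DenseWithInferredCompute(features=8)
x = jnp.ones((4, 16), jnp.bfloat16)
params = layer.init(jax.random.PRNGKey(0), x)
assert params["params"]["kernel"].dtype == jnp.float32
assert layer.apply(params, x).dtype == jnp.bfloat16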

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@phu0ngng phu0ngng requested a review from ptrendx February 14, 2025 23:16
             axes=lora_b_kernel_axes,
         )
-        lora_b_kernel = lora_b_kernel.astype(self.dtype)
+        if not FP8Helper.is_fp8_enabled():
+            lora_b_kernel = lora_b_kernel.astype(input_dtype)
Member:

Are LoRA weights also applied in FP8? It's possible they are not, and then this cast would be required.

Collaborator Author (phu0ngng):

Fixed.
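
For context, here is a hedged illustration of the point being discussed, under the assumption that the LoRA kernels do not go through the FP8 path; the helper below is hypothetical and not the merged code. Casting the LoRA kernels to the input dtype keeps the LoRA matmuls in the same compute dtype as the rest of the layer.

import jax.numpy as jnp

def apply_lora(x, lora_a_kernel, lora_b_kernel):
    # Hypothetical helper, for illustration only: cast the LoRA parameters
    # to the input's compute dtype before the matmuls.
    lora_a_kernel = lora_a_kernel.astype(x.dtype)
    lora_b_kernel = lora_b_kernel.astype(x.dtype)
    return (x @ lora_a_kernel) @ lora_b_kernel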

@@ -265,8 +265,8 @@ def test_forward(
         """Test only the forward"""
         inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

-        ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
-        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)
+        ref_layer_cls = partial(self.reference_layer, **self.attrs)
Member:


Is reference_layer following the correct logic?

Collaborator Author (phu0ngng):

Updated the reference layer in utils.py.
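
A hedged sketch of the idea behind that update (illustrative only; the actual reference layers live in the test utilities): the reference forward pass should also cast its parameters to the input dtype, so that the reference and the tested layer compute in the same dtype.

import jax.numpy as jnp

def reference_dense(x, kernel, bias):
    # Mirror the tested module: the compute dtype follows the input dtype.
    kernel = kernel.astype(x.dtype)
    bias = bias.astype(x.dtype)
    return jnp.dot(x, kernel) + bias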

@ptrendx ptrendx added the 2.1.0 label Feb 15, 2025
@phu0ngng (Collaborator Author):

/te-ci jax L1

3 similar /te-ci comments followed.

@phu0ngng phu0ngng force-pushed the flax_with_input_dtype branch from 77f2edb to d41d6f8 on February 18, 2025 23:37
@phu0ngng phu0ngng merged commit 6673f16 into NVIDIA:main Feb 18, 2025
11 checks passed
@phu0ngng phu0ngng deleted the flax_with_input_dtype branch February 18, 2025 23:49
ptrendx pushed a commit that referenced this pull request Feb 19, 2025
flax module with compute dtype inferred from the inputs

Signed-off-by: Phuong Nguyen <[email protected]>