[JAX] Flax with compute dtype inferred from input dtype. #1485
Conversation
```python
    axes=lora_b_kernel_axes,
)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
if not FP8Helper.is_fp8_enabled():
    lora_b_kernel = lora_b_kernel.astype(input_dtype)
```
Are LoRA weights also applied in FP8? It's possible they are not, in which case this cast would be required.
Fixed.
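For context, a minimal sketch of the resolved pattern, assuming LoRA weights are not executed in FP8. The `fp8_enabled` flag stands in for TE's `FP8Helper.is_fp8_enabled()`, and `apply_lora_b` is a hypothetical helper, not TE code:

```python
import jax.numpy as jnp

def apply_lora_b(hidden, lora_b_kernel, input_dtype, fp8_enabled):
    # Hypothetical helper illustrating the fix: outside FP8 execution,
    # the LoRA-B kernel is cast to the input dtype so the matmul runs
    # in the compute type inferred from the inputs.
    if not fp8_enabled:
        lora_b_kernel = lora_b_kernel.astype(input_dtype)
    return jnp.matmul(hidden, lora_b_kernel)
```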
```diff
@@ -265,8 +265,8 @@ def test_forward(
         """Test only the forward"""
         inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

-        ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
-        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)
+        ref_layer_cls = partial(self.reference_layer, **self.attrs)
```
Is `reference_layer` following the correct logic?
Updated the reference layer in `utils.py`.
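As a rough, self-contained illustration of what the tests now check (not the actual `utils.py` code): parameters keep their init dtype, but activations follow the input dtype once the kernel is cast before the matmul.

```python
import jax.numpy as jnp

w = jnp.ones((8, 16), dtype=jnp.float32)   # param/init dtype
x = jnp.ones((2, 8), dtype=jnp.bfloat16)   # input sets the compute dtype
y = jnp.dot(x, w.astype(x.dtype))
assert y.dtype == jnp.bfloat16             # output tracks the input dtype
```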
/te-ci jax L1
3 similar comments
Force-pushed from 77f2edb to d41d6f8: flax module with compute dtype inferred from the inputs
Signed-off-by: Phuong Nguyen <[email protected]>
Description
Flax modules should use `dtype` to initialize their parameters, while the compute dtype should depend on the input data type. This PR adds the capability to infer the compute dtype from the input dtype.
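A minimal sketch of the idea in plain Flax (a hypothetical module, not TE's implementation): parameters are created in `self.dtype`, while the computation runs in the dtype of the incoming array.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class DenseWithInferredCompute(nn.Module):
    features: int
    dtype: jnp.dtype = jnp.float32  # dtype used to initialize parameters

    @nn.compact
    def __call__(self, x):
        kernel = self.param(
            "kernel",
            nn.initializers.lecun_normal(),
            (x.shape[-1], self.features),
            self.dtype,
        )
        # The compute dtype is inferred from the input, not from self.dtype.
        return jnp.dot(x, kernel.astype(x.dtype))

x = jnp.ones((4, 8), dtype=jnp.bfloat16)
mod = DenseWithInferredCompute(features=16)
params = mod.init(jax.random.PRNGKey(0), x)
y = mod.apply(params, x)  # params stay float32; y.dtype is bfloat16
```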