Skip to content

Commit 108ba4c

Browse files
committed
fix
Signed-off-by: Pawel Gadzinski <[email protected]>
1 parent 8c68a47 commit 108ba4c

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

tests/jax/test_layer.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,10 @@ def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
311311
variables = {"params": params, **others}
312312
output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
313313
return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)
314+
315+
def _output_fn(self, params, others, model, diff_xs, no_diff_xs):
316+
variables = {"params": params, **others}
317+
return model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
314318

315319
def _sync_params(self, ref, target):
316320
"""Copy the reference params to target"""
@@ -334,11 +338,14 @@ def test_forward(
334338
test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
335339
ref_params, test_params = self._sync_params(ref_params, test_params)
336340

337-
ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer)
338-
test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer)
341+
ref_out = self._output_fn(ref_params, ref_others, ref_layer, inputs, ref_masks)
342+
test_out = self._output_fn(test_params, test_others, test_layer, inputs, test_masks)
339343

340344
tols = dtype_tols(dtype, rtol=rtol, atol=atol)
341-
assert_allclose(ref_out, test_out, **tols)
345+
if not get_quantize_config().is_fp8_enabled():
346+
assert_allclose(ref_out, test_out, **tols)
347+
else:
348+
assert_allclose(ref_out.mean(), test_out.mean(), **tols)
342349

343350
def test_backward(
344351
self,

0 commit comments

Comments
 (0)