Skip to content

Commit c0a0947

Browse files
committed
fix
Signed-off-by: Pawel Gadzinski <[email protected]>
1 parent 108ba4c commit c0a0947

File tree

1 file changed

+3
-10
lines changed

1 file changed

+3
-10
lines changed

tests/jax/test_layer.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -311,10 +311,6 @@ def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
311311
variables = {"params": params, **others}
312312
output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
313313
return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)
314-
315-
def _output_fn(self, params, others, model, diff_xs, no_diff_xs):
316-
variables = {"params": params, **others}
317-
return model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
318314

319315
def _sync_params(self, ref, target):
320316
"""Copy the reference params to target"""
@@ -338,14 +334,11 @@ def test_forward(
338334
test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
339335
ref_params, test_params = self._sync_params(ref_params, test_params)
340336

341-
ref_out = self._output_fn(ref_params, ref_others, ref_layer, inputs, ref_masks)
342-
test_out = self._output_fn(test_params, test_others, test_layer, inputs, test_masks)
337+
ref_out = self._loss_fn(ref_params, ref_others, ref_layer, inputs, ref_masks)
338+
test_out = self._loss_fn(test_params, test_others, test_layer, inputs, test_masks)
343339

344340
tols = dtype_tols(dtype, rtol=rtol, atol=atol)
345-
if not get_quantize_config().is_fp8_enabled():
346-
assert_allclose(ref_out, test_out, **tols)
347-
else:
348-
assert_allclose(ref_out.mean(), test_out.mean(), **tols)
341+
assert_allclose(ref_out.mean(), test_out.mean(), **tols)
349342

350343
def test_backward(
351344
self,

0 commit comments

Comments
 (0)