@@ -311,10 +311,6 @@ def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
311311 variables = {"params" : params , ** others }
312312 output = model .apply (variables , * diff_xs , * no_diff_xs , rngs = self .apply_rng )
313313 return jnp .mean (output , dtype = jnp .float32 ).astype (output .dtype )
314-
315- def _output_fn (self , params , others , model , diff_xs , no_diff_xs ):
316- variables = {"params" : params , ** others }
317- return model .apply (variables , * diff_xs , * no_diff_xs , rngs = self .apply_rng )
318314
319315 def _sync_params (self , ref , target ):
320316 """Copy the reference params to target"""
@@ -338,14 +334,11 @@ def test_forward(
338334 test_layer , test_params , test_others = self ._generate_layer (layer_cls , inputs , test_masks )
339335 ref_params , test_params = self ._sync_params (ref_params , test_params )
340336
341- ref_out = self ._output_fn (ref_params , ref_others , ref_layer , inputs , ref_masks )
342- test_out = self ._output_fn (test_params , test_others , test_layer , inputs , test_masks )
337+ ref_out = self ._loss_fn (ref_params , ref_others , ref_layer , inputs , ref_masks )
338+ test_out = self ._loss_fn (test_params , test_others , test_layer , inputs , test_masks )
343339
344340 tols = dtype_tols (dtype , rtol = rtol , atol = atol )
345- if not get_quantize_config ().is_fp8_enabled ():
346- assert_allclose (ref_out , test_out , ** tols )
347- else :
348- assert_allclose (ref_out .mean (), test_out .mean (), ** tols )
341+ assert_allclose (ref_out .mean (), test_out .mean (), ** tols )
349342
350343 def test_backward (
351344 self ,
0 commit comments