@@ -311,6 +311,10 @@ def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
311311 variables = {"params" : params , ** others }
312312 output = model .apply (variables , * diff_xs , * no_diff_xs , rngs = self .apply_rng )
313313 return jnp .mean (output , dtype = jnp .float32 ).astype (output .dtype )
314+
315+ def _output_fn (self , params , others , model , diff_xs , no_diff_xs ):
316+ variables = {"params" : params , ** others }
317+ return model .apply (variables , * diff_xs , * no_diff_xs , rngs = self .apply_rng )
314318
315319 def _sync_params (self , ref , target ):
316320 """Copy the reference params to target"""
@@ -334,11 +338,14 @@ def test_forward(
334338 test_layer , test_params , test_others = self ._generate_layer (layer_cls , inputs , test_masks )
335339 ref_params , test_params = self ._sync_params (ref_params , test_params )
336340
337- ref_out = self ._loss_fn ( inputs , ref_masks , ref_params , ref_others , ref_layer )
338- test_out = self ._loss_fn ( inputs , test_masks , test_params , test_others , test_layer )
341+ ref_out = self ._output_fn ( ref_params , ref_others , ref_layer , inputs , ref_masks )
342+ test_out = self ._output_fn ( test_params , test_others , test_layer , inputs , test_masks )
339343
340344 tols = dtype_tols (dtype , rtol = rtol , atol = atol )
341- assert_allclose (ref_out , test_out , ** tols )
345+ if not get_quantize_config ().is_fp8_enabled ():
346+ assert_allclose (ref_out , test_out , ** tols )
347+ else :
348+ assert_allclose (ref_out .mean (), test_out .mean (), ** tols )
342349
343350 def test_backward (
344351 self ,
0 commit comments