@@ -199,17 +199,19 @@ def test_hl_dot_codegen_acc_differs_uses_addition(self):
         self.assertIn("out_dtype=tl.float32", code)
 
         # Test case 2: separate addition (acc_dtype = float16, common dtype = float32)
-        input_dtype_2 = torch.float32
-        acc_dtype_2 = torch.float16
-        x2 = torch.randn(64, 64, device=DEVICE, dtype=input_dtype_2)
-        y2 = torch.randn(64, 64, device=DEVICE, dtype=input_dtype_2)
-        code2, out2 = code_and_output(dot_kernel_acc_arg, (x2, y2, acc_dtype_2))
-        # Validate we use separate addition pattern with cast
-        self.assertIn("tl.dot(", code2)
-        # Check for the addition pattern: acc + result
-        self.assertIn(" + ", code2)
-        # Check that we cast the result to acc_dtype
-        self.assertIn("tl.cast", code2)
+        # TODO(Eikan): Support this case on XPU
+        if not torch.xpu.is_available():
+            input_dtype_2 = torch.float32
+            acc_dtype_2 = torch.float16
+            x2 = torch.randn(64, 64, device=DEVICE, dtype=input_dtype_2)
+            y2 = torch.randn(64, 64, device=DEVICE, dtype=input_dtype_2)
+            code2, out2 = code_and_output(dot_kernel_acc_arg, (x2, y2, acc_dtype_2))
+            # Validate we use separate addition pattern with cast
+            self.assertIn("tl.dot(", code2)
+            # Check for the addition pattern: acc + result
+            self.assertIn(" + ", code2)
+            # Check that we cast the result to acc_dtype
+            self.assertIn("tl.cast", code2)
 
         # Test case 3: separate addition (acc_dtype = int32, common dtype = int8)
         input_dtype_3 = torch.int8
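For context, a hedged sketch of the two generated-code shapes the assertions above distinguish. The `tl.dot` parameters (`acc`, `out_dtype`) are Triton's, but the exact lines are illustrative, not the compiler's actual output:

```python
# Illustrative only (assumption): the two Triton codegen patterns the test
# asserts on, shown as comments since tl.dot only runs inside a Triton kernel.
#
# Pattern 1 -- fused accumulation, when acc_dtype matches the dot's output dtype:
#     acc = tl.dot(a, b, acc=acc, out_dtype=tl.float32)
#
# Pattern 2 -- separate addition, when acc_dtype differs: compute the dot,
# cast the result down to acc_dtype, then add it to the accumulator:
#     acc = acc + tl.cast(tl.dot(a, b), tl.float16)
```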
@@ -951,6 +953,17 @@ def test_matmul_reshape_n_2(self):
             REF_EAGER_TEST_FAILURES_FP8_E4M3FN_LOW_COMPUTE_CAP[test_name]
         )(_test_func)
 
+    # Apply skipIfXPU decorator if needed
+    if acc_dtype is torch.float16 and input_dtype in (
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+        torch.bfloat16,
+        torch.float32,
+    ):
+        _test_func = skipIfXPU("skip: float16 accumulator for non-fp16 input data types")(
+            _test_func
+        )
+
     # Additional ref eager skips for unsupported accumulator/input combos
     if acc_dtype is torch.float16 and input_dtype in (
         torch.bfloat16,
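The second hunk applies a `skipIfXPU` helper whose definition is not shown in this diff. A minimal sketch of what such a decorator plausibly looks like, assuming it simply wraps `unittest.skipIf` around `torch.xpu.is_available()` (the helper name comes from the diff; the body is an assumption):

```python
# Minimal sketch (assumption): a skipIfXPU decorator in the spirit of the one
# the diff applies. It skips the wrapped test whenever an XPU device is present.
import unittest

import torch


def skipIfXPU(reason: str):
    """Skip the decorated test when torch reports an available XPU device."""
    # hasattr guard keeps this working on builds without the torch.xpu module.
    has_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
    return unittest.skipIf(has_xpu, reason)
```

Applied the same way the hunk does, `_test_func = skipIfXPU("reason")(_test_func)` leaves the test untouched on non-XPU machines and marks it skipped on XPU ones.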