@@ -281,7 +281,7 @@ def test_gemm_correct_shape_2d(shape_pairs: list) -> None:
281
281
y = wrapper .randu (shape_pairs [1 ], dtype )
282
282
283
283
result_shape = (shape_pairs [0 ][0 ], shape_pairs [1 ][1 ])
284
- result = wrapper .gemm (x , y , MatProp .NONE , MatProp .NONE , 1 , 1 )
284
+ result = wrapper .gemm (x , y , MatProp .NONE , MatProp .NONE , 1 , 1 , None )
285
285
286
286
assert wrapper .get_dims (result )[0 :2 ] == result_shape
287
287
@@ -302,7 +302,7 @@ def test_gemm_correct_shape_3d(shape_pairs: list) -> None:
302
302
y = wrapper .randu (shape_pairs [1 ], dtype )
303
303
result_shape = (shape_pairs [0 ][0 ], shape_pairs [1 ][1 ], shape_pairs [0 ][2 ])
304
304
305
- result = wrapper .gemm (x , y , MatProp .NONE , MatProp .NONE , 1 , 1 )
305
+ result = wrapper .gemm (x , y , MatProp .NONE , MatProp .NONE , 1 , 1 , None )
306
306
assert wrapper .get_dims (result )[0 :3 ] == result_shape
307
307
308
308
@@ -322,7 +322,7 @@ def test_gemm_correct_shape_4d(shape_pairs: list) -> None:
322
322
y = wrapper .randu (shape_pairs [1 ], dtype )
323
323
result_shape = (shape_pairs [0 ][0 ], shape_pairs [1 ][1 ], shape_pairs [0 ][2 ], shape_pairs [0 ][3 ])
324
324
325
- result = wrapper .gemm (x , y , MatProp .NONE , MatProp .NONE , 1 , 1 )
325
+ result = wrapper .gemm (x , y , MatProp .NONE , MatProp .NONE , 1 , 1 , None )
326
326
assert wrapper .get_dims (result )[0 :4 ] == result_shape
327
327
328
328
@@ -339,7 +339,7 @@ def test_gemm_correct_dtype(dtype: dtypes.Dtype) -> None:
339
339
x = wrapper .randu (shape , dtype )
340
340
y = wrapper .randu (shape , dtype )
341
341
342
- result = wrapper .gemm (x , y , MatProp .NONE , MatProp .NONE , 1 , 1 )
342
+ result = wrapper .gemm (x , y , MatProp .NONE , MatProp .NONE , 1 , 1 , None )
343
343
344
344
assert dtypes .c_api_value_to_dtype (wrapper .get_type (result )) == dtype
345
345
@@ -361,7 +361,7 @@ def test_gemm_invalid_pair(shape_pairs: list) -> None:
361
361
x = wrapper .randu (shape_pairs [0 ], dtype )
362
362
y = wrapper .randu (shape_pairs [1 ], dtype )
363
363
364
- wrapper .gemm (x , y , MatProp .NONE , MatProp .NONE , 1 , 1 )
364
+ wrapper .gemm (x , y , MatProp .NONE , MatProp .NONE , 1 , 1 , None )
365
365
366
366
367
367
def test_gemm_empty_shape () -> None :
@@ -371,7 +371,7 @@ def test_gemm_empty_shape() -> None:
371
371
dtype = dtypes .f32
372
372
373
373
x = wrapper .randu (empty_shape , dtype )
374
- wrapper .gemm (x , x , MatProp .NONE , MatProp .NONE , 1 , 1 )
374
+ wrapper .gemm (x , x , MatProp .NONE , MatProp .NONE , 1 , 1 , None )
375
375
376
376
377
377
@pytest .mark .parametrize (
@@ -390,7 +390,7 @@ def test_gemm_invalid_dtype(dtype_index: int) -> None:
390
390
x = wrapper .randu (shape , dtype )
391
391
y = wrapper .randu (shape , dtype )
392
392
393
- wrapper .gemm (x , y , MatProp .NONE , MatProp .NONE , 1 , 1 )
393
+ wrapper .gemm (x , y , MatProp .NONE , MatProp .NONE , 1 , 1 , None )
394
394
395
395
396
396
def test_gemm_empty_matrix () -> None :
@@ -400,7 +400,7 @@ def test_gemm_empty_matrix() -> None:
400
400
dtype = dtypes .f32
401
401
402
402
x = wrapper .randu (empty_shape , dtype )
403
- wrapper .gemm (x , x , MatProp .NONE , MatProp .NONE , 1 , 1 )
403
+ wrapper .gemm (x , x , MatProp .NONE , MatProp .NONE , 1 , 1 , None )
404
404
405
405
406
406
# matmul tests
0 commit comments