Skip to content

Commit 2b037b9

Browse files
committed
Added gemm matrix accumulation matrix into interface and tests
1 parent 983500e commit 2b037b9

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

arrayfire_wrapper/lib/linear_algebra/blas_operations.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from arrayfire_wrapper.dtypes import c_api_value_to_dtype, complex32, complex64, float32, float64
66
from arrayfire_wrapper.lib._constants import MatProp
77
from arrayfire_wrapper.lib._utility import call_from_clib
8-
from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_type
8+
from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_type, copy_array
99

1010

1111
def dot(lhs: AFArray, rhs: AFArray, lhs_opts: MatProp, rhs_opts: MatProp, /) -> AFArray:
@@ -29,11 +29,17 @@ def dot_all(lhs: AFArray, rhs: AFArray, lhs_opts: MatProp, rhs_opts: MatProp, /)
2929
return real.value if imag.value == 0 else real.value + imag.value * 1j
3030

3131

32-
def gemm(lhs: AFArray, rhs: AFArray, lhs_opts: MatProp, rhs_opts: MatProp, alpha: Any, beta: Any, /) -> AFArray:
32+
def gemm(lhs: AFArray, rhs: AFArray, lhs_opts: MatProp, rhs_opts: MatProp, alpha: Any, beta: Any, accum: AFArray | None, /) -> AFArray:
3333
"""
3434
source: https://arrayfire.org/docs/group__blas__func__matmul.htm#ga0463ae584163128718237b02faf5caf7
3535
"""
36-
out = AFArray.create_null_pointer()
36+
out = None
37+
if not accum is None:
38+
out = copy_array(accum)
39+
else:
40+
beta = 0.0
41+
out = AFArray.create_null_pointer()
42+
3743
lhs_dtype = c_api_value_to_dtype(get_type(lhs))
3844

3945
type_mapping = {

tests/test_blas.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def test_gemm_correct_shape_2d(shape_pairs: list) -> None:
281281
y = wrapper.randu(shape_pairs[1], dtype)
282282

283283
result_shape = (shape_pairs[0][0], shape_pairs[1][1])
284-
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
284+
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None)
285285

286286
assert wrapper.get_dims(result)[0:2] == result_shape
287287

@@ -302,7 +302,7 @@ def test_gemm_correct_shape_3d(shape_pairs: list) -> None:
302302
y = wrapper.randu(shape_pairs[1], dtype)
303303
result_shape = (shape_pairs[0][0], shape_pairs[1][1], shape_pairs[0][2])
304304

305-
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
305+
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None)
306306
assert wrapper.get_dims(result)[0:3] == result_shape
307307

308308

@@ -322,7 +322,7 @@ def test_gemm_correct_shape_4d(shape_pairs: list) -> None:
322322
y = wrapper.randu(shape_pairs[1], dtype)
323323
result_shape = (shape_pairs[0][0], shape_pairs[1][1], shape_pairs[0][2], shape_pairs[0][3])
324324

325-
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
325+
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None)
326326
assert wrapper.get_dims(result)[0:4] == result_shape
327327

328328

@@ -339,7 +339,7 @@ def test_gemm_correct_dtype(dtype: dtypes.Dtype) -> None:
339339
x = wrapper.randu(shape, dtype)
340340
y = wrapper.randu(shape, dtype)
341341

342-
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
342+
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None)
343343

344344
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
345345

@@ -361,7 +361,7 @@ def test_gemm_invalid_pair(shape_pairs: list) -> None:
361361
x = wrapper.randu(shape_pairs[0], dtype)
362362
y = wrapper.randu(shape_pairs[1], dtype)
363363

364-
wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
364+
wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None)
365365

366366

367367
def test_gemm_empty_shape() -> None:
@@ -371,7 +371,7 @@ def test_gemm_empty_shape() -> None:
371371
dtype = dtypes.f32
372372

373373
x = wrapper.randu(empty_shape, dtype)
374-
wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1)
374+
wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1, None)
375375

376376

377377
@pytest.mark.parametrize(
@@ -390,7 +390,7 @@ def test_gemm_invalid_dtype(dtype_index: int) -> None:
390390
x = wrapper.randu(shape, dtype)
391391
y = wrapper.randu(shape, dtype)
392392

393-
wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
393+
wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None)
394394

395395

396396
def test_gemm_empty_matrix() -> None:
@@ -400,7 +400,7 @@ def test_gemm_empty_matrix() -> None:
400400
dtype = dtypes.f32
401401

402402
x = wrapper.randu(empty_shape, dtype)
403-
wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1)
403+
wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1, None)
404404

405405

406406
# matmul tests

0 commit comments

Comments
 (0)