Skip to content

Commit

Permalink
add type annotations to binary functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Dec 2, 2024
1 parent 5ddba5b commit 17d8018
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions array_api_strict/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
_real_numeric_dtypes,
_numeric_dtypes,
_result_type,
_dtype_categories as _dtype_dtype_categories,
_dtype_categories,
)
from ._array_object import Array
from ._flags import requires_api_version
Expand Down Expand Up @@ -46,11 +46,26 @@ def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func):


def create_binary_func(func_name, dtype_category, np_func):
def inner(x1: Array, x2: Array, /) -> Array:
def inner(x1, x2, /) -> Array:
return _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func)
return inner


# static type annotation for ArrayOrPythonScalar arguments given a category
# NB: keep the keys in sync with the _dtype_categories dict
_annotations = {
"all": "bool | int | float | complex | Array",
"real numeric": "int | float | Array",
"numeric": "int | float | complex | Array",
"integer": "int | Array",
"integer or boolean": "int | bool | Array",
"boolean": "bool | Array",
"real floating-point": "float | Array",
"complex floating-point": "complex | Array",
"floating-point": "float | complex | Array",
}


# func_name: dtype_category (must match that from _dtypes.py)
_binary_funcs = {
"add": "numeric",
Expand Down Expand Up @@ -97,7 +112,7 @@ def inner(x1: Array, x2: Array, /) -> Array:
# create and attach functions to the module
for func_name, dtype_category in _binary_funcs.items():
# sanity check
assert dtype_category in _dtype_dtype_categories
assert dtype_category in _dtype_categories

numpy_name = _numpy_renames.get(func_name, func_name)
np_func = getattr(np, numpy_name)
Expand All @@ -106,6 +121,8 @@ def inner(x1: Array, x2: Array, /) -> Array:
func.__name__ = func_name

func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name)
func.__annotations__['x1'] = _annotations[dtype_category]
func.__annotations__['x2'] = _annotations[dtype_category]

vars()[func_name] = func

Expand All @@ -117,20 +134,22 @@ def inner(x1: Array, x2: Array, /) -> Array:
nextafter = requires_api_version('2024.12')(nextafter) # noqa: F821


def bitwise_left_shift(x1: Array, x2: Array, /) -> Array:
def bitwise_left_shift(x1: int | Array, x2: int | Array, /) -> Array:
is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0
if is_negative:
raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0")
return _bitwise_left_shift(x1, x2) # noqa: F821
bitwise_left_shift.__doc__ = _bitwise_left_shift.__doc__ # noqa: F821
if _bitwise_left_shift.__doc__: # noqa: F821
bitwise_left_shift.__doc__ = _bitwise_left_shift.__doc__ # noqa: F821


def bitwise_right_shift(x1: Array, x2: Array, /) -> Array:
def bitwise_right_shift(x1: int | Array, x2: int | Array, /) -> Array:
is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0
if is_negative:
raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0")
return _bitwise_right_shift(x1, x2) # noqa: F821
bitwise_right_shift.__doc__ = _bitwise_right_shift.__doc__ # noqa: F821
if _bitwise_right_shift.__doc__: # noqa: F821
bitwise_right_shift.__doc__ = _bitwise_right_shift.__doc__ # noqa: F821


# clean up to not pollute the namespace
Expand Down

0 comments on commit 17d8018

Please sign in to comment.