Skip to content

Commit 983296f

Browse files
committed
Revert batmobile
1 parent 85fce08 commit 983296f

File tree

2 files changed

+111
-4
lines changed

2 files changed

+111
-4
lines changed

array_api_compat/common/_typing.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from collections.abc import Mapping
34
from types import ModuleType as Namespace
45
from typing import (
56
TYPE_CHECKING,
@@ -105,11 +106,72 @@ def shape(self, /) -> _T_co: ...
105106
DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
106107

107108

109+
# `__array_namespace_info__.dtypes(kind="bool")`
110+
class DTypesBool(TypedDict):
111+
bool: DType
112+
113+
114+
# `__array_namespace_info__.dtypes(kind="signed integer")`
115+
class DTypesSigned(TypedDict):
116+
int8: DType
117+
int16: DType
118+
int32: DType
119+
int64: DType
120+
121+
122+
# `__array_namespace_info__.dtypes(kind="unsigned integer")`
123+
class DTypesUnsigned(TypedDict):
124+
uint8: DType
125+
uint16: DType
126+
uint32: DType
127+
uint64: DType
128+
129+
130+
# `__array_namespace_info__.dtypes(kind="integral")`
131+
class DTypesIntegral(DTypesSigned, DTypesUnsigned):
132+
pass
133+
134+
135+
# `__array_namespace_info__.dtypes(kind="real floating")`
136+
class DTypesReal(TypedDict):
137+
float32: DType
138+
float64: DType
139+
140+
141+
# `__array_namespace_info__.dtypes(kind="complex floating")`
142+
class DTypesComplex(TypedDict):
143+
complex64: DType
144+
complex128: DType
145+
146+
147+
# `__array_namespace_info__.dtypes(kind="numeric")`
148+
class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex):
149+
pass
150+
151+
152+
# `__array_namespace_info__.dtypes(kind=None)` (default)
153+
class DTypesAll(DTypesBool, DTypesNumeric):
154+
pass
155+
156+
157+
# `__array_namespace_info__.dtypes(kind=?)` (fallback)
158+
DTypesAny: TypeAlias = Mapping[str, DType]
159+
160+
108161
__all__ = [
109162
"Array",
110163
"Capabilities",
111164
"DType",
112165
"DTypeKind",
166+
"DTypesAny",
167+
"DTypesAll",
168+
"DTypesBool",
169+
"DTypesNumeric",
170+
"DTypesIntegral",
171+
"DTypesSigned",
172+
"DTypesUnsigned",
173+
"DTypesReal",
174+
"DTypesComplex",
113175
"DefaultDTypes",
114176
"Device",
115177
"HasShape",

array_api_compat/dask/array/_info.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from __future__ import annotations
1414

15-
from typing import Literal, TypeAlias
15+
from typing import Literal, TypeAlias, overload
1616

1717
import dask.array as da
1818
from numpy import bool_ as bool
@@ -34,8 +34,21 @@
3434
)
3535

3636
from ...common._helpers import _DASK_DEVICE, _check_device, _dask_device
37-
from ...common._typing import Capabilities, DefaultDTypes, DType, DTypeKind
38-
37+
from ...common._typing import (
38+
Capabilities,
39+
DefaultDTypes,
40+
DType,
41+
DTypeKind,
42+
DTypesAll,
43+
DTypesAny,
44+
DTypesBool,
45+
DTypesComplex,
46+
DTypesIntegral,
47+
DTypesNumeric,
48+
DTypesReal,
49+
DTypesSigned,
50+
DTypesUnsigned,
51+
)
3952
Device: TypeAlias = Literal["cpu"] | _dask_device
4053

4154

@@ -202,9 +215,41 @@ def default_dtypes(self, /, *, device: Device | None = None) -> DefaultDTypes:
202215
"indexing": dtype(intp),
203216
}
204217

218+
@overload
219+
def dtypes(
220+
self, /, *, device: Device | None = None, kind: None = None
221+
) -> DTypesAll: ...
222+
@overload
223+
def dtypes(
224+
self, /, *, device: Device | None = None, kind: Literal["bool"]
225+
) -> DTypesBool: ...
226+
@overload
227+
def dtypes(
228+
self, /, *, device: Device | None = None, kind: Literal["signed integer"]
229+
) -> DTypesSigned: ...
230+
@overload
231+
def dtypes(
232+
self, /, *, device: Device | None = None, kind: Literal["unsigned integer"]
233+
) -> DTypesUnsigned: ...
234+
@overload
235+
def dtypes(
236+
self, /, *, device: Device | None = None, kind: Literal["integral"]
237+
) -> DTypesIntegral: ...
238+
@overload
239+
def dtypes(
240+
self, /, *, device: Device | None = None, kind: Literal["real floating"]
241+
) -> DTypesReal: ...
242+
@overload
243+
def dtypes(
244+
self, /, *, device: Device | None = None, kind: Literal["complex floating"]
245+
) -> DTypesComplex: ...
246+
@overload
247+
def dtypes(
248+
self, /, *, device: Device | None = None, kind: Literal["numeric"]
249+
) -> DTypesNumeric: ...
205250
def dtypes(
206251
self, /, *, device: Device | None = None, kind: DTypeKind | None = None
207-
) -> dict[str, DType]:
252+
) -> DTypesAny:
208253
"""
209254
The array API data types supported by Dask.
210255

0 commit comments

Comments
 (0)