Skip to content

Commit 715b4bf

Browse files
committed
Add more special functions to BaseFakeNumpyNamespace
And use positional-only where appropriate
1 parent 4817722 commit 715b4bf

File tree

5 files changed

+118
-83
lines changed

5 files changed

+118
-83
lines changed

arraycontext/fake_numpy.py

Lines changed: 110 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -313,54 +313,101 @@ def inner(a: ArrayOrScalar) -> ArrayOrScalar:
313313
# as attributes, making __getattr__ fail to retrieve the intended function.
314314

315315
def broadcast_to(self,
316-
array: ArrayOrContainerOrScalar,
316+
array: ArrayOrContainerOrScalar, /,
317317
shape: tuple[int, ...]
318318
) -> ArrayOrContainerOrScalar: ...
319319

320320
def concatenate(self,
321-
arrays: Sequence[ArrayOrContainerT],
321+
arrays: Sequence[ArrayOrContainerT], /,
322322
axis: int = 0
323323
) -> ArrayOrContainerT: ...
324324

325325
def stack(self,
326-
arrays: Sequence[ArrayOrContainerT],
326+
arrays: Sequence[ArrayOrContainerT], /,
327327
axis: int = 0
328328
) -> ArrayOrContainerT: ...
329329

330330
def ravel(self,
331-
a: ArrayOrContainerOrScalarT,
331+
a: ArrayOrContainerOrScalarT, /,
332332
order: OrderCF = "C"
333333
) -> ArrayOrContainerOrScalarT: ...
334334

335335
def array_equal(self,
336336
a: ArrayOrContainerOrScalar,
337-
b: ArrayOrContainerOrScalar
337+
b: ArrayOrContainerOrScalar,
338+
/
338339
) -> Array: ...
339340

340-
def sqrt(self,
341-
a: ArrayOrContainerOrScalarT,
341+
def sqrt(self, a: ArrayOrContainerOrScalarT, /,
342+
) -> ArrayOrContainerOrScalarT: ...
343+
def abs(self, a: ArrayOrContainerOrScalarT, /,
342344
) -> ArrayOrContainerOrScalarT: ...
343345

344-
def abs(self,
345-
a: ArrayOrContainerOrScalarT,
346+
def sin(self, a: ArrayOrContainerOrScalarT, /,
347+
) -> ArrayOrContainerOrScalarT: ...
348+
def cos(self, a: ArrayOrContainerOrScalarT, /,
349+
) -> ArrayOrContainerOrScalarT: ...
350+
def tan(self, a: ArrayOrContainerOrScalarT, /,
351+
) -> ArrayOrContainerOrScalarT: ...
352+
def arcsin(self, a: ArrayOrContainerOrScalarT, /,
353+
) -> ArrayOrContainerOrScalarT: ...
354+
def arccos(self, a: ArrayOrContainerOrScalarT, /,
355+
) -> ArrayOrContainerOrScalarT: ...
356+
def arctan(self, a: ArrayOrContainerOrScalarT, /,
346357
) -> ArrayOrContainerOrScalarT: ...
347358

348-
def sin(self,
359+
def hypot(self,
349360
a: ArrayOrContainerOrScalarT,
361+
b: ArrayOrContainerOrScalarT,
362+
/,
350363
) -> ArrayOrContainerOrScalarT: ...
351-
352-
def cos(self,
364+
def arctan2(self,
353365
a: ArrayOrContainerOrScalarT,
366+
b: ArrayOrContainerOrScalarT,
367+
/,
354368
) -> ArrayOrContainerOrScalarT: ...
355369

356-
def floor(self,
357-
a: ArrayOrContainerOrScalarT,
370+
def deg2rad(self, a: ArrayOrContainerOrScalarT, /,
371+
) -> ArrayOrContainerOrScalarT: ...
372+
def rad2deg(self, a: ArrayOrContainerOrScalarT, /,
358373
) -> ArrayOrContainerOrScalarT: ...
359374

360-
def ceil(self,
361-
a: ArrayOrContainerOrScalarT,
375+
def sinh(self, a: ArrayOrContainerOrScalarT, /,
376+
) -> ArrayOrContainerOrScalarT: ...
377+
def cosh(self, a: ArrayOrContainerOrScalarT, /,
378+
) -> ArrayOrContainerOrScalarT: ...
379+
def tanh(self, a: ArrayOrContainerOrScalarT, /,
380+
) -> ArrayOrContainerOrScalarT: ...
381+
def arcsinh(self, a: ArrayOrContainerOrScalarT, /,
382+
) -> ArrayOrContainerOrScalarT: ...
383+
def arccosh(self, a: ArrayOrContainerOrScalarT, /,
384+
) -> ArrayOrContainerOrScalarT: ...
385+
def arctanh(self, a: ArrayOrContainerOrScalarT, /,
362386
) -> ArrayOrContainerOrScalarT: ...
363387

388+
def ceil(self, a: ArrayOrContainerOrScalarT, /,
389+
) -> ArrayOrContainerOrScalarT: ...
390+
def floor(self, a: ArrayOrContainerOrScalarT, /,
391+
) -> ArrayOrContainerOrScalarT: ...
392+
393+
def exp(self, a: ArrayOrContainerOrScalarT, /
394+
) -> ArrayOrContainerOrScalarT: ...
395+
def expm1(self, a: ArrayOrContainerOrScalarT, /
396+
) -> ArrayOrContainerOrScalarT: ...
397+
def exp2(self, a: ArrayOrContainerOrScalarT, /
398+
) -> ArrayOrContainerOrScalarT: ...
399+
def log(self, a: ArrayOrContainerOrScalarT, /
400+
) -> ArrayOrContainerOrScalarT: ...
401+
def log10(self, a: ArrayOrContainerOrScalarT, /
402+
) -> ArrayOrContainerOrScalarT: ...
403+
def log2(self, a: ArrayOrContainerOrScalarT, /
404+
) -> ArrayOrContainerOrScalarT: ...
405+
def log1p(self, a: ArrayOrContainerOrScalarT, /
406+
) -> ArrayOrContainerOrScalarT: ...
407+
def logaddexp(self, a: ArrayOrContainerOrScalarT, /
408+
) -> ArrayOrContainerOrScalarT: ...
409+
def logaddexp2(self, a: ArrayOrContainerOrScalarT, /
410+
) -> ArrayOrContainerOrScalarT: ...
364411
# {{{ binary/ternary ufuncs
365412

366413
# FIXME: These are more restrictive than necessary, but they'll do the job
@@ -397,73 +444,71 @@ def where(self,
397444

398445
@overload
399446
def sum(self,
400-
a: ArrayOrContainer,
447+
a: ArrayOrContainer, /,
401448
axis: int | tuple[int, ...] | None = None,
402449
dtype: DTypeLike = None,
403450
) -> Array: ...
404451
@overload
405452
def sum(self,
406-
a: ScalarLike,
453+
a: ScalarLike, /,
407454
axis: int | tuple[int, ...] | None = None,
408455
dtype: DTypeLike = None,
409456
) -> ScalarLike: ...
410457

411458
def sum(self,
412-
a: ArrayOrContainerOrScalar,
459+
a: ArrayOrContainerOrScalar, /,
413460
axis: int | tuple[int, ...] | None = None,
414461
dtype: DTypeLike = None,
415462
) -> ArrayOrScalar: ...
416463

417464
@overload
418465
def min(self,
419-
a: ArrayOrContainer,
466+
a: ArrayOrContainer, /,
420467
axis: int | tuple[int, ...] | None = None,
421468
) -> Array: ...
422469
@overload
423470
def min(self,
424-
a: ScalarLike,
471+
a: ScalarLike, /,
425472
axis: int | tuple[int, ...] | None = None,
426473
) -> ScalarLike: ...
427474

428475
def min(self,
429-
a: ArrayOrContainerOrScalar,
476+
a: ArrayOrContainerOrScalar, /,
430477
axis: int | tuple[int, ...] | None = None,
431478
) -> ArrayOrScalar: ...
432479

433480
@overload
434481
def max(self,
435-
a: ArrayOrContainer,
482+
a: ArrayOrContainer, /,
436483
axis: int | tuple[int, ...] | None = None,
437484
) -> Array: ...
438485
@overload
439486
def max(self,
440-
a: ScalarLike,
487+
a: ScalarLike, /,
441488
axis: int | tuple[int, ...] | None = None,
442489
) -> ScalarLike: ...
443490

444491
def max(self,
445-
a: ArrayOrContainerOrScalar,
492+
a: ArrayOrContainerOrScalar, /,
446493
axis: int | tuple[int, ...] | None = None,
447494
) -> ArrayOrScalar: ...
448495

449496
@deprecated("use min instead")
450497
def amin(self,
451-
a: ArrayOrContainerOrScalar,
498+
a: ArrayOrContainerOrScalar, /,
452499
axis: int | tuple[int, ...] | None = None,
453500
) -> ArrayOrScalar: ...
454501

455502
@deprecated("use max instead")
456503
def amax(self,
457-
a: ArrayOrContainerOrScalar,
504+
a: ArrayOrContainerOrScalar, /,
458505
axis: int | tuple[int, ...] | None = None,
459506
) -> ArrayOrScalar: ...
460507

461-
def any(self,
462-
a: ArrayOrContainerOrScalar,
508+
def any(self, a: ArrayOrContainerOrScalar, /,
463509
) -> ArrayOrScalar: ...
464510

465-
def all(self,
466-
a: ArrayOrContainerOrScalar,
511+
def all(self, a: ArrayOrContainerOrScalar, /,
467512
) -> ArrayOrScalar: ...
468513

469514
# }}}
@@ -475,50 +520,40 @@ def all(self,
475520
# These operations provide access to numpy-style comparisons in that
476521
# case.
477522

478-
def greater(
479-
self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar
480-
) -> ArrayOrContainerOrScalar:
481-
...
482-
483-
def greater_equal(
484-
self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar
485-
) -> ArrayOrContainerOrScalar:
486-
...
487-
488-
def less(
489-
self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar
490-
) -> ArrayOrContainerOrScalar:
491-
...
492-
493-
def less_equal(
494-
self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar
495-
) -> ArrayOrContainerOrScalar:
496-
...
497-
498-
def equal(
499-
self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar
500-
) -> ArrayOrContainerOrScalar:
501-
...
502-
503-
def not_equal(
504-
self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar
505-
) -> ArrayOrContainerOrScalar:
506-
...
507-
508-
def logical_or(
509-
self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar
510-
) -> ArrayOrContainerOrScalar:
511-
...
512-
513-
def logical_and(
514-
self, x: ArrayOrContainerOrScalar, y: ArrayOrContainerOrScalar
515-
) -> ArrayOrContainerOrScalar:
516-
...
517-
518-
def logical_not(
519-
self, x: ArrayOrContainerOrScalar
520-
) -> ArrayOrContainerOrScalar:
521-
...
523+
def greater(self,
524+
x: ArrayOrContainerOrScalar,
525+
y: ArrayOrContainerOrScalar, /,
526+
) -> ArrayOrContainerOrScalar: ...
527+
def greater_equal(self,
528+
x: ArrayOrContainerOrScalar,
529+
y: ArrayOrContainerOrScalar, /
530+
) -> ArrayOrContainerOrScalar: ...
531+
def less(self,
532+
x: ArrayOrContainerOrScalar,
533+
y: ArrayOrContainerOrScalar, /
534+
) -> ArrayOrContainerOrScalar: ...
535+
def less_equal(self,
536+
x: ArrayOrContainerOrScalar,
537+
y: ArrayOrContainerOrScalar, /
538+
) -> ArrayOrContainerOrScalar: ...
539+
def equal(self,
540+
x: ArrayOrContainerOrScalar,
541+
y: ArrayOrContainerOrScalar, /
542+
) -> ArrayOrContainerOrScalar: ...
543+
def not_equal(self,
544+
x: ArrayOrContainerOrScalar,
545+
y: ArrayOrContainerOrScalar, /
546+
) -> ArrayOrContainerOrScalar: ...
547+
def logical_or(self,
548+
x: ArrayOrContainerOrScalar,
549+
y: ArrayOrContainerOrScalar, /
550+
) -> ArrayOrContainerOrScalar: ...
551+
def logical_and(self,
552+
x: ArrayOrContainerOrScalar,
553+
y: ArrayOrContainerOrScalar, /
554+
) -> ArrayOrContainerOrScalar: ...
555+
def logical_not(self, x: ArrayOrContainerOrScalar, /
556+
) -> ArrayOrContainerOrScalar: ...
522557

523558
# }}}
524559

arraycontext/impl/jax/fake_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,11 @@ def _rec_vdot(ary1, ary2):
157157

158158
# {{{ logic functions
159159

160-
def all(self, a):
160+
def all(self, a, /):
161161
return rec_map_reduce_array_container(
162162
partial(reduce, jnp.logical_and), jnp.all, a)
163163

164-
def any(self, a):
164+
def any(self, a, /):
165165
return rec_map_reduce_array_container(
166166
partial(reduce, jnp.logical_or), jnp.any, a)
167167

arraycontext/impl/numpy/fake_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,11 +236,11 @@ def inner_ravel(ary: ArrayOrScalar) -> ArrayOrScalar:
236236
def vdot(self, x, y):
237237
return rec_multimap_reduce_array_container(sum, np.vdot, x, y)
238238

239-
def any(self, a):
239+
def any(self, a, /):
240240
return rec_map_reduce_array_container(partial(reduce, np.logical_or),
241241
lambda subary: np.any(subary), a)
242242

243-
def all(self, a):
243+
def all(self, a, /):
244244
return rec_map_reduce_array_container(partial(reduce, np.logical_and),
245245
lambda subary: np.all(subary), a)
246246

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def vdot(self, x, y, dtype=None):
197197

198198
# {{{ logic functions
199199

200-
def all(self, a):
200+
def all(self, a, /):
201201
queue = self._array_context.queue
202202

203203
def _all(ary):
@@ -210,7 +210,7 @@ def _all(ary):
210210
_all,
211211
a)
212212

213-
def any(self, a):
213+
def any(self, a, /):
214214
queue = self._array_context.queue
215215

216216
def _any(ary):

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,12 @@ def stack(self, arrays, axis=0):
173173

174174
# {{{ logic functions
175175

176-
def all(self, a):
176+
def all(self, a, /):
177177
return rec_map_reduce_array_container(
178178
partial(reduce, pt.logical_and),
179179
lambda subary: pt.all(subary), a)
180180

181-
def any(self, a):
181+
def any(self, a, /):
182182
return rec_map_reduce_array_container(
183183
partial(reduce, pt.logical_or),
184184
lambda subary: pt.any(subary), a)

0 commit comments

Comments
 (0)