Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add dtype kwarg support to fftfreq and rfftfreq #885

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions src/array_api_stubs/_draft/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"ifftshift",
]

from ._types import Tuple, Union, Sequence, array, Optional, Literal, device
from ._types import Tuple, Union, Sequence, array, Optional, Literal, dtype, device


def fft(
Expand Down Expand Up @@ -551,7 +551,14 @@ def ihfft(
"""


def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> array:
def fftfreq(
n: int,
/,
*,
d: float = 1.0,
dtype: Optional[dtype] = None,
device: Optional[device] = None,
) -> array:
"""
Computes the discrete Fourier transform sample frequencies.

Expand All @@ -568,13 +575,15 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> ar
window length.
d: float
sample spacing between individual samples of the Fourier transform input. Default: ``1.0``.
dtype: Optional[dtype]
output array data type. Must be a real-valued floating-point data type. If ``dtype`` is ``None``, the output array data type must be the default real-valued floating-point data type. Default: ``None``.
device: Optional[device]
device on which to place the created array. Default: ``None``.

Returns
-------
out: array
an array of shape ``(n,)`` containing the sample frequencies. The returned array must have the default real-valued floating-point data type.
an array of shape ``(n,)`` containing the sample frequencies.

Notes
-----
Expand All @@ -586,7 +595,14 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> ar
"""


def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> array:
def rfftfreq(
n: int,
/,
*,
d: float = 1.0,
dtype: Optional[dtype] = None,
device: Optional[device] = None,
) -> array:
"""
Computes the discrete Fourier transform sample frequencies (for ``rfft`` and ``irfft``).

Expand All @@ -605,13 +621,15 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> a
window length.
d: float
sample spacing between individual samples of the Fourier transform input. Default: ``1.0``.
dtype: Optional[dtype]
output array data type. Must be a real-valued floating-point data type. If ``dtype`` is ``None``, the output array data type must be the default real-valued floating-point data type. Default: ``None``.
device: Optional[device]
device on which to place the created array. Default: ``None``.

Returns
-------
out: array
an array of shape ``(n//2+1,)`` containing the sample frequencies. The returned array must have the default real-valued floating-point data type.
an array of shape ``(n//2+1,)`` containing the sample frequencies.

Notes
-----
Expand Down
Loading