Skip to content

Commit

Permalink
Use better wording for the return dtypes for fft functions
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer committed Dec 12, 2023
1 parent 5a14534 commit c51ce80
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions src/array_api_stubs/_draft/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def fft(
Returns
-------
out: array
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a complex floating-point data type determined by :ref:`type-promotion`.
an array transformed along the axis (dimension) indicated by ``axis``. If ``x`` has a complex floating-point data type, the returned array must have the same data type as ``x``. If ``x`` has a real-valued floating-point data type, the returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
Notes
-----
Expand Down Expand Up @@ -111,7 +111,7 @@ def ifft(
Returns
-------
out: array
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a complex floating-point data type determined by :ref:`type-promotion`.
an array transformed along the axis (dimension) indicated by ``axis``. If ``x`` has a complex floating-point data type, the returned array must have the same data type as ``x``. If ``x`` has a real-valued floating-point data type, the returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
Notes
-----
Expand Down Expand Up @@ -169,7 +169,7 @@ def fftn(
Returns
-------
out: array
an array transformed along the axes (dimension) indicated by ``axes``. The returned array must have a complex floating-point data type determined by :ref:`type-promotion`.
an array transformed along the axes (dimension) indicated by ``axes``. If ``x`` has a complex floating-point data type, the returned array must have the same data type as ``x``. If ``x`` has a real-valued floating-point data type, the returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
Notes
-----
Expand Down Expand Up @@ -227,7 +227,7 @@ def ifftn(
Returns
-------
out: array
an array transformed along the axes (dimension) indicated by ``axes``. The returned array must have a complex floating-point data type determined by :ref:`type-promotion`.
an array transformed along the axes (dimension) indicated by ``axes``. If ``x`` has a complex floating-point data type, the returned array must have the same data type as ``x``. If ``x`` has a real-valued floating-point data type, the returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
Notes
-----
Expand Down Expand Up @@ -278,7 +278,7 @@ def rfft(
Returns
-------
out: array
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a complex-valued floating-point data type determined by :ref:`type-promotion`.
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
Notes
-----
Expand Down Expand Up @@ -329,7 +329,7 @@ def irfft(
Returns
-------
out: array
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`. The length along the transformed axis is ``n`` (if given) or ``2*(m-1)`` (otherwise).
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have the same data type as ``x``. The length along the transformed axis is ``n`` (if given) or ``2*(m-1)`` (otherwise).
Notes
-----
Expand Down Expand Up @@ -387,7 +387,7 @@ def rfftn(
Returns
-------
out: array
an array transformed along the axes (dimension) indicated by ``axes``. The returned array must have a complex-valued floating-point data type determined by :ref:`type-promotion`.
an array transformed along the axes (dimension) indicated by ``axes``. The returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
Notes
-----
Expand Down Expand Up @@ -445,7 +445,7 @@ def irfftn(
Returns
-------
out: array
an array transformed along the axes (dimension) indicated by ``axes``. The returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`. The length along the last transformed axis is ``s[-1]`` (if given) or ``2*(m - 1)`` (otherwise), and all other axes ``s[i]``.
an array transformed along the axes (dimension) indicated by ``axes``. The returned array must have a real-valued floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``complex128``, then the returned array must have a ``float64`` data type). The length along the last transformed axis is ``s[-1]`` (if given) or ``2*(m - 1)`` (otherwise), and all other axes ``s[i]``.
Notes
-----
Expand Down Expand Up @@ -493,7 +493,7 @@ def hfft(
Returns
-------
out: array
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`.
an array transformed along the axis (dimension) indicated by ``axis``. If ``x`` has a complex floating-point data type, the returned array must have the same data type as ``x``. If ``x`` has a real-valued floating-point data type, the returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
Notes
-----
Expand Down Expand Up @@ -541,7 +541,7 @@ def ihfft(
Returns
-------
out: array
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a complex-valued floating-point data type determined by :ref:`type-promotion`.
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
Notes
-----
Expand All @@ -552,7 +552,7 @@ def ihfft(

def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> array:
"""
Returns the discrete Fourier transform sample frequencies.
Computes the discrete Fourier transform sample frequencies.
For a Fourier transform of length ``n`` and length unit of ``d`` the frequencies are described as:
Expand All @@ -573,7 +573,7 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> ar
Returns
-------
out: array
an array of length ``n`` containing the sample frequencies. The returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`.
an array of length ``n`` containing the sample frequencies. The returned array must have the default real-valued floating-point data type.
Notes
-----
Expand All @@ -584,7 +584,7 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> ar

def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> array:
"""
Returns the discrete Fourier transform sample frequencies (for ``rfft`` and ``irfft``).
Computes the discrete Fourier transform sample frequencies (for ``rfft`` and ``irfft``).
For a Fourier transform of length ``n`` and length unit of ``d`` the frequencies are described as:
Expand All @@ -607,7 +607,7 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> a
Returns
-------
out: array
an array of length ``n//2+1`` containing the sample frequencies. The returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`.
an array of length ``n//2+1`` containing the sample frequencies. The returned array must have the default real-valued floating-point data type.
Notes
-----
Expand Down

0 comments on commit c51ce80

Please sign in to comment.