Skip to content

Commit

Permalink
Add result_type and cast_to_result_type
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Aug 27, 2024
1 parent 40a5ff2 commit 6d04310
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ._src.display.display_generic import display_generic
from ._src.display.internal import internal_print_generic
from ._src.display.print_generic import print_generic
from ._src.dtype_tools import cast_to_result_type, result_type
from ._src.dtypes import default_atol, default_rtol, default_tols
from ._src.graph import (graph_arrow, graph_edge_name, register_graph_as_jax_pytree,
register_graph_as_nnx_node)
Expand All @@ -27,6 +28,7 @@
from ._src.shims import custom_jvp, custom_jvp_method, custom_vjp, custom_vjp_method, hessian, jit
from ._src.testing import (assert_tree_allclose, get_relative_test_string, get_test_string,
tree_allclose)
from ._src.tree_tools import dynamic_tree_all

__all__ = [
'Array',
Expand Down Expand Up @@ -68,6 +70,7 @@
'abstract_custom_jvp',
'abstract_jit',
'assert_tree_allclose',
'cast_to_result_type',
'copy_cotangent',
'cotangent_combinator',
'create_diagonal_array',
Expand All @@ -84,6 +87,7 @@
'display_generic',
'divide_nonnegative',
'divide_where',
'dynamic_tree_all',
'fork_streams',
'get_relative_test_string',
'get_test_string',
Expand All @@ -106,6 +110,7 @@
'register_graph_as_jax_pytree',
'register_graph_as_nnx_node',
'replace_cotangent',
'result_type',
'scale_cotangent',
'softplus',
'tree_allclose',
Expand Down
55 changes: 55 additions & 0 deletions tjax/_src/dtype_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import annotations

from typing import TypeVar

import jax
import jax.numpy as jnp
from jax.dtypes import canonicalize_dtype

_T = TypeVar('_T', bound=tuple[None | jax.Array, ...])


def result_type(*args: None | jax.Array,
dtype: jax.typing.DTypeLike | None = None,
ensure_inexact: bool = True
) -> jax.typing.DTypeLike:
"""Find a common data type for arrays.
Args:
*args: Arrays to consider (or None).
dtype: Overrides the inferred data type.
ensure_inexact: Ensures that the result type is inexact, or raises.
Returns:
The common data type of *args.
"""
if dtype is None:
filtered_args = [x for x in args if x is not None]
dtype = jnp.result_type(*filtered_args)
if ensure_inexact and not jnp.issubdtype(dtype, jnp.inexact):
dtype = jnp.promote_types(canonicalize_dtype(float), dtype)
if ensure_inexact and not jnp.issubdtype(dtype, jnp.inexact):
msg = f'Data type must be inexact: {dtype}'
raise ValueError(msg)
return dtype


def cast_to_result_type(arrays: _T,
/,
*,
dtype: None | jax.typing.DTypeLike = None,
ensure_inexact: bool = True
) -> _T:
"""Casts the input arrays to a common data dtype.
Args:
arrays: Arrays to be converted (or None).
dtype: Overrides the inferred data type.
ensure_inexact: Ensures that the result type is inexact, or raises.
Returns:
The arrays cast to the inferred type.
"""
dtype = result_type(*arrays, dtype=dtype, ensure_inexact=ensure_inexact)
return tuple(jnp.asarray(x, dtype) # type: ignore # pyright: ignore
if x is not None else None for x in arrays)

0 comments on commit 6d04310

Please sign in to comment.