Skip to content

Commit

Permalink
Fix flax object display
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Aug 24, 2024
1 parent 0aed709 commit a2ee0b8
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 11 deletions.
63 changes: 63 additions & 0 deletions tests/test_flax_display.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from io import StringIO

import jax.numpy as jnp
from flax import nnx
from pytest import CaptureFixture
from rich.console import Console

from tjax import JaxRealArray, print_generic
from tjax.dataclasses import DataClassModule, field

from .test_display import verify


def test_module_display(capsys: CaptureFixture[str],
console: Console) -> None:
class C(nnx.Module):
def __init__(self) -> None:
super().__init__()
self.x = jnp.zeros(4)
self.y = 'abc'
self.a = nnx.Variable(jnp.zeros(4))

c = C()
print_generic(c=c, console=console, immediate=True)
assert isinstance(console.file, StringIO)
captured = console.file.getvalue()
verify(captured,
"""
c=C[flax-module]
├── x=Jax Array (4,) float64
│ └── 0.0000 │ 0.0000 │ 0.0000 │ 0.0000
├── y="abc"
└── a=Variable
└── value=Jax Array (4,) float64
└── 0.0000 │ 0.0000 │ 0.0000 │ 0.0000
""")


def test_dataclass_module_display(capsys: CaptureFixture[str],
console: Console) -> None:
class C(DataClassModule):
x: JaxRealArray
y: str = field(static=True)

def __post_init__(self, rngs: nnx.Rngs) -> None:
if hasattr(super(), '__post_init__'):
super().__post_init__(rngs)
self.a = nnx.Variable(jnp.zeros(4))

c = C(jnp.zeros(4), 'abc', rngs=nnx.Rngs())
print_generic(c=c, console=console)
assert isinstance(console.file, StringIO)
captured = console.file.getvalue()
verify(captured,
"""
c=C[dataclass,flax-module]
├── a=Variable
│ └── value=Jax Array (4,) float64
│ └── 0.0000 │ 0.0000 │ 0.0000 │ 0.0000
├── x=Jax Array (4,) float64
│ └── 0.0000 │ 0.0000 │ 0.0000 │ 0.0000
└── y="abc"
""")
34 changes: 23 additions & 11 deletions tjax/_src/display/display_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,7 @@ def is_node_type(x: type[Any]) -> bool:


def attribute_filter(value: Any, attribute_name: str) -> bool:
is_private = attribute_name.startswith('_')
if flax_loaded:
from flax import nnx # noqa: PLC0415
if isinstance(value, nnx.State) and is_private:
return False
if (isinstance(value, nnx.Variable | nnx.VariableState)
and (is_private or attribute_name.endswith('_hooks'))):
return False
if isinstance(value, FlaxModule) and attribute_name.startswith('_graph_node__'):
return False
# is_private = attribute_name.startswith('_')
return True


Expand All @@ -89,6 +80,10 @@ def display_generic(value: Any,
with _verify(value, seen, key) as x:
if x:
return x
if flax_loaded:
from flax import nnx # noqa: PLC0415
if isinstance(value, nnx.nnx.reprlib.Representable): # pyright: ignore
return _display_flax_object(value, seen=seen, key=key)
if is_dataclass(value) and not isinstance(value, type):
return _display_dataclass(value, seen=seen, key=key) # type: ignore[unreachable]
return _display_object(value, seen=seen, key=key)
Expand Down Expand Up @@ -271,11 +266,28 @@ def _display_dataclass(value: DataclassInstance,
return retval


def _display_flax_object(value: Any,
*,
seen: MutableSet[int],
key: str = '',
) -> Tree:
from flax import nnx # noqa: PLC0415
assert isinstance(value, nnx.nnx.reprlib.Representable) # pyright: ignore
iterator = value.__nnx_repr__()
config = next(iterator)
assert isinstance(config, nnx.nnx.reprlib.Object) # pyright: ignore
retval = display_class(key, type(value))
for element in iterator:
assert isinstance(element, nnx.nnx.reprlib.Attr) # pyright: ignore
retval.children.append(display_generic(getattr(value, element.key), seen=seen,
key=element.key))
return retval


def _display_object(value: Any,
*,
seen: MutableSet[int],
key: str = '',
hide_private: bool = False,
) -> Tree:
retval = display_class(key, type(value))
variables = _variables(value)
Expand Down

0 comments on commit a2ee0b8

Please sign in to comment.