Skip to content

Commit

Permalink
Update for Flax 0.9
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Sep 9, 2024
1 parent b4f5f75 commit cba42b8
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/test_flax_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def graph() -> nx.DiGraph[Any]:


def test_rebuild(graph: nx.DiGraph[Any]) -> None:
graph_def, state, _ = nnx.graph.flatten(graph)
rebuilt_graph, _ = nnx.graph.unflatten(graph_def, state)
graph_def, state = nnx.graph.flatten(graph)
rebuilt_graph = nnx.graph.unflatten(graph_def, state)
assert nx.utils.graphs_equal(graph, rebuilt_graph)


Expand All @@ -46,7 +46,7 @@ def __init__(self) -> None:


def test_flatten(graph: nx.DiGraph[Any]) -> None:
_, state, _ = nnx.graph.flatten(graph)
_, state = nnx.graph.flatten(graph)
substate = state[GraphEdgeKey('a', 'b')]
assert isinstance(substate, nnx.State)
variable = substate['x']
Expand Down

0 comments on commit cba42b8

Please sign in to comment.