Skip to content

Commit 9c7b6c7

Browse files
Nush395Torax team
authored andcommitted
Add __hash__ and __eq__ to more objects in anticipation of expanding JIT scope.
PiperOrigin-RevId: 774754949
1 parent 2ca8e6b commit 9c7b6c7

File tree

11 files changed

+235
-1
lines changed

11 files changed

+235
-1
lines changed

torax/_src/config/build_runtime_params.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222
`get_consistent_dynamic_runtime_params_slice_and_geometry` which returns a
2323
DynamicRuntimeParamsSlice and a corresponding geometry with consistent Ip.
2424
"""
25+
26+
import functools
27+
2528
import chex
29+
import jax
2630
from torax._src.config import runtime_params_slice
2731
from torax._src.geometry import geometry
2832
from torax._src.geometry import geometry_provider as geometry_provider_lib
@@ -89,6 +93,22 @@ def __init__(
8993
self._mhd = torax_config.mhd
9094
self._neoclassical = torax_config.neoclassical
9195

96+
@functools.cached_property
97+
def initial_slice(self) -> runtime_params_slice.DynamicRuntimeParamsSlice:
98+
return self(0.0)
99+
100+
@property
101+
def structure(self) -> jax.tree_util.PyTreeDef:
102+
return jax.tree_util.tree_structure(self.initial_slice)
103+
104+
def __hash__(self) -> int:
105+
"""Used for jax.jit caching where we only care about the structure."""
106+
return hash(self.structure)
107+
108+
def __eq__(self, other: typing_extensions.Self) -> bool:
109+
"""Used for jax.jit caching where we only care about the structure."""
110+
return self.structure == other.structure
111+
92112
@classmethod
93113
def from_config(
94114
cls,

torax/_src/config/tests/build_runtime_params_test.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,18 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import os
1515
from absl.testing import absltest
1616
from absl.testing import parameterized
1717
import numpy as np
1818
from torax._src.config import build_runtime_params
19+
from torax._src.config import config_loader
1920
from torax._src.config import profile_conditions as profile_conditions_lib
2021
from torax._src.geometry import pydantic_model as geometry_pydantic_model
2122
from torax._src.pedestal_model import pydantic_model as pedestal_pydantic_model
2223
from torax._src.pedestal_model import set_tped_nped
2324
from torax._src.test_utils import default_configs
25+
from torax._src.test_utils import paths
2426
from torax._src.torax_pydantic import model_config
2527
from torax._src.torax_pydantic import torax_pydantic
2628

@@ -251,6 +253,44 @@ def test_profile_conditions_set_electron_density_and_boundary_condition(
251253
)
252254
self.assertTrue(static_slice.n_e_right_bc_is_absolute)
253255

256+
def test_same_shape_dynamic_params_has_same_hash_and_equals(self):
257+
test_data_dir = paths.test_data_dir()
258+
torax_config = config_loader.build_torax_config_from_file(
259+
os.path.join(test_data_dir, 'test_iterhybrid_rampup.py')
260+
)
261+
provider1 = (
262+
build_runtime_params.DynamicRuntimeParamsSliceProvider.from_config(
263+
torax_config
264+
)
265+
)
266+
torax_config.update_fields({'sources.generic_heat.P_total': 1.0})
267+
provider2 = (
268+
build_runtime_params.DynamicRuntimeParamsSliceProvider.from_config(
269+
torax_config
270+
)
271+
)
272+
self.assertEqual(provider1, provider2)
273+
self.assertEqual(hash(provider1), hash(provider2))
274+
275+
def test_different_dynamic_params_has_different_hash_and_not_equals(self):
276+
test_data_dir = paths.test_data_dir()
277+
torax_config = config_loader.build_torax_config_from_file(
278+
os.path.join(test_data_dir, 'test_iterhybrid_rampup.py')
279+
)
280+
provider1 = (
281+
build_runtime_params.DynamicRuntimeParamsSliceProvider.from_config(
282+
torax_config
283+
)
284+
)
285+
torax_config.update_fields({'transport': {'model_name': 'constant'}})
286+
provider2 = (
287+
build_runtime_params.DynamicRuntimeParamsSliceProvider.from_config(
288+
torax_config
289+
)
290+
)
291+
self.assertNotEqual(provider1, provider2)
292+
self.assertNotEqual(hash(provider1), hash(provider2))
293+
254294

255295
if __name__ == '__main__':
256296
absltest.main()

torax/_src/geometry/geometry_provider.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,14 @@ def __call__(
8686
def torax_mesh(self) -> torax_pydantic.Grid1D:
8787
"""Returns the mesh used by Torax, this is consistent across time."""
8888

89+
def __hash__(self) -> int:
90+
"""For the purposes of jax.jit caching, we only care about the mesh."""
91+
return hash(self.torax_mesh)
92+
93+
def __eq__(self, other: typing_extensions.Self) -> bool:
94+
"""For the purposes of jax.jit caching, we only care about the mesh."""
95+
return self.torax_mesh == other.torax_mesh
96+
8997

9098
class ConstantGeometryProvider(GeometryProvider):
9199
"""Returns the same Geometry for all calls."""

torax/_src/geometry/tests/geometry_provider_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,22 @@ def test_none_z_magnetic_axis_stays_none_time_dependent(self):
7474
self.assertIsNone(provider(0.0)._z_magnetic_axis)
7575
self.assertIsNone(provider(10.0)._z_magnetic_axis)
7676

77+
def test_same_mesh_has_same_hash_and_equality(self):
78+
geo1 = geometry_pydantic_model.CircularConfig(n_rho=25).build_geometry()
79+
geo2 = geometry_pydantic_model.CircularConfig(n_rho=25).build_geometry()
80+
provider1 = geometry_provider.ConstantGeometryProvider(geo1)
81+
provider2 = geometry_provider.ConstantGeometryProvider(geo2)
82+
self.assertEqual(provider1, provider2)
83+
self.assertEqual(hash(provider1), hash(provider2))
84+
85+
def test_different_mesh_has_different_hash_and_not_equals(self):
86+
geo1 = geometry_pydantic_model.CircularConfig(n_rho=25).build_geometry()
87+
geo2 = geometry_pydantic_model.CircularConfig(n_rho=50).build_geometry()
88+
provider1 = geometry_provider.ConstantGeometryProvider(geo1)
89+
provider2 = geometry_provider.ConstantGeometryProvider(geo2)
90+
self.assertNotEqual(provider1, provider2)
91+
self.assertNotEqual(hash(provider1), hash(provider2))
92+
7793

7894
if __name__ == "__main__":
7995
absltest.main()

torax/_src/mhd/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,17 @@
1616

1717
import chex
1818
from torax._src.mhd.sawtooth import sawtooth_model
19+
import typing_extensions
1920

2021

2122
@chex.dataclass
2223
class MHDModels:
2324
"""Container for instantiated MHD model objects."""
2425

2526
sawtooth: sawtooth_model.SawtoothModel | None = None
27+
28+
def __hash__(self) -> int:
29+
return hash(self.sawtooth)
30+
31+
def __eq__(self, other: typing_extensions.Self) -> bool:
32+
return self.sawtooth == other.sawtooth

torax/_src/mhd/tests/base_test.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
from absl.testing import absltest
17+
from torax._src.config import config_loader
18+
from torax._src.orchestration import run_simulation
19+
from torax._src.test_utils import paths
20+
21+
22+
class BaseTest(absltest.TestCase):
23+
24+
def test_different_mhd_models_have_same_hash_and_equals(self):
25+
test_data_dir = paths.test_data_dir()
26+
torax_config = config_loader.build_torax_config_from_file(
27+
os.path.join(test_data_dir, "test_iterhybrid_rampup.py")
28+
)
29+
static_runtime_params_slice, _, _, _, _, _, step_fn = (
30+
run_simulation.prepare_simulation(torax_config)
31+
)
32+
model1 = torax_config.mhd.build_mhd_models(
33+
static_runtime_params_slice,
34+
step_fn.solver.transport_model,
35+
step_fn.solver.source_models,
36+
step_fn.solver.pedestal_model,
37+
step_fn.solver.neoclassical_models,
38+
)
39+
model2 = torax_config.mhd.build_mhd_models(
40+
static_runtime_params_slice,
41+
step_fn.solver.transport_model,
42+
step_fn.solver.source_models,
43+
step_fn.solver.pedestal_model,
44+
step_fn.solver.neoclassical_models,
45+
)
46+
self.assertEqual(model1, model2)
47+
self.assertEqual(hash(model1), hash(model2))
48+
49+
50+
if __name__ == "__main__":
51+
absltest.main()

torax/_src/orchestration/step_function.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from torax._src.sources import source_profiles as source_profiles_lib
4444
from torax._src.time_step_calculator import time_step_calculator as ts
4545
from torax._src.transport_model import transport_model as transport_model_lib
46+
import typing_extensions
4647

4748
# pylint: disable=invalid-name
4849

@@ -623,6 +624,25 @@ def body_fun(
623624
geo_t_plus_dt,
624625
)
625626

627+
def __hash__(self) -> int:
628+
return hash((
629+
self.solver,
630+
self.mhd_models,
631+
self.time_step_calculator,
632+
self._dynamic_runtime_params_slice_provider,
633+
self._geometry_provider,
634+
))
635+
636+
def __eq__(self, other: typing_extensions.Self) -> bool:
637+
return (
638+
self.solver == other.solver
639+
and self.mhd_models == other.mhd_models
640+
and self.time_step_calculator == other.time_step_calculator
641+
and self._dynamic_runtime_params_slice_provider
642+
== other._dynamic_runtime_params_slice_provider
643+
and self._geometry_provider == other._geometry_provider
644+
)
645+
626646

627647
@functools.partial(
628648
jax_utils.jit,

torax/_src/time_step_calculator/chi_time_step_calculator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,9 @@ def next_dt(
6969
)
7070

7171
return dt
72+
73+
def __hash__(self) -> int:
74+
return hash(self.__class__.__name__)
75+
76+
def __eq__(self, other) -> bool:
77+
return isinstance(other, ChiTimeStepCalculator)

torax/_src/time_step_calculator/fixed_time_step_calculator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,9 @@ def next_dt(
5656
dt = jnp.array(dynamic_runtime_params_slice.numerics.fixed_dt)
5757

5858
return dt
59+
60+
def __hash__(self) -> int:
61+
return hash(self.__class__.__name__)
62+
63+
def __eq__(self, other) -> bool:
64+
return isinstance(other, FixedTimeStepCalculator)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from absl.testing import absltest
15+
from absl.testing import parameterized
16+
from torax._src.time_step_calculator import chi_time_step_calculator
17+
from torax._src.time_step_calculator import fixed_time_step_calculator
18+
19+
20+
class TimeStepCalculatorTest(parameterized.TestCase):
21+
22+
@parameterized.named_parameters(
23+
dict(
24+
testcase_name='fixed',
25+
calculator_type=fixed_time_step_calculator.FixedTimeStepCalculator,
26+
),
27+
dict(
28+
testcase_name='chi',
29+
calculator_type=chi_time_step_calculator.ChiTimeStepCalculator,
30+
),
31+
)
32+
def test_different_time_step_calculators_have_same_hash_and_equals(
33+
self,
34+
calculator_type,
35+
):
36+
calculator1 = calculator_type()
37+
calculator2 = calculator_type()
38+
self.assertEqual(calculator1, calculator2)
39+
self.assertEqual(hash(calculator1), hash(calculator2))
40+
41+
42+
if __name__ == '__main__':
43+
absltest.main()

torax/_src/time_step_calculator/time_step_calculator.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,20 @@ def next_dt(
6868
core_profiles: Core plasma profiles in the tokamak.
6969
core_transport: Transport coefficients.
7070
"""
71+
72+
@abc.abstractmethod
73+
def __hash__(self) -> int:
74+
"""Returns a hash of the time step calculator.
75+
76+
Should be implemented to support jax.jit caching.
77+
"""
78+
79+
@abc.abstractmethod
80+
def __eq__(self, other) -> bool:
81+
"""Returns whether the time step calculator is equal to the other.
82+
83+
Should be implemented to support jax.jit caching.
84+
85+
Args:
86+
other: The object to compare to.
87+
"""

0 commit comments

Comments
 (0)