diff --git a/torax/_src/config/numerics.py b/torax/_src/config/numerics.py index dc998bb4c..6dc45b103 100644 --- a/torax/_src/config/numerics.py +++ b/torax/_src/config/numerics.py @@ -40,7 +40,6 @@ class DynamicNumerics: resistivity_multiplier: array_typing.ScalarFloat adaptive_T_source_prefactor: float adaptive_n_source_prefactor: float - calcphibdot: bool @chex.dataclass(frozen=True) @@ -54,6 +53,7 @@ class StaticNumerics: evolve_current: bool evolve_density: bool adaptive_dt: bool + calcphibdot: bool class Numerics(torax_pydantic.BaseModelFrozen): @@ -145,7 +145,6 @@ def build_dynamic_params( chi_timestep_prefactor=self.chi_timestep_prefactor, fixed_dt=self.fixed_dt, dt_reduction_factor=self.dt_reduction_factor, - calcphibdot=self.calcphibdot, resistivity_multiplier=self.resistivity_multiplier.get_value(t), adaptive_T_source_prefactor=self.adaptive_T_source_prefactor, adaptive_n_source_prefactor=self.adaptive_n_source_prefactor, @@ -159,4 +158,5 @@ def build_static_params(self) -> StaticNumerics: evolve_current=self.evolve_current, evolve_density=self.evolve_density, adaptive_dt=self.adaptive_dt, + calcphibdot=self.calcphibdot, ) diff --git a/torax/_src/orchestration/step_function.py b/torax/_src/orchestration/step_function.py index 74c79acc4..973313480 100644 --- a/torax/_src/orchestration/step_function.py +++ b/torax/_src/orchestration/step_function.py @@ -200,6 +200,7 @@ def __call__( ) = _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( input_state.t, dt_crash, + self._static_runtime_params_slice, self._dynamic_runtime_params_slice_provider, geo_t, self._geometry_provider, @@ -239,6 +240,7 @@ def __call__( _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( input_state.t, dt, + self._static_runtime_params_slice, self._dynamic_runtime_params_slice_provider, geo_t, self._geometry_provider, @@ -544,6 +546,7 @@ def body_fun( ) = _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( input_state.t, dt, + self._static_runtime_params_slice, self._dynamic_runtime_params_slice_provider, geo_t, self._geometry_provider, @@ -852,6 +855,7 @@ def _calculate_total_transport_coeffs( def _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( t: jax.Array, dt: jax.Array, + static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, dynamic_runtime_params_slice_provider: build_runtime_params.DynamicRuntimeParamsSliceProvider, geo_t: geometry.Geometry, geometry_provider: geometry_provider_lib.GeometryProvider, @@ -865,6 +869,7 @@ def _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( Args: t: Time at which the simulation is currently at. dt: Time step duration. + static_runtime_params_slice: Static runtime parameters. dynamic_runtime_params_slice_provider: Object that returns a set of runtime parameters which may change from time step to time step or simulation run to run. If these runtime parameters change, it does NOT trigger a JAX @@ -886,7 +891,7 @@ def _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot( geometry_provider=geometry_provider, ) ) - if dynamic_runtime_params_slice_t_plus_dt.numerics.calcphibdot: + if static_runtime_params_slice.numerics.calcphibdot: geo_t, geo_t_plus_dt = geometry.update_geometries_with_Phibdot( dt=dt, geo_t=geo_t,