Skip to content

Add calcphibdot to StaticNumerics. #1318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions torax/_src/config/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class DynamicNumerics:
resistivity_multiplier: array_typing.ScalarFloat
adaptive_T_source_prefactor: float
adaptive_n_source_prefactor: float
calcphibdot: bool


@chex.dataclass(frozen=True)
Expand All @@ -54,6 +53,7 @@ class StaticNumerics:
evolve_current: bool
evolve_density: bool
adaptive_dt: bool
calcphibdot: bool


class Numerics(torax_pydantic.BaseModelFrozen):
Expand Down Expand Up @@ -145,7 +145,6 @@ def build_dynamic_params(
chi_timestep_prefactor=self.chi_timestep_prefactor,
fixed_dt=self.fixed_dt,
dt_reduction_factor=self.dt_reduction_factor,
calcphibdot=self.calcphibdot,
resistivity_multiplier=self.resistivity_multiplier.get_value(t),
adaptive_T_source_prefactor=self.adaptive_T_source_prefactor,
adaptive_n_source_prefactor=self.adaptive_n_source_prefactor,
Expand All @@ -159,4 +158,5 @@ def build_static_params(self) -> StaticNumerics:
evolve_current=self.evolve_current,
evolve_density=self.evolve_density,
adaptive_dt=self.adaptive_dt,
calcphibdot=self.calcphibdot,
)
7 changes: 6 additions & 1 deletion torax/_src/orchestration/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def __call__(
) = _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot(
input_state.t,
dt_crash,
self._static_runtime_params_slice,
self._dynamic_runtime_params_slice_provider,
geo_t,
self._geometry_provider,
Expand Down Expand Up @@ -239,6 +240,7 @@ def __call__(
_get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot(
input_state.t,
dt,
self._static_runtime_params_slice,
self._dynamic_runtime_params_slice_provider,
geo_t,
self._geometry_provider,
Expand Down Expand Up @@ -544,6 +546,7 @@ def body_fun(
) = _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot(
input_state.t,
dt,
self._static_runtime_params_slice,
self._dynamic_runtime_params_slice_provider,
geo_t,
self._geometry_provider,
Expand Down Expand Up @@ -852,6 +855,7 @@ def _calculate_total_transport_coeffs(
def _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot(
t: jax.Array,
dt: jax.Array,
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice_provider: build_runtime_params.DynamicRuntimeParamsSliceProvider,
geo_t: geometry.Geometry,
geometry_provider: geometry_provider_lib.GeometryProvider,
Expand All @@ -865,6 +869,7 @@ def _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot(
Args:
t: Time at which the simulation is currently at.
dt: Time step duration.
static_runtime_params_slice: Static runtime parameters.
dynamic_runtime_params_slice_provider: Object that returns a set of runtime
parameters which may change from time step to time step or simulation run
to run. If these runtime parameters change, it does NOT trigger a JAX
Expand All @@ -886,7 +891,7 @@ def _get_geo_and_dynamic_runtime_params_at_t_plus_dt_and_phibdot(
geometry_provider=geometry_provider,
)
)
if dynamic_runtime_params_slice_t_plus_dt.numerics.calcphibdot:
if static_runtime_params_slice.numerics.calcphibdot:
geo_t, geo_t_plus_dt = geometry.update_geometries_with_Phibdot(
dt=dt,
geo_t=geo_t,
Expand Down