Skip to content

Add experimental compile to Solver.__call__. #1298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions torax/_src/solver/nonlinear_theta_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,32 @@

@chex.dataclass(frozen=True)
class DynamicOptimizerRuntimeParams(runtime_params.DynamicRuntimeParams):
initial_guess_mode: int
n_max_iterations: int
loss_tol: float


@chex.dataclass(frozen=True)
class DynamicNewtonRaphsonRuntimeParams(runtime_params.DynamicRuntimeParams):
log_iterations: bool
class StaticOptimizerRuntimeParams(runtime_params.StaticRuntimeParams):
initial_guess_mode: int


@chex.dataclass(frozen=True)
class DynamicNewtonRaphsonRuntimeParams(runtime_params.DynamicRuntimeParams):
maxiter: int
residual_tol: float
residual_coarse_tol: float
delta_reduction_factor: float
tau_min: float


@chex.dataclass(frozen=True)
class StaticNewtonRaphsonRuntimeParams(
runtime_params.StaticRuntimeParams
):
initial_guess_mode: int
log_iterations: bool


class NonlinearThetaMethod(solver.Solver):
"""Time step update using nonlinear solvers and the theta method."""

Expand Down Expand Up @@ -179,6 +189,8 @@ def _x_new_helper(
"""See abstract method docstring in NonlinearThetaMethod."""
solver_params = dynamic_runtime_params_slice_t.solver
assert isinstance(solver_params, DynamicOptimizerRuntimeParams)
static_solver_params = static_runtime_params_slice.solver
assert isinstance(static_solver_params, StaticOptimizerRuntimeParams)
(
x_new,
solver_numeric_outputs,
Expand All @@ -202,7 +214,7 @@ def _x_new_helper(
coeffs_callback=coeffs_callback,
evolving_names=evolving_names,
initial_guess_mode=enums.InitialGuessMode(
solver_params.initial_guess_mode,
static_solver_params.initial_guess_mode,
),
maxiter=solver_params.n_max_iterations,
tol=solver_params.loss_tol,
Expand Down Expand Up @@ -236,6 +248,8 @@ def _x_new_helper(
"""See abstract method docstring in NonlinearThetaMethod."""
solver_params = dynamic_runtime_params_slice_t.solver
assert isinstance(solver_params, DynamicNewtonRaphsonRuntimeParams)
static_solver_params = static_runtime_params_slice.solver
assert isinstance(static_solver_params, StaticNewtonRaphsonRuntimeParams)

(
x_new,
Expand All @@ -259,9 +273,9 @@ def _x_new_helper(
neoclassical_models=self.neoclassical_models,
coeffs_callback=coeffs_callback,
evolving_names=evolving_names,
log_iterations=solver_params.log_iterations,
log_iterations=static_solver_params.log_iterations,
initial_guess_mode=enums.InitialGuessMode(
solver_params.initial_guess_mode
static_solver_params.initial_guess_mode
),
maxiter=solver_params.maxiter,
tol=solver_params.residual_tol,
Expand Down
25 changes: 22 additions & 3 deletions torax/_src/solver/pydantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Pydantic config for Solver."""
import abc
import dataclasses
import functools
from typing import Literal

Expand Down Expand Up @@ -163,15 +164,24 @@ class NewtonRaphsonThetaMethod(BaseSolver):
def linear_solver(self) -> bool:
return self.initial_guess_mode == enums.InitialGuessMode.LINEAR

def build_static_params(
self,
) -> nonlinear_theta_method.StaticNewtonRaphsonRuntimeParams:
"""Builds static runtime params from the config."""
base_params = super().build_static_params()
return nonlinear_theta_method.StaticNewtonRaphsonRuntimeParams(
**dataclasses.asdict(base_params),
initial_guess_mode=self.initial_guess_mode.value,
log_iterations=self.log_iterations,
)

@functools.cached_property
def build_dynamic_params(
self,
) -> nonlinear_theta_method.DynamicNewtonRaphsonRuntimeParams:
return nonlinear_theta_method.DynamicNewtonRaphsonRuntimeParams(
chi_pereverzev=self.chi_pereverzev,
D_pereverzev=self.D_pereverzev,
log_iterations=self.log_iterations,
initial_guess_mode=self.initial_guess_mode.value,
maxiter=self.n_max_iterations,
residual_tol=self.residual_tol,
residual_coarse_tol=self.residual_coarse_tol,
Expand Down Expand Up @@ -216,14 +226,23 @@ class OptimizerThetaMethod(BaseSolver):
def linear_solver(self) -> bool:
return self.initial_guess_mode == enums.InitialGuessMode.LINEAR

def build_static_params(
self,
) -> nonlinear_theta_method.StaticOptimizerRuntimeParams:
"""Builds static runtime params from the config."""
base_params = super().build_static_params()
return nonlinear_theta_method.StaticOptimizerRuntimeParams(
**dataclasses.asdict(base_params),
initial_guess_mode=self.initial_guess_mode.value,
)

@functools.cached_property
def build_dynamic_params(
self,
) -> nonlinear_theta_method.DynamicOptimizerRuntimeParams:
return nonlinear_theta_method.DynamicOptimizerRuntimeParams(
chi_pereverzev=self.chi_pereverzev,
D_pereverzev=self.D_pereverzev,
initial_guess_mode=self.initial_guess_mode.value,
n_max_iterations=self.n_max_iterations,
loss_tol=self.loss_tol,
n_corrector_steps=self.n_corrector_steps,
Expand Down
7 changes: 7 additions & 0 deletions torax/_src/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import functools
import jax
from torax._src import state
from torax._src import xnp
from torax._src.config import runtime_params_slice
from torax._src.fvm import cell_variable
from torax._src.geometry import geometry
Expand Down Expand Up @@ -98,6 +99,12 @@ def evolving_names(self) -> tuple[str, ...]:
evolving_names.append('n_e')
return tuple(evolving_names)

@functools.partial(
xnp.jit,
static_argnames=[
'self',
],
)
def __call__(
self,
t: jax.Array,
Expand Down
5 changes: 5 additions & 0 deletions torax/tests/sim_experimental_compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class SimExperimentalCompileTest(sim_test_case.SimTestCase):
'test_implicit_short_optimizer',
'test_implicit_short_optimizer.py',
),
# Using sawtooth solver.
(
'test_iterhybrid_rampup_sawtooth',
'test_iterhybrid_rampup_sawtooth.py',
),
)
def test_run_simulation_experimental_compile(
self,
Expand Down