
Commit 4c41766

Criticality search method on the Model class (#3569)
1 parent a74c142 commit 4c41766

File tree: 4 files changed (+326, -9 lines)


openmc/model/model.py

Lines changed: 269 additions & 1 deletion
@@ -1,18 +1,21 @@
 from __future__ import annotations
-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 import copy
+from dataclasses import dataclass, field
 from functools import cache
 from pathlib import Path
 import math
 from numbers import Integral, Real
 import random
 import re
 from tempfile import NamedTemporaryFile, TemporaryDirectory
+from typing import Any, Protocol
 import warnings

 import h5py
 import lxml.etree as ET
 import numpy as np
+from scipy.optimize import curve_fit

 import openmc
 import openmc._xml as xml
@@ -24,6 +27,12 @@
 from openmc.utility_funcs import change_directory


+# Protocol for a function that is passed to search_keff
+class ModelModifier(Protocol):
+    def __call__(self, val: float, **kwargs: Any) -> None:
+        ...
+
+
 class Model:
     """Model container.

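To make the protocol concrete before the new method appears in the next hunk: any callable that takes the search parameter (plus keyword arguments forwarded from func_kwargs) and mutates the model in place satisfies ModelModifier. A hypothetical example, not part of this commit:

# Hypothetical modifier satisfying the ModelModifier protocol: interpret the
# search parameter as a fuel density and receive the material via func_kwargs.
def set_fuel_density(val: float, *, material: openmc.Material) -> None:
    material.set_density('g/cm3', val)

# e.g. model.keff_search(set_fuel_density, 4.0, 19.0,
#                        func_kwargs={'material': fuel})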
@@ -2196,3 +2205,262 @@ def _replace_infinity(value):

         # Take a wild guess as to how many rays are needed
         self.settings.particles = 2 * int(max_length)
+
+    def keff_search(
+        self,
+        func: ModelModifier,
+        x0: float,
+        x1: float,
+        target: float = 1.0,
+        k_tol: float = 1e-4,
+        sigma_final: float = 3e-4,
+        p: float = 0.5,
+        q: float = 0.95,
+        memory: int = 4,
+        x_min: float | None = None,
+        x_max: float | None = None,
+        b0: int | None = None,
+        b_min: int = 20,
+        b_max: int | None = None,
+        maxiter: int = 50,
+        output: bool = False,
+        func_kwargs: dict[str, Any] | None = None,
+        run_kwargs: dict[str, Any] | None = None,
+    ) -> SearchResult:
+        r"""Perform a keff search on a model parametrized by a single variable.
+
+        This method uses the GRsecant method described in a paper by `Price and
+        Roskoff <https://doi.org/10.1016/j.pnucene.2023.104731>`_. The GRsecant
+        method is a modification of the secant method that accounts for
+        uncertainties in the function evaluations. The method uses a weighted
+        linear fit of the most recent function evaluations to predict the next
+        point to evaluate. It also adaptively changes the number of batches to
+        meet the target uncertainty at each iteration.
+
+        The target uncertainty for iteration :math:`n+1` is determined by the
+        following equation (following Eq. (8) in the paper):
+
+        .. math::
+            \sigma_{n+1} = q \sigma_\text{final} \left( \frac{\min \left\{
+            \left\lvert k_i - k_\text{target} \right\rvert : i = 0, 1, \dots, n
+            \right\}}{k_\text{tol}} \right)^p
+
+        where :math:`q` is a multiplicative factor less than 1, given as the
+        ``q`` parameter below.
+
+        Parameters
+        ----------
+        func : ModelModifier
+            Function that takes the parameter to be searched and makes a
+            modification to the model.
+        x0 : float
+            First guess for the parameter passed to ``func``
+        x1 : float
+            Second guess for the parameter passed to ``func``
+        target : float, optional
+            keff value to search for
+        k_tol : float, optional
+            Stopping criterion on the function value; ``abs(keff - target)``
+            must be at most ``k_tol`` to be accepted.
+        sigma_final : float, optional
+            Maximum accepted k-effective uncertainty for the stopping criterion.
+        p : float, optional
+            Exponent used in the stopping criterion.
+        q : float, optional
+            Multiplicative factor used in the stopping criterion.
+        memory : int, optional
+            Number of most-recent points used in the weighted linear fit of
+            ``f(x) = a + b x`` to predict the next point.
+        x_min : float, optional
+            Minimum allowed value for the parameter ``x``.
+        x_max : float, optional
+            Maximum allowed value for the parameter ``x``.
+        b0 : int, optional
+            Number of active batches to use for the initial function
+            evaluations. If None, uses the model's current setting.
+        b_min : int, optional
+            Minimum number of active batches to use in a function evaluation.
+        b_max : int, optional
+            Maximum number of active batches to use in a function evaluation.
+        maxiter : int, optional
+            Maximum number of iterations to perform.
+        output : bool, optional
+            Whether or not to display output showing iteration progress.
+        func_kwargs : dict, optional
+            Keyword arguments to pass to ``func``.
+        run_kwargs : dict, optional
+            Keyword arguments to pass to :meth:`openmc.Model.run` or
+            :func:`openmc.lib.run`.
+
+        Returns
+        -------
+        SearchResult
+            Result object containing the estimated root (parameter value) and
+            evaluation history (parameters, means, standard deviations, and
+            batches), plus convergence status and termination reason.
+
+        """
+        import openmc.lib
+
+        check_type('model modifier', func, Callable)
+        check_type('target', target, Real)
+        if memory < 2:
+            raise ValueError("memory must be ≥ 2")
+        func_kwargs = {} if func_kwargs is None else dict(func_kwargs)
+        run_kwargs = {} if run_kwargs is None else dict(run_kwargs)
+        run_kwargs.setdefault('output', False)
+
+        # Create lists to store the history of evaluations
+        xs: list[float] = []
+        fs: list[float] = []
+        ss: list[float] = []
+        gs: list[int] = []
+        count = 0
+
+        # Helper function to evaluate f and store results
+        def eval_at(x: float, batches: int) -> tuple[float, float]:
+            # Modify the model with the current guess
+            func(x, **func_kwargs)
+
+            # Change the number of batches and run the model
+            batches += self.settings.inactive
+            if openmc.lib.is_initialized:
+                openmc.lib.settings.set_batches(batches)
+                openmc.lib.reset()
+                openmc.lib.run(**run_kwargs)
+                sp_filepath = f'statepoint.{batches}.h5'
+            else:
+                self.settings.batches = batches
+                sp_filepath = self.run(**run_kwargs)
+
+            # Extract keff and its uncertainty
+            with openmc.StatePoint(sp_filepath) as sp:
+                keff = sp.keff
+
+            if output:
+                nonlocal count
+                count += 1
+                print(f'Iteration {count}: {batches=}, {x=:.6g}, {keff=:.5f}')
+
+            xs.append(float(x))
+            fs.append(float(keff.n - target))
+            ss.append(float(keff.s))
+            gs.append(int(batches))
+            return fs[-1], ss[-1]
+
+        # Default b0 to current model settings if not explicitly provided
+        if b0 is None:
+            b0 = self.settings.batches - self.settings.inactive
+
+        # Perform the search (inlined GRsecant) in a temporary directory
+        with TemporaryDirectory() as tmpdir:
+            if not openmc.lib.is_initialized:
+                run_kwargs.setdefault('cwd', tmpdir)
+
+            # ---- Seed with two evaluations
+            f0, s0 = eval_at(x0, b0)
+            if abs(f0) <= k_tol and s0 <= sigma_final:
+                return SearchResult(x0, xs, fs, ss, gs, True, "converged")
+            f1, s1 = eval_at(x1, b0)
+            if abs(f1) <= k_tol and s1 <= sigma_final:
+                return SearchResult(x1, xs, fs, ss, gs, True, "converged")
+
+            for _ in range(maxiter - 2):
+                # ------ Step 1: propose next x via GRsecant
+                m = min(memory, len(xs))
+
+                # Perform a curve fit on f(x) = a + bx accounting for
+                # uncertainties. This is equivalent to minimizing the function
+                # in Equation (A.14)
+                (a, b), _ = curve_fit(
+                    lambda x, a, b: a + b*x,
+                    xs[-m:], fs[-m:], sigma=ss[-m:], absolute_sigma=True
+                )
+                x_new = float(-a / b)
+
+                # Clamp x_new to the bounds if provided
+                if x_min is not None:
+                    x_new = max(x_new, x_min)
+                if x_max is not None:
+                    x_new = min(x_new, x_max)
+
+                # ------ Step 2: choose target σ for next run (Eq. 8 + clamp)
+                min_abs_f = float(np.min(np.abs(fs)))
+                base = q * sigma_final
+                ratio = min_abs_f / k_tol if k_tol > 0 else 1.0
+                sig = base * (ratio ** p)
+                sig_target = max(sig, base)
+
+                # ------ Step 3: choose batches to hit σ_target (Appendix C)
+
+                # Use at least two past points for regression
+                if len(gs) >= 2 and np.var(np.log(gs)) > 0.0:
+                    # Perform a curve fit based on Eq. (C.3) to solve for ln(k).
+                    # Note that unlike in the paper, we do not leave r as an
+                    # undetermined parameter; instead we fix r = 0.5.
+                    (ln_k,), _ = curve_fit(
+                        lambda ln_b, ln_k: ln_k - 0.5*ln_b,
+                        np.log(gs[-4:]), np.log(ss[-4:]),
+                    )
+                    k = float(np.exp(ln_k))
+                else:
+                    k = float(ss[-1] * math.sqrt(gs[-1]))
+
+                b_new = (k / sig_target) ** 2
+
+                # Clamp and round up to integer
+                b_new = max(b_min, math.ceil(b_new))
+                if b_max is not None:
+                    b_new = min(b_new, b_max)
+
+                # Evaluate at proposed x with batches determined above
+                f_new, s_new = eval_at(x_new, b_new)
+
+                # Termination based on both criteria (|f| and σ)
+                if abs(f_new) <= k_tol and s_new <= sigma_final:
+                    return SearchResult(x_new, xs, fs, ss, gs, True, "converged")
+
+        return SearchResult(xs[-1], xs, fs, ss, gs, False, "maxiter")
+
+
+@dataclass
+class SearchResult:
+    """Result of a GRsecant keff search.
+
+    Attributes
+    ----------
+    root : float
+        Estimated parameter value where f(x) = 0 at termination.
+    parameters : list[float]
+        Parameter values (x) evaluated during the search, in order.
+    means : list[float]
+        Estimated values of (keff - target) for each evaluation.
+    stdevs : list[float]
+        One-sigma uncertainties of keff for each evaluation.
+    batches : list[int]
+        Number of active batches used for each evaluation.
+    converged : bool
+        Whether both |f| <= k_tol and sigma <= sigma_final were met.
+    flag : str
+        Reason for termination (e.g., "converged", "maxiter").
+    """
+    root: float
+    parameters: list[float] = field(repr=False)
+    means: list[float] = field(repr=False)
+    stdevs: list[float] = field(repr=False)
+    batches: list[int] = field(repr=False)
+    converged: bool
+    flag: str
+
+    @property
+    def function_calls(self) -> int:
+        """Number of function evaluations performed."""
+        return len(self.parameters)
+
+    @property
+    def total_batches(self) -> int:
+        """Total number of active batches used across all evaluations."""
+        return sum(self.batches)
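For orientation, and outside the diff itself: the proposal step above reduces to a weighted linear fit of recent (x, keff - target) pairs followed by evaluating the fitted root, with the next target uncertainty taken from Eq. (8). A minimal standalone sketch with made-up history values:

# Sketch of the GRsecant proposal step (illustration only; xs, fs, ss are
# hypothetical histories of x, keff - target, and one-sigma uncertainty).
import numpy as np
from scipy.optimize import curve_fit

xs = [6.0, 9.0, 8.4]
fs = [-0.31, 0.05, -0.01]
ss = [3e-3, 2e-3, 1e-3]

# Weighted fit of f(x) = a + b*x; weighting by sigma corresponds to the
# least-squares objective the paper minimizes (its Eq. (A.14))
(a, b), _ = curve_fit(lambda x, a, b: a + b*x,
                      xs, fs, sigma=ss, absolute_sigma=True)
x_next = -a / b  # fitted root becomes the next guess

# Eq. (8): shrink the target uncertainty as |keff - target| approaches k_tol,
# but never below q * sigma_final
q, p, sigma_final, k_tol = 0.95, 0.5, 3e-4, 1e-4
sigma_next = max(q * sigma_final * (min(map(abs, fs)) / k_tol)**p,
                 q * sigma_final)
print(x_next, sigma_next)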

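The batch-count update in Step 3 is likewise just the 1/sqrt(N) scaling of Monte Carlo uncertainties: fit sigma ≈ k / sqrt(batches) to recent runs (Eq. (C.3) with r fixed at 0.5), then invert for the number of batches needed to reach the target sigma. Another hedged sketch with made-up numbers:

# Sketch of the batch-count selection (illustration only; history is made up).
import math
import numpy as np
from scipy.optimize import curve_fit

batches = [20, 20, 45]               # active batches used in past evaluations
sigmas = [9.0e-4, 8.6e-4, 5.9e-4]    # corresponding keff uncertainties

# Fit ln(sigma) = ln(k) - 0.5 * ln(batches), i.e. sigma = k / sqrt(batches)
(ln_k,), _ = curve_fit(lambda ln_b, ln_k: ln_k - 0.5*ln_b,
                       np.log(batches), np.log(sigmas))
k = math.exp(ln_k)

# Invert for the batches needed to reach a target uncertainty
sigma_target = 4.0e-4
b_next = max(20, math.ceil((k / sigma_target) ** 2))
print(b_next)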
src/settings.cpp

Lines changed: 0 additions & 5 deletions
@@ -1220,11 +1220,6 @@ extern "C" int openmc_set_n_batches(
     return OPENMC_E_INVALID_ARGUMENT;
   }

-  if (simulation::current_batch >= n_batches) {
-    set_errmsg("Number of batches must be greater than current batch.");
-    return OPENMC_E_INVALID_ARGUMENT;
-  }
-
   if (!settings::trigger_on) {
     // Set n_batches and n_max_batches to same value
     settings::n_batches = n_batches;

tests/unit_tests/test_lib.py

Lines changed: 0 additions & 3 deletions
@@ -496,9 +496,6 @@ def test_set_n_batches(lib_run):

     for i in range(7):
         openmc.lib.next_batch()
-    # Setting n_batches less than current_batch should raise error
-    with pytest.raises(exc.InvalidArgumentError):
-        settings.set_batches(6)
     # n_batches should stay the same
     assert settings.get_batches() == 10

tests/unit_tests/test_model.py

Lines changed: 57 additions & 0 deletions
@@ -901,6 +901,7 @@ def test_id_map_aligned_model():
     assert tr_instance == 3, f"Expected cell instance 3 at top-right corner, got {tr_instance}"
     assert tr_material == 5, f"Expected material ID 5 at top-right corner, got {tr_material}"

+
 def test_setter_from_list():
     mat = openmc.Material()
     model = openmc.Model(materials=[mat])
@@ -913,3 +914,59 @@ def test_setter_from_list():
     plot = openmc.Plot()
     model = openmc.Model(plots=[plot])
     assert isinstance(model.plots, openmc.Plots)
+
+
+def test_keff_search(run_in_tmpdir):
+    """Test the Model.keff_search method"""
+
+    # Create model of a sphere of U235
+    mat = openmc.Material()
+    mat.set_density('g/cm3', 18.9)
+    mat.add_nuclide('U235', 1.0)
+    sphere = openmc.Sphere(r=10.0, boundary_type='vacuum')
+    cell = openmc.Cell(fill=mat, region=-sphere)
+    geometry = openmc.Geometry([cell])
+    settings = openmc.Settings(particles=1000, inactive=10, batches=30)
+    model = openmc.Model(geometry=geometry, settings=settings)
+
+    # Define function to modify sphere radius
+    def modify_radius(radius):
+        sphere.r = radius
+
+    # Perform keff search
+    k_tol = 4e-3
+    sigma_final = 2e-3
+    result = model.keff_search(
+        func=modify_radius,
+        x0=6.0,
+        x1=9.0,
+        k_tol=k_tol,
+        sigma_final=sigma_final,
+        output=True,
+    )
+
+    final_keff = result.means[-1] + 1.0  # Add back target since means are (keff - target)
+    final_sigma = result.stdevs[-1]
+
+    # Check for convergence and that tolerances are met
+    assert result.converged, "keff_search did not converge"
+    assert abs(final_keff - 1.0) <= k_tol, \
+        f"Final keff {final_keff:.5f} not within k_tol {k_tol}"
+    assert final_sigma <= sigma_final, \
+        f"Final uncertainty {final_sigma:.5f} exceeds sigma_final {sigma_final}"
+
+    # Check type of result
+    assert isinstance(result, openmc.model.SearchResult)
+
+    # Check that we have function evaluation history
+    assert len(result.parameters) >= 2
+    assert len(result.means) == len(result.parameters)
+    assert len(result.stdevs) == len(result.parameters)
+    assert len(result.batches) == len(result.parameters)
+
+    # Check that function_calls property works
+    assert result.function_calls == len(result.parameters)
+
+    # Check that total_batches property works
+    assert result.total_batches == sum(result.batches)
+    assert result.total_batches > 0
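As a final note outside the diff, the returned history can be post-processed along these lines; the arithmetic simply undoes the (keff - target) shift stored in means. This assumes a result object from a keff_search call with the default target of 1.0:

# Hypothetical post-processing of a SearchResult (illustration only)
target = 1.0
keffs = [m + target for m in result.means]   # recover keff estimates
for x, keff, sigma, b in zip(result.parameters, keffs,
                             result.stdevs, result.batches):
    print(f'x = {x:.4g}: keff = {keff:.5f} +/- {sigma:.5f} ({b} active batches)')
print(result.function_calls, 'runs,', result.total_batches, 'total active batches')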
