Skip to content

Commit

Permalink
Add option for disabling warnings (#1058)
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer authored Dec 12, 2023
1 parent 760993c commit 2970e1b
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 9 deletions.
7 changes: 5 additions & 2 deletions docs/src/user/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ In addition to the solver, you can alter the behavior of the Optim package by us
* `store_trace`: Should a trace of the optimization algorithm's state be stored? Defaults to `false`.
* `show_trace`: Should a trace of the optimization algorithm's state be shown on `stdout`? Defaults to `false`.
* `extended_trace`: Save additional information. Solver dependent. Defaults to `false`.
* `show_warnings`: Should warnings due to NaNs or Inf be shown? Defaults to `true`.
* `trace_simplex`: Include the full simplex in the trace for `NelderMead`. Defaults to `false`.
* `show_every`: Trace output is printed every `show_every`th iteration.
* `callback`: A function to be called during tracing. A return value of `true` stops the `optimize` call. The callback function is called every `show_every`th iteration. If `store_trace` is false, the argument to the callback is of the type [`OptimizationState`](https://github.com/JuliaNLSolvers/Optim.jl/blob/a1035134ca1f3ebe855f1cde034e32683178225a/src/types.jl#L155), describing the state of the current iteration. If `store_trace` is true, the argument is a list of all the states from the first iteration to the current.
Expand All @@ -73,7 +74,8 @@ res = optimize(f, g!,
Optim.Options(g_tol = 1e-12,
iterations = 10,
store_trace = true,
show_trace = false))
show_trace = false,
show_warnings = true))
```
Another interface is also available, based directly on keywords:
```jl
Expand All @@ -83,7 +85,8 @@ res = optimize(f, g!,
g_tol = 1e-12,
iterations = 10,
store_trace = true,
show_trace = false)
show_trace = false,
show_warnings = true)
```
Notice the need to specify the method using a keyword if this syntax is used.
This approach might be deprecated in the future, and as a result we recommend writing code
Expand Down
2 changes: 1 addition & 1 deletion docs/src/user/minimization.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ If we want to manually specify this method, we use the usual syntax as for multi
optimize(f, lower, upper, GoldenSection(); kwargs...)
```

Keywords are used to set options for this special type of optimization. In addition to the `iterations`, `store_trace`, `show_trace` and `extended_trace` options, the following options are also available:
Keywords are used to set options for this special type of optimization. In addition to the `iterations`, `store_trace`, `show_trace`, `show_warnings`, and `extended_trace` options, the following options are also available:

* `rel_tol`: The relative tolerance used for determining convergence. Defaults to `sqrt(eps(T))`.
* `abs_tol`: The absolute tolerance used for determining convergence. Defaults to `eps(T)`.
Expand Down
5 changes: 4 additions & 1 deletion src/deprecate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ function optimize(
store_trace::Bool = false,
show_trace::Bool = false,
extended_trace::Bool = false,
show_warnings::Bool = true,
callback = nothing,
show_every::Integer = 1,
linesearch = LineSearches.HagerZhang{T}(),
Expand All @@ -25,7 +26,8 @@ function optimize(
optimizer = ConjugateGradient,
optimizer_o = Options(store_trace = store_trace,
show_trace = show_trace,
extended_trace = extended_trace),
extended_trace = extended_trace,
show_warnings = show_warnings),
nargs...) where T<:AbstractFloat
if !has_deprecated_fminbox[]
@warn("Fminbox with the optimizer keyword is deprecated, construct Fminbox{optimizer}() and pass it to optimize(...) instead.")
Expand All @@ -37,6 +39,7 @@ function optimize(
store_trace=store_trace,
show_trace=show_trace,
extended_trace=extended_trace,
show_warnings=show_warnings,
show_every=show_every,
callback=callback,
linesearch=linesearch,
Expand Down
4 changes: 2 additions & 2 deletions src/multivariate/optimize/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ function optimize(d::D, initial_x::Tx, method::M,
end

if g_calls(d) > 0 && !all(isfinite, gradient(d))
@warn "Terminated early due to NaN in gradient."
options.show_warnings && @warn "Terminated early due to NaN in gradient."
break
end
if h_calls(d) > 0 && !(d isa TwiceDifferentiableHV) && !all(isfinite, hessian(d))
@warn "Terminated early due to NaN in Hessian."
options.show_warnings && @warn "Terminated early due to NaN in Hessian."
break
end
end # while
Expand Down
5 changes: 4 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ outer_iterations::Int = 1000,
store_trace::Bool = false,
show_trace::Bool = false,
extended_trace::Bool = false,
show_warnings::Bool = true,
show_every::Int = 1,
callback = nothing,
time_limit = NaN
Expand Down Expand Up @@ -65,6 +66,7 @@ struct Options{T, TCallback}
trace_simplex::Bool
show_trace::Bool
extended_trace::Bool
show_warnings::Bool
show_every::Int
callback::TCallback
time_limit::Float64
Expand Down Expand Up @@ -101,6 +103,7 @@ function Options(;
trace_simplex::Bool = false,
show_trace::Bool = false,
extended_trace::Bool = false,
show_warnings::Bool = true,
show_every::Int = 1,
callback = nothing,
time_limit = NaN)
Expand All @@ -127,7 +130,7 @@ function Options(;
outer_f_reltol = outer_f_tol
end
Options(promote(x_abstol, x_reltol, f_abstol, f_reltol, g_abstol, g_reltol, outer_x_abstol, outer_x_reltol, outer_f_abstol, outer_f_reltol, outer_g_abstol, outer_g_reltol)..., f_calls_limit, g_calls_limit, h_calls_limit,
allow_f_increases, allow_outer_f_increases, successive_f_tol, Int(iterations), Int(outer_iterations), store_trace, trace_simplex, show_trace, extended_trace,
allow_f_increases, allow_outer_f_increases, successive_f_tol, Int(iterations), Int(outer_iterations), store_trace, trace_simplex, show_trace, extended_trace, show_warnings,
Int(show_every), callback, Float64(time_limit))
end

Expand Down
2 changes: 2 additions & 0 deletions src/univariate/optimize/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ function optimize(f,
iterations::Integer = 1_000,
store_trace::Bool = false,
show_trace::Bool = false,
show_warnings::Bool = true,
callback = nothing,
show_every = 1,
extended_trace::Bool = false) where T <: Real
Expand All @@ -24,6 +25,7 @@ function optimize(f,
iterations = iterations,
store_trace = store_trace,
show_trace = show_trace,
show_warnings = show_warnings,
show_every = show_every,
callback = callback,
extended_trace = extended_trace)
Expand Down
3 changes: 2 additions & 1 deletion src/univariate/solvers/brent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ function optimize(
iterations::Integer = 1_000,
store_trace::Bool = false,
show_trace::Bool = false,
show_warnings::Bool = true,
callback = nothing,
show_every = 1,
extended_trace::Bool = false) where T <: AbstractFloat
t0 = time()
options = (store_trace=store_trace, show_trace=show_trace, show_every=show_every, callback=callback)
options = (store_trace=store_trace, show_trace=show_trace, show_warnings=show_warnings, show_every=show_every, callback=callback)
if x_lower > x_upper
error("x_lower must be less than x_upper")
end
Expand Down
3 changes: 2 additions & 1 deletion src/univariate/solvers/golden_section.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ function optimize(f, x_lower::T, x_upper::T,
iterations::Integer = 1_000,
store_trace::Bool = false,
show_trace::Bool = false,
show_warnings::Bool = true,
callback = nothing,
show_every = 1,
extended_trace::Bool = false,
Expand All @@ -33,7 +34,7 @@ function optimize(f, x_lower::T, x_upper::T,
error("x_lower must be less than x_upper")
end
t0 = time()
options = (store_trace=store_trace, show_trace=show_trace, show_every=show_every, callback=callback)
options = (store_trace=store_trace, show_trace=show_trace, show_warnings=show_warnings, show_every=show_every, callback=callback)
# Save for later
initial_lower = x_lower
initial_upper = x_upper
Expand Down

0 comments on commit 2970e1b

Please sign in to comment.