ENH deal with tiny loss improvements in line search #724

Merged 2 commits on Nov 8, 2023
src/glum/_solvers.py: 30 changes (28 additions, 2 deletions)
@@ -15,7 +15,7 @@
     enet_coordinate_descent_gram,
     identify_active_rows,
 )
-from ._distribution import ExponentialDispersionModel
+from ._distribution import ExponentialDispersionModel, get_one_over_variance
 from ._link import Link
 from ._util import _safe_lin_pred, _safe_sandwich_dot

@@ -758,6 +758,7 @@ def line_search(state: IRLSState, data: IRLSData, d: np.ndarray):
"""
# line search parameters
(beta, sigma) = (0.5, 0.0001)
eps = 16 * np.finfo(state.obj_val.dtype).eps

# line search by sequence beta^k, k=0, 1, ..
# F(w + lambda d) - F(w) <= lambda * bound
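As a rough illustration of what the new threshold means (a sketch for orientation, not part of the diff): with float64, 16 * machine epsilon is about 3.6e-15, so objective-value changes below roughly 3.6e-15 * |obj_val| are indistinguishable from round-off and cannot be used to decide the sufficient-decrease test reliably.

import numpy as np

# Illustration only: the relative tolerance implied by eps = 16 * machine epsilon.
eps = 16 * np.finfo(np.float64).eps   # ~3.55e-15 for float64
obj_val = 1.234
tiny_loss = abs(obj_val * eps)        # loss changes below this are pure round-off
print(tiny_loss)                      # ~4.4e-15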
@@ -771,6 +772,9 @@
     # Note: the L2 penalty term is included in the score.
     bound = sigma * (-(state.score @ d) + P1wd_1 - P1w_1)

+    # np.sum(np.abs(state.score))
+    sum_abs_grad_old = -1  # defer calculation

     # The step direction in row space. We'll be multiplying this by varying
     # step sizes during the line search. Factoring this matrix-vector product
     # out of the inner loop improves performance a lot!
@@ -785,8 +789,30 @@
         eta_wd, mu_wd, obj_val_wd, coef_wd_P2 = _update_predictions(
             state, data, coef_wd, X_dot_d, factor=factor
         )
-        if (mu_wd.max() < 1e43) and (obj_val_wd - state.obj_val <= factor * bound):
+        # 1. Check Armijo / sufficient decrease condition.
+        loss_improvement = obj_val_wd - state.obj_val
+        if mu_wd.max() < 1e43 and loss_improvement <= factor * bound:
             break
+        # 2. Deal with relative loss differences around machine precision.
+        tiny_loss = np.abs(state.obj_val * eps)
+        if np.abs(loss_improvement) <= tiny_loss:
+            if sum_abs_grad_old < 0:
+                sum_abs_grad_old = linalg.norm(state.score, ord=1)
+            # 2.1 Check sum of absolute gradients as alternative condition.
+            # Therefore, we need the recent gradient, see update_quadratic.
+            sigma_inv = get_one_over_variance(
+                data.family, data.link, mu_wd, eta_wd, 1.0, data.sample_weight
+            )
+            d1 = data.link.inverse_derivative(eta_wd)  # = h'(eta)
+            d1_sigma_inv = d1 * sigma_inv
+            gradient_rows = d1_sigma_inv * (data.y - mu_wd)
+            grad = gradient_rows @ data.X
+            if data.fit_intercept:
+                grad = np.concatenate(([gradient_rows.sum()], grad))
+            grad -= coef_wd_P2
+            sum_abs_grad = linalg.norm(grad, ord=1)
+            if sum_abs_grad < sum_abs_grad_old:
+                break
         factor *= beta
     else:
         warnings.warn(
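For orientation outside glum's internals, here is a minimal sketch of the technique this PR implements: a backtracking line search that accepts a step either by the Armijo sufficient-decrease test or, when the loss difference sits at round-off level, by a decrease in the L1 norm of the gradient. The function name, the objective/gradient callables and the quadratic test problem are illustrative assumptions, not glum's API.

import numpy as np

def backtracking_line_search(objective, gradient, w, d, beta=0.5, sigma=1e-4, max_iter=30):
    # Backtracking with an Armijo condition plus a machine-precision fallback,
    # loosely mirroring the logic added in this PR (illustrative sketch only).
    eps = 16 * np.finfo(float).eps
    obj_old = objective(w)
    grad_old = gradient(w)
    bound = sigma * (grad_old @ d)                  # expected decrease per unit step (< 0 for a descent direction)
    sum_abs_grad_old = np.linalg.norm(grad_old, ord=1)

    factor = 1.0
    for _ in range(max_iter):
        w_new = w + factor * d
        improvement = objective(w_new) - obj_old
        # 1. Armijo / sufficient decrease condition.
        if improvement <= factor * bound:
            return w_new, factor
        # 2. Loss difference at round-off level: compare gradient norms instead.
        if abs(improvement) <= abs(obj_old) * eps:
            if np.linalg.norm(gradient(w_new), ord=1) < sum_abs_grad_old:
                return w_new, factor
        factor *= beta
    return w, 0.0                                   # no acceptable step found

# Tiny quadratic test problem: f(w) = 0.5 * ||w||^2 with steepest-descent direction.
def f(w):
    return 0.5 * (w @ w)

def g(w):
    return w

w0 = np.array([1.0, -2.0])
w1, step = backtracking_line_search(f, g, w0, -g(w0))
print(step, w1)  # the full step is accepted: step == 1.0, w1 == [0. 0.]

In the diff above, the gradient at the trial point is assembled from the link derivative and the variance function (the get_one_over_variance call) rather than from a generic gradient callable, and it is only computed in the rare branch where the loss difference is at machine precision, so the common path of the line search stays as cheap as before.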