torch.optim.lbfgs - added box constraint and line search methods(back… #938
@@ -1,6 +1,7 @@
 import torch
 from functools import reduce
 from .optimizer import Optimizer
+from math import isinf


 class LBFGS(Optimizer):

@@ -29,24 +30,29 @@ class LBFGS(Optimizer):
             (default: 1e-5).
         tolerance_change (float): termination tolerance on function value/parameter
             changes (default: 1e-9).
+        line_search_fn (str): line search method, currently available:
+            ['backtracking', 'goldstein', 'weak_wolfe']
+        bounds (list of tuples of tensor): bounds[i][0] and bounds[i][1] are the
+            elementwise lower bound and upper bound of param[i], respectively
         history_size (int): update history size (default: 100).
     """

     def __init__(self, params, lr=1, max_iter=20, max_eval=None,
                  tolerance_grad=1e-5, tolerance_change=1e-9, history_size=100,
-                 line_search_fn=None):
+                 line_search_fn=None, bounds=None):
         if max_eval is None:
             max_eval = max_iter * 5 // 4
         defaults = dict(lr=lr, max_iter=max_iter, max_eval=max_eval,
                         tolerance_grad=tolerance_grad, tolerance_change=tolerance_change,
-                        history_size=history_size, line_search_fn=line_search_fn)
+                        history_size=history_size, line_search_fn=line_search_fn, bounds=bounds)
         super(LBFGS, self).__init__(params, defaults)

         if len(self.param_groups) != 1:
             raise ValueError("LBFGS doesn't support per-parameter options "
                              "(parameter groups)")

         self._params = self.param_groups[0]['params']
+        self._bounds = [(None, None)] * len(self._params) if bounds is None else bounds
         self._numel_cache = None

     def _numel(self):

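For orientation, here is a minimal usage sketch of the two new constructor arguments, assuming the patched class above is importable as LBFGS. The model and bound values are illustrative; bounds supplies one (lower, upper) pair of tensors per parameter, with None for an unconstrained side.

    import torch

    model = torch.nn.Linear(3, 1)
    params = list(model.parameters())

    # One (lower, upper) pair per parameter; use None to leave a side unconstrained.
    bounds = [(torch.full_like(p, -1.0), torch.full_like(p, 1.0)) for p in params]

    # LBFGS here refers to the patched class from this diff, not the stock optimizer.
    optimizer = LBFGS(params, lr=1, max_iter=20,
                      line_search_fn='backtracking', bounds=bounds)
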
@@ -62,7 +68,7 @@ def _add_grad(self, step_size, update):
         offset = 0
         for p in self._params:
             numel = p.numel()
-            p.data.add_(step_size, update[offset:offset + numel])
+            p.data.add_(step_size, update[offset:offset + numel].resize_(p.size()))
             offset += numel
         assert offset == self._numel()

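The change above reshapes each flat update slice to its parameter's shape before the in-place add. A standalone sketch of the same idea with current tensor APIs (view_as instead of resize_, and the alpha keyword); the names params and update are illustrative:

    import torch

    def add_flat_update(params, step_size, update):
        # Walk the flat update vector and apply each slice to its parameter,
        # viewing the slice in the parameter's shape first.
        offset = 0
        for p in params:
            numel = p.numel()
            p.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == update.numel()

    params = [torch.zeros(2, 3), torch.zeros(4)]
    add_flat_update(params, 0.1, torch.ones(10))
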
@@ -195,18 +201,25 @@ def step(self, closure):
             ls_func_evals = 0
             if line_search_fn is not None:
                 # perform line search, using user function
-                raise RuntimeError("line search function is not supported yet")
+                # raise RuntimeError("line search function is not supported yet")
+                if line_search_fn == 'weak_wolfe':
+                    t = self._weak_wolfe(closure, d)
+                elif line_search_fn == 'goldstein':
+                    t = self._goldstein(closure, d)
+                elif line_search_fn == 'backtracking':
+                    t = self._backtracking(closure, d)
+                self._add_grad(t, d)
             else:
                 # no line search, simply move with fixed-step
                 self._add_grad(t, d)
-                if n_iter != max_iter:
-                    # re-evaluate function only if not in last iteration
-                    # the reason we do this: in a stochastic setting,
-                    # no use to re-evaluate that function here
-                    loss = closure().data[0]
-                    flat_grad = self._gather_flat_grad()
-                    abs_grad_sum = flat_grad.abs().sum()
-                    ls_func_evals = 1
+            if n_iter != max_iter:
+                # re-evaluate function only if not in last iteration
+                # the reason we do this: in a stochastic setting,
+                # no use to re-evaluate that function here
+                loss = closure().data[0]
+                flat_grad = self._gather_flat_grad()
+                abs_grad_sum = flat_grad.abs().sum()
+                ls_func_evals = 1

             # update func eval
             current_evals += ls_func_evals

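The dispatch above relies on the closure that LBFGS.step already requires: a callable that re-evaluates the model and returns the loss with gradients populated, which the line-search helpers below call repeatedly at trial points. A minimal closure sketch against the stock torch.optim.LBFGS interface (model, data and loss are illustrative):

    import torch

    x = torch.randn(32, 3)
    y = torch.randn(32, 1)
    model = torch.nn.Linear(3, 1)
    optimizer = torch.optim.LBFGS(model.parameters(), max_iter=20)

    def closure():
        # Zero old gradients, recompute the loss, and backpropagate
        # so the optimizer sees fresh gradients at the current point.
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    optimizer.step(closure)
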
@@ -239,3 +252,136 @@ def step(self, closure):
         state['prev_loss'] = prev_loss

         return orig_loss
+
+    def _copy_param(self):
+        original_param_data_list = []
+        for p in self._params:
+            param_data = p.data.new(p.size())
+            param_data.copy_(p.data)
+            original_param_data_list.append(param_data)
+        return original_param_data_list
+
+    def _set_param(self, param_data_list):
+        for i in range(len(param_data_list)):
+            self._params[i].data.copy_(param_data_list[i])
+
+    def _set_param_incremental(self, alpha, d):
+        offset = 0
+        for p in self._params:
+            numel = p.numel()
+            p.data.copy_(p.data + alpha * d[offset:offset + numel].resize_(p.size()))
+            offset += numel
+        assert offset == self._numel()
+
+    def _directional_derivative(self, d):
+        deriv = 0.0
+        offset = 0
+        for p in self._params:
+            numel = p.numel()
+            deriv += torch.sum(p.grad.data * d[offset:offset + numel].resize_(p.size()))
+            offset += numel
+        assert offset == self._numel()
+        return deriv
+
+    def _max_alpha(self, d):
Contributor: Reference for bound-constrained optimization: equation 16.71 in Numerical Optimization, 2nd Edition, by Nocedal and Wright.
+        offset = 0
+        max_alpha = float('inf')
+        for p, bnd in zip(self._params, self._bounds):
+            numel = p.numel()
+            l_bnd, u_bnd = bnd
+            p_grad = d[offset:offset + numel].resize_(p.size())
+            if l_bnd is not None:
+                from_l_bnd = ((l_bnd - p.data) / p_grad)[p_grad < 0]
+                min_l_bnd = torch.min(from_l_bnd) if from_l_bnd.numel() > 0 else max_alpha
+            if u_bnd is not None:
+                from_u_bnd = ((u_bnd - p.data) / p_grad)[p_grad > 0]
+                min_u_bnd = torch.min(from_u_bnd) if from_u_bnd.numel() > 0 else max_alpha
+            max_alpha = min(max_alpha, min_l_bnd, min_u_bnd)
+            offset += numel
Contributor: This line takes the minimal alpha across all directions. This implies that, if the point is near the boundary, the next step is constrained to remain very close to it and the descent could stall. I would expect a different step size in each direction instead; see equation 16.71 mentioned in the previous comment. Is there another reference for this?

Contributor: The relevant part of SciPy is the Fortran L-BFGS-B library's cauchy function. The gradient is scaled differently for each component, as done in Nocedal equation 16.71.
+        return max_alpha
+
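To make the reviewers' suggestion concrete: instead of a single scalar max_alpha for the whole direction, equation 16.71 in Nocedal and Wright gives each coordinate its own breakpoint, and the trial point is then projected back into the box. A hedged sketch of that idea, not the PR's code; the names are illustrative, and x, d, lower and upper are flat tensors of equal size.

    import torch

    def breakpoints(x, d, lower, upper):
        # Per-coordinate maximum step along d before hitting a bound:
        # t_i = (upper_i - x_i) / d_i if d_i > 0, (lower_i - x_i) / d_i if d_i < 0, inf otherwise.
        t = torch.full_like(x, float('inf'))
        pos, neg = d > 0, d < 0
        t[pos] = (upper[pos] - x[pos]) / d[pos]
        t[neg] = (lower[neg] - x[neg]) / d[neg]
        return t

    def projected_trial_point(x, d, alpha, lower, upper):
        # Take the step with a scalar alpha, then clamp each coordinate back
        # into its box, so components at a bound stop while the others keep moving.
        return torch.max(torch.min(x + alpha * d, upper), lower)

This is roughly the gradient-projection flavor that the second comment points to in SciPy's L-BFGS-B cauchy routine.
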
+    def _backtracking(self, closure, d):
+        # 0 < rho < 0.5 and 0 < w < 1
+        rho = 1e-4
+        w = 0.5
+
+        original_param_data_list = self._copy_param()
+        phi_0 = closure().data[0]
+        phi_0_prime = self._directional_derivative(d)
+        alpha_k = 1.0
Contributor: Unlike the …
+        while True:
+            self._set_param_incremental(alpha_k, d)
+            phi_k = closure().data[0]
+            self._set_param(original_param_data_list)
+            if phi_k <= phi_0 + rho * alpha_k * phi_0_prime:
+                break
+            else:
+                alpha_k *= w
+        return alpha_k
+
+    def _goldstein(self, closure, d):
+        # 0 < rho < 0.5 and t > 1
+        rho = 1e-4
+        t = 2.0
+
+        original_param_data_list = self._copy_param()
+        phi_0 = closure().data[0]
+        phi_0_prime = self._directional_derivative(d)
+        a_k = 0.0
+        b_k = self._max_alpha(d)
+        alpha_k = min(1e4, (a_k + b_k) / 2.0)
+        while True:
+            self._set_param_incremental(alpha_k, d)
+            phi_k = closure().data[0]
+            self._set_param(original_param_data_list)
+            if phi_k <= phi_0 + rho * alpha_k * phi_0_prime:
+                if phi_k >= phi_0 + (1 - rho) * alpha_k * phi_0_prime:
+                    break
+                else:
+                    a_k = alpha_k
+                    alpha_k = t * alpha_k if isinf(b_k) else (a_k + b_k) / 2.0
+            else:
+                b_k = alpha_k
+                alpha_k = (a_k + b_k) / 2.0
+            if torch.sum(torch.abs(alpha_k * d)) < self.param_groups[0]['tolerance_grad']:
+                break
+            if abs(b_k - a_k) < 1e-6:
+                break
+        return alpha_k
+
+    def _weak_wolfe(self, closure, d):
+        # 0 < rho < 0.5 and rho < sigma < 1
+        rho = 1e-4
+        sigma = 0.9
+
+        original_param_data_list = self._copy_param()
+        phi_0 = closure().data[0]
+        phi_0_prime = self._directional_derivative(d)
+        a_k = 0.0
+        b_k = self._max_alpha(d)
+        alpha_k = min(1e4, (a_k + b_k) / 2.0)
+        while True:
+            self._set_param_incremental(alpha_k, d)
+            phi_k = closure().data[0]
+            phi_k_prime = self._directional_derivative(d)
+            self._set_param(original_param_data_list)
+            if phi_k <= phi_0 + rho * alpha_k * phi_0_prime:
+                if phi_k_prime >= sigma * phi_0_prime:
+                    break
+                else:
+                    alpha_hat = alpha_k + (alpha_k - a_k) * phi_k_prime / (phi_0_prime - phi_k_prime)
+                    a_k = alpha_k
+                    phi_0 = phi_k
+                    phi_0_prime = phi_k_prime
+                    alpha_k = alpha_hat
+            else:
+                alpha_hat = a_k + 0.5 * (alpha_k - a_k) / (1 + (phi_0 - phi_k) / ((alpha_k - a_k) * phi_0_prime))
+                b_k = alpha_k
+                alpha_k = alpha_hat
+            if torch.sum(torch.abs(alpha_k * d)) < self.param_groups[0]['tolerance_grad']:
+                break
+            if abs(b_k - a_k) < 1e-6:
+                break
+        return alpha_k
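For reference, the three acceptance tests used above can be written as small predicates on the one-dimensional restriction phi(alpha) = f(x + alpha * d). Here phi_0 and dphi_0 are the value and slope at alpha = 0, and phi_a and dphi_a the value and slope at the trial alpha; this is a paraphrase of the conditions checked in the methods above, not code from the PR.

    def armijo(phi_0, dphi_0, phi_a, alpha, rho=1e-4):
        # Sufficient decrease: used on its own by the backtracking search.
        return phi_a <= phi_0 + rho * alpha * dphi_0

    def goldstein(phi_0, dphi_0, phi_a, alpha, rho=1e-4):
        # Sufficient decrease plus a lower bound that rejects overly small steps.
        return (phi_a <= phi_0 + rho * alpha * dphi_0 and
                phi_a >= phi_0 + (1 - rho) * alpha * dphi_0)

    def weak_wolfe(phi_0, dphi_0, phi_a, dphi_a, alpha, rho=1e-4, sigma=0.9):
        # Sufficient decrease plus the weak curvature condition on the new slope.
        return (phi_a <= phi_0 + rho * alpha * dphi_0 and
                dphi_a >= sigma * dphi_0)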