 from .optimizer import Optimizer


+def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
+    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
+    # Compute bounds of interpolation area
+    if bounds is not None:
+        xmin_bound, xmax_bound = bounds
+    else:
+        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
+
+    # Code for most common case: cubic interpolation of 2 points
+    #   w/ function and derivative values for both
+    # Solution in this case (where x2 is the farthest point):
+    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
+    #   d2 = sqrt(d1^2 - g1*g2);
+    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
+    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
+    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
+    d2_square = d1 ** 2 - g1 * g2
+    if d2_square >= 0:
+        d2 = d2_square.sqrt()
+        if x1 <= x2:
+            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
+        else:
+            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
+        return min(max(min_pos, xmin_bound), xmax_bound)
+    else:
+        return (xmin_bound + xmax_bound) / 2.
+
+
+def _strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9,
+                  max_ls=25):
+    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
+    d_norm = d.abs().max()
+    g = g.clone()
+    # evaluate objective and gradient using initial step
+    f_new, g_new = obj_func(x, t, d)
+    ls_func_evals = 1
+    gtd_new = g_new.dot(d)
+
+    # bracket an interval containing a point satisfying the Wolfe criteria
+    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
+    done = False
+    ls_iter = 0
+    while ls_iter < max_ls:
+        # check conditions
+        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
+            bracket = [t_prev, t]
+            bracket_f = [f_prev, f_new]
+            bracket_g = [g_prev, g_new.clone()]
+            bracket_gtd = [gtd_prev, gtd_new]
+            break
+
+        if abs(gtd_new) <= -c2 * gtd:
+            bracket = [t]
+            bracket_f = [f_new]
+            bracket_g = [g_new]
+            done = True
+            break
+
+        if gtd_new >= 0:
+            bracket = [t_prev, t]
+            bracket_f = [f_prev, f_new]
+            bracket_g = [g_prev, g_new.clone()]
+            bracket_gtd = [gtd_prev, gtd_new]
+            break
+
+        # interpolate
+        min_step = t + 0.01 * (t - t_prev)
+        max_step = t * 10
+        tmp = t
+        t = _cubic_interpolate(t_prev, f_prev, gtd_prev, t, f_new, gtd_new,
+                               bounds=(min_step, max_step))
+
+        # next step
+        t_prev = tmp
+        f_prev = f_new
+        g_prev = g_new.clone()
+        gtd_prev = gtd_new
+        f_new, g_new = obj_func(x, t, d)
+        ls_func_evals += 1
+        gtd_new = g_new.dot(d)
+        ls_iter += 1
+
+    # reached max number of iterations?
+    if ls_iter == max_ls:
+        bracket = [0, t]
+        bracket_f = [f, f_new]
+        bracket_g = [g, g_new]
+
+    # zoom phase: we now have a point satisfying the criteria, or
+    # a bracket around it. We refine the bracket until we find the
+    # exact point satisfying the criteria
+    insuf_progress = False
+    # find high and low points in bracket
+    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
+    while not done and ls_iter < max_ls:
+        # compute new trial value
+        t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
+                               bracket[1], bracket_f[1], bracket_gtd[1])
+
+        # test that we are making sufficient progress:
+        # in case `t` is so close to boundary, we mark that we are making
+        # insufficient progress, and if
+        #   + we have made insufficient progress in the last step, or
+        #   + `t` is at one of the boundary,
+        # we will move `t` to a position which is `0.1 * len(bracket)`
+        # away from the nearest boundary point.
+        eps = 0.1 * (max(bracket) - min(bracket))
+        if min(max(bracket) - t, t - min(bracket)) < eps:
+            # interpolation close to boundary
+            if insuf_progress or t >= max(bracket) or t <= min(bracket):
+                # evaluate at 0.1 away from boundary
+                if abs(t - max(bracket)) < abs(t - min(bracket)):
+                    t = max(bracket) - eps
+                else:
+                    t = min(bracket) + eps
+                insuf_progress = False
+            else:
+                insuf_progress = True
+        else:
+            insuf_progress = False
+
+        # Evaluate new point
+        f_new, g_new = obj_func(x, t, d)
+        ls_func_evals += 1
+        gtd_new = g_new.dot(d)
+        ls_iter += 1
+
+        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
+            # Armijo condition not satisfied or not lower than lowest point
+            bracket[high_pos] = t
+            bracket_f[high_pos] = f_new
+            bracket_g[high_pos] = g_new.clone()
+            bracket_gtd[high_pos] = gtd_new
+            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
+        else:
+            if abs(gtd_new) <= -c2 * gtd:
+                # Wolfe conditions satisfied
+                done = True
+            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
+                # old high becomes new low
+                bracket[high_pos] = bracket[low_pos]
+                bracket_f[high_pos] = bracket_f[low_pos]
+                bracket_g[high_pos] = bracket_g[low_pos]
+                bracket_gtd[high_pos] = bracket_gtd[low_pos]
+
+            # new point becomes new low
+            bracket[low_pos] = t
+            bracket_f[low_pos] = f_new
+            bracket_g[low_pos] = g_new.clone()
+            bracket_gtd[low_pos] = gtd_new
+
+        # line-search bracket is so small
+        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
+            break
+
+    # return stuff
+    t = bracket[low_pos]
+    f_new = bracket_f[low_pos]
+    g_new = bracket_g[low_pos]
+    return f_new, g_new, t, ls_func_evals
+
+
 class LBFGS(Optimizer):
-    """Implements L-BFGS algorithm.
+    """Implements L-BFGS algorithm, heavily inspired by `minFunc
+    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`.

     .. warning::
         This optimizer doesn't support per-parameter options and parameter
@@ -30,6 +193,7 @@ class LBFGS(Optimizer):
         tolerance_change (float): termination tolerance on function
             value/parameter changes (default: 1e-9).
         history_size (int): update history size (default: 100).
+        line_search_fn (str): either 'strong_Wolfe' or None (default: None).
     """

     def __init__(self, params, lr=1, max_iter=20, max_eval=None,
@@ -58,11 +222,11 @@ def _gather_flat_grad(self):
         views = []
         for p in self._params:
             if p.grad is None:
-                view = p.data.new(p.data.numel()).zero_()
-            elif p.grad.data.is_sparse:
-                view = p.grad.data.to_dense().view(-1)
+                view = p.new(p.numel()).zero_()
+            elif p.grad.is_sparse:
+                view = p.grad.to_dense().view(-1)
             else:
-                view = p.grad.data.view(-1)
+                view = p.grad.view(-1)
             views.append(view)
         return torch.cat(views, 0)

@@ -75,6 +239,20 @@ def _add_grad(self, step_size, update):
             offset += numel
         assert offset == self._numel()

+    def _clone_param(self):
+        return [p.clone() for p in self._params]
+
+    def _set_param(self, params_data):
+        for p, pdata in zip(self._params, params_data):
+            p.data.copy_(pdata)
+
+    def _directional_evaluate(self, closure, x, t, d):
+        self._add_grad(t, d)
+        loss = float(closure())
+        flat_grad = self._gather_flat_grad()
+        self._set_param(x)
+        return loss, flat_grad
+
     def step(self, closure):
         """Performs a single optimization step.

@@ -106,16 +284,18 @@ def step(self, closure):
         state['func_evals'] += 1

         flat_grad = self._gather_flat_grad()
-        abs_grad_sum = flat_grad.abs().sum()
+        opt_cond = flat_grad.abs().max() <= tolerance_grad

-        if abs_grad_sum <= tolerance_grad:
+        # optimal condition
+        if opt_cond:
             return orig_loss

         # tensors cached in state (for tracing)
         d = state.get('d')
         t = state.get('t')
         old_dirs = state.get('old_dirs')
         old_stps = state.get('old_stps')
+        ro = state.get('ro')
         H_diag = state.get('H_diag')
         prev_flat_grad = state.get('prev_flat_grad')
         prev_loss = state.get('prev_loss')
@@ -134,6 +314,7 @@ def step(self, closure):
                 d = flat_grad.neg()
                 old_dirs = []
                 old_stps = []
+                ro = []
                 H_diag = 1
             else:
                 # do lbfgs update (update memory)
@@ -146,10 +327,12 @@ def step(self, closure):
                         # shift history by one (limited-memory)
                         old_dirs.pop(0)
                         old_stps.pop(0)
+                        ro.pop(0)

                     # store new direction/step
                     old_dirs.append(y)
                     old_stps.append(s)
+                    ro.append(1. / ys)

                     # update scale of initial Hessian approximation
                     H_diag = ys / y.dot(y)  # (y*y)
@@ -158,15 +341,10 @@ def step(self, closure):
             # multiplied by the gradient
             num_old = len(old_dirs)

-            if 'ro' not in state:
-                state['ro'] = [None] * history_size
+            if 'al' not in state:
                 state['al'] = [None] * history_size
-            ro = state['ro']
             al = state['al']

-            for i in range(num_old):
-                ro[i] = 1. / old_dirs[i].dot(old_stps[i])
-
             # iteration in L-BFGS loop collapsed to use just one buffer
             q = flat_grad.neg()
             for i in range(num_old - 1, -1, -1):
@@ -191,18 +369,32 @@ def step(self, closure):
             ############################################################
             # reset initial guess for step size
             if state['n_iter'] == 1:
-                t = min(1., 1. / abs_grad_sum) * lr
+                t = min(1., 1. / flat_grad.abs().sum()) * lr
             else:
                 t = lr

             # directional derivative
             gtd = flat_grad.dot(d)  # g * d

+            # directional derivative is below tolerance
+            if gtd > -tolerance_change:
+                break
+
             # optional line search: user function
             ls_func_evals = 0
             if line_search_fn is not None:
                 # perform line search, using user function
-                raise RuntimeError("line search function is not supported yet")
+                if line_search_fn != "strong_Wolfe":
+                    raise RuntimeError("only 'strong_Wolfe' is supported")
+                else:
+                    x_init = self._clone_param()
+
+                    def obj_func(x, t, d):
+                        return self._directional_evaluate(closure, x, t, d)
+                    loss, flat_grad, t, ls_func_evals = _strong_Wolfe(obj_func, x_init, t, d,
+                                                                      loss, flat_grad, gtd)
+                    self._add_grad(t, d)
+                    opt_cond = flat_grad.abs().max() <= tolerance_grad
             else:
                 # no line search, simply move with fixed-step
                 self._add_grad(t, d)
@@ -212,7 +404,7 @@ def step(self, closure):
                     # no use to re-evaluate that function here
                     loss = float(closure())
                     flat_grad = self._gather_flat_grad()
-                    abs_grad_sum = flat_grad.abs().sum()
+                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                     ls_func_evals = 1

             # update func eval
@@ -228,13 +420,12 @@ def step(self, closure):
             if current_evals >= max_eval:
                 break

-            if abs_grad_sum <= tolerance_grad:
-                break
-
-            if gtd > -tolerance_change:
+            # optimal condition
+            if opt_cond:
                 break

-            if d.mul(t).abs_().sum() <= tolerance_change:
+            # lack of progress
+            if d.mul(t).abs().max() <= tolerance_change:
                 break

             if abs(loss - prev_loss) < tolerance_change:
@@ -244,6 +435,7 @@ def step(self, closure):
         state['d'] = d
         state['t'] = t
         state['old_dirs'] = old_dirs
         state['old_stps'] = old_stps
+        state['ro'] = ro
         state['H_diag'] = H_diag
         state['prev_flat_grad'] = prev_flat_grad
         state['prev_loss'] = prev_loss
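For context, here is a minimal usage sketch of the optimizer once this change is applied, passing the new line_search_fn argument with the spelling used in this diff ("strong_Wolfe"). The toy least-squares objective, tensor shapes, and iteration count are illustrative assumptions, not part of the change; the closure pattern (re-zeroing gradients, recomputing the loss, calling backward on every invocation) is what LBFGS.step expects, since the strong Wolfe search may evaluate the closure several times per step.

import torch
from torch.optim import LBFGS

# toy problem (assumed for illustration): minimize ||A x - b||^2 over x
A = torch.randn(10, 5)
b = torch.randn(10)
x = torch.zeros(5, requires_grad=True)

optimizer = LBFGS([x], lr=1, max_iter=20, line_search_fn="strong_Wolfe")

def closure():
    # called repeatedly by step(); must rebuild the graph each time
    optimizer.zero_grad()
    loss = (A.mv(x) - b).pow(2).sum()
    loss.backward()
    return loss

for _ in range(10):
    optimizer.step(closure)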