 import torch
 from torch._six import inf
 import torch.optim as optim
-import torch.legacy.optim as old_optim
 import torch.nn.functional as F
 from torch.optim import SGD
 from torch.autograd import Variable
@@ -24,44 +23,7 @@ def drosenbrock(tensor):
     return torch.DoubleTensor((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * (y - x ** 2)))
 
 
-def wrap_old_fn(old_fn, **config):
-    def wrapper(closure, params, state):
-        return old_fn(closure, params, config, state)
-    return wrapper
-
-
 class TestOptim(TestCase):
-    def _test_rosenbrock(self, constructor, old_fn):
-        params_t = torch.Tensor([1.5, 1.5])
-        state = {}
-
-        params = Variable(torch.Tensor([1.5, 1.5]), requires_grad=True)
-        optimizer = constructor([params])
-
-        solution = torch.Tensor([1, 1])
-        initial_dist = params.data.dist(solution)
-
-        def eval():
-            optimizer.zero_grad()
-            loss = rosenbrock(params)
-            loss.backward()
-            # loss.backward() will give **slightly** different
-            # gradients than drosenbrock, because of a different ordering
-            # of floating point operations. In most cases it doesn't matter,
-            # but some optimizers are so sensitive that they can temporarily
-            # diverge up to 1e-4, just to converge again. This makes the
-            # comparison more stable.
-            params.grad.data.copy_(drosenbrock(params.data))
-            return loss
-
-        for i in range(2000):
-            optimizer.step(eval)
-            old_fn(lambda _: (rosenbrock(params_t), drosenbrock(params_t)),
-                   params_t, state)
-            self.assertEqual(params.data, params_t)
-
-        self.assertLessEqual(params.data.dist(solution), initial_dist)
-
     def _test_rosenbrock_sparse(self, constructor, sparse_only=False):
         params_t = torch.Tensor([1.5, 1.5])
 
@@ -237,16 +199,6 @@ def _build_params_dict_single(self, weight, bias, **kwargs):
         return [dict(params=bias, **kwargs)]
 
     def test_sgd(self):
-        self._test_rosenbrock(
-            lambda params: optim.SGD(params, lr=1e-3),
-            wrap_old_fn(old_optim.sgd, learningRate=1e-3)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.SGD(params, lr=1e-3, momentum=0.9,
-                                     dampening=0, weight_decay=1e-4),
-            wrap_old_fn(old_optim.sgd, learningRate=1e-3, momentum=0.9,
-                        dampening=0, weightDecay=1e-4)
-        )
         self._test_basic_cases(
             lambda weight, bias: optim.SGD([weight, bias], lr=1e-3)
         )
@@ -273,14 +225,6 @@ def test_sgd_sparse(self):
         )
 
     def test_adam(self):
-        self._test_rosenbrock(
-            lambda params: optim.Adam(params, lr=1e-2),
-            wrap_old_fn(old_optim.adam, learningRate=1e-2)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.Adam(params, lr=1e-2, weight_decay=1e-2),
-            wrap_old_fn(old_optim.adam, learningRate=1e-2, weightDecay=1e-2)
-        )
         self._test_basic_cases(
             lambda weight, bias: optim.Adam([weight, bias], lr=1e-3)
         )
@@ -310,18 +254,6 @@ def test_sparse_adam(self):
             optim.SparseAdam(None, lr=1e-2, betas=(1.0, 0.0))
 
     def test_adadelta(self):
-        self._test_rosenbrock(
-            lambda params: optim.Adadelta(params),
-            wrap_old_fn(old_optim.adadelta)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.Adadelta(params, rho=0.95),
-            wrap_old_fn(old_optim.adadelta, rho=0.95)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.Adadelta(params, weight_decay=1e-2),
-            wrap_old_fn(old_optim.adadelta, weightDecay=1e-2)
-        )
         self._test_basic_cases(
             lambda weight, bias: optim.Adadelta([weight, bias])
         )
@@ -333,18 +265,6 @@ def test_adadelta(self):
             optim.Adadelta(None, lr=1e-2, rho=1.1)
 
     def test_adagrad(self):
-        self._test_rosenbrock(
-            lambda params: optim.Adagrad(params, lr=1e-1),
-            wrap_old_fn(old_optim.adagrad, learningRate=1e-1)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.Adagrad(params, lr=1e-1, lr_decay=1e-3),
-            wrap_old_fn(old_optim.adagrad, learningRate=1e-1, learningRateDecay=1e-3)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.Adagrad(params, lr=1e-1, weight_decay=1e-2),
-            wrap_old_fn(old_optim.adagrad, learningRate=1e-1, weightDecay=1e-2)
-        )
         self._test_basic_cases(
             lambda weight, bias: optim.Adagrad([weight, bias], lr=1e-1)
         )
@@ -367,18 +287,6 @@ def test_adagrad_sparse(self):
 
     @skipIfRocm
     def test_adamax(self):
-        self._test_rosenbrock(
-            lambda params: optim.Adamax(params, lr=1e-1),
-            wrap_old_fn(old_optim.adamax, learningRate=1e-1)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.Adamax(params, lr=1e-1, weight_decay=1e-2),
-            wrap_old_fn(old_optim.adamax, learningRate=1e-1, weightDecay=1e-2)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.Adamax(params, lr=1e-1, betas=(0.95, 0.998)),
-            wrap_old_fn(old_optim.adamax, learningRate=1e-1, beta1=0.95, beta2=0.998)
-        )
         self._test_basic_cases(
             lambda weight, bias: optim.Adamax([weight, bias], lr=1e-1)
         )
@@ -391,18 +299,6 @@ def test_adamax(self):
             optim.Adamax(None, lr=1e-2, betas=(0.0, 1.0))
 
     def test_rmsprop(self):
-        self._test_rosenbrock(
-            lambda params: optim.RMSprop(params, lr=1e-2),
-            wrap_old_fn(old_optim.rmsprop, learningRate=1e-2)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.RMSprop(params, lr=1e-2, weight_decay=1e-2),
-            wrap_old_fn(old_optim.rmsprop, learningRate=1e-2, weightDecay=1e-2)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.RMSprop(params, lr=1e-2, alpha=0.95),
-            wrap_old_fn(old_optim.rmsprop, learningRate=1e-2, alpha=0.95)
-        )
         self._test_basic_cases(
             lambda weight, bias: optim.RMSprop([weight, bias], lr=1e-2)
         )
@@ -415,18 +311,6 @@ def test_rmsprop(self):
             optim.RMSprop(None, lr=1e-2, momentum=-1.0)
 
     def test_asgd(self):
-        self._test_rosenbrock(
-            lambda params: optim.ASGD(params, lr=1e-3),
-            wrap_old_fn(old_optim.asgd, eta0=1e-3)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.ASGD(params, lr=1e-3, alpha=0.8),
-            wrap_old_fn(old_optim.asgd, eta0=1e-3, alpha=0.8)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.ASGD(params, lr=1e-3, t0=1e3),
-            wrap_old_fn(old_optim.asgd, eta0=1e-3, t0=1e3)
-        )
         self._test_basic_cases(
             lambda weight, bias: optim.ASGD([weight, bias], lr=1e-3, t0=100)
         )
@@ -440,18 +324,6 @@ def test_asgd(self):
 
     @skipIfRocm
     def test_rprop(self):
-        self._test_rosenbrock(
-            lambda params: optim.Rprop(params, lr=1e-3),
-            wrap_old_fn(old_optim.rprop, stepsize=1e-3)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.Rprop(params, lr=1e-3, etas=(0.6, 1.1)),
-            wrap_old_fn(old_optim.rprop, stepsize=1e-3, etaminus=0.6, etaplus=1.1)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.Rprop(params, lr=1e-3, step_sizes=(1e-4, 3)),
-            wrap_old_fn(old_optim.rprop, stepsize=1e-3, stepsizemin=1e-4, stepsizemax=3)
-        )
         self._test_basic_cases(
             lambda weight, bias: optim.Rprop([weight, bias], lr=1e-3)
         )
@@ -464,14 +336,6 @@ def test_rprop(self):
             optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5))
 
     def test_lbfgs(self):
-        self._test_rosenbrock(
-            lambda params: optim.LBFGS(params),
-            wrap_old_fn(old_optim.lbfgs)
-        )
-        self._test_rosenbrock(
-            lambda params: optim.LBFGS(params, lr=5e-2, max_iter=5),
-            wrap_old_fn(old_optim.lbfgs, learningRate=5e-2, maxIter=5)
-        )
         self._test_basic_cases(
             lambda weight, bias: optim.LBFGS([weight, bias]),
             ignore_multidevice=True