
Commit e9ef20e

Kaixhin authored and soumith committed
Add Cosine Annealing LR Scheduler (#3311)
* Add Cosine Annealing LR Scheduler
* Update eta_min in tests to prevent numerical mistakes
* Use non-zero min_eta in test_cos_anneal_lr
1 parent 847c56a commit e9ef20e

3 files changed: +69 lines, -18 lines


docs/source/optim.rst

Lines changed: 3 additions & 1 deletion
@@ -130,7 +130,7 @@ How to adjust Learning Rate
 ---------------------------
 
 :mod:`torch.optim.lr_scheduler` provides several methods to adjust the learning
-rate based on the number of epoches. :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`
+rate based on the number of epochs. :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`
 allows dynamic learning rate reducing based on some validation measurements.
 
 .. autoclass:: torch.optim.lr_scheduler.LambdaLR
@@ -141,5 +141,7 @@ allows dynamic learning rate reducing based on some validation measurements.
     :members:
 .. autoclass:: torch.optim.lr_scheduler.ExponentialLR
     :members:
+.. autoclass:: torch.optim.lr_scheduler.CosineAnnealingLR
+    :members:
 .. autoclass:: torch.optim.lr_scheduler.ReduceLROnPlateau
     :members:
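
For context, a minimal sketch of how a wrapped scheduler is driven once per epoch, using the ReduceLROnPlateau class the documentation above mentions; the throwaway parameter and the constant placeholder metric are illustrative assumptions, not code from this commit:

    import torch
    from torch.optim import SGD
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    # A single throwaway tensor stands in for real model parameters (assumption).
    params = [torch.zeros(1, requires_grad=True)]
    optimizer = SGD(params, lr=0.5)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, threshold=0.01)

    for epoch in range(20):
        # ... train for one epoch, then compute a validation metric ...
        val_loss = 1.0  # placeholder metric; a real run would pass the measured value
        scheduler.step(val_loss)  # lr is reduced when val_loss stops improving
        print(epoch, optimizer.param_groups[0]['lr'])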

test/test_optim.py

Lines changed: 28 additions & 17 deletions
@@ -1,3 +1,4 @@
+import math
 import unittest
 import functools
 from copy import deepcopy
@@ -8,7 +9,7 @@
 from torch.optim import SGD
 from torch.autograd import Variable
 from torch import sparse
-from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau
+from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau
 from common import TestCase, run_tests


@@ -460,117 +461,127 @@ def test_step_lr(self):
         # lr = 0.05 if epoch < 3
         # lr = 0.005 if 30 <= epoch < 6
         # lr = 0.0005 if epoch >= 9
+        epochs = 10
         single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3
-        targets = [single_targets, list(map(lambda x: x * 10, single_targets))]
+        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
         scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
-        epochs = 10
         self._test(scheduler, targets, epochs)

     def test_multi_step_lr(self):
         # lr = 0.05 if epoch < 2
         # lr = 0.005 if 2 <= epoch < 5
         # lr = 0.0005 if epoch < 9
         # lr = 0.00005 if epoch >= 9
+        epochs = 10
         single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
-        targets = [single_targets, list(map(lambda x: x * 10, single_targets))]
+        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
         scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
-        epochs = 10
         self._test(scheduler, targets, epochs)

     def test_exp_lr(self):
-        single_targets = [0.05 * (0.9 ** x) for x in range(10)]
-        targets = [single_targets, list(map(lambda x: x * 10, single_targets))]
+        epochs = 10
+        single_targets = [0.05 * (0.9 ** x) for x in range(epochs)]
+        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
         scheduler = ExponentialLR(self.opt, gamma=0.9)
+        self._test(scheduler, targets, epochs)
+
+    def test_cos_anneal_lr(self):
         epochs = 10
+        eta_min = 1e-10
+        single_targets = [eta_min + (0.05 - eta_min) *
+                          (1 + math.cos(x / epochs * math.pi)) / 2
+                          for x in range(epochs)]
+        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
         self._test(scheduler, targets, epochs)

     def test_reduce_lr_on_plateau1(self):
+        epochs = 10
         for param_group in self.opt.param_groups:
             param_group['lr'] = 0.5
         targets = [[0.5] * 20]
         metrics = [10 - i * 0.0167 for i in range(20)]
         scheduler = ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min',
                                       threshold=0.01, patience=5, cooldown=5)
-        epochs = 10
         self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

     def test_reduce_lr_on_plateau2(self):
+        epochs = 22
         for param_group in self.opt.param_groups:
             param_group['lr'] = 0.5
         targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2]
         metrics = [10 - i * 0.0165 for i in range(22)]
         scheduler = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs',
                                       mode='min', threshold=0.1)
-        epochs = 22
         self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

     def test_reduce_lr_on_plateau3(self):
+        epochs = 22
         for param_group in self.opt.param_groups:
             param_group['lr'] = 0.5
         targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4]
         metrics = [-0.8] * 2 + [-0.234] * 20
         scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5,
                                       threshold_mode='abs')
-        epochs = 22
         self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

     def test_reduce_lr_on_plateau4(self):
+        epochs = 20
         for param_group in self.opt.param_groups:
             param_group['lr'] = 0.5
         targets = [[0.5] * 20]
         metrics = [1.5 * (1.025 ** i) for i in range(20)]  # 1.025 > 1.1**0.25
         scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=3,
                                       threshold_mode='rel', threshold=0.1)
-        epochs = 20
         self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

     def test_reduce_lr_on_plateau5(self):
+        epochs = 20
         for param_group in self.opt.param_groups:
             param_group['lr'] = 0.5
         targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
         metrics = [1.5 * (1.005 ** i) for i in range(20)]
         scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel',
                                       threshold=0.1, patience=5, cooldown=5)
-        epochs = 20
         self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

     def test_reduce_lr_on_plateau6(self):
+        epochs = 20
         for param_group in self.opt.param_groups:
             param_group['lr'] = 0.5
         targets = [[0.5] * 20]
         metrics = [1.5 * (0.85 ** i) for i in range(20)]
         scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel',
                                       threshold=0.1)
-        epochs = 20
         self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

     def test_reduce_lr_on_plateau7(self):
+        epochs = 20
         for param_group in self.opt.param_groups:
             param_group['lr'] = 0.5
         targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
         metrics = [1] * 7 + [0.6] + [0.5] * 12
         scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel',
                                       threshold=0.1, patience=5, cooldown=5)
-        epochs = 20
         self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

     def test_reduce_lr_on_plateau8(self):
+        epochs = 20
         for param_group in self.opt.param_groups:
             param_group['lr'] = 0.5
         targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14]
         metrics = [1.5 * (1.005 ** i) for i in range(20)]
         scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel', min_lr=[0.4, 0.3],
                                       threshold=0.1, patience=5, cooldown=5)
-        epochs = 20
         self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

     def test_lambda_lr(self):
+        epochs = 10
         self.opt.param_groups[0]['lr'] = 0.05
         self.opt.param_groups[1]['lr'] = 0.4
-        targets = [[0.05 * (0.9 ** x) for x in range(10)], [0.4 * (0.8 ** x) for x in range(10)]]
+        targets = [[0.05 * (0.9 ** x) for x in range(epochs)], [0.4 * (0.8 ** x) for x in range(epochs)]]
         scheduler = LambdaLR(self.opt,
                              lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2])
-        epochs = 10
         self._test(scheduler, targets, epochs)

     def _test(self, scheduler, targets, epochs=10):
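
As a sanity check on test_cos_anneal_lr above, the expected schedule can be re-evaluated standalone; this only re-states the closed-form target list from the diff (base lr 0.05, T_max = epochs = 10, eta_min = 1e-10, assuming Python 3 true division) and is not part of the commit:

    import math

    epochs, base_lr, eta_min = 10, 0.05, 1e-10
    targets = [eta_min + (base_lr - eta_min) *
               (1 + math.cos(x / epochs * math.pi)) / 2
               for x in range(epochs)]
    print(targets[0])  # 0.05: epoch 0 keeps the base learning rate
    print(targets[9])  # ~0.00122: by epoch 9 the lr has decayed most of the way toward eta_min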

torch/optim/lr_scheduler.py

Lines changed: 38 additions & 0 deletions
@@ -1,3 +1,4 @@
+import math
 from bisect import bisect_right
 from .optimizer import Optimizer

@@ -160,6 +161,43 @@ def get_lr(self):
                 for base_lr in self.base_lrs]


+class CosineAnnealingLR(_LRScheduler):
+    """Set the learning rate of each parameter group using a cosine annealing
+    schedule, where :math:`\eta_{max}` is set to the initial lr and
+    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
+
+    .. math::
+
+        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
+        \cos(\frac{T_{cur}}{T_{max}}\pi))
+
+    When last_epoch=-1, sets initial lr as lr.
+
+    It has been proposed in
+    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
+    implements the cosine annealing part of SGDR, and not the restarts.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        T_max (int): Maximum number of iterations.
+        eta_min (float): Minimum learning rate. Default: 0.
+        last_epoch (int): The index of last epoch. Default: -1.
+
+    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+        https://arxiv.org/abs/1608.03983
+    """
+
+    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
+        self.T_max = T_max
+        self.eta_min = eta_min
+        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        return [self.eta_min + (base_lr - self.eta_min) *
+                (1 + math.cos(self.last_epoch / self.T_max * math.pi)) / 2
+                for base_lr in self.base_lrs]
+
+
 class ReduceLROnPlateau(object):
     """Reduce learning rate when a metric has stopped improving.
     Models often benefit from reducing the learning rate by a factor
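
A minimal usage sketch of the new scheduler (not part of the diff; the throwaway parameter and the empty training loop are illustrative assumptions): wrap an optimizer, then call scheduler.step() once per epoch so the learning rate follows the half-cosine curve from the initial lr down toward eta_min over T_max epochs.

    import torch
    from torch.optim import SGD
    from torch.optim.lr_scheduler import CosineAnnealingLR

    # A single throwaway tensor stands in for real model parameters (assumption).
    params = [torch.zeros(1, requires_grad=True)]
    optimizer = SGD(params, lr=0.05)
    scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-10)

    for epoch in range(10):
        # ... one epoch of training with optimizer.step() calls would run here ...
        print(epoch, optimizer.param_groups[0]['lr'])  # 0.05 at epoch 0, decaying along the half-cosine toward eta_min
        scheduler.step()  # advance the cosine schedule by one epoch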
