
Commit 17f76f9

guol-fnst authored and facebook-github-bot committed
Verbose param for schedulers that don't have it #38726 (#41580)
Summary: Verbose param for schedulers that don't have it #38726
Pull Request resolved: #41580
Reviewed By: izdeby
Differential Revision: D22671163
Pulled By: vincentqb
fbshipit-source-id: 53a6c9e929141d411b6846bc25f3fe7f46fdf3be
1 parent 37e7f0c commit 17f76f9

File tree: 1 file changed, +62 −25 lines changed

torch/optim/lr_scheduler.py

Lines changed: 62 additions & 25 deletions
@@ -23,7 +23,7 @@

 class _LRScheduler(object):

-    def __init__(self, optimizer, last_epoch=-1):
+    def __init__(self, optimizer, last_epoch=-1, verbose=False):

         # Attach optimizer
         if not isinstance(optimizer, Optimizer):
@@ -74,6 +74,7 @@ def wrapper(*args, **kwargs):
         self.optimizer.step = with_counter(self.optimizer.step)
         self.optimizer._step_count = 0
         self._step_count = 0
+        self.verbose = verbose

         self.step()

@@ -103,6 +104,18 @@ def get_lr(self):
         # Compute learning rate using chainable form of the scheduler
         raise NotImplementedError

+    def print_lr(self, is_verbose, group, lr, epoch=None):
+        """Display the current learning rate.
+        """
+        if is_verbose:
+            if epoch is None:
+                print('Adjusting learning rate'
+                      ' of group {} to {:.4e}.'.format(group, lr))
+            else:
+                print('Epoch {:5d}: adjusting learning rate'
+                      ' of group {} to {:.4e}.'.format(epoch, group, lr))
+
+
     def step(self, epoch=None):
         # Raise a warning if old pattern is detected
         # https://github.com/pytorch/pytorch/issues/20124
@@ -147,8 +160,10 @@ def __exit__(self, type, value, traceback):
                 else:
                     values = self.get_lr()

-        for param_group, lr in zip(self.optimizer.param_groups, values):
+        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+            param_group, lr = data
             param_group['lr'] = lr
+            self.print_lr(self.verbose, i, lr, epoch)

         self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

@@ -163,6 +178,8 @@ class LambdaLR(_LRScheduler):
             factor given an integer parameter epoch, or a list of such
             functions, one for each group in optimizer.param_groups.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.

     Example:
         >>> # Assuming optimizer has two groups.
@@ -175,7 +192,7 @@ class LambdaLR(_LRScheduler):
         >>> scheduler.step()
     """

-    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
+    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
         self.optimizer = optimizer

         if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
@@ -186,7 +203,7 @@ def __init__(self, optimizer, lr_lambda, last_epoch=-1):
                     len(optimizer.param_groups), len(lr_lambda)))
         self.lr_lambdas = list(lr_lambda)
         self.last_epoch = last_epoch
-        super(LambdaLR, self).__init__(optimizer, last_epoch)
+        super(LambdaLR, self).__init__(optimizer, last_epoch, verbose)

     def state_dict(self):
         """Returns the state of the scheduler as a :class:`dict`.
@@ -245,6 +262,8 @@ class MultiplicativeLR(_LRScheduler):
             factor given an integer parameter epoch, or a list of such
             functions, one for each group in optimizer.param_groups.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.

     Example:
         >>> lmbda = lambda epoch: 0.95
@@ -255,7 +274,7 @@ class MultiplicativeLR(_LRScheduler):
         >>> scheduler.step()
     """

-    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
+    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
         self.optimizer = optimizer

         if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
@@ -266,7 +285,7 @@ def __init__(self, optimizer, lr_lambda, last_epoch=-1):
                     len(optimizer.param_groups), len(lr_lambda)))
         self.lr_lambdas = list(lr_lambda)
         self.last_epoch = last_epoch
-        super(MultiplicativeLR, self).__init__(optimizer, last_epoch)
+        super(MultiplicativeLR, self).__init__(optimizer, last_epoch, verbose)

     def state_dict(self):
         """Returns the state of the scheduler as a :class:`dict`.
@@ -326,6 +345,8 @@ class StepLR(_LRScheduler):
         gamma (float): Multiplicative factor of learning rate decay.
             Default: 0.1.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.

     Example:
         >>> # Assuming optimizer uses lr = 0.05 for all groups
@@ -340,10 +361,10 @@ class StepLR(_LRScheduler):
         >>> scheduler.step()
     """

-    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
+    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
         self.step_size = step_size
         self.gamma = gamma
-        super(StepLR, self).__init__(optimizer, last_epoch)
+        super(StepLR, self).__init__(optimizer, last_epoch, verbose)

     def get_lr(self):
         if not self._get_lr_called_within_step:
@@ -372,6 +393,8 @@ class MultiStepLR(_LRScheduler):
         gamma (float): Multiplicative factor of learning rate decay.
             Default: 0.1.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.

     Example:
         >>> # Assuming optimizer uses lr = 0.05 for all groups
@@ -385,10 +408,10 @@ class MultiStepLR(_LRScheduler):
         >>> scheduler.step()
     """

-    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
+    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
         self.milestones = Counter(milestones)
         self.gamma = gamma
-        super(MultiStepLR, self).__init__(optimizer, last_epoch)
+        super(MultiStepLR, self).__init__(optimizer, last_epoch, verbose)

     def get_lr(self):
         if not self._get_lr_called_within_step:
@@ -414,11 +437,13 @@ class ExponentialLR(_LRScheduler):
         optimizer (Optimizer): Wrapped optimizer.
         gamma (float): Multiplicative factor of learning rate decay.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
     """

-    def __init__(self, optimizer, gamma, last_epoch=-1):
+    def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
         self.gamma = gamma
-        super(ExponentialLR, self).__init__(optimizer, last_epoch)
+        super(ExponentialLR, self).__init__(optimizer, last_epoch, verbose)

     def get_lr(self):
         if not self._get_lr_called_within_step:
@@ -468,15 +493,17 @@ class CosineAnnealingLR(_LRScheduler):
         T_max (int): Maximum number of iterations.
         eta_min (float): Minimum learning rate. Default: 0.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.

     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
         https://arxiv.org/abs/1608.03983
     """

-    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
+    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):
         self.T_max = T_max
         self.eta_min = eta_min
-        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)
+        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch, verbose)

     def get_lr(self):
         if not self._get_lr_called_within_step:
@@ -522,8 +549,6 @@ class ReduceLROnPlateau(object):
             with no improvement, and will only decrease the LR after the
             3rd epoch if the loss still hasn't improved then.
             Default: 10.
-        verbose (bool): If ``True``, prints a message to stdout for
-            each update. Default: ``False``.
         threshold (float): Threshold for measuring the new optimum,
             to only focus on significant changes. Default: 1e-4.
         threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
@@ -539,6 +564,8 @@ class ReduceLROnPlateau(object):
         eps (float): Minimal decay applied to lr. If the difference
             between new and old lr is smaller than eps, the update is
             ignored. Default: 1e-8.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.

     Example:
         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
@@ -551,8 +578,8 @@ class ReduceLROnPlateau(object):
     """

     def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
-                 verbose=False, threshold=1e-4, threshold_mode='rel',
-                 cooldown=0, min_lr=0, eps=1e-8):
+                 threshold=1e-4, threshold_mode='rel', cooldown=0,
+                 min_lr=0, eps=1e-8, verbose=False):

         if factor >= 1.0:
             raise ValueError('Factor should be < 1.0.')
@@ -749,6 +776,8 @@ class CyclicLR(_LRScheduler):
             number of *batches* computed, not the total number of epochs computed.
             When last_epoch=-1, the schedule is started from the beginning.
             Default: -1
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.

     Example:
         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
@@ -777,7 +806,8 @@ def __init__(self,
                  cycle_momentum=True,
                  base_momentum=0.8,
                  max_momentum=0.9,
-                 last_epoch=-1):
+                 last_epoch=-1,
+                 verbose=False):

         # Attach optimizer
         if not isinstance(optimizer, Optimizer):
@@ -830,7 +860,7 @@ def __init__(self,
             self.base_momentums = list(map(lambda group: group['momentum'], optimizer.param_groups))
             self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)

-        super(CyclicLR, self).__init__(optimizer, last_epoch)
+        super(CyclicLR, self).__init__(optimizer, last_epoch, verbose)
         self.base_lrs = base_lrs

     def _format_param(self, name, optimizer, param):
@@ -917,12 +947,14 @@ class CosineAnnealingWarmRestarts(_LRScheduler):
         T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
         eta_min (float, optional): Minimum learning rate. Default: 0.
         last_epoch (int, optional): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.

     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
         https://arxiv.org/abs/1608.03983
     """

-    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
+    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
         if T_0 <= 0 or not isinstance(T_0, int):
             raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
         if T_mult < 1 or not isinstance(T_mult, int):
@@ -932,7 +964,7 @@ def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
         self.T_mult = T_mult
         self.eta_min = eta_min

-        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)
+        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch, verbose)

         self.T_cur = self.last_epoch

@@ -1008,8 +1040,10 @@ def __exit__(self, type, value, traceback):
                 return self

         with _enable_get_lr_call(self):
-            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
+                param_group, lr = data
                 param_group['lr'] = lr
+                self.print_lr(self.verbose, i, lr, epoch)

         self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

@@ -1090,6 +1124,8 @@ class OneCycleLR(_LRScheduler):
             number of *batches* computed, not the total number of epochs computed.
             When last_epoch=-1, the schedule is started from the beginning.
             Default: -1
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.

     Example:
         >>> data_loader = torch.utils.data.DataLoader(...)
@@ -1117,7 +1153,8 @@ def __init__(self,
                  max_momentum=0.95,
                  div_factor=25.,
                  final_div_factor=1e4,
-                 last_epoch=-1):
+                 last_epoch=-1,
+                 verbose=False):

         # Validate optimizer
         if not isinstance(optimizer, Optimizer):
@@ -1179,7 +1216,7 @@ def __init__(self,
                 group['max_momentum'] = m_momentum
                 group['base_momentum'] = b_momentum

-        super(OneCycleLR, self).__init__(optimizer, last_epoch)
+        super(OneCycleLR, self).__init__(optimizer, last_epoch, verbose)

     def _format_param(self, name, optimizer, param):
         """Return correctly formatted lr/momentum for each param group."""

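One caveat visible in the ReduceLROnPlateau hunk above: verbose moved from the middle of __init__'s signature to the end, so a caller that passed it positionally after patience would now be setting threshold instead. Passing it by keyword, as in this hypothetical snippet, is unaffected.

import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Keyword argument keeps working across the signature change in this commit;
# a positional fifth argument would now bind to `threshold`, not `verbose`.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10,
                              verbose=True)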