@@ -23,7 +23,7 @@
 
 class _LRScheduler(object):
 
-    def __init__(self, optimizer, last_epoch=-1):
+    def __init__(self, optimizer, last_epoch=-1, verbose=False):
 
         # Attach optimizer
         if not isinstance(optimizer, Optimizer):
@@ -74,6 +74,7 @@ def wrapper(*args, **kwargs):
         self.optimizer.step = with_counter(self.optimizer.step)
         self.optimizer._step_count = 0
         self._step_count = 0
+        self.verbose = verbose
 
         self.step()
 
@@ -103,6 +104,18 @@ def get_lr(self):
         # Compute learning rate using chainable form of the scheduler
         raise NotImplementedError
 
+    def print_lr(self, is_verbose, group, lr, epoch=None):
+        """Display the current learning rate.
+        """
+        if is_verbose:
+            if epoch is None:
+                print('Adjusting learning rate'
+                      ' of group {} to {:.4e}.'.format(group, lr))
+            else:
+                print('Epoch {:5d}: adjusting learning rate'
+                      ' of group {} to {:.4e}.'.format(epoch, group, lr))
+
+
     def step(self, epoch=None):
         # Raise a warning if old pattern is detected
         # https://github.com/pytorch/pytorch/issues/20124
@@ -147,8 +160,10 @@ def __exit__(self, type, value, traceback):
                 else:
                     values = self.get_lr()
 
-        for param_group, lr in zip(self.optimizer.param_groups, values):
+        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+            param_group, lr = data
             param_group['lr'] = lr
+            self.print_lr(self.verbose, i, lr, epoch)
 
         self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
 
@@ -163,6 +178,8 @@ class LambdaLR(_LRScheduler):
             factor given an integer parameter epoch, or a list of such
             functions, one for each group in optimizer.param_groups.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
 
     Example:
         >>> # Assuming optimizer has two groups.
@@ -175,7 +192,7 @@ class LambdaLR(_LRScheduler):
         >>> scheduler.step()
     """
 
-    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
+    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
         self.optimizer = optimizer
 
         if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
@@ -186,7 +203,7 @@ def __init__(self, optimizer, lr_lambda, last_epoch=-1):
                     len(optimizer.param_groups), len(lr_lambda)))
         self.lr_lambdas = list(lr_lambda)
         self.last_epoch = last_epoch
-        super(LambdaLR, self).__init__(optimizer, last_epoch)
+        super(LambdaLR, self).__init__(optimizer, last_epoch, verbose)
 
     def state_dict(self):
         """Returns the state of the scheduler as a :class:`dict`.
@@ -245,6 +262,8 @@ class MultiplicativeLR(_LRScheduler):
             factor given an integer parameter epoch, or a list of such
             functions, one for each group in optimizer.param_groups.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
 
     Example:
         >>> lmbda = lambda epoch: 0.95
@@ -255,7 +274,7 @@ class MultiplicativeLR(_LRScheduler):
         >>> scheduler.step()
     """
 
-    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
+    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
         self.optimizer = optimizer
 
         if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
@@ -266,7 +285,7 @@ def __init__(self, optimizer, lr_lambda, last_epoch=-1):
                     len(optimizer.param_groups), len(lr_lambda)))
         self.lr_lambdas = list(lr_lambda)
         self.last_epoch = last_epoch
-        super(MultiplicativeLR, self).__init__(optimizer, last_epoch)
+        super(MultiplicativeLR, self).__init__(optimizer, last_epoch, verbose)
 
     def state_dict(self):
         """Returns the state of the scheduler as a :class:`dict`.
@@ -326,6 +345,8 @@ class StepLR(_LRScheduler):
         gamma (float): Multiplicative factor of learning rate decay.
             Default: 0.1.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
 
     Example:
         >>> # Assuming optimizer uses lr = 0.05 for all groups
@@ -340,10 +361,10 @@ class StepLR(_LRScheduler):
         >>> scheduler.step()
     """
 
-    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
+    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
         self.step_size = step_size
         self.gamma = gamma
-        super(StepLR, self).__init__(optimizer, last_epoch)
+        super(StepLR, self).__init__(optimizer, last_epoch, verbose)
 
     def get_lr(self):
         if not self._get_lr_called_within_step:
@@ -372,6 +393,8 @@ class MultiStepLR(_LRScheduler):
         gamma (float): Multiplicative factor of learning rate decay.
             Default: 0.1.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
 
     Example:
         >>> # Assuming optimizer uses lr = 0.05 for all groups
@@ -385,10 +408,10 @@ class MultiStepLR(_LRScheduler):
         >>> scheduler.step()
     """
 
-    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
+    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
         self.milestones = Counter(milestones)
         self.gamma = gamma
-        super(MultiStepLR, self).__init__(optimizer, last_epoch)
+        super(MultiStepLR, self).__init__(optimizer, last_epoch, verbose)
 
     def get_lr(self):
         if not self._get_lr_called_within_step:
@@ -414,11 +437,13 @@ class ExponentialLR(_LRScheduler):
         optimizer (Optimizer): Wrapped optimizer.
         gamma (float): Multiplicative factor of learning rate decay.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
     """
 
-    def __init__(self, optimizer, gamma, last_epoch=-1):
+    def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
         self.gamma = gamma
-        super(ExponentialLR, self).__init__(optimizer, last_epoch)
+        super(ExponentialLR, self).__init__(optimizer, last_epoch, verbose)
 
     def get_lr(self):
         if not self._get_lr_called_within_step:
@@ -468,15 +493,17 @@ class CosineAnnealingLR(_LRScheduler):
         T_max (int): Maximum number of iterations.
         eta_min (float): Minimum learning rate. Default: 0.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
 
     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
         https://arxiv.org/abs/1608.03983
     """
 
-    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
+    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):
         self.T_max = T_max
         self.eta_min = eta_min
-        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)
+        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch, verbose)
 
     def get_lr(self):
         if not self._get_lr_called_within_step:
@@ -522,8 +549,6 @@ class ReduceLROnPlateau(object):
             with no improvement, and will only decrease the LR after the
             3rd epoch if the loss still hasn't improved then.
             Default: 10.
-        verbose (bool): If ``True``, prints a message to stdout for
-            each update. Default: ``False``.
         threshold (float): Threshold for measuring the new optimum,
             to only focus on significant changes. Default: 1e-4.
         threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
@@ -539,6 +564,8 @@ class ReduceLROnPlateau(object):
         eps (float): Minimal decay applied to lr. If the difference
             between new and old lr is smaller than eps, the update is
             ignored. Default: 1e-8.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
 
     Example:
         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
@@ -551,8 +578,8 @@ class ReduceLROnPlateau(object):
     """
 
     def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
-                 verbose=False, threshold=1e-4, threshold_mode='rel',
-                 cooldown=0, min_lr=0, eps=1e-8):
+                 threshold=1e-4, threshold_mode='rel', cooldown=0,
+                 min_lr=0, eps=1e-8, verbose=False):
 
         if factor >= 1.0:
             raise ValueError('Factor should be < 1.0.')
@@ -749,6 +776,8 @@ class CyclicLR(_LRScheduler):
             number of *batches* computed, not the total number of epochs computed.
             When last_epoch=-1, the schedule is started from the beginning.
             Default: -1
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
 
     Example:
         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
@@ -777,7 +806,8 @@ def __init__(self,
                  cycle_momentum=True,
                  base_momentum=0.8,
                  max_momentum=0.9,
-                 last_epoch=-1):
+                 last_epoch=-1,
+                 verbose=False):
 
         # Attach optimizer
         if not isinstance(optimizer, Optimizer):
@@ -830,7 +860,7 @@ def __init__(self,
             self.base_momentums = list(map(lambda group: group['momentum'], optimizer.param_groups))
             self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
 
-        super(CyclicLR, self).__init__(optimizer, last_epoch)
+        super(CyclicLR, self).__init__(optimizer, last_epoch, verbose)
         self.base_lrs = base_lrs
 
     def _format_param(self, name, optimizer, param):
@@ -917,12 +947,14 @@ class CosineAnnealingWarmRestarts(_LRScheduler):
         T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
         eta_min (float, optional): Minimum learning rate. Default: 0.
         last_epoch (int, optional): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
 
     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
         https://arxiv.org/abs/1608.03983
     """
 
-    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
+    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
         if T_0 <= 0 or not isinstance(T_0, int):
             raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
         if T_mult < 1 or not isinstance(T_mult, int):
@@ -932,7 +964,7 @@ def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
         self.T_mult = T_mult
         self.eta_min = eta_min
 
-        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)
+        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch, verbose)
 
         self.T_cur = self.last_epoch
 
@@ -1008,8 +1040,10 @@ def __exit__(self, type, value, traceback):
                 return self
 
         with _enable_get_lr_call(self):
-            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
+                param_group, lr = data
                 param_group['lr'] = lr
+                self.print_lr(self.verbose, i, lr, epoch)
 
         self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
 
@@ -1090,6 +1124,8 @@ class OneCycleLR(_LRScheduler):
             number of *batches* computed, not the total number of epochs computed.
             When last_epoch=-1, the schedule is started from the beginning.
             Default: -1
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
 
     Example:
         >>> data_loader = torch.utils.data.DataLoader(...)
@@ -1117,7 +1153,8 @@ def __init__(self,
                  max_momentum=0.95,
                  div_factor=25.,
                  final_div_factor=1e4,
-                 last_epoch=-1):
+                 last_epoch=-1,
+                 verbose=False):
 
         # Validate optimizer
         if not isinstance(optimizer, Optimizer):
@@ -1179,7 +1216,7 @@ def __init__(self,
                 group['max_momentum'] = m_momentum
                 group['base_momentum'] = b_momentum
 
-        super(OneCycleLR, self).__init__(optimizer, last_epoch)
+        super(OneCycleLR, self).__init__(optimizer, last_epoch, verbose)
 
     def _format_param(self, name, optimizer, param):
         """Return correctly formatted lr/momentum for each param group."""
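
Below is a minimal usage sketch of the new verbose flag (not part of the diff itself; the model and optimizer are placeholders, and it assumes a PyTorch build that includes this change). With verbose=True, every call to scheduler.step() goes through print_lr and prints one line per parameter group, following the format strings added above:

from torch import nn, optim

# Placeholder model/optimizer, only needed to give the scheduler parameter groups.
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

# verbose=True makes the scheduler announce every lr update via print_lr.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1, verbose=True)

for epoch in range(4):
    optimizer.step()
    scheduler.step()  # prints e.g. "Adjusting learning rate of group 0 to 5.0000e-03."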