
Commit 8ece538

vfn authored and facebook-github-bot committed
Addresses bad behavior with overridden optimizer.step by #20124 (#21460)
Summary: This PR addresses the problem described in the comment #20203 (comment) and fixes the previously introduced bad behaviour, where a warning was raised every time LR scheduling was initialized.

Now the code checks that:
- on the second call of `lr_scheduler.step`, `optimizer.step` has already been called, otherwise a warning is raised (as was done in #20203);
- if the optimizer's `step` has been overridden, a different warning is raised once to make the user aware of the new pattern `opt.step()` -> `lrs.step()`, since the call order cannot be checked in that case.

The tests now check that:
- at initialization (`lrs = StepLR(...)`) no warnings are raised;
- if we replace `optimizer.step` with something else (similar to the [code of nvidia/apex](https://github.com/NVIDIA/apex/blob/master/apex/amp/_process_optimizer.py#L287)), the other warning is raised.

cc ezyang

PS: honestly, I would say that a lot of overhead is introduced for simple warnings. I hope all these checks can be removed in `1.2.0` or a later version...

Pull Request resolved: #21460
Differential Revision: D15701776
Pulled By: ezyang
fbshipit-source-id: eac5712b9146d9d3392a30f6339cd33d90c497c7
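For illustration (not part of the commit itself), a minimal sketch of the call order these warnings steer users toward; the toy model and optimizer here are assumptions, not taken from the PR:

```python
import torch
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(2, 2)  # hypothetical toy model
opt = torch.optim.SGD(model.parameters(), lr=0.05)
scheduler = StepLR(opt, step_size=3, gamma=0.1)

for epoch in range(10):
    # ... forward/backward pass would go here ...
    opt.step()        # new pattern: optimizer first,
    scheduler.step()  # then the scheduler; the reversed (pre-1.1.0) order now warns
```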
1 parent 51d0da2 commit 8ece538

File tree: 2 files changed, +101 −21 lines changed


test/test_optim.py (78 additions, 10 deletions)

```diff
@@ -1,3 +1,4 @@
+import warnings
 import math
 import unittest
 import functools
@@ -529,7 +530,10 @@ def setUp(self):
 
     def test_old_pattern_warning(self):
         epochs = 35
-        scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
+        self.assertTrue(len(ws) == 0, "No warning should be raised")
 
         def old_pattern():
             for e in range(epochs):
@@ -540,7 +544,10 @@ def old_pattern():
 
     def test_old_pattern_warning_with_arg(self):
         epochs = 35
-        scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
+        self.assertTrue(len(ws) == 0, "No warning should be raised")
 
         def old_pattern2():
             for e in range(epochs):
@@ -554,7 +561,10 @@ def test_old_pattern_warning_resuming(self):
         for i, group in enumerate(self.opt.param_groups):
             group['initial_lr'] = 0.01
 
-        scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
+        self.assertTrue(len(ws) == 0, "No warning should be raised")
 
         def old_pattern():
             for e in range(epochs):
@@ -568,7 +578,10 @@ def test_old_pattern_warning_resuming_with_arg(self):
         for i, group in enumerate(self.opt.param_groups):
             group['initial_lr'] = 0.01
 
-        scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
+        self.assertTrue(len(ws) == 0, "No warning should be raised")
 
         def old_pattern2():
             for e in range(epochs):
@@ -577,11 +590,40 @@ def old_pattern2():
 
         self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate')
 
-    def test_new_pattern_no_warning(self):
-        import warnings
+    def test_old_pattern_warning_with_overriden_optim_step(self):
+        epochs = 35
+        for i, group in enumerate(self.opt.param_groups):
+            group['initial_lr'] = 0.01
+
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
+        self.assertTrue(len(ws) == 0, "No warning should be raised")
 
+        # emulate use-case with optimizer.step overriden
+        import types
+
+        old_step = self.opt.step
+
+        def new_step(o, *args, **kwargs):
+            retval = old_step(*args, **kwargs)
+            return retval
+
+        self.opt.step = types.MethodType(new_step, self.opt)
+
+        def old_pattern2():
+            for e in range(epochs):
+                scheduler.step(e)
+                self.opt.step()
+
+        self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate')
+
+    def test_new_pattern_no_warning(self):
         epochs = 35
-        scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
+        self.assertTrue(len(ws) == 0, "No warning should be raised")
 
         with warnings.catch_warnings(record=True) as ws:
             warnings.simplefilter("always")  # allow any warning to be raised
@@ -591,10 +633,11 @@ def test_new_pattern_no_warning(self):
             self.assertTrue(len(ws) == 0, "No warning should be raised")
 
     def test_new_pattern_no_warning_with_arg(self):
-        import warnings
-
         epochs = 35
-        scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
+        self.assertTrue(len(ws) == 0, "No warning should be raised")
 
         with warnings.catch_warnings(record=True) as ws:
             warnings.simplefilter("always")  # allow any warning to be raised
@@ -603,6 +646,31 @@ def test_new_pattern_no_warning_with_arg(self):
                 scheduler.step(e)
             self.assertTrue(len(ws) == 0, "No warning should be raised")
 
+    def test_new_pattern_no_warning_with_overriden_optim_step(self):
+        epochs = 35
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
+        self.assertTrue(len(ws) == 0, "No warning should be raised")
+
+        # emulate use-case with optimizer.step overriden
+        import types
+
+        old_step = self.opt.step
+
+        def new_step(o, *args, **kwargs):
+            retval = old_step(*args, **kwargs)
+            return retval
+
+        self.opt.step = types.MethodType(new_step, self.opt)
+
+        def new_pattern():
+            for e in range(epochs):
+                self.opt.step()
+                scheduler.step()
+
+        self.assertWarnsRegex(new_pattern, r'`optimizer.step\(\)` has been overridden')
+
     def test_step_lr(self):
         # lr = 0.05 if epoch < 3
         # lr = 0.005 if 30 <= epoch < 6
```
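The override the two new tests emulate mirrors what apex does: rebinding `optimizer.step` discards the scheduler's counting wrapper, and with it the `_with_counter` marker, so the scheduler can no longer verify the call order and falls back to the one-time "overridden" warning. A standalone sketch of that mechanism, under the same assumptions as the tests (a throwaway SGD optimizer on a single parameter):

```python
import types
import torch
from torch.optim.lr_scheduler import StepLR

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.05)
scheduler = StepLR(opt, step_size=3, gamma=0.1)  # wraps opt.step with a counter

old_step = opt.step

def new_step(self, *args, **kwargs):
    # rebinding step, as apex does, drops the wrapper and its _with_counter marker
    return old_step(*args, **kwargs)

opt.step = types.MethodType(new_step, opt)

opt.step()
scheduler.step()  # first user call: warns once that optimizer.step() was overridden
```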

torch/optim/lr_scheduler.py (23 additions, 11 deletions)

```diff
@@ -29,15 +29,17 @@ def __init__(self, optimizer, last_epoch=-1):
         # Following https://github.com/pytorch/pytorch/issues/20124
         # We would like to ensure that `lr_scheduler.step()` is called after
         # `optimizer.step()`
-        def with_counter(func):
+        def with_counter(func, opt):
             @wraps(func)
             def wrapper(*args, **kwargs):
-                wrapper.called += 1
+                opt._step_count += 1
                 return func(*args, **kwargs)
-            wrapper.called = 0
+            wrapper._with_counter = True
             return wrapper
 
-        self.optimizer.step = with_counter(self.optimizer.step)
+        self.optimizer.step = with_counter(self.optimizer.step, self.optimizer)
+        self.optimizer._step_count = 0
+        self._step_count = 0
         self.step(last_epoch)
 
     def state_dict(self):
```
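For context (not part of the diff), a self-contained sketch of how the counting wrapper above behaves; `_Opt` is a hypothetical stand-in for a real optimizer:

```python
from functools import wraps

class _Opt:
    """Hypothetical stand-in for an optimizer with a step() method."""
    _step_count = 0

    def step(self):
        return "stepped"

def with_counter(func, opt):
    @wraps(func)
    def wrapper(*args, **kwargs):
        opt._step_count += 1      # count real optimizer steps
        return func(*args, **kwargs)
    wrapper._with_counter = True  # marker that scheduler.step() checks for later
    return wrapper

opt = _Opt()
opt.step = with_counter(opt.step, opt)
opt.step()
print(opt._step_count)                     # 1
print(hasattr(opt.step, "_with_counter"))  # True; becomes False if step is replaced
```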
```diff
@@ -63,13 +65,23 @@ def get_lr(self):
     def step(self, epoch=None):
         # Raise a warning if old pattern is detected
         # https://github.com/pytorch/pytorch/issues/20124
-        if self.optimizer.step.called < 1:
-            warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
-                          "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
-                          "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
-                          "will result in PyTorch skipping the first value of the learning rate schedule."
-                          "See more details at "
-                          "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
+        if self._step_count == 1:
+            if not hasattr(self.optimizer.step, "_with_counter"):
+                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
+                              "initialization. Please, make sure to call `optimizer.step()` before "
+                              "`lr_scheduler.step()`. See more details at "
+                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
+
+            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
+            elif self.optimizer._step_count < 1:
+                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
+                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
+                              "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
+                              "will result in PyTorch skipping the first value of the learning rate schedule."
+                              "See more details at "
+                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
+        self._step_count += 1
+
         if epoch is None:
             epoch = self.last_epoch + 1
         self.last_epoch = epoch
```
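To see the new bookkeeping in action, a small sketch (assuming a build of PyTorch containing this patch) that trips the old-pattern warning by stepping the scheduler before the optimizer:

```python
import warnings
import torch
from torch.optim.lr_scheduler import StepLR

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.05)
scheduler = StepLR(opt, step_size=3, gamma=0.1)

with warnings.catch_warnings(record=True) as ws:
    warnings.simplefilter("always")
    scheduler.step()  # optimizer._step_count is still 0 here, so this warns

print(any('how-to-adjust-learning-rate' in str(w.message) for w in ws))  # True
```

Note that `__init__` calls `self.step(last_epoch)` once, which increments `_step_count` from 0 to 1, so the first user-initiated `scheduler.step()` is exactly the call where `_step_count == 1` and the checks fire.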
