Skip to content

Commit a257bd1

Browse files
kahneapaszke
authored andcommitted
added state_dict/load_state_dict for ReduceLROnPlateau (#7201)
1 parent 4eaf526 commit a257bd1

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

test/test_optim.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,16 @@ def test_cosine_lr_state_dict(self):
647647
lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2),
648648
epochs=epochs)
649649

650+
def test_reduce_lr_on_plateau_state_dict(self):
651+
scheduler = ReduceLROnPlateau(self.opt, mode='min', factor=0.1, patience=2)
652+
for score in [1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 3.0, 2.0, 1.0]:
653+
scheduler.step(score)
654+
scheduler_copy = ReduceLROnPlateau(self.opt, mode='max', factor=0.5, patience=10)
655+
scheduler_copy.load_state_dict(scheduler.state_dict())
656+
for key in scheduler.__dict__.keys():
657+
if key not in {'optimizer', 'is_better'}:
658+
self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key], allow_inf=True)
659+
650660
def _check_scheduler_state_dict(self, constr, constr2, epochs=10):
651661
scheduler = constr()
652662
for _ in range(epochs):

torch/optim/lr_scheduler.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,10 @@ def _init_is_better(self, mode, threshold, threshold_mode):
378378
self.mode_worse = (-float('inf'))
379379

380380
self.is_better = partial(self._cmp, mode, threshold_mode, threshold)
381+
382+
def state_dict(self):
383+
return {key: value for key, value in self.__dict__.items() if key not in {'optimizer', 'is_better'}}
384+
385+
def load_state_dict(self, state_dict):
386+
self.__dict__.update(state_dict)
387+
self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)

0 commit comments

Comments
 (0)