Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,5 +618,6 @@ def _test_reduce_lr_on_plateau(self, scheduler, targets, metrics, epochs=10, ver
msg='LR is wrong in epoch {}: expected {}, got {}'.format(
epoch, target[epoch], param_group['lr']), delta=1e-5)


if __name__ == '__main__':
run_tests()
36 changes: 23 additions & 13 deletions torch/optim/lr_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import math
from bisect import bisect_right
from functools import partial

from .optimizer import Optimizer


Expand Down Expand Up @@ -322,22 +324,30 @@ def _reduce_lr(self, epoch):
def in_cooldown(self):
return self.cooldown_counter > 0

def _init_is_better(self, mode, threshold, threshold_mode):
if mode not in {'min', 'max'}:
raise ValueError('mode ' + mode + ' is unknown!')
if threshold_mode not in {'rel', 'abs'}:
raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
def _cmp(self, mode, threshold_mode, threshold, a, best):
if mode == 'min' and threshold_mode == 'rel':
rel_epsilon = 1. - threshold
self.is_better = lambda a, best: a < best * rel_epsilon
self.mode_worse = float('Inf')
return a < best * rel_epsilon

elif mode == 'min' and threshold_mode == 'abs':
self.is_better = lambda a, best: a < best - threshold
self.mode_worse = float('Inf')
return a < best - threshold

elif mode == 'max' and threshold_mode == 'rel':
rel_epsilon = threshold + 1.
self.is_better = lambda a, best: a > best * rel_epsilon
self.mode_worse = -float('Inf')
return a > best * rel_epsilon

else: # mode == 'max' and epsilon_mode == 'abs':
self.is_better = lambda a, best: a > best + threshold
self.mode_worse = -float('Inf')
return a > best + threshold

def _init_is_better(self, mode, threshold, threshold_mode):
if mode not in {'min', 'max'}:
raise ValueError('mode ' + mode + ' is unknown!')
if threshold_mode not in {'rel', 'abs'}:
raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

if mode == 'min':
self.mode_worse = float('inf')
else: # mode == 'max':
self.mode_worse = (-float('inf'))

self.is_better = partial(self._cmp, mode, threshold_mode, threshold)