Commit 939ae80

vincentqb authored and facebook-github-bot committed
Change schedulers to chainable form (#24352)
Summary:
Pull Request resolved: #24352

Enable chainable schedulers as requested in #13022 by implementing the changes mentioned below from [comment](#21800 (comment)).

* Changing the behavior of schedulers to the chainable formula when available
* Using the closed form whenever epoch is different from None until the next release, with a deprecation warning
* Making `get_computed_values` the supported way of obtaining the last learning rate computed by the scheduler (see [comment](#21800 (comment)) for new syntax)
* Returning a deprecation warning when invoking the undocumented `get_lr` function (see [comment](#21800 (comment))), referring to `get_computed_values`, and deprecating it in the next release.
* `CosineAnnealingWarmRestarts` still takes an epoch parameter as it is the only one with a mechanic relying on fractional epoch
* `MultiplicativeLR` consumes a function providing the multiplicative factor at each epoch. It mimics `LambdaLR` in its syntax.

# #20527

### Before

The user calls the scheduler with a constant epoch, either across loops or in the same loop.

```
import torch.optim as optim
from torch import nn

conv = nn.Conv2d(3, 3, 3)
optimizer = optim.Adam(conv.parameters())
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 2)

# Scheduler with sometimes-constant epoch number
for epoch in [0, 0, 1, 1, 2, 2, 3, 3]:
    lr_scheduler.step(epoch)
    print(optimizer.param_groups[0]['lr'])
```

### After

If the user wants to step only when the epoch number changes, the check is done manually:

```
import torch.optim as optim
from torch import nn

conv = nn.Conv2d(3, 3, 3)
optimizer = optim.Adam(conv.parameters())
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 2)

last_epoch = -1
for epoch in [0, 0, 1, 1, 2, 2, 3, 3]:
    # Check if epoch number has changed manually
    if epoch - last_epoch > 0:
        lr_scheduler.step()
    last_epoch = epoch
    print(epoch, lr_scheduler.get_computed_values())
```

# #22107

### Before

```
import torch
from torchvision.models import resnet18

net = resnet18()
optimizer = torch.optim.SGD(net.parameters(), 0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)

for i in range(10):
    # Scheduler computes and returns new learning rate, leading to unexpected behavior
    print(i, scheduler.get_lr())
    scheduler.step()
```

### After

```
import torch
from torchvision.models import resnet18

net = resnet18()
optimizer = torch.optim.SGD(net.parameters(), 0.1)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)

for i in range(10):
    # Returns last computed learning rate by scheduler
    print(i, lr_scheduler.get_computed_values())
    lr_scheduler.step()
```

Test Plan: Imported from OSS

Differential Revision: D17349760

Pulled By: vincentqb

fbshipit-source-id: 0a6ac01e2a6b45000bc6f9df732033dd81f0d89f
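The distinction the summary draws between the chainable formula and the closed form can be sketched in plain Python. This is only an illustration of the idea for a StepLR-style decay; the helper names `closed_form_lr` and `chainable_lr` are hypothetical and are not the functions changed by this commit.

```python
import math

# Hypothetical stand-ins for illustration; not the torch.optim.lr_scheduler code.

def closed_form_lr(base_lr, gamma, step_size, epoch):
    """Closed form: computed from the base rate and the epoch number alone."""
    return base_lr * gamma ** (epoch // step_size)


def chainable_lr(current_lr, gamma, step_size, epoch):
    """Chainable form: computed from whatever the current rate already is."""
    if epoch > 0 and epoch % step_size == 0:
        return current_lr * gamma
    return current_lr


base_lr, gamma, step_size = 0.1, 0.1, 2

lr = base_lr
chained, closed = [], []
for epoch in range(6):
    # The chainable update only needs the previous rate, so it composes with
    # any other rule that has modified the rate in the meantime.
    lr = chainable_lr(lr, gamma, step_size, epoch)
    chained.append(lr)
    closed.append(closed_form_lr(base_lr, gamma, step_size, epoch))

# When nothing else touches the rate, both forms agree up to rounding.
agree = all(math.isclose(a, b) for a, b in zip(chained, closed))
print(agree)
```

Because the chainable update reads the current rate rather than recomputing from the base rate, a second rule applied between steps would be preserved, whereas the closed form would overwrite it.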
1 parent 2503fdc commit 939ae80

2 files changed, +436 -91 lines changed
