Documentation for torch.optim.swa_utils
#41228
@@ -175,3 +175,107 @@ should write your code this way:
    :members:
.. autoclass:: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
    :members:

Stochastic Weight Averaging
---------------------------

:mod:`torch.optim.swa_utils` implements Stochastic Weight Averaging (SWA). In particular,
the :class:`torch.optim.swa_utils.AveragedModel` class implements SWA models,
:class:`torch.optim.swa_utils.SWALR` implements the SWA learning rate scheduler, and
:func:`torch.optim.swa_utils.update_bn` is a utility function used to update the SWA batch
normalization statistics at the end of training.

SWA has been proposed in `Averaging Weights Leads to Wider Optima and Better Generalization`_.

.. _`Averaging Weights Leads to Wider Optima and Better Generalization`: https://arxiv.org/abs/1803.05407
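
The snippets in the following sections assume imports along these lines (a minimal sketch; adjust to
your own setup):

>>> import torch
>>> from torch.optim.swa_utils import AveragedModel, SWALR
>>> from torch.optim.lr_scheduler import CosineAnnealingLR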

Constructing averaged models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The `AveragedModel` class serves to compute the weights of the SWA model. You can create an
averaged model by running:

>>> swa_model = AveragedModel(model)

Here the model ``model`` can be an arbitrary :class:`torch.nn.Module` object. ``swa_model``
will keep track of the running averages of the parameters of ``model``. To update these
averages, you can use the :func:`update_parameters` function:

>>> swa_model.update_parameters(model)
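
For intuition, here is a toy check (not part of the original docs) showing that ``swa_model`` holds
the equal running average of the parameter snapshots passed to :func:`update_parameters`:

>>> # Toy model with a single weight; the names below are purely illustrative.
>>> net = torch.nn.Linear(1, 1, bias=False)
>>> swa_net = AveragedModel(net)
>>> with torch.no_grad():
>>>     net.weight.fill_(0.0)
>>> swa_net.update_parameters(net)   # first update copies the weight: average is 0.0
>>> with torch.no_grad():
>>>     net.weight.fill_(2.0)
>>> swa_net.update_parameters(net)   # average is now (0.0 + 2.0) / 2 = 1.0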

SWA learning rate schedules
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Typically, in SWA the learning rate is set to a high constant value. :class:`SWALR` is a
learning rate scheduler that anneals the learning rate to a fixed value and then keeps it
constant. For example, the following code creates a scheduler that linearly anneals the
learning rate from its initial value to 0.05 over 5 epochs within each parameter group:

>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>>         anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

You can also use cosine annealing to a fixed value instead of linear annealing by setting
``anneal_strategy="cos"``, as sketched below.
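
For example, a cosine-annealing counterpart of the scheduler above might look like this (a sketch
reusing the same illustrative hyperparameters):

>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>>         anneal_strategy="cos", anneal_epochs=5, swa_lr=0.05)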

Taking care of batch normalization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`update_bn` is a utility function that computes the batchnorm statistics for the SWA model
on a given dataloader ``loader`` at the end of training:

>>> torch.optim.swa_utils.update_bn(loader, swa_model)

:func:`update_bn` applies ``swa_model`` to every element in the dataloader and computes the activation
statistics for each batch normalization layer in the model.

.. warning::
    :func:`update_bn` assumes that each batch in the dataloader ``loader`` is either a tensor or a list of
    tensors where the first element is the tensor that the network ``swa_model`` should be applied to.
    If your dataloader has a different structure, you can update the batch normalization statistics of
    ``swa_model`` by doing a forward pass with ``swa_model`` on each element of the dataset, as sketched below.
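
For instance, if your loader yields dictionaries, a manual pass along these lines serves the same
purpose (a sketch; the ``"input"`` key is hypothetical, and :func:`update_bn` itself additionally
resets the existing running statistics before its own pass):

>>> # Forward passes in training mode update the BatchNorm running statistics.
>>> swa_model.train()
>>> with torch.no_grad():
>>>     for batch in loader:
>>>         swa_model(batch["input"])
>>> swa_model.eval()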

Custom averaging strategies
^^^^^^^^^^^^^^^^^^^^^^^^^^^

By default, :class:`torch.optim.swa_utils.AveragedModel` computes a running equal average of
the parameters that you provide, but you can also use custom averaging functions with the
``avg_fn`` parameter. In the following example, ``ema_model`` computes an exponential moving average.

Example:

>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         0.1 * averaged_model_parameter + 0.9 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)
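
``ema_model`` is then updated in the same way, via ``ema_model.update_parameters(model)``. For
reference, the default running equal average corresponds to an ``avg_fn`` along these lines (a
sketch, not the literal library code):

>>> default_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         averaged_model_parameter + \
>>>         (model_parameter - averaged_model_parameter) / (num_averaged + 1)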

Putting it all together
^^^^^^^^^^^^^^^^^^^^^^^

In the example below, ``swa_model`` is the SWA model that accumulates the averages of the weights.
We train the model for a total of 300 epochs, switch to the SWA learning rate schedule, and start
collecting SWA averages of the parameters at epoch 160:

>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>       if epoch > swa_start:
>>>           swa_model.update_parameters(model)

Contributor
@izmailovpavel Hi, thanks for the implementation of SWA in PyTorch! Do you find it useful to average the weights during the annealing phase? By default

Contributor (Author)
Hey Daniil, thank you for looking into the implementation! Generally, it's reasonable to average the weights during the annealing phase, but I imagine there could be cases when it's not desirable, e.g. when the learning rate before the annealing is way too high.

Contributor
Thanks for the response! It is just a bit confusing. According to the paper, the purpose of SWA is to average weights during exploration of the loss surface (to find high-performing networks), but averaging during annealing would lead to averaging across different (decreasing) learning rates, which the paper describes as unsuitable for improving generalization (citing experiments in Ruppert, 1988), since SGD does not behave very differently under such a schedule.

Contributor (Author)
Whether or not averaging in the annealing phase is desirable depends on how much the learning rate changes during annealing. Generally, averaging at different learning rates can work well; see e.g. https://arxiv.org/abs/1806.05594. Although it can also be bad, as you said.

Contributor (Author)
It may be safer to fix it in the docs as you suggest. It makes the example a bit more complex... @vincentqb @andrewgordonwilson what do you think?

Contributor
Is there any chance to remove the annealing phase from the scheduler? I believe this is the part that accidentally differs from the paper.

Contributor
What would be an alternate fix that would make this simpler?

Contributor
I believe the simplest option is to allow SWALR to take

Contributor (Author)
Sorry for not following up on this discussion for so long. @Daniil-Osokin I agree that we should allow
>>>           swa_scheduler.step()
>>>       else:
>>>           scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)
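
If training ran on a GPU, you may want the batch-norm update to run on the same device; a sketch
(assuming the optional ``device`` argument of :func:`update_bn` and that CUDA is available):

>>> device = torch.device("cuda")
>>> swa_model = swa_model.to(device)
>>> torch.optim.swa_utils.update_bn(loader, swa_model, device=device)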