13 changes: 13 additions & 0 deletions test/test_nn.py
@@ -570,6 +570,19 @@ def __init__(self):

return l, n, s

def test_requires_grad_(self):
    m = self._create_basic_net()[-1]
    assert len(list(m.buffers())) > 0, 'invalid test'
    assert all(not b.requires_grad for b in m.buffers()), 'invalid test'
    assert len(list(m.parameters())) > 0, 'invalid test'
    assert all(p.requires_grad for p in m.parameters()), 'invalid test'
    for requires_grad in (False, True):
        self.assertIs(m.requires_grad_(requires_grad), m)
        for p in m.parameters():
            self.assertEqual(p.requires_grad, requires_grad)
        for b in m.buffers():
            self.assertFalse(b.requires_grad)

def test_module_backcompat(self):
from torch.serialization import SourceChangeWarning
path = download_file('https://download.pytorch.org/test_data/linear.pt')
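A minimal sketch, outside the diff, of the behaviour the new test checks, assuming this patch is applied: ``requires_grad_`` returns the module itself, toggles ``requires_grad`` on every parameter, and never touches buffers. ``nn.BatchNorm1d`` is used here only because it has both parameters and buffers; it is not part of the PR.

import torch
import torch.nn as nn

m = nn.BatchNorm1d(4)  # has parameters (weight, bias) and buffers (running stats)

# requires_grad_ returns the module, so the call can be chained
assert m.requires_grad_(False) is m
assert all(not p.requires_grad for p in m.parameters())
assert all(not b.requires_grad for b in m.buffers())  # buffers are never affected

m.requires_grad_(True)
assert all(p.requires_grad for p in m.parameters())
assert all(not b.requires_grad for b in m.buffers())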
30 changes: 30 additions & 0 deletions torch/nn/modules/module.py
@@ -1050,6 +1050,10 @@ def train(self, mode=True):
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.

Args:
    mode (bool): whether to set training mode (``True``) or evaluation
        mode (``False``). Default: ``True``.

Returns:
    Module: self
"""
@@ -1065,9 +1069,35 @@ def eval(self):
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.

This is equivalent to :meth:`self.train(False) <torch.nn.Module.train>`.

Returns:
    Module: self
"""
return self.train(False)

def requires_grad_(self, requires_grad=True):
    r"""Change whether autograd should record operations on parameters in
    this module.

    This method sets the parameters' :attr:`requires_grad` attributes
    in-place.

    This method is helpful for freezing part of the module for finetuning
    or for training parts of a model individually (e.g., GAN training).

    Args:
        requires_grad (bool): whether autograd should record operations on
            parameters in this module. Default: ``True``.

    Returns:
        Module: self
    """
    for p in self.parameters():
        p.requires_grad_(requires_grad)
    return self

def zero_grad(self):
r"""Sets gradients of all model parameters to zero."""
for p in self.parameters():
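For context, a hedged usage sketch of the finetuning case the new docstring mentions: freeze one part of a model with ``requires_grad_(False)`` and train only the rest. The backbone/head layout, layer sizes, and optimizer below are illustrative assumptions, not code from this PR.

import torch
import torch.nn as nn

# Hypothetical two-part model: a pretrained "backbone" to freeze and a new head to finetune
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
head = nn.Linear(32, 10)
model = nn.Sequential(backbone, head)

backbone.requires_grad_(False)  # freeze the backbone in a single call

# only parameters that still require grad go to the optimizer
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()

assert all(p.grad is None for p in backbone.parameters())  # frozen: no gradients recorded
assert all(p.grad is not None for p in head.parameters())  # head still trains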