Skip to content

Commit 478d480

Browse files
ssnl authored and facebook-github-bot committed
Add Module.requires_grad_ (#22576)
Summary: addresses #20241 Pull Request resolved: #22576 Differential Revision: D16149314 Pulled By: zou3519 fbshipit-source-id: 1cc4c1ec084df30e00e9ae73ce1a53494a034d5c
1 parent 456d27d commit 478d480

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

test/test_nn.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,19 @@ def __init__(self):
570570

571571
return l, n, s
572572

573+
def test_requires_grad_(self):
    """Check that Module.requires_grad_ toggles requires_grad on all
    parameters (but never on buffers) and returns ``self`` (fluent API)."""
    m = self._create_basic_net()[-1]
    # Fixture sanity checks: the net must expose both buffers and parameters,
    # with the expected initial requires_grad state, or the test proves nothing.
    assert len(list(m.buffers())) > 0, 'invalid test'
    # NOTE: `all(...)` already returns a bool; the original code compared it
    # with `> 0` (a copy-paste leftover from the `len(...)` checks above).
    assert all(not b.requires_grad for b in m.buffers()), 'invalid test'
    assert len(list(m.parameters())) > 0, 'invalid test'
    assert all(p.requires_grad for p in m.parameters()), 'invalid test'
    for requires_grad in (False, True):
        # requires_grad_ must return self ...
        self.assertIs(m.requires_grad_(requires_grad), m)
        # ... set the flag on every parameter ...
        for p in m.parameters():
            self.assertEqual(p.requires_grad, requires_grad)
        # ... and leave buffers untouched (buffers never require grad here).
        for b in m.buffers():
            self.assertFalse(b.requires_grad)
585+
573586
def test_module_backcompat(self):
574587
from torch.serialization import SourceChangeWarning
575588
path = download_file('https://download.pytorch.org/test_data/linear.pt')

torch/nn/modules/module.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,10 @@ def train(self, mode=True):
10501050
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
10511051
etc.
10521052
1053+
Args:
1054+
mode (bool): whether to set training mode (``True``) or evaluation
1055+
mode (``False``). Default: ``True``.
1056+
10531057
Returns:
10541058
Module: self
10551059
"""
@@ -1065,9 +1069,35 @@ def eval(self):
10651069
particular modules for details of their behaviors in training/evaluation
10661070
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
10671071
etc.
1072+
1073+
This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.
1074+
1075+
Returns:
1076+
Module: self
10681077
"""
10691078
return self.train(False)
10701079

1080+
def requires_grad_(self, requires_grad=True):
    r"""Set whether autograd should record operations on parameters in this
    module.

    Every parameter's :attr:`requires_grad` attribute is updated in-place.

    This is useful for freezing part of the module for finetuning, or for
    training parts of a model individually (e.g., GAN training).

    Args:
        requires_grad (bool): whether autograd should record operations on
            parameters in this module. Default: ``True``.

    Returns:
        Module: self
    """
    for param in self.parameters():
        param.requires_grad_(requires_grad)
    return self
1100+
10711101
def zero_grad(self):
10721102
r"""Sets gradients of all model parameters to zero."""
10731103
for p in self.parameters():

0 commit comments

Comments
 (0)