@@ -1050,6 +1050,10 @@ def train(self, mode=True):
         mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
         etc.
 
+        Args:
+            mode (bool): whether to set training mode (``True``) or evaluation
+                mode (``False``). Default: ``True``.
+
         Returns:
             Module: self
         """
@@ -1065,9 +1069,35 @@ def eval(self):
         particular modules for details of their behaviors in training/evaluation
         mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
         etc.
+
+        This is equivalent to :meth:`self.train(False) <torch.nn.Module.train>`.
+
+        Returns:
+            Module: self
         """
         return self.train(False)
 
+    def requires_grad_(self, requires_grad=True):
+        r"""Change whether autograd should record operations on parameters in
+        this module.
+
+        This method sets the parameters' :attr:`requires_grad` attributes
+        in-place.
+
+        This method is helpful for freezing part of the module for finetuning
+        or training parts of a model individually (e.g., GAN training).
+
+        Args:
+            requires_grad (bool): whether autograd should record operations on
+                parameters in this module. Default: ``True``.
+
+        Returns:
+            Module: self
+        """
+        for p in self.parameters():
+            p.requires_grad_(requires_grad)
+        return self
+
     def zero_grad(self):
         r"""Sets gradients of all model parameters to zero."""
         for p in self.parameters():
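
The freezing use case that the new ``requires_grad_`` docstring mentions might look like this sketch (the tiny two-layer net is illustrative only, not from the commit):

    >>> import torch.nn as nn
    >>> net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    >>> frozen = net[0].requires_grad_(False)   # freeze the first layer; returns the submodule
    >>> [p.requires_grad for p in net.parameters()]
    [False, False, True, True]

Because the method returns ``self``, it composes with other chained module calls; here only the second ``Linear``'s weight and bias will accumulate gradients during backpropagation.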