File tree Expand file tree Collapse file tree 1 file changed +4
-7
lines changed
Expand file tree Collapse file tree 1 file changed +4
-7
lines changed Original file line number Diff line number Diff line change @@ -357,25 +357,22 @@ def modules(self, memo=None):
357357 for m in module .modules (memo ):
358358 yield m
359359
360- def train (self ):
360+ def train (self , mode = True ):
361361 """Sets the module in training mode.
362362
363363 This has any effect only on modules such as Dropout or BatchNorm.
364364 """
365- self .training = True
365+ self .training = mode
366366 for module in self .children ():
367- module .train ()
367+ module .train (mode )
368368 return self
369369
370370 def eval (self ):
371371 """Sets the module in evaluation mode.
372372
373373 This has any effect only on modules such as Dropout or BatchNorm.
374374 """
375- self .training = False
376- for module in self .children ():
377- module .eval ()
378- return self
375+ return self .train (False )
379376
380377 def zero_grad (self ):
381378 """Sets gradients of all model parameters to zero."""
You can’t perform that action at this time.
0 commit comments