Skip to content

Commit b5f7592

Browse files
szagoruykosoumith
authored andcommitted
boolean mode in module.train
1 parent f366e5f commit b5f7592

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

torch/nn/modules/module.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -357,25 +357,22 @@ def modules(self, memo=None):
357357
for m in module.modules(memo):
358358
yield m
359359

360-
def train(self):
360+
def train(self, mode=True):
361361
"""Sets the module in training mode.
362362
363363
This has any effect only on modules such as Dropout or BatchNorm.
364364
"""
365-
self.training = True
365+
self.training = mode
366366
for module in self.children():
367-
module.train()
367+
module.train(mode)
368368
return self
369369

370370
def eval(self):
371371
"""Sets the module in evaluation mode.
372372
373373
This has any effect only on modules such as Dropout or BatchNorm.
374374
"""
375-
self.training = False
376-
for module in self.children():
377-
module.eval()
378-
return self
375+
return self.train(False)
379376

380377
def zero_grad(self):
381378
"""Sets gradients of all model parameters to zero."""

0 commit comments

Comments
 (0)