Issue description
I was training an autoencoder with MSELoss, and the loss values on the training data were huge but the loss values on the validation data were small. It appeared as if the loss was not being averaged on the training pass, but it was on the validation pass. A little poking around in the debugger revealed this to be the case.
The problem is in the _pointwise_loss function in torch/nn/functional.py:
    def _pointwise_loss(lambd, lambd_optimized, input, target, reduction='elementwise_mean'):
        if target.requires_grad:
            d = lambd(input, target)
            if reduction == 'none':
                return d
            return torch.mean(d) if reduction == 'elementwise_mean' else torch.sum(d)
        else:
            return lambd_optimized(input, target, _Reduction.get_enum(reduction))

In the line return torch.mean(d) if reduction == 'elementwise_mean' else torch.sum(d), the variable reduction is actually an integer with value 1, not a string with value 'elementwise_mean', and so the loss is summed instead of averaged, even though the 'elementwise_mean' option was chosen.
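The fall-through can be illustrated without PyTorch. The sketch below copies only the branch structure of the line quoted above and uses plain Python lists in place of tensors; reduce_like_pointwise_loss and squared_errors are hypothetical names made up for this illustration, and the integer 1 stands in for the enum value reported in this issue.

```python
def reduce_like_pointwise_loss(values, reduction):
    # Hypothetical stand-in for the target.requires_grad branch quoted above,
    # operating on a list of floats rather than a tensor.
    if reduction == 'none':
        return values
    # An integer never compares equal to the string 'elementwise_mean',
    # so an integer `reduction` always takes the sum branch.
    return sum(values) / len(values) if reduction == 'elementwise_mean' else sum(values)


squared_errors = [0.25, 0.25, 0.25, 0.25]

# Passing the string averages, as intended.
as_string = reduce_like_pointwise_loss(squared_errors, 'elementwise_mean')

# Passing the integer (assumed here to be 1, as observed in the debugger) sums instead.
as_enum = reduce_like_pointwise_loss(squared_errors, 1)

print(as_string, as_enum)
```

With four elements, the summed result is four times the averaged one, which is the same kind of discrepancy seen between the training and validation losses.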
Code example
This code reproduces the problem:
    import torch
    import torchvision as tv


    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.c1 = torch.nn.Conv2d(1, 16, 3, padding=1)
            self.c2 = torch.nn.Conv2d(16, 1, 3, padding=1)
            self.relu = torch.nn.ReLU(inplace=True)

        def forward(self, x):
            x = self.c2(self.relu(self.c1(x)))
            return x


    net = Net()
    # torchvision is imported as tv, so the dataset must be referenced through that alias.
    data = tv.datasets.MNIST('MNIST', True, download=True, transform=tv.transforms.ToTensor())
    mse = torch.nn.MSELoss(reduction='elementwise_mean')

    # Target requires grad: the loss comes out summed.
    img = data[0][0].view(1, 1, 28, 28).requires_grad_()
    pred = net(img)
    loss1 = mse(pred, img).item()

    # Same target without requires_grad: the loss comes out averaged.
    img.requires_grad = False
    loss2 = mse(pred, img).item()

    print(loss1, loss2)

And the output is:
155.7478 0.1987
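As a quick sanity check on these numbers, the two losses should differ by a factor equal to the number of elements if one is a sum and the other a mean. The sketch below only does arithmetic on the values printed above; the variable names are made up for this check, and the element count comes from the view(1, 1, 28, 28) call in the reproduction.

```python
loss_summed = 155.7478     # first value printed above (target requires grad)
loss_averaged = 0.1987     # second value printed above (target does not require grad)

# Number of elements in the (1, 1, 28, 28) input used in the reproduction.
num_elements = 1 * 1 * 28 * 28

# If the first loss is a sum over all elements, dividing by the element
# count should recover the averaged loss, up to the printed rounding.
recovered_mean = loss_summed / num_elements

print(num_elements, round(recovered_mean, 4), loss_averaged)
```

Dividing the first loss by the 784 elements gives the second loss to four decimal places, which is consistent with the first value being a sum rather than a mean.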
System Info
PyTorch version: 0.4.1
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 16.04.2 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.10.2
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 7.5.17
GPU models and configuration:
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
GPU 2: GeForce GTX 1080 Ti
GPU 3: GeForce GTX 1080 Ti
GPU 4: GeForce GTX 1080 Ti
GPU 5: GeForce GTX 1080 Ti
GPU 6: GeForce GTX 1080 Ti
GPU 7: GeForce GTX 1080 Ti
Nvidia driver version: 387.34
cuDNN version: Probably one of the following:
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so.5
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so.5.0.5
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn_static.a
/usr/local/cuda-9.0/lib64/libcudnn.so
/usr/local/cuda-9.0/lib64/libcudnn.so.7
/usr/local/cuda-9.0/lib64/libcudnn.so.7.0.3
/usr/local/cuda-9.0/lib64/libcudnn_static.a
Versions of relevant libraries:
[pip] Could not collect
[conda] Could not collect