Description
🐛 Bug
To Reproduce
Steps to reproduce the behavior:
- Run the code below
import torch

class Foo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, *grad):
        raise ValueError("something")

t = torch.Tensor(20)
t.requires_grad_()
output = Foo.apply(t)
loss = torch.nn.MSELoss()
loss(output, torch.Tensor(20)).backward()
Expected behavior
A stack trace from the point of the raise, showing that the error originates in Foo.backward
Actual output
Traceback (most recent call last):
  File "/private/home/tbirch/bug.py", line 17, in <module>
    loss(output, torch.Tensor(20)).backward()
  File "/private/home/tbirch/.conda/envs/py38/lib/python3.8/site-packages/torch/tensor.py", line 198, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/private/home/tbirch/.conda/envs/py38/lib/python3.8/site-packages/torch/autograd/__init__.py", line 98, in backward
    Variable._execution_engine.run_backward(
RuntimeError: something
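Until the engine reports the original frames, one possible stopgap is to print the Python traceback inside the failing function itself, before the exception leaves it. The sketch below is only an illustration: `print_traceback_on_error` and `failing_backward` are hypothetical names, not part of PyTorch, and the stand-in function replaces `Foo.backward` so the snippet runs without torch. It does not change what the autograd engine raises. It only makes the frame of the original raise visible on stderr.

```python
import functools
import sys
import traceback


def print_traceback_on_error(fn):
    """Hypothetical helper: print the full traceback of any exception
    raised by fn to stderr, then re-raise the same exception."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            # Print here, while the original Python frames are still
            # attached to the active exception.
            traceback.print_exc(file=sys.stderr)
            raise

    return wrapper


# Stand-in for Foo.backward so the sketch runs without torch.
# On a custom Function, the decorator would presumably sit beneath
# @staticmethod:
#
#     @staticmethod
#     @print_traceback_on_error
#     def backward(ctx, *grad):
#         ...
@print_traceback_on_error
def failing_backward(ctx, *grad):
    raise ValueError("something")
```

Calling `failing_backward(None)` still raises the `ValueError`, but a traceback naming `failing_backward` is written to stderr first.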
Environment
(same issue in 1.5.1 and 1.6.0)
Collecting environment information...
PyTorch version: N/A
Is debug build: N/A
CUDA used to build PyTorch: N/A
OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2
Python version: 3.7
Is CUDA available: N/A
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 418.116.00
cuDNN version: Could not collect
Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.5.1
[pip3] torchtext==0.7.0
[pip3] torchvision==0.6.0a0+35d732a
[pip3] torchviz==0.0.1
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.1.243 h6bb024c_0
[conda] mkl 2019.4 243
[conda] mkl-service 2.3.0 py38h516909a_0 conda-forge
[conda] mkl_fft 1.1.0 py38hc1659b7_1 conda-forge
[conda] mkl_random 1.1.0 py38h962f231_0
[conda] numpy 1.18.5 py38ha1c710e_0
[conda] numpy-base 1.18.5 py38hde5b4d6_0
[conda] pytorch 1.5.1 py3.8_cuda10.1.243_cudnn7.6.3_0 pytorch
[conda] torchtext 0.7.0 pypi_0 pypi
[conda] torchvision 0.6.1 py38_cu101 pytorch
[conda] torchviz 0.0.1 pypi_0 pypi