
Conversation

@pritamdamania87
Contributor

@pritamdamania87 pritamdamania87 commented Aug 27, 2020

Stack from ghstack:

This PR attempts to address #42560 by capturing the appropriate
exception_ptr in the autograd engine and passing it over to the Future.

As part of this change, there is a significant change to the Future API: setError now
only accepts an exception_ptr.

For the example in #42560, the exception trace would now look like:

> Traceback (most recent call last):
>   File "test_autograd.py", line 6914, in test_preserve_backtrace
>     Foo.apply(t).sum().backward()
>   File "torch/tensor.py", line 214, in backward
>     torch.autograd.backward(self, gradient, retain_graph, create_graph)
>   File "torch/autograd/__init__.py", line 127, in backward
>     allow_unreachable=True)  # allow_unreachable flag
>   File "torch/autograd/function.py", line 87, in apply
>     return self._forward_cls.backward(self, *args)
>   File "test_autograd.py", line 6910, in backward
>     raise ValueError("something")
> ValueError: something

Differential Revision: [D23365408](https://our.internmc.facebook.com/intern/diff/D23365408/)
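
To make the setError change concrete, here is a minimal C++ sketch of the pattern. The class below is only a stand-in, not the real c10::ivalue::Future (which has far more machinery); it just shows how capturing the in-flight exception with std::current_exception() and rethrowing it later preserves the original error type and message:

```
#include <exception>
#include <iostream>
#include <stdexcept>
#include <utility>

// Stand-in future that stores an exception_ptr instead of an error string,
// mirroring the shape of the setError() change described above.
class SimpleFuture {
 public:
  void setError(std::exception_ptr eptr) {
    eptr_ = std::move(eptr);
  }

  // Rethrows the captured exception; its original type and message survive.
  void wait() const {
    if (eptr_) {
      std::rethrow_exception(eptr_);
    }
  }

 private:
  std::exception_ptr eptr_;
};

int main() {
  SimpleFuture fut;
  try {
    // Imagine this being raised inside a user-defined backward().
    throw std::runtime_error("something");
  } catch (...) {
    // Capture the in-flight exception instead of flattening it to a string.
    fut.setError(std::current_exception());
  }

  try {
    fut.wait();
  } catch (const std::runtime_error& e) {
    std::cout << "rethrown: " << e.what() << std::endl;
  }
  return 0;
}
```

The PR applies this idea in the autograd engine, which is why the Python exception (and therefore the traceback shown above) survives until the future is consumed.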

@facebook-github-bot facebook-github-bot added the oncall: jit label Aug 27, 2020
pritamdamania87 pushed a commit that referenced this pull request Aug 27, 2020

ghstack-source-id: 110820151
Pull Request resolved: #43684
@dr-ci

dr-ci bot commented Aug 27, 2020

💊 CI failures summary and remediations

As of commit 05d30a3 (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test2 (1/2)

Step: "Attaching workspace" (full log | diagnosis details | 🔁 rerun) <confirmed not flaky by 2 failures>

Downloading workspace layers
Downloading workspace layers
  workflows/workspaces/fc5d0f53-a0a7-4b42-9210-f120a044e074/0/14abf715-3b50-4ee8-8d34-304244dce2e6/0/105.tar.gz - 8.4 MB
Applying workspace layers
  14abf715-3b50-4ee8-8d34-304244dce2e6

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test1 (2/2)

Step: "Attaching workspace" (full log | diagnosis details | 🔁 rerun) <confirmed not flaky by 2 failures>

Downloading workspace layers
Downloading workspace layers
  workflows/workspaces/fc5d0f53-a0a7-4b42-9210-f120a044e074/0/14abf715-3b50-4ee8-8d34-304244dce2e6/0/105.tar.gz - 8.4 MB
Applying workspace layers
  14abf715-3b50-4ee8-8d34-304244dce2e6

3 failures confirmed as flaky and can be ignored:

  • pytorch_linux_xenial_py3_6_gcc5_4_ge_config_simple_test
  • pytorch_linux_xenial_py3_6_gcc5_4_test
  • caffe2_onnx_main_py3_6_clang7_ubuntu16_04_test

🚧 4 fixed upstream failures:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

Since your merge base is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Check out the recency history of this "viable master" tracking branch.



Collaborator

@albanD albanD left a comment


Could you separate the formatting changes into a different commit? It makes it unnecessarily hard to find what actually changed here.

@pritamdamania87 pritamdamania87 requested a review from albanD August 28, 2020 04:17
@pritamdamania87
Contributor Author

> Could you separate the formatting changes into a different commit? It makes it unnecessarily hard to find what actually changed here.

Sorry about that, I ran clang-format and didn't realize the formatting changes were so significant. I've removed all the formatting changes from the PR to make it easier to review.

pritamdamania87 pushed a commit that referenced this pull request Aug 28, 2020
Pull Request resolved: #43684

ghstack-source-id: 110920637

Differential Revision: [D23365408](https://our.internmc.facebook.com/intern/diff/D23365408/)
Collaborator

@albanD albanD left a comment


This looks great!
Thanks for doing this update!

<< errorMsg;
LOG(INFO)
<< "Skipping setting following error on the Future since "
<< "it is already marked completed (this is not neccessarily an error)";
Collaborator


Do you want to use try_retrieve_error_message() here to keep the old behavior? Or is it not very important to log that message?
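
For context, a helper in the spirit of try_retrieve_error_message() typically looks something like the sketch below; the name and exact behavior here are illustrative, not necessarily what the codebase uses:

```
#include <exception>
#include <string>

// Illustrative sketch: pull a printable message out of an exception_ptr
// without letting the exception escape, so it can still be logged even when
// the Future is already marked completed.
std::string tryRetrieveErrorMessage(std::exception_ptr eptr) {
  if (!eptr) {
    return "No exception present";
  }
  try {
    std::rethrow_exception(eptr);
  } catch (const std::exception& e) {
    return e.what();
  } catch (...) {
    return "Unknown exception type";
  }
}
```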

pritamdamania87 pushed a commit that referenced this pull request Aug 29, 2020
Pull Request resolved: #43684

ghstack-source-id: 111002998

Differential Revision: [D23365408](https://our.internmc.facebook.com/intern/diff/D23365408/)
pritamdamania87 pushed a commit that referenced this pull request Aug 31, 2020
Pull Request resolved: #43684

ghstack-source-id: 111080082

Differential Revision: [D23365408](https://our.internmc.facebook.com/intern/diff/D23365408/)
pritamdamania87 pushed a commit that referenced this pull request Sep 1, 2020
Pull Request resolved: #43684

ghstack-source-id: 111109637

Differential Revision: [D23365408](https://our.internmc.facebook.com/intern/diff/D23365408/)
@facebook-github-bot
Contributor

This pull request has been merged in f1624b8.

@xwang233
Collaborator

xwang233 commented Sep 2, 2020

I think this PR removed the C++ stack trace from cuDNN errors. I'm not sure if this is intended behavior. Hope we can get a fix for it.

For example, there is this known issue: if you have CUDA 11 + cuDNN 8.0.1 (or you can try this environment on NGC PyTorch 20.07, https://ngc.nvidia.com/catalog/containers/nvidia:pytorch/tags) and run this convolution double-backward code, you will get a CUDNN_STATUS_NOT_SUPPORTED runtime error. (The issue itself was fixed in cuDNN 8.0.2, so don't worry about it. :)

Before this PR:

1.7.0a0+825c109
Traceback (most recent call last):
  File "d3.py", line 21, in <module>
    gi.sum().backward()
  File "/home/xwang/Developer/pytorch/torch/tensor.py", line 214, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/xwang/Developer/pytorch/torch/autograd/__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: cuDNN error: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input.
Exception raised from operator() at ../aten/src/ATen/native/cudnn/Conv.cpp:845 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x55 (0x7f43152bded5 in /home/xwang/Developer/pytorch/torch/lib/libc10.so)
frame #1: <unknown function> + 0x2911ea8 (0x7f4317c23ea8 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0x2915e13 (0x7f4317c27e13 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x2911544 (0x7f4317c23544 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x2914c83 (0x7f4317c26c83 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #5: at::native::cudnn_convolution(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) + 0xab (0x7f4317c26f1b in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #6: <unknown function> + 0x2998d6d (0x7f4317caad6d in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0x29c7bca (0x7f4317cd9bca in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #8: <unknown function> + 0x1426a8d (0x7f431ed6ca8d in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #9: at::cudnn_convolution(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) + 0x163 (0x7f431ec83bf3 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #10: <unknown function> + 0x27b3864 (0x7f43200f9864 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #11: <unknown function> + 0xbf830a (0x7f431e53e30a in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x1426a8d (0x7f431ed6ca8d in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #13: at::cudnn_convolution(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) + 0x163 (0x7f431ec83bf3 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #14: at::native::_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool) + 0x3a79 (0x7f431e6a6719 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0x14b53a6 (0x7f431edfb3a6 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #16: <unknown function> + 0x150237e (0x7f431ee4837e in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #17: <unknown function> + 0xbf80b8 (0x7f431e53e0b8 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #18: <unknown function> + 0x1422563 (0x7f431ed68563 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #19: at::_convolution(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool) + 0x225 (0x7f431ec7fe25 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #20: at::native::_convolution_double_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 3ul>) + 0xcfb (0x7f431e69fffb in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #21: <unknown function> + 0x14b55ad (0x7f431edfb5ad in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #22: <unknown function> + 0x1502531 (0x7f431ee48531 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #23: <unknown function> + 0x14fdd08 (0x7f431ee43d08 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #24: <unknown function> + 0x14231a0 (0x7f431ed691a0 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #25: at::_convolution_double_backward(c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 3ul>) + 0x2a3 (0x7f431ec80863 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #26: torch::autograd::generated::CudnnConvolutionBackwardBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x3e2 (0x7f431ffcd072 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #27: <unknown function> + 0x2a7fcab (0x7f43203c5cab in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #28: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x1596 (0x7f43203c1af6 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #29: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x507 (0x7f43203c2387 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #30: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x8a (0x7f43203bafea in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #31: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x4e (0x7f4324c3cede in /home/xwang/Developer/pytorch/torch/lib/libtorch_python.so)
frame #32: <unknown function> + 0xc39a3 (0x7f43303d49a3 in /lib64/libstdc++.so.6)
frame #33: <unknown function> + 0x94e2 (0x7f43595a34e2 in /lib64/libpthread.so.0)
frame #34: clone + 0x43 (0x7f4359a096c3 in /lib64/libc.so.6)

After this PR:

1.7.0a0+f1624b8
Traceback (most recent call last):
  File "d3.py", line 21, in <module>
    gi.sum().backward()
  File "/home/xwang/Developer/pytorch/torch/tensor.py", line 214, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/xwang/Developer/pytorch/torch/autograd/__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: cuDNN error: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input.

@albanD
Collaborator

albanD commented Sep 2, 2020

Thanks for the report!

Are you sure that you did not remove the TORCH_SHOW_CPP_STACKTRACES=1 flag from one run to the other?

@xwang233
Collaborator

xwang233 commented Sep 2, 2020

I'm fairly certain that this is my first time seeing the TORCH_SHOW_CPP_STACKTRACES flag. I kept everything the same for all runs, without changing environment variables or flags. Thanks for checking!

@albanD
Collaborator

albanD commented Sep 2, 2020

That is interesting.
Could you open a new issue about this please, so that we can properly track it and discuss it there?
Also, if you have a code snippet that shows the issue without requiring such specific versions of CUDA, that would be perfect!

@xwang233
Collaborator

xwang233 commented Sep 2, 2020

I just realized that I can use TORCH_SHOW_CPP_STACKTRACES=1 python main.py to show the C++ stack trace again. I thought it was gone forever. It seems that the C++ stack trace is just hidden by default after this PR. Thanks for the help!
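
For anyone else hitting this, the behavior described above can be sketched roughly as follows. This is illustrative only, not the actual c10::Error code; it just shows an error string that includes the captured C++ backtrace only when TORCH_SHOW_CPP_STACKTRACES=1 is set:

```
#include <cstdlib>
#include <string>

// Illustrative only (not the actual c10::Error code): append the captured
// C++ backtrace to the error message only when TORCH_SHOW_CPP_STACKTRACES=1
// is set, matching the default-hidden behavior discussed in this thread.
std::string formatError(const std::string& msg, const std::string& cppBacktrace) {
  const char* flag = std::getenv("TORCH_SHOW_CPP_STACKTRACES");
  if (flag != nullptr && std::string(flag) == "1") {
    return msg + "\n" + cppBacktrace;
  }
  return msg;  // default: keep the shorter, Python-side message
}
```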
