Conversation

@xwang233 (Collaborator) commented Sep 19, 2020

Originally proposed at #44473 (comment) by @colesbury .

This PR adds the functionality to print relevant tensor shapes and convolution parameters along with the stack trace once a cuDNN exception is thrown.
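
A minimal usage sketch (illustrative only, not taken from this PR's tests; the layer and input shapes are arbitrary, and the extra output only appears if cuDNN actually fails):

import torch
import torch.nn as nn

conv = nn.Conv3d(2, 4, kernel_size=3, groups=2).cuda()
x = torch.randn(1, 2, 3, 3, 3, device='cuda', requires_grad=True)
try:
    out = conv(x)
    out.backward(torch.randn_like(out))
    torch.cuda.synchronize()
except RuntimeError as e:
    # With this PR, the message includes the ConvolutionParams, the tensor
    # descriptors, the pointer addresses, and a standalone repro snippet in
    # addition to the C++ stack trace.
    print(e)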

@xwang233 (Collaborator, Author) commented Sep 19, 2020

There is a known cuDNN error in 8.0.0 and 8.0.1 that was fixed in 8.0.2, the version we are currently running on CI. That makes this feature hard to test on CI, but you can apply the diff below and run it against cuDNN 8.0.0 or 8.0.1.

diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index 85faea1957634..70c49c185ed6a 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -2114,6 +2114,14 @@ def fractional_max_pool3d_test(test_case):
         cudnn=True,
         check_with_long_tensor=True,
     ),
+    dict(
+        fullname='Conv3d_groups_40999',
+        constructor=lambda: nn.Conv3d(2, 4, kernel_size=3, groups=2),
+        cpp_constructor_args='torch::nn::Conv3dOptions(2, 4, 3).groups(2)',
+        input_size=(1, 2, 3, 3, 3),
+        cudnn=True,
+        check_with_long_tensor=True,
+    ),
     dict(
         fullname='Conv3d_dilated',
         constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),

Example stack trace after this PR:

$ CUDA_LAUNCH_BLOCKING=1 python test/test_nn.py -k Conv3d_groups_40999
.E.s
======================================================================
ERROR: test_Conv3d_groups_40999_cuda (__main__.TestNN)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/xwang/Developer/pytorch/torch/testing/_internal/common_utils.py", line 818, in wrapper
    method(*args, **kwargs)
  File "/home/xwang/Developer/pytorch/torch/testing/_internal/common_utils.py", line 818, in wrapper
    method(*args, **kwargs)
  File "test/test_nn.py", line 8709, in <lambda>
    add(cuda_test_name, lambda self, test=test, kwargs=kwargs: test.test_cuda(self, **kwargs))
  File "/home/xwang/Developer/pytorch/torch/testing/_internal/common_nn.py", line 4817, in test_cuda
    gpu_gg = torch.autograd.grad(
  File "/home/xwang/Developer/pytorch/torch/autograd/__init__.py", line 202, in grad
    return Variable._execution_engine.run_backward(
RuntimeError: cuDNN error: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input.
You can try to repro this exception using the following code snippet. If that doesn't trigger the error, please include your original repro script when reporting this issue.

import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.allow_tf32 = True
data = torch.randn([1, 1, 3, 3, 3], dtype=torch.float, device='cuda', requires_grad=True)
net = torch.nn.Conv3d(1, 2, kernel_size=[1, 1, 1], padding=[0, 0, 0], stride=[1, 1, 1], dilation=[1, 1, 1], groups=1)
net = net.cuda().float()
out = net(data)
out.backward(torch.randn_like(out))
torch.cuda.synchronize()

ConvolutionParams 
    data_type = CUDNN_DATA_FLOAT
    padding = [0, 0, 0]
    stride = [1, 1, 1]
    dilation = [1, 1, 1]
    groups = 1
    deterministic = false
    allow_tf32 = true
input: TensorDescriptor 0x7fcd52a52e70
    type = CUDNN_DATA_FLOAT
    nbDims = 5
    dimA = 1, 1, 3, 3, 3, 
    strideA = 27, 27, 9, 3, 1, 
output: TensorDescriptor 0x7fcd500081d0
    type = CUDNN_DATA_FLOAT
    nbDims = 5
    dimA = 1, 2, 3, 3, 3, 
    strideA = 54, 27, 9, 3, 1, 
weight: FilterDescriptor 0x7fcd500085f0
    type = CUDNN_DATA_FLOAT
    tensor_format = CUDNN_TENSOR_NCHW
    nbDims = 5
    dimA = 2, 1, 1, 1, 1, 
Pointer addresses: 
    input: 0x7fc775001e00
    output: 0x7fc775002400
    weight: 0x7fc775000608
Forward algorithm: 1

Exception raised from operator() at ../aten/src/ATen/native/cudnn/Conv.cpp:931 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x68 (0x7fcf01e5feb8 in /home/xwang/Developer/pytorch/torch/lib/libc10.so)
frame #1: <unknown function> + 0xa13112 (0x7fceca47f112 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0x2d5dc0b (0x7fcecc7c9c0b in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x2d525bd (0x7fcecc7be5bd in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x2d4ea54 (0x7fcecc7baa54 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0x2d4fdeb (0x7fcecc7bbdeb in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #6: at::native::cudnn_convolution(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) + 0xc3 (0x7fcecc7bc0f3 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0x2da8711 (0x7fcecc814711 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #8: <unknown function> + 0x2dde7e6 (0x7fcecc84a7e6 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cuda.so)
frame #9: <unknown function> + 0x196ed24 (0x7fced36d4d24 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #10: at::cudnn_convolution(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) + 0x1e4 (0x7fced3619dd4 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #11: <unknown function> + 0x2e2dee9 (0x7fced4b93ee9 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x12af5e6 (0x7fced30155e6 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #13: <unknown function> + 0x196ed24 (0x7fced36d4d24 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #14: at::cudnn_convolution(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) + 0x1e4 (0x7fced3619dd4 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #15: at::native::_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool, bool) + 0x2347 (0x7fced316cf87 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #16: <unknown function> + 0x1a20c76 (0x7fced3786c76 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #17: <unknown function> + 0x1a84313 (0x7fced37ea313 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #18: <unknown function> + 0x12af134 (0x7fced3015134 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #19: <unknown function> + 0x1969136 (0x7fced36cf136 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #20: at::_convolution(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool, bool) + 0x260 (0x7fced3615d20 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #21: at::native::_convolution_double_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool, bool, std::array<bool, 3ul>) + 0xebd (0x7fced316751d in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #22: <unknown function> + 0x1a21072 (0x7fced3787072 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #23: <unknown function> + 0x1a8460e (0x7fced37ea60e in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #24: <unknown function> + 0x1a74f1f (0x7fced37daf1f in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #25: <unknown function> + 0x196a697 (0x7fced36d0697 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #26: at::_convolution_double_backward(c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool, bool, std::array<bool, 3ul>) + 0x2ee (0x7fced361669e in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #27: torch::autograd::generated::CudnnConvolutionBackwardBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x46b (0x7fced4a3aa2b in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #28: <unknown function> + 0x317e08b (0x7fced4ee408b in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #29: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x1852 (0x7fced4edecf2 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #30: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x5d0 (0x7fced4edf660 in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #31: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x9b (0x7fced4ed923b in /home/xwang/Developer/pytorch/torch/lib/libtorch_cpu.so)
frame #32: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x5e (0x7fced99a775e in /home/xwang/Developer/pytorch/torch/lib/libtorch_python.so)
frame #33: <unknown function> + 0xcfc24 (0x7fcf01fb9c24 in /usr/lib/libstdc++.so.6)
frame #34: <unknown function> + 0x93e9 (0x7fcf03b853e9 in /usr/lib/libpthread.so.0)
frame #35: clone + 0x43 (0x7fcf03c9e293 in /usr/lib/libc.so.6)


----------------------------------------------------------------------
Ran 4 tests in 2.422s

FAILED (errors=1, skipped=1)

@xwang233 (Collaborator, Author) commented Sep 19, 2020

cc @ptrblck

cc @zasdfgbnm to check if there are any extra parameters that need to be printed

@dr-ci bot commented Sep 19, 2020

💊 CI failures summary and remediations

As of commit c3a6719 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_backward_compatibility_check_test (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Sep 25 22:15:30 The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not.
Sep 25 22:15:30 processing existing schema:  __setstate__(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase _0, (int, Tensor[], float[], int[]) _1) -> (None _0) 
Sep 25 22:15:30 processing existing schema:  bit_rate(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase _0) -> (int _0) 
Sep 25 22:15:30 processing existing schema:  version(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase _0) -> (int _0) 
Sep 25 22:15:30 processing existing schema:  __getstate__(__torch__.torch.classes.xnnpack.LinearOpContext _0) -> ((Tensor, Tensor?, Scalar?, Scalar?) _0) 
Sep 25 22:15:30 processing existing schema:  __setstate__(__torch__.torch.classes.xnnpack.LinearOpContext _0, (Tensor, Tensor?, Scalar?, Scalar?) _1) -> (None _0) 
Sep 25 22:15:30 processing existing schema:  __getstate__(__torch__.torch.classes.xnnpack.Conv2dOpContext _0) -> ((Tensor, Tensor?, int[], int[], int[], int, Scalar?, Scalar?) _0) 
Sep 25 22:15:30 processing existing schema:  __setstate__(__torch__.torch.classes.xnnpack.Conv2dOpContext _0, (Tensor, Tensor?, int[], int[], int[], int, Scalar?, Scalar?) _1) -> (None _0) 
Sep 25 22:15:30 processing existing schema:  __getstate__(__torch__.torch.classes.xnnpack.TransposeConv2dOpContext _0) -> ((Tensor, Tensor?, int[], int[], int[], int[], int, Scalar?, Scalar?) _0) 
Sep 25 22:15:30 processing existing schema:  __setstate__(__torch__.torch.classes.xnnpack.TransposeConv2dOpContext _0, (Tensor, Tensor?, int[], int[], int[], int[], int, Scalar?, Scalar?) _1) -> (None _0) 
Sep 25 22:15:30 processing existing schema:  __init__(__torch__.torch.classes.dist_rpc.WorkerInfo _0, str _1, int _2) -> (None _0) 
Sep 25 22:15:30 The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not.  
Sep 25 22:15:30  
Sep 25 22:15:30 Broken ops: [ 
Sep 25 22:15:30 	static::mul.a(Tensor a, Tensor b) -> (Tensor) 
Sep 25 22:15:30 	static::mul.b(Tensor a, int b) -> (Tensor) 
Sep 25 22:15:30 	static::add(Tensor a, Tensor b) -> (Tensor) 
Sep 25 22:15:30 ] 
Sep 25 22:15:30 + cleanup 
Sep 25 22:15:30 + retcode=1 
Sep 25 22:15:30 + set +x 
Sep 25 22:15:30 =================== sccache compilation log =================== 


@ngimel (Collaborator) commented Sep 19, 2020

My suggestion is to format the error message so that it can easily be used to reproduce the error in a standalone example:

you can reproduce this error by running the following script:
mod = nn.Conv3d(2, 4, kernel_size=3, groups=2).to("cuda")
inp = torch.randn(..., device=..., dtype=..., memory_format=..., requires_grad=...)  # as_strided call if needed
out = mod(inp)
# possibly backward, if the error happens in backward

Unfortunately you cannot always guarantee that the algorithm will be the same, but if deterministic/benchmark/tf32 settings are the same, chances are good.
This will make it easier for people to report bugs, and for you to get cuDNN logging.
Outputting descriptor object pointers and data pointers probably does not make sense?

@xwang233 (Collaborator, Author)

Outputting descriptor object pointers and data pointers probably does not make sense?

Hmm, the descriptor pointer addresses have already been printed for 3 years. I'm not sure whether they are needed.

The data pointer addresses, however, are definitely needed. In the example I showed above, you can see that the weight pointer is not aligned to a 16-byte boundary, which is what caused the original cuDNN problem.
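
For reference, a quick way to check that 16-byte alignment from Python (a hedged sketch; is_16_byte_aligned is just an illustrative helper, not something added by this PR):

import torch

def is_16_byte_aligned(t: torch.Tensor) -> bool:
    # data_ptr() returns the address of the first element, so alignment
    # can be checked with simple modular arithmetic.
    return t.data_ptr() % 16 == 0

w = torch.randn(8)                # fresh allocations start aligned
print(is_16_byte_aligned(w))      # True
print(is_16_byte_aligned(w[1:]))  # a view offset by 4 bytes: typically False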

@xwang233 (Collaborator, Author) commented Sep 21, 2020

My suggestion is to format the error message so that it can easily be used to reproduce the error in a standalone example:

@ngimel, generation of the Python repro code snippet has been added. Is this ready to go?

@ailzhang added the "triaged" label Sep 22, 2020
ss << "groups=" << args.params.groups << ")" << to_channels_last << "\n";
ss << "net = net.cuda()." << partial_dtype << "()\n";
ss << "out = net(data)\n";
ss << "out.backward(torch.randn_like(out))\n";
Collaborator:

Should we include this only when the error message is for backward?

Collaborator Author:

Sure, that is a good point. However, if we only wanted to include this for errors raised in the backward pass, the codegen logic would become much more complicated. Besides, having users run an extra backward pass is basically harmless.


@xwang233 requested a review from ezyang September 25, 2020 00:08
@ezyang (Contributor) commented Sep 25, 2020

Related #37160

}
};

std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d);
Contributor:

You might want to mark this with TORCH_CUDA_API

Collaborator Author:

Adding TORCH_CUDA_API here introduces test failures (not sure if they are relevant), like:

Sep 25 19:10:06 Exception occurred:
Sep 25 19:10:06   File "/opt/conda/lib/python3.6/site-packages/sphinx/domains/cpp.py", line 6099, in _parse_type
Sep 25 19:10:06     raise self._make_multi_error(prevErrors, header)
Sep 25 19:10:06 sphinx.util.cfamily.DefinitionError: Error when parsing function declaration.
Sep 25 19:10:06 If the function has no return type:
Sep 25 19:10:06   Error in declarator or parameters-and-qualifiers
Sep 25 19:10:06   Invalid C++ declaration: Expecting "(" in parameters-and-qualifiers. [error at 15]
Sep 25 19:10:06     TORCH_CUDA_API std::ostream & operator<< (std::ostream &out, const FilterDescriptor &d)
Sep 25 19:10:06     ---------------^
Sep 25 19:10:06 If the function has a return type:
Sep 25 19:10:06   Error in declarator or parameters-and-qualifiers
Sep 25 19:10:06   If pointer to member declarator:
Sep 25 19:10:06     Invalid C++ declaration: Expected '::' in pointer to member (function). [error at 28]
Sep 25 19:10:06       TORCH_CUDA_API std::ostream & operator<< (std::ostream &out, const FilterDescriptor &d)
Sep 25 19:10:06       ----------------------------^
Sep 25 19:10:06   If declarator-id:
Sep 25 19:10:06     Invalid C++ declaration: Expecting "(" in parameters-and-qualifiers. [error at 28]
Sep 25 19:10:06       TORCH_CUDA_API std::ostream & operator<< (std::ostream &out, const FilterDescriptor &d)
Sep 25 19:10:06       ----------------------------^

https://dr.pytorch.org/api/view-log-full?build_id=137561284
https://dr.pytorch.org/api/view-log-full?build_id=137564136

The operator<< for TensorDescriptor next to it doesn't have TORCH_CUDA_API, so I guess we don't need it here either?

}
};

std::string repro_from_args(const ConvolutionArgs& args) {
Contributor:

I'm a little ambivalent about this. On the one hand, it is cool to produce repros in this style. But on the other hand, ensuring that the sample code here doesn't go stale is not going to be easy. A compromise might be to emit the data in question in a generic machine-readable format (e.g., JSON) and then have a little Python script on the other side that takes that data and reinterprets it into an actual sequence of torch calls.
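
A rough sketch of that compromise, assuming the C++ side emitted something like the JSON below; the field names and the build_repro helper are made up for illustration and are not part of this PR (running it requires a CUDA device):

import json
import torch

# Hypothetical machine-readable dump the error message could carry
# instead of ready-made Python code.
dump = json.loads("""
{
  "input_size": [1, 1, 3, 3, 3],
  "dtype": "float32",
  "out_channels": 2,
  "kernel_size": [1, 1, 1],
  "padding": [0, 0, 0],
  "stride": [1, 1, 1],
  "dilation": [1, 1, 1],
  "groups": 1
}
""")

def build_repro(d):
    # The Python side reinterprets the dump into actual torch calls, so the
    # C++ string never has to track changes in the Python API.
    dtype = getattr(torch, d["dtype"])
    x = torch.randn(d["input_size"], dtype=dtype, device="cuda", requires_grad=True)
    net = torch.nn.Conv3d(
        d["input_size"][1], d["out_channels"],
        kernel_size=d["kernel_size"], padding=d["padding"],
        stride=d["stride"], dilation=d["dilation"], groups=d["groups"],
    ).to(device="cuda", dtype=dtype)
    out = net(x)
    out.backward(torch.randn_like(out))
    torch.cuda.synchronize()

build_repro(dump)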

Contributor:

I'm OK with adding this in as a one-off, but if we start using this pattern in other places in our codebase, we should think about some general infrastructure for doing this sort of thing.

Collaborator:

We could also expose this as a function and test that the produced Python code is runnable.
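
A hedged sketch of what such a test could look like; check_repro_is_runnable and the hand-written snippet are illustrative stand-ins, since this PR does not actually expose the generator to Python:

def check_repro_is_runnable(snippet: str) -> None:
    # The generated repro is a self-contained script, so exec'ing it in a
    # fresh namespace is enough to prove it parses and runs end to end.
    namespace = {}
    exec(compile(snippet, "<generated-repro>", "exec"), namespace)

# Hand-written snippet standing in for the generated one (CPU-only so it
# runs anywhere).
check_repro_is_runnable(
    "import torch\n"
    "data = torch.randn([1, 1, 3, 3], dtype=torch.float, requires_grad=True)\n"
    "net = torch.nn.Conv2d(1, 2, kernel_size=[1, 1])\n"
    "out = net(data)\n"
    "out.backward(torch.randn_like(out))\n"
)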

Collaborator Author:

Thanks. I checked your proposal at #37160, which is very nice and comprehensive. Since convolution is mostly guaranteed to stay backward compatible and interface stable, this sample code should keep working for a long time. I'm glad to refactor it in the way you mentioned once a better "repro code generation mechanism" becomes available, and I'm looking forward to it.

@facebook-github-bot (Contributor) left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@xwang233 (Collaborator, Author)

@ezyang @ngimel, are we going to have this feature in the next release?

@facebook-github-bot (Contributor) left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@ezyang merged this pull request in b4ba66a.

facebook-github-bot pushed a commit that referenced this pull request Oct 2, 2020
Summary:
Originally introduced in #45023. When I tested the original PR, the test case was a Conv3d, so this problem was not discovered.

Arrays in `ConvolutionParams` have a fixed length of 3 or 5, because `max_dim` is set as a constexpr of 3 regardless of whether the convolution is Conv2d or Conv3d. The current code makes some error messages look odd. See below in the comments.

https://github.com/pytorch/pytorch/blob/9201c37d020007979e144693d86c8e8599e2fd8f/aten/src/ATen/native/cudnn/Conv.cpp#L212-L226

Pull Request resolved: #45729

Reviewed By: mruberry

Differential Revision: D24081542

Pulled By: ngimel

fbshipit-source-id: 141f8946f4d0db63a723131775731272abeaa6ab
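
A tiny Python stand-in for the formatting issue that commit describes (illustrative only; the actual fix lives in the C++ printing code in Conv.cpp):

# The params arrays always hold max_dim == 3 entries, even for a Conv2d,
# so the dump has to slice them to the real number of spatial dimensions.
MAX_DIM = 3
padding = [1, 1, 0]   # only the first two entries are meaningful for a Conv2d
spatial_dims = 2
assert len(padding) == MAX_DIM

print(f"padding = {padding}")                  # odd: prints a stale third entry
print(f"padding = {padding[:spatial_dims]}")   # intended: padding = [1, 1]
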
xwang233 added a commit to xwang233/pytorch that referenced this pull request Oct 2, 2020
malfet pushed a commit that referenced this pull request Oct 6, 2020