Description
🐛 Describe the bug
torch.nn.functional.interpolate fails with a RuntimeError when the following conditions are met:
- The input tensor uses the channels_last memory format.
- The input shape is larger than a certain threshold.
The following code works fine, producing a tensor with the expected shape [31, 64, 1024, 1024]:
import torch
x = torch.rand((31, 64, 512, 512)).cuda().to(memory_format=torch.channels_last)
torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest').shape
torch.Size([31, 64, 1024, 1024])
However, when the input batch dimension is 32 or larger, it fails:
x = torch.rand((32, 64, 512, 512)).cuda().to(memory_format=torch.channels_last)
torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest').shape
RuntimeError: upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements
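The batch size of 32 is where the output element count crosses the limit named in the error message. The arithmetic below is a minimal sketch, assuming that INT_MAX refers to the maximum of a signed 32-bit integer and that the check is applied to the number of output elements; `output_numel` is a hypothetical helper written for this illustration, not part of PyTorch.

```python
INT_MAX = 2**31 - 1  # maximum of a signed 32-bit integer (assumed meaning of INT_MAX)


def output_numel(shape, scale_factor=2):
    """Element count of an NCHW tensor after upsampling H and W by scale_factor."""
    n, c, h, w = shape
    return n * c * (h * scale_factor) * (w * scale_factor)


for batch in (31, 32):
    numel = output_numel((batch, 64, 512, 512))
    print(batch, numel, numel < INT_MAX)

# 31 2080374784 True
# 32 2147483648 False
```

With a batch of 32 the output has exactly 2^31 elements, one more than INT_MAX, which matches the observed threshold.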
If the memory layout is contiguous rather than channels last, it works fine too:
x = torch.rand((32, 64, 512, 512)).cuda()
torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest').shape
torch.Size([32, 64, 1024, 1024])
The error is raised here. I'm not sure about the details, but I think a potential workaround could be to automatically revert to contiguous format, rather than failing.
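Until something like that is done inside PyTorch, the fallback can be approximated in user code. The following is only a sketch of the idea, not the library's behaviour: `interpolate_nearest_safe` and its `limit` parameter are hypothetical names, and the threshold check is an assumption based on the wording of the error message. It also costs extra memory for the temporary copies.

```python
import torch
import torch.nn.functional as F

INT_MAX = 2**31 - 1  # assumed limit, taken from the wording of the error message


def interpolate_nearest_safe(x, scale_factor=2, limit=INT_MAX):
    """Nearest-neighbour upsampling of a 4D NCHW tensor (hypothetical helper).

    If x is channels_last and the output would have `limit` elements or more,
    run the op on a contiguous copy and convert the result back.
    """
    n, c, h, w = x.shape
    out_numel = n * c * int(h * scale_factor) * int(w * scale_factor)
    is_channels_last = x.is_contiguous(memory_format=torch.channels_last)
    if is_channels_last and out_numel >= limit:
        # Fall back to the default (NCHW) layout, which did not raise in the report.
        y = F.interpolate(x.contiguous(), scale_factor=scale_factor, mode="nearest")
        # Restore the layout the caller was using.
        return y.contiguous(memory_format=torch.channels_last)
    return F.interpolate(x, scale_factor=scale_factor, mode="nearest")
```

Called as `interpolate_nearest_safe(x)` on the failing input above, this would take the contiguous path; smaller inputs go through `interpolate` unchanged.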
Versions
Collecting environment information...
PyTorch version: 1.10.1+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-91-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.4.48
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2080 Ti
GPU 1: NVIDIA GeForce RTX 3090
Nvidia driver version: 470.42.01
cuDNN version: Probably one of the following:
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.10.1+cu113
[pip3] torchaudio==0.10.1+cu113
[pip3] torchvision==0.11.2+cu113
[conda] numpy 1.20.3 pypi_0 pypi
[conda] torch 1.10.1+cu113 pypi_0 pypi
[conda] torchaudio 0.10.1+cu113 pypi_0 pypi
[conda] torchvision 0.11.2+cu113 pypi_0 pypi
I was using the RTX 3090 in this test. I have observed the same behaviour using other cards (RTX A6000, for instance) in other systems running different versions of Python, PyTorch and OS.
cc @ezyang @gchanan @zou3519 @ngimel @VitalyFedyunin @jamesr66a @csarofeen @ptrblck @xwang233