Description
🐛 Bug
`torch.nn.functional.gumbel_softmax` yields NaNs on a CUDA device (but not on CPU). Default parameters are used (tau=1, hard=False).
To Reproduce
The following code generates random logits on CPU and on GPU and prints a message whenever NaNs are encountered. Code to reproduce the behavior:
import sys
print("Python version:", sys.version)

import torch
print("\nPyTorch version:", torch.__version__)
torch.cuda.current_device()

device = 'cuda'
for k in range(5000):
    logits = torch.zeros((1000, 64)).normal_().to(device)
    z = torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, dim=-1)
    n_nans = torch.isnan(z).sum()
    if n_nans > 0:
        print(f"\ndevice: {device} step: {k} Number of NaNs (z): {n_nans} "
              f"logits mean: {logits.mean().item()} logits min: {logits.min().item()} "
              f"logits max: {logits.max().item()} number of NaNs (logits): {torch.isnan(logits).sum()}")

device = 'cpu'
for k in range(5000):
    logits = torch.zeros((1000, 64)).normal_().to(device)
    z = torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, dim=-1)
    n_nans = torch.isnan(z).sum()
    if n_nans > 0:
        print(f"\ndevice: {device} step: {k} Number of NaNs (z): {n_nans} "
              f"logits mean: {logits.mean().item()} logits min: {logits.min().item()} "
              f"logits max: {logits.max().item()} number of NaNs (logits): {torch.isnan(logits).sum()}")
Output:
Python version: 3.6.8 (default, Jan 14 2019, 11:02:34)
[GCC 8.0.1 20180414 (experimental) [trunk revision 259383]]
PyTorch version: 1.1.0
device: cuda step: 1946 Number of NaNs (z): 64 logits mean: 0.01120043732225895 logits min: -4.310624122619629 logits max: 4.168891429901123 number of NaNs (logits): 0
device: cuda step: 2115 Number of NaNs (z): 64 logits mean: 0.0014047547010704875 logits min: -4.174510478973389 logits max: 4.151383399963379 number of NaNs (logits): 0
device: cuda step: 2183 Number of NaNs (z): 64 logits mean: -0.0049570659175515175 logits min: -3.6217334270477295 logits max: 4.108720779418945 number of NaNs (logits): 0
device: cuda step: 2527 Number of NaNs (z): 64 logits mean: -0.0021658651530742645 logits min: -4.11637544631958 logits max: 3.9977431297302246 number of NaNs (logits): 0
device: cuda step: 3353 Number of NaNs (z): 64 logits mean: 0.003437937470152974 logits min: -4.200624465942383 logits max: 4.650900840759277 number of NaNs (logits): 0
device: cuda step: 4036 Number of NaNs (z): 64 logits mean: -0.0031167957931756973 logits min: -4.098625659942627 logits max: 4.771944046020508 number of NaNs (logits): 0
device: cuda step: 4837 Number of NaNs (z): 64 logits mean: 0.001691692159511149 logits min: -4.2665252685546875 logits max: 4.493197441101074 number of NaNs (logits): 0
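One plausible source of the NaNs (an assumption on my part, not confirmed against the 1.1 CUDA kernels): the classic inverse-CDF Gumbel sampler g = -log(-log(u)), with u drawn from Uniform(0, 1), degenerates if the uniform draw hits an endpoint exactly. A minimal CPU sketch of the degenerate cases:

```python
import torch

# Classic Gumbel(0, 1) sampling via inverse CDF: g = -log(-log(u)).
# If u is exactly 0 or exactly 1, the result is -inf or +inf, which
# the subsequent softmax can turn into NaN across the whole row.
u = torch.tensor([0.0, 0.5, 1.0])
g = -torch.log(-torch.log(u))
print(g)  # endpoints are infinite, only the interior value is finite
```

If the CUDA uniform generator can return 0.0 (or a value that rounds to 1.0 in float32), this would explain why an entire row of 64 entries turns NaN at once while the logits themselves stay finite.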
Expected behavior
`torch.nn.functional.gumbel_softmax` should not return NaNs.
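Pending a fix, a possible workaround is to sample the Gumbel noise through `torch.Tensor.exponential_` instead of transforming a raw uniform: if E ~ Exp(1), then -log(E) is Gumbel(0, 1) distributed, and this path avoids the doubly nested log of the inverse-CDF sampler. This is only a sketch under that assumption; the name `stable_gumbel_softmax` is mine and is not part of `torch.nn.functional`:

```python
import torch

def stable_gumbel_softmax(logits, tau=1.0, hard=False, dim=-1):
    # Sample Gumbel(0, 1) noise as -log(E) with E ~ Exp(1), rather than
    # -log(-log(U)) with U ~ Uniform(0, 1), to sidestep the U == 0 / U == 1
    # endpoint cases (an assumed cause of the NaNs reported above).
    gumbels = -torch.empty_like(logits).exponential_().log()
    y_soft = ((logits + gumbels) / tau).softmax(dim)
    if hard:
        # Straight-through: one-hot in the forward pass, soft gradients.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        return y_hard - y_soft.detach() + y_soft
    return y_soft
```

Drop-in usage matches the repro above: `z = stable_gumbel_softmax(logits, tau=1, hard=False, dim=-1)`.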
Environment
Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.12.0
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: Tesla K80
Nvidia driver version: 410.79
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.0
Versions of relevant libraries:
[pip3] numpy==1.16.4
[pip3] torch==1.1.0
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.3.1
[pip3] torchvision==0.3.0
[conda] Could not collect
Additional context
Code example was run in Colab.