
torch.nn.functional.gumbel_softmax yields NaNs #22442

@vlievin

Description


🐛 Bug

`torch.nn.functional.gumbel_softmax` yields NaNs on a CUDA device (but not on CPU). Default parameters are used (tau=1, hard=False).

To Reproduce

The following code generates random logits on CPU and on GPU and prints a message whenever NaNs are encountered. Code to reproduce the behavior:

import sys
print("Python version:", sys.version)
import torch
print("\nPyTorch version:", torch.__version__)
torch.cuda.current_device()

for device in ('cuda', 'cpu'):
    for k in range(5000):
        logits = torch.zeros((1000, 64)).normal_().to(device)
        z = torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, dim=-1)
        n_nans = torch.isnan(z).sum()
        if n_nans > 0:
            print(f"\ndevice: {device} step: {k} Number of NaNs (z): {n_nans} "
                  f"logits mean: {logits.mean().item()} "
                  f"logits min: {logits.min().item()} "
                  f"logits max: {logits.max().item()} "
                  f"number of NaNs (logits): {torch.isnan(logits).sum()}")

Output:

Python version: 3.6.8 (default, Jan 14 2019, 11:02:34) 
[GCC 8.0.1 20180414 (experimental) [trunk revision 259383]]

PyTorch version: 1.1.0

device: cuda step: 1946 Number of NaNs (z): 64 logits mean: 0.01120043732225895 logits min: -4.310624122619629 logits max: 4.168891429901123 number of NaNs (logits): 0

device: cuda step: 2115 Number of NaNs (z): 64 logits mean: 0.0014047547010704875 logits min: -4.174510478973389 logits max: 4.151383399963379 number of NaNs (logits): 0

device: cuda step: 2183 Number of NaNs (z): 64 logits mean: -0.0049570659175515175 logits min: -3.6217334270477295 logits max: 4.108720779418945 number of NaNs (logits): 0

device: cuda step: 2527 Number of NaNs (z): 64 logits mean: -0.0021658651530742645 logits min: -4.11637544631958 logits max: 3.9977431297302246 number of NaNs (logits): 0

device: cuda step: 3353 Number of NaNs (z): 64 logits mean: 0.003437937470152974 logits min: -4.200624465942383 logits max: 4.650900840759277 number of NaNs (logits): 0

device: cuda step: 4036 Number of NaNs (z): 64 logits mean: -0.0031167957931756973 logits min: -4.098625659942627 logits max: 4.771944046020508 number of NaNs (logits): 0

device: cuda step: 4837 Number of NaNs (z): 64 logits mean: 0.001691692159511149 logits min: -4.2665252685546875 logits max: 4.493197441101074 number of NaNs (logits): 0

Expected behavior

`torch.nn.functional.gumbel_softmax` should not return NaNs.
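A plausible source of the NaNs (an assumption, not confirmed by the report) is the Gumbel noise being computed as `-log(-log(u))` from a uniform sample `u` that can hit exactly 0 or 1 on CUDA, which turns the inner `log` into `-inf` and propagates NaN through the whole softmax row. A minimal workaround sketch under that assumption, using a hypothetical helper name `stable_gumbel_softmax`, clamps the uniform sample away from the endpoints:

```python
import torch

def stable_gumbel_softmax(logits, tau=1.0, hard=False, dim=-1, eps=1e-10):
    """Hypothetical drop-in replacement for F.gumbel_softmax that clamps the
    uniform sample so -log(-log(u)) never evaluates log(0)."""
    u = torch.rand_like(logits).clamp_(min=eps, max=1.0 - eps)
    gumbels = -torch.log(-torch.log(u))  # Gumbel(0, 1) noise, always finite
    y_soft = ((logits + gumbels) / tau).softmax(dim)
    if hard:
        # Straight-through estimator: one-hot in the forward pass,
        # soft gradients in the backward pass.
        index = y_soft.argmax(dim, keepdim=True)
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        return y_hard - y_soft.detach() + y_soft
    return y_soft
```

This only patches around the suspected sampling issue; the underlying fix would belong in the library's own noise generation.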

Environment

Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.12.0

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: Tesla K80
Nvidia driver version: 410.79
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.0

Versions of relevant libraries:
[pip3] numpy==1.16.4
[pip3] torch==1.1.0
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.3.1
[pip3] torchvision==0.3.0
[conda] Could not collect

Additional context

Code example was run in Colab.


Labels

module: cuda (Related to torch.cuda, and CUDA support in general)
module: numerical-stability (Problems related to numerical stability of operations)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
