
[JIT] Fusion of Dropout without constant is_training parameter is unsuccessful #24032

@kevinstephano

Description

🐛 Bug

When I create a jit.script function that calls torch.nn.functional.dropout without a constant is_training parameter, the fusion does not happen. This did previously work.

To Reproduce

Here is an example script:

import torch
import torch.nn.functional as F

@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
    # type: (Tensor, Tensor, float, bool) -> Tensor
    out = F.dropout(x, p=prob, training=is_training)
    out = residual + out
    return out

@torch.jit.script
def jit_dropout_add_const(x, residual, prob):
    # type: (Tensor, Tensor, float) -> Tensor
    out = F.dropout(x, p=prob, training=True)
    out = residual + out
    return out


inputs = torch.ones(5, 5, dtype=torch.float16, device=torch.device("cuda:0"), requires_grad=True)
residuals = torch.ones(5, 5, dtype=torch.float16, device=torch.device("cuda:0"))

output_bad = jit_dropout_add(inputs, residuals, 0.1, True)
output_good = jit_dropout_add_const(inputs, residuals, 0.1)
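
One way to confirm whether fusion actually happened is to inspect the optimized graphs. A minimal sketch, assuming the ScriptFunction.graph_for helper and the legacy fuser's prim::FusionGroup node kind (both may differ across PyTorch versions):

# Sketch: look for a fusion group in each optimized graph.
# Note: fusion groups can sit inside prim::DifferentiableGraph
# subgraphs, so a flat scan over top-level nodes may miss them.
bad_graph = jit_dropout_add.graph_for(inputs, residuals, 0.1, True)
good_graph = jit_dropout_add_const.graph_for(inputs, residuals, 0.1)
print("bad case fused: ", any(n.kind() == "prim::FusionGroup" for n in bad_graph.nodes()))
print("good case fused:", any(n.kind() == "prim::FusionGroup" for n in good_graph.nodes()))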

Expected behavior

If you run the script above with PYTORCH_FUSION_DEBUG=1, you will see that in the good case the fuser emits the kernel shown below, in which n0 and n1 are read from memory.

void kernel_1(IndexType totalElements, const TensorInfo<half,1> t0, const TensorInfo<half,1> t1, double s2, double s3, const TensorInfo<half,1> t4, const TensorInfo<half,1> t5, const TensorInfo<uint8_t,1> t6 ,unsigned long long seed, unsigned long long offset) {
  
  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
        linearIndex < totalElements;
        linearIndex += gridDim.x * blockDim.x) {
      

      // calculate the results
      float n0 = __half2float(t0.data[t0_offset]);
      float n1 = __half2float(t1.data[t1_offset]);
      double n2 = s2;
      double n3 = s3;
      int64_t n4 = 1;
      float n5 = uniform(rnd());
      uint8_t n6 = n5 < n3;
      float n7 = (((float) n6));
      float n8 = n7 * n1;
      float n9 = n8 * ((float) n2);
      float n10 = n0 + ((float) n4)*n9;
      t4.data[t4_offset] = __float2half(n10);
      t5.data[t5_offset] = __float2half(n9);
      t6.data[t6_offset] = n6;
    
    }
}
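
For reference, the dump above comes from the fuser's debug output. A minimal sketch of enabling it from inside the script, assuming PYTORCH_FUSION_DEBUG is read when the fused kernel is compiled (setting the variable in the shell before launching Python works as well):

import os
# Assumption: the legacy fuser checks this variable at kernel-compilation
# time, so it must be set before the first scripted call that fuses.
os.environ["PYTORCH_FUSION_DEBUG"] = "1"
# ... then define and call the scripted functions as in the repro above;
# the generated CUDA source is printed for each kernel the fuser compiles.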

cc @suo

Labels

oncall: jit