Labels: oncall: jit
🐛 Bug
When I create a jit.script function that includes torch.nn.functional.dropout without a constant is_training parameter, fusion does not work. This previously worked.
To Reproduce
Here is an example script:
import torch
import torch.nn.functional as F

@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
    # type: (Tensor, Tensor, float, bool) -> Tensor
    out = F.dropout(x, p=prob, training=is_training)
    out = residual + out
    return out

@torch.jit.script
def jit_dropout_add_const(x, residual, prob):
    # type: (Tensor, Tensor, float) -> Tensor
    out = F.dropout(x, p=prob, training=True)
    out = residual + out
    return out

inputs = torch.ones(5, 5, dtype=torch.float16, device=torch.device("cuda:0"), requires_grad=True)
residuals = torch.ones(5, 5, dtype=torch.float16, device=torch.device("cuda:0"))

output_bad = jit_dropout_add(inputs, residuals, 0.1, True)
output_good = jit_dropout_add_const(inputs, residuals, 0.1)
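For reference, one way to check whether fusion actually happened, besides reading the PYTORCH_FUSION_DEBUG output, is to look for a fusion group node in the optimized graph. This is only a sketch: graph_for is an internal debugging helper, not a stable API, and the expected node name (prim::FusionGroup) assumes the legacy CUDA fuser.

# Sketch: inspect the optimized graphs for a fusion group node.
# graph_for is an internal debugging helper; treat this as an assumption, not part of the repro.
print(jit_dropout_add.graph_for(inputs, residuals, 0.1, True))    # bad case: no prim::FusionGroup expected
print(jit_dropout_add_const.graph_for(inputs, residuals, 0.1))    # good case: prim::FusionGroup expected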
Expected behavior
If you run the script above with PYTORCH_FUSION_DEBUG=1, you will see that in the good case, shown below, n0 and n1 are read from memory.
void kernel_1(IndexType totalElements, const TensorInfo<half,1> t0, const TensorInfo<half,1> t1, double s2, double s3, const TensorInfo<half,1> t4, const TensorInfo<half,1> t5, const TensorInfo<uint8_t,1> t6, unsigned long long seed, unsigned long long offset) {
  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x) {
    // calculate the results
    float n0 = __half2float(t0.data[t0_offset]);
    float n1 = __half2float(t1.data[t1_offset]);
    double n2 = s2;
    double n3 = s3;
    int64_t n4 = 1;
    float n5 = uniform(rnd());
    uint8_t n6 = n5 < n3;
    float n7 = (((float) n6));
    float n8 = n7 * n1;
    float n9 = n8 * ((float) n2);
    float n10 = n0 + ((float) n4)*n9;
    t4.data[t4_offset] = __float2half(n10);
    t5.data[t5_offset] = __float2half(n9);
    t6.data[t6_offset] = n6;
  }
}
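A possible workaround, sketched below under the assumption that the training flag is known at the Python call site: resolve the flag outside TorchScript and dispatch to scripted variants that each see training as a constant, mirroring jit_dropout_add_const above. jit_dropout_add_eval and dropout_add are hypothetical helpers introduced here for illustration, not part of the original report.

@torch.jit.script
def jit_dropout_add_eval(x, residual, prob):
    # type: (Tensor, Tensor, float) -> Tensor
    # training=False is a compile-time constant here, like training=True in jit_dropout_add_const
    out = F.dropout(x, p=prob, training=False)
    out = residual + out
    return out

def dropout_add(x, residual, prob, is_training):
    # Plain Python dispatch so each scripted function sees a constant training flag
    if is_training:
        return jit_dropout_add_const(x, residual, prob)
    return jit_dropout_add_eval(x, residual, prob)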
cc @suo