74 changes: 74 additions & 0 deletions test/test_jit.py
@@ -6840,6 +6840,80 @@ def opt_func(x):
# type: (Optional[int]) -> bool
return isinstance(x, int)

def test_dropout_eval(self):
class ScriptedConv2d(torch.jit.ScriptModule):
def __init__(self, in_channels, out_channels, **kwargs):
super(ScriptedConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

@torch.jit.script_method
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return F.relu(x, inplace=True)

class ScriptMod(torch.jit.ScriptModule):
def __init__(self):
super(ScriptMod, self).__init__()
self.Conv2d_1a_3x3 = ScriptedConv2d(3, 32, kernel_size=3, stride=2)

@torch.jit.script_method
def forward(self, x):
x = self.Conv2d_1a_3x3(x)
return F.dropout(x, training=self.training)

class EagerConv2d(torch.nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(EagerConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return F.relu(x, inplace=True)

class EagerMod(torch.nn.Module):
def __init__(self):
super(EagerMod, self).__init__()
self.Conv2d_1a_3x3 = EagerConv2d(3, 32, kernel_size=3, stride=2)

def forward(self, x):
x = self.Conv2d_1a_3x3(x)
return F.dropout(x, training=self.training)

script_input = torch.rand(4, 3, 299, 299)
eager_input = script_input.clone()

with freeze_rng_state():
script_mod = ScriptMod()
script_mod.eval()
script_output = script_mod(script_input)

with freeze_rng_state():
eager_mod = EagerMod()
eager_mod.eval()
eager_output = eager_mod(eager_input)

self.assertEqual(script_output, eager_output)

with freeze_rng_state():
script_mod = ScriptMod()
script_mod.train()
script_output = script_mod(script_input)

with freeze_rng_state():
eager_mod = EagerMod()
eager_mod.train()
eager_output = eager_mod(eager_input)

self.assertEqual(script_output, eager_output)

def test_python_call(self):
def pyfunc(a):
return a * 3.0
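As a quick sanity check of what the new test pins down (a minimal sketch, not part of the PR): in eval mode dropout is the identity, so the scripted and eager modules must agree exactly, with no RNG involvement at all.

    import torch
    import torch.nn.functional as F

    x = torch.rand(4, 3, 299, 299)
    # With training=False, dropout returns its input unchanged.
    assert torch.equal(F.dropout(x, training=False), x)
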
1 change: 0 additions & 1 deletion test/test_jit_fuser.py
@@ -292,7 +292,6 @@ def func(x):

a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
s = torch.jit.script(func)
self.assertAllFused(s.graph_for(a,), except_for={'aten::div', 'prim::Constant'})
c = s(a)
c.sum().backward()
graph = backward_graph(s)
5 changes: 5 additions & 0 deletions torch/csrc/jit/symbolic_script.cpp
@@ -1129,6 +1129,11 @@ const std::vector<std::string> functions = {
mask.bernoulli_(p1m)
res = mask * input / p1m

if not train:
p1m = 1.
res = input
mask = torch.ones_like(input)

def backward(grad_output):
use_cuda = grad_output.is_cuda
if use_cuda:
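To see why the added branch fixes eval mode, here is a minimal Python sketch of the symbolic dropout formula above (CPU path only; the fused CUDA kernel is elided, and the restructuring into if/else is mine): with train=False it forces p1m = 1 and an all-ones mask, so the forward pass is the identity and the backward pass reduces to grad_output.

    import torch

    def dropout_symbolic(input, p, train):
        # Mirrors the TorchScript AD formula in symbolic_script.cpp (CPU path).
        p1m = 1. - p
        if train:
            mask = torch.empty_like(input).bernoulli_(p1m)
            res = mask * input / p1m
        else:
            # The new eval branch: identity forward, all-ones mask.
            p1m = 1.
            res = input
            mask = torch.ones_like(input)

        def backward(grad_output):
            # When not train, mask is all ones and p1m == 1,
            # so this is exactly grad_output.
            return grad_output * mask / p1m

        return res, backward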