@@ -13166,48 +13166,158 @@ def test_uses():
1316613166
1316713167 self.checkScript(test_uses, ())
1316813168
13169- @unittest.skipIf(True, "Removing weak script")
13170- def test_overloading(self):
13171- @torch._jit_internal.weak_module
13172- class W(torch.nn.Module):
13173- __overloads__ = {'forward': ['forward_tuple', 'forward_tensor']}
13174-
13169+ def test_method_overloading(self):
13170+ class Over(torch.nn.Module):
1317513171 def __init__(self):
13176- super(W , self).__init__()
13172+ super(Over , self).__init__()
1317713173
13178- @torch._jit_internal.weak_script_method
13179- def forward_tuple (self, x):
13174+ @torch.jit._overload_method # noqa: F811
13175+ def forward (self, x): # noqa: F811
1318013176 # type: (Tuple[Tensor, Tensor]) -> Tensor
13181- return x[0] + 5
13182-
13183- def forward(self, x):
13184- # manually do argument switching
13185- if isinstance(x, tuple):
13186- return self.forward_tuple(x)
13187- else:
13188- return self.forward_tensor(x)
13177+ pass
1318913178
13190- @torch._jit_internal.weak_script_method
13191- def forward_tensor (self, x):
13179+ @torch.jit._overload_method # noqa: F811
13180+ def forward (self, x): # noqa: F811
1319213181 # type: (Tensor) -> Tensor
13193- return x + 20
13182+ pass
13183+
13184+ def forward(self, x): # noqa: F811
13185+ if isinstance(x, Tensor):
13186+ return x + 20
13187+ else:
13188+ return x[0] + 5
1319413189
1319513190 class S(torch.jit.ScriptModule):
1319613191 def __init__(self):
1319713192 super(S, self).__init__()
13198- self.weak = W ()
13193+ self.weak = Over ()
1319913194
1320013195 @torch.jit.script_method
1320113196 def forward(self, x):
1320213197 return self.weak(x) + self.weak((x, x))
1320313198
13204- s = S()
13199+ s_mod = S()
1320513200 x = torch.ones(1)
13206- self.assertEqual(s(x), x + 20 + 5 + x)
13201+ self.assertEqual(s_mod(x), x + 20 + 5 + x)
13202+
13203+ over = Over()
13204+ self.assertEqual(over((x, x)), x + 5)
13205+ self.assertEqual(over((x)), x + 20)
13206+
13207+ class Unannotated(torch.nn.Module):
13208+ def __init__(self):
13209+ super(Unannotated, self).__init__()
13210+
13211+ @torch.jit._overload_method # noqa: F811
13212+ def hello(self, x): # noqa: F811
13213+ pass
13214+
13215+ @torch.jit._overload_method # noqa: F811
13216+ def hello(self, x): # noqa: F811
13217+ # type: (int) -> (int)
13218+ pass
13219+
13220+ def hello(self, x): # noqa: F811
13221+ return x + 3
13222+
13223+ def forward(self):
13224+ return self.hello(1), self.hello(.5)
13225+
13226+ w = Unannotated()
13227+ with self.assertRaisesRegex(Exception, "explicitly add type annotations to overloaded functions"):
13228+ torch.jit.script(w)
13229+
13230+ class CompileOverloadError(torch.nn.Module):
13231+ def __init__(self):
13232+ super(CompileOverloadError, self).__init__()
13233+
13234+ @torch.jit._overload_method # noqa: F811
13235+ def hello(self, x): # noqa: F811
13236+ # type: (str) -> (int)
13237+ pass
13238+
13239+ @torch.jit._overload_method # noqa: F811
13240+ def hello(self, x): # noqa: F811
13241+ # type: (int) -> (int)
13242+ pass
13243+
13244+ def hello(self, x): # noqa: F811
13245+ return x + 1
13246+
13247+ def forward(self):
13248+ return self.hello("hi"), self.hello(.5)
13249+
13250+ w = CompileOverloadError()
13251+ with self.assertRaisesRegex(Exception, "but instead found type \'str\'"):
13252+ torch.jit.script(w)
13253+
13254+ # testing overload declared first, then non-overload
13255+ with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
13256+ class W3(torch.nn.Module):
13257+ def __init__(self):
13258+ super(W3, self).__init__()
13259+
13260+ @torch.jit._overload_method # noqa: F811
13261+ def forward(self, x): # noqa: F811
13262+ # type: (int) -> int
13263+ pass
13264+
13265+ @torch.jit._overload_method # noqa: F811
13266+ def forward(self, x): # noqa: F811
13267+ # type: (Tensor) -> Tensor
13268+ pass
13269+
13270+ def forward(self, x): # noqa: F811
13271+ return x + 5
13272+
13273+ a = W3()
13274+ b = torch.jit.script(a)
13275+
13276+ class W3(torch.nn.Module):
13277+ def __init__(self):
13278+ super(W3, self).__init__()
13279+
13280+ def forward(self, x): # noqa: F811
13281+ return x + 5 + 10
13282+
13283+ a = W3()
13284+ b = torch.jit.script(a)
13285+
13286+ # testing non-overload declared first, then overload
13287+ class W2(torch.nn.Module):
13288+ def __init__(self):
13289+ super(W2, self).__init__()
13290+
13291+ def hello(self, x1, x2):
13292+ return x1 + x2
13293+
13294+ def forward(self, x):
13295+ return self.hello(x, x)
13296+
13297+ a = torch.jit.script(W2())
13298+ self.assertEqual(a(torch.tensor(1)), torch.tensor(2))
13299+
13300+ class W2(torch.nn.Module):
13301+ def __init__(self):
13302+ super(W2, self).__init__()
13303+
13304+ @torch.jit._overload_method # noqa: F811
13305+ def hello(self, x): # noqa: F811
13306+ pass
13307+
13308+ @torch.jit._overload_method # noqa: F811
13309+ def hello(self, x): # noqa: F811
13310+ # type: (int) -> (int)
13311+ pass
13312+
13313+ def hello(self, x): # noqa: F811
13314+ return x + 5 + 10
13315+
13316+ def forward(self, x):
13317+ return self.hello(1), self.hello(x)
1320713318
13208- w = W()
13209- self.assertEqual(w((x, x)), x + 5)
13210- self.assertEqual(w((x)), x + 20)
13319+ with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
13320+ a = torch.jit.script(W2())
1321113321
1321213322 def test_select_after_chunk(self):
1321313323 def foo(x):