@@ -13274,48 +13274,158 @@ def test_uses():
1327413274
1327513275 self.checkScript(test_uses, ())
1327613276
13277- @unittest.skipIf(True, "Removing weak script")
13278- def test_overloading(self):
13279- @torch._jit_internal.weak_module
13280- class W(torch.nn.Module):
13281- __overloads__ = {'forward': ['forward_tuple', 'forward_tensor']}
13282-
13277+ def test_method_overloading(self):
13278+ class Over(torch.nn.Module):
1328313279 def __init__(self):
13284- super(W, self).__init__()
13280+ super(Over, self).__init__()
1328513281
13286- @torch._jit_internal.weak_script_method
13287- def forward_tuple(self, x):
13282+ @torch.jit._overload_method # noqa: F811
13283+ def forward(self, x): # noqa: F811
1328813284 # type: (Tuple[Tensor, Tensor]) -> Tensor
13289- return x[0] + 5
13290-
13291- def forward(self, x):
13292- # manually do argument switching
13293- if isinstance(x, tuple):
13294- return self.forward_tuple(x)
13295- else:
13296- return self.forward_tensor(x)
13285+ pass
1329713286
13298- @torch._jit_internal.weak_script_method
13299- def forward_tensor(self, x):
13287+ @torch.jit._overload_method # noqa: F811
13288+ def forward(self, x): # noqa: F811
1330013289 # type: (Tensor) -> Tensor
13301- return x + 20
13290+ pass
13291+
13292+ def forward(self, x): # noqa: F811
13293+ if isinstance(x, Tensor):
13294+ return x + 20
13295+ else:
13296+ return x[0] + 5
1330213297
1330313298 class S(torch.jit.ScriptModule):
1330413299 def __init__(self):
1330513300 super(S, self).__init__()
13306- self.weak = W()
13301+ self.weak = Over()
1330713302
1330813303 @torch.jit.script_method
1330913304 def forward(self, x):
1331013305 return self.weak(x) + self.weak((x, x))
1331113306
13312- s = S()
13307+ s_mod = S()
1331313308 x = torch.ones(1)
13314- self.assertEqual(s(x), x + 20 + 5 + x)
13309+ self.assertEqual(s_mod(x), x + 20 + 5 + x)
13310+
13311+ over = Over()
13312+ self.assertEqual(over((x, x)), x + 5)
13313+ self.assertEqual(over((x)), x + 20)
13314+
13315+ class Unannotated(torch.nn.Module):
13316+ def __init__(self):
13317+ super(Unannotated, self).__init__()
13318+
13319+ @torch.jit._overload_method # noqa: F811
13320+ def hello(self, x): # noqa: F811
13321+ pass
13322+
13323+ @torch.jit._overload_method # noqa: F811
13324+ def hello(self, x): # noqa: F811
13325+ # type: (int) -> (int)
13326+ pass
13327+
13328+ def hello(self, x): # noqa: F811
13329+ return x + 3
13330+
13331+ def forward(self):
13332+ return self.hello(1), self.hello(.5)
13333+
13334+ w = Unannotated()
13335+ with self.assertRaisesRegex(Exception, "explicitly add type annotations to overloaded functions"):
13336+ torch.jit.script(w)
13337+
13338+ class CompileOverloadError(torch.nn.Module):
13339+ def __init__(self):
13340+ super(CompileOverloadError, self).__init__()
13341+
13342+ @torch.jit._overload_method # noqa: F811
13343+ def hello(self, x): # noqa: F811
13344+ # type: (str) -> (int)
13345+ pass
13346+
13347+ @torch.jit._overload_method # noqa: F811
13348+ def hello(self, x): # noqa: F811
13349+ # type: (int) -> (int)
13350+ pass
13351+
13352+ def hello(self, x): # noqa: F811
13353+ return x + 1
13354+
13355+ def forward(self):
13356+ return self.hello("hi"), self.hello(.5)
13357+
13358+ w = CompileOverloadError()
13359+ with self.assertRaisesRegex(Exception, "but instead found type \'str\'"):
13360+ torch.jit.script(w)
13361+
13362+ # testing overload declared first, then non-overload
13363+ with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
13364+ class W3(torch.nn.Module):
13365+ def __init__(self):
13366+ super(W3, self).__init__()
13367+
13368+ @torch.jit._overload_method # noqa: F811
13369+ def forward(self, x): # noqa: F811
13370+ # type: (int) -> int
13371+ pass
13372+
13373+ @torch.jit._overload_method # noqa: F811
13374+ def forward(self, x): # noqa: F811
13375+ # type: (Tensor) -> Tensor
13376+ pass
13377+
13378+ def forward(self, x): # noqa: F811
13379+ return x + 5
13380+
13381+ a = W3()
13382+ b = torch.jit.script(a)
13383+
13384+ class W3(torch.nn.Module):
13385+ def __init__(self):
13386+ super(W3, self).__init__()
13387+
13388+ def forward(self, x): # noqa: F811
13389+ return x + 5 + 10
13390+
13391+ a = W3()
13392+ b = torch.jit.script(a)
13393+
13394+ # testing non-overload declared first, then overload
13395+ class W2(torch.nn.Module):
13396+ def __init__(self):
13397+ super(W2, self).__init__()
13398+
13399+ def hello(self, x1, x2):
13400+ return x1 + x2
13401+
13402+ def forward(self, x):
13403+ return self.hello(x, x)
13404+
13405+ a = torch.jit.script(W2())
13406+ self.assertEqual(a(torch.tensor(1)), torch.tensor(2))
13407+
13408+ class W2(torch.nn.Module):
13409+ def __init__(self):
13410+ super(W2, self).__init__()
13411+
13412+ @torch.jit._overload_method # noqa: F811
13413+ def hello(self, x): # noqa: F811
13414+ pass
13415+
13416+ @torch.jit._overload_method # noqa: F811
13417+ def hello(self, x): # noqa: F811
13418+ # type: (int) -> (int)
13419+ pass
13420+
13421+ def hello(self, x): # noqa: F811
13422+ return x + 5 + 10
13423+
13424+ def forward(self, x):
13425+ return self.hello(1), self.hello(x)
1331513426
13316- w = W()
13317- self.assertEqual(w((x, x)), x + 5)
13318- self.assertEqual(w((x)), x + 20)
13427+ with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
13428+ a = torch.jit.script(W2())
1331913429
1332013430 def test_select_after_chunk(self):
1332113431 def foo(x):
0 commit comments