@@ -13166,31 +13166,26 @@ def test_uses():
1316613166
1316713167 self.checkScript(test_uses, ())
1316813168
13169- @unittest.skipIf(True, "Removing weak script")
13170- def test_overloading(self):
13171- @torch._jit_internal.weak_module
13169+ def test_method_overloading(self):
1317213170 class W(torch.nn.Module):
13173- __overloads__ = {'forward': ['forward_tuple', 'forward_tensor']}
13174-
1317513171 def __init__(self):
1317613172 super(W, self).__init__()
1317713173
13178- @torch._jit_internal.weak_script_method
13179- def forward_tuple (self, x):
13174+ @torch.jit._overload_method # noqa: F811
13175+ def forward (self, x): # noqa: F811
1318013176 # type: (Tuple[Tensor, Tensor]) -> Tensor
13181- return x[0] + 5
13182-
13183- def forward(self, x):
13184- # manually do argument switching
13185- if isinstance(x, tuple):
13186- return self.forward_tuple(x)
13187- else:
13188- return self.forward_tensor(x)
13177+ pass
1318913178
13190- @torch._jit_internal.weak_script_method
13191- def forward_tensor (self, x):
13179+ @torch.jit._overload_method # noqa: F811
13180+ def forward (self, x): # noqa: F811
1319213181 # type: (Tensor) -> Tensor
13193- return x + 20
13182+ pass
13183+
13184+ def forward(self, x): # noqa: F811
13185+ if isinstance(x, Tensor):
13186+ return x + 20
13187+ else:
13188+ return x[0] + 5
1319413189
1319513190 class S(torch.jit.ScriptModule):
1319613191 def __init__(self):
@@ -13201,14 +13196,61 @@ def __init__(self):
1320113196 def forward(self, x):
1320213197 return self.weak(x) + self.weak((x, x))
1320313198
13204- s = S()
13199+ s_mod = S()
1320513200 x = torch.ones(1)
13206- self.assertEqual(s (x), x + 20 + 5 + x)
13201+ self.assertEqual(s_mod (x), x + 20 + 5 + x)
1320713202
1320813203 w = W()
1320913204 self.assertEqual(w((x, x)), x + 5)
1321013205 self.assertEqual(w((x)), x + 20)
1321113206
13207+ class Unannotated(torch.nn.Module):
13208+ def __init__(self):
13209+ super(Unannotated, self).__init__()
13210+
13211+ @torch.jit._overload_method # noqa: F811
13212+ def hello(self, x): # noqa: F811
13213+ pass
13214+
13215+ @torch.jit._overload_method # noqa: F811
13216+ def hello(self, x): # noqa: F811
13217+ # type: (int) -> (int)
13218+ pass
13219+
13220+ def hello(self, x): # noqa: F811
13221+ return x + 3
13222+
13223+ def forward(self):
13224+ return self.hello(1), self.hello(.5)
13225+
13226+ w = Unannotated()
13227+ with self.assertRaisesRegex(Exception, "explicitly add type annotations to overloaded functions"):
13228+ torch.jit.script(w)
13229+
13230+ class CompileOverloadError(torch.nn.Module):
13231+ def __init__(self):
13232+ super(CompileOverloadError, self).__init__()
13233+
13234+ @torch.jit._overload_method # noqa: F811
13235+ def hello(self, x): # noqa: F811
13236+ # type: (str) -> (int)
13237+ pass
13238+
13239+ @torch.jit._overload_method # noqa: F811
13240+ def hello(self, x): # noqa: F811
13241+ # type: (int) -> (int)
13242+ pass
13243+
13244+ def hello(self, x): # noqa: F811
13245+ return x + 1
13246+
13247+ def forward(self):
13248+ return self.hello("hi"), self.hello(.5)
13249+
13250+ w = CompileOverloadError()
13251+ with self.assertRaisesRegex(Exception, "but instead found type \'str\'"):
13252+ torch.jit.script(w)
13253+
1321213254 def test_select_after_chunk(self):
1321313255 def foo(x):
1321413256 chunked = torch.chunk(x, 1)
0 commit comments