11# Owner(s): ["oncall: jit"]
22
3- # flake8: noqa
4- # TODO: enable linting check for this file
5-
63from typing import List , Any
74import torch
85import torch .nn as nn
1411# Make the helper files in test/ importable
1512pytorch_test_dir = os .path .dirname (os .path .dirname (os .path .realpath (__file__ )))
1613sys .path .append (pytorch_test_dir )
17- from torch .testing ._internal .jit_utils import JitTestCase , execWrapper
1814
1915if __name__ == '__main__' :
2016 raise RuntimeError ("This test file is not meant to be run directly, use:\n \n "
@@ -109,6 +105,7 @@ def forward2(self, x: Tensor) -> Tensor:
109105 return self .two (self .one (x , x )) + 1
110106
111107 make_global (OneTwoModule , OneTwoClass )
108+
112109 def use_module_interface (mod_list : List [OneTwoModule ], x : torch .Tensor ):
113110 return mod_list [0 ].forward (x ) + mod_list [1 ].forward (x )
114111
@@ -135,6 +132,7 @@ class TestInterface(nn.Module):
135132 def one (self , inp1 , inp2 ):
136133 # type: (Tensor, Tensor) -> Tensor
137134 pass
135+
138136 def forward (self , input ):
139137 # type: (Tensor) -> Tensor
140138 r"""stuff 1"""
@@ -169,6 +167,7 @@ def forward(self, x: Tensor) -> Tensor:
169167 pass
170168
171169 make_global (OneTwoModule )
170+
172171 @torch .jit .script
173172 def as_module_interface (x : OneTwoModule ) -> OneTwoModule :
174173 return x
@@ -208,6 +207,7 @@ def forward(self, input: torch.Tensor) -> Any:
208207 pass
209208
210209 make_global (TensorToAny )
210+
211211 @torch .jit .script
212212 def as_tensor_to_any (x : TensorToAny ) -> TensorToAny :
213213 return x
@@ -218,6 +218,7 @@ def forward(self, input: Any) -> Any:
218218 pass
219219
220220 make_global (AnyToAny )
221+
221222 @torch .jit .script
222223 def as_any_to_any (x : AnyToAny ) -> AnyToAny :
223224 return x
@@ -373,8 +374,8 @@ def forward(self, input: Tensor) -> Tensor:
373374 scripted_no_module_interface .proxy_mod = torch .jit .script (OrigModule ())
374375 # proxy_mod is neither a module interface or have the same JIT type, should fail
375376 with self .assertRaisesRegex (RuntimeError ,
376- "Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' " +
377- "for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'" ):
377+ r "Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' " +
378+ r "for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'" ):
378379 scripted_no_module_interface .proxy_mod = torch .jit .script (NewModule ())
379380
380381 def test_script_module_as_interface_swap (self ):
@@ -466,7 +467,7 @@ def forward(self, x):
466467 m .eval ()
467468 mf = torch ._C ._freeze_module (m ._c )
468469 # Assume interface has no aliasing
469- mf = torch ._C ._freeze_module (m ._c , freezeInterfaces = True )
470+ mf = torch ._C ._freeze_module (m ._c , freezeInterfaces = True )
470471 input = torch .tensor ([1 ])
471472 out_s = m .forward (input )
472473 out_f = mf .forward (input )
@@ -481,6 +482,7 @@ def __init__(self):
481482 def forward (self , x ):
482483 self .b += 2
483484 return self .b
485+
484486 @torch .jit .export
485487 def getb (self , x ):
486488 return self .b
@@ -513,7 +515,7 @@ def forward(self, x):
513515 m .proxy_mod = m .sub
514516 m .eval ()
515517 with self .assertRaisesRegex (RuntimeError , "failed to freeze interface attribute 'proxy_mod'" ):
516- mf = torch ._C ._freeze_module (m ._c , freezeInterfaces = True )
518+ mf = torch ._C ._freeze_module (m ._c , freezeInterfaces = True )
517519
518520 def test_freeze_module_with_inplace_mutation_in_interface (self ):
519521 class SubModule (torch .nn .Module ):
@@ -524,6 +526,7 @@ def __init__(self):
524526 def forward (self , x ):
525527 self .b [0 ] += 2
526528 return self .b
529+
527530 @torch .jit .export
528531 def getb (self , x ):
529532 return self .b
@@ -551,15 +554,15 @@ def __init__(self):
551554
552555 def forward (self , x ):
553556 y = self .proxy_mod (x )
554- z = self .sub .getb (x )
557+ z = self .sub .getb (x )
555558 return y [0 ] + z [0 ]
556559
557560 m = torch .jit .script (TestModule ())
558561 m .proxy_mod = m .sub
559562 m .sub .b = m .proxy_mod .b
560563 m .eval ()
561564 with self .assertRaisesRegex (RuntimeError , "failed to freeze interface attribute 'proxy_mod'" ):
562- mf = torch ._C ._freeze_module (m ._c , freezeInterfaces = True )
565+ mf = torch ._C ._freeze_module (m ._c , freezeInterfaces = True )
563566
564567 def test_freeze_module_with_mutated_interface (self ):
565568 class SubModule (torch .nn .Module ):
@@ -569,6 +572,7 @@ def __init__(self):
569572
570573 def forward (self , x ):
571574 return self .b
575+
572576 @torch .jit .export
573577 def getb (self , x ):
574578 return self .b
@@ -597,13 +601,13 @@ def __init__(self):
597601 def forward (self , x ):
598602 self .proxy_mod = self .sub
599603 y = self .proxy_mod (x )
600- z = self .sub .getb (x )
604+ z = self .sub .getb (x )
601605 return y [0 ] + z [0 ]
602606
603607 m = torch .jit .script (TestModule ())
604608 m .eval ()
605609 with self .assertRaisesRegex (RuntimeError , "failed to freeze interface attribute 'proxy_mod'" ):
606- mf = torch ._C ._freeze_module (m ._c , freezeInterfaces = True )
610+ mf = torch ._C ._freeze_module (m ._c , freezeInterfaces = True )
607611
608612 def test_freeze_module_with_interface_and_fork (self ):
609613 class SubModule (torch .nn .Module ):
@@ -638,13 +642,13 @@ def __init__(self):
638642
639643 def forward (self , x ):
640644 y = self .proxy_mod (x )
641- z = self .sub (x )
645+ z = self .sub (x )
642646 return y + z
643647
644648 class MainModule (torch .nn .Module ):
645649 def __init__ (self ):
646650 super (MainModule , self ).__init__ ()
647- self .test = TestModule ()
651+ self .test = TestModule ()
648652
649653 def forward (self , x ):
650654 fut = torch .jit ._fork (self .test .forward , x )
@@ -654,7 +658,7 @@ def forward(self, x):
654658
655659 m = torch .jit .script (MainModule ())
656660 m .eval ()
657- mf = torch ._C ._freeze_module (m ._c , freezeInterfaces = True )
661+ mf = torch ._C ._freeze_module (m ._c , freezeInterfaces = True )
658662
659663 def test_module_apis_interface (self ):
660664 @torch .jit .interface
0 commit comments