Skip to content

Commit bb1e3d8

Browse files
kit1980pytorchmergebot
authored andcommitted
Enable lint for test_module_interface.py (#83359)
Pull Request resolved: #83359 Approved by: https://github.com/huydhn
1 parent c280857 commit bb1e3d8

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

test/jit/test_module_interface.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
# Owner(s): ["oncall: jit"]
22

3-
# flake8: noqa
4-
# TODO: enable linting check for this file
5-
63
from typing import List, Any
74
import torch
85
import torch.nn as nn
@@ -14,7 +11,6 @@
1411
# Make the helper files in test/ importable
1512
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
1613
sys.path.append(pytorch_test_dir)
17-
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
1814

1915
if __name__ == '__main__':
2016
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
@@ -109,6 +105,7 @@ def forward2(self, x: Tensor) -> Tensor:
109105
return self.two(self.one(x, x)) + 1
110106

111107
make_global(OneTwoModule, OneTwoClass)
108+
112109
def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
113110
return mod_list[0].forward(x) + mod_list[1].forward(x)
114111

@@ -135,6 +132,7 @@ class TestInterface(nn.Module):
135132
def one(self, inp1, inp2):
136133
# type: (Tensor, Tensor) -> Tensor
137134
pass
135+
138136
def forward(self, input):
139137
# type: (Tensor) -> Tensor
140138
r"""stuff 1"""
@@ -169,6 +167,7 @@ def forward(self, x: Tensor) -> Tensor:
169167
pass
170168

171169
make_global(OneTwoModule)
170+
172171
@torch.jit.script
173172
def as_module_interface(x: OneTwoModule) -> OneTwoModule:
174173
return x
@@ -208,6 +207,7 @@ def forward(self, input: torch.Tensor) -> Any:
208207
pass
209208

210209
make_global(TensorToAny)
210+
211211
@torch.jit.script
212212
def as_tensor_to_any(x: TensorToAny) -> TensorToAny:
213213
return x
@@ -218,6 +218,7 @@ def forward(self, input: Any) -> Any:
218218
pass
219219

220220
make_global(AnyToAny)
221+
221222
@torch.jit.script
222223
def as_any_to_any(x: AnyToAny) -> AnyToAny:
223224
return x
@@ -373,8 +374,8 @@ def forward(self, input: Tensor) -> Tensor:
373374
scripted_no_module_interface.proxy_mod = torch.jit.script(OrigModule())
374375
# proxy_mod is neither a module interface or have the same JIT type, should fail
375376
with self.assertRaisesRegex(RuntimeError,
376-
"Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' " +
377-
"for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'"):
377+
r"Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' " +
378+
r"for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'"):
378379
scripted_no_module_interface.proxy_mod = torch.jit.script(NewModule())
379380

380381
def test_script_module_as_interface_swap(self):
@@ -466,7 +467,7 @@ def forward(self, x):
466467
m.eval()
467468
mf = torch._C._freeze_module(m._c)
468469
# Assume interface has no aliasing
469-
mf = torch._C._freeze_module(m._c, freezeInterfaces = True)
470+
mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
470471
input = torch.tensor([1])
471472
out_s = m.forward(input)
472473
out_f = mf.forward(input)
@@ -481,6 +482,7 @@ def __init__(self):
481482
def forward(self, x):
482483
self.b += 2
483484
return self.b
485+
484486
@torch.jit.export
485487
def getb(self, x):
486488
return self.b
@@ -513,7 +515,7 @@ def forward(self, x):
513515
m.proxy_mod = m.sub
514516
m.eval()
515517
with self.assertRaisesRegex(RuntimeError, "failed to freeze interface attribute 'proxy_mod'"):
516-
mf = torch._C._freeze_module(m._c, freezeInterfaces = True)
518+
mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
517519

518520
def test_freeze_module_with_inplace_mutation_in_interface(self):
519521
class SubModule(torch.nn.Module):
@@ -524,6 +526,7 @@ def __init__(self):
524526
def forward(self, x):
525527
self.b[0] += 2
526528
return self.b
529+
527530
@torch.jit.export
528531
def getb(self, x):
529532
return self.b
@@ -551,15 +554,15 @@ def __init__(self):
551554

552555
def forward(self, x):
553556
y = self.proxy_mod(x)
554-
z= self.sub.getb(x)
557+
z = self.sub.getb(x)
555558
return y[0] + z[0]
556559

557560
m = torch.jit.script(TestModule())
558561
m.proxy_mod = m.sub
559562
m.sub.b = m.proxy_mod.b
560563
m.eval()
561564
with self.assertRaisesRegex(RuntimeError, "failed to freeze interface attribute 'proxy_mod'"):
562-
mf = torch._C._freeze_module(m._c, freezeInterfaces = True)
565+
mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
563566

564567
def test_freeze_module_with_mutated_interface(self):
565568
class SubModule(torch.nn.Module):
@@ -569,6 +572,7 @@ def __init__(self):
569572

570573
def forward(self, x):
571574
return self.b
575+
572576
@torch.jit.export
573577
def getb(self, x):
574578
return self.b
@@ -597,13 +601,13 @@ def __init__(self):
597601
def forward(self, x):
598602
self.proxy_mod = self.sub
599603
y = self.proxy_mod(x)
600-
z= self.sub.getb(x)
604+
z = self.sub.getb(x)
601605
return y[0] + z[0]
602606

603607
m = torch.jit.script(TestModule())
604608
m.eval()
605609
with self.assertRaisesRegex(RuntimeError, "failed to freeze interface attribute 'proxy_mod'"):
606-
mf = torch._C._freeze_module(m._c, freezeInterfaces = True)
610+
mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
607611

608612
def test_freeze_module_with_interface_and_fork(self):
609613
class SubModule(torch.nn.Module):
@@ -638,13 +642,13 @@ def __init__(self):
638642

639643
def forward(self, x):
640644
y = self.proxy_mod(x)
641-
z= self.sub(x)
645+
z = self.sub(x)
642646
return y + z
643647

644648
class MainModule(torch.nn.Module):
645649
def __init__(self):
646650
super(MainModule, self).__init__()
647-
self.test= TestModule()
651+
self.test = TestModule()
648652

649653
def forward(self, x):
650654
fut = torch.jit._fork(self.test.forward, x)
@@ -654,7 +658,7 @@ def forward(self, x):
654658

655659
m = torch.jit.script(MainModule())
656660
m.eval()
657-
mf = torch._C._freeze_module(m._c, freezeInterfaces = True)
661+
mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
658662

659663
def test_module_apis_interface(self):
660664
@torch.jit.interface

0 commit comments

Comments
 (0)