Skip to content

Commit bd7e9c4

Browse files
suofacebook-github-bot
authored andcommitted
[jit] stop printing crap in test_jit (#33917)
Summary: Pull Request resolved: #33917 Test Plan: Imported from OSS Differential Revision: D20150750 Pulled By: suo fbshipit-source-id: 9a35298a8856d423fb6b9043174853cccf968706
1 parent d66c320 commit bd7e9c4

File tree

8 files changed

+40
-26
lines changed

8 files changed

+40
-26
lines changed

test/jit/test_class_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -926,8 +926,8 @@ def __init__(self, w):
926926

927927
def forward(self, x):
928928
# Make sure class constant is accessible in method
929-
print(self.w)
930-
return x
929+
y = self.w
930+
return x, y
931931

932932
# Test serialization/deserialization of class constant
933933
for c in (2, 1.0, None, True, 'str', (2, 3), [5.9, 7.3]):

test/jit/test_export_modes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def forward(self, x, y):
7979
f = io.BytesIO()
8080
torch.onnx.export_to_pretty_string(
8181
ModelWithAtenNotONNXOp(), (x, y), f,
82+
add_node_names=False,
83+
do_constant_folding=False,
8284
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
8385

8486
# torch.fmod is using to test ONNX_ATEN.
@@ -94,4 +96,6 @@ def forward(self, x, y):
9496
y = torch.randn(3, 4, dtype=torch.float32)
9597
torch.onnx.export_to_pretty_string(
9698
ModelWithAtenFmod(), (x, y), f,
99+
add_node_names=False,
100+
do_constant_folding=False,
97101
operator_export_type=OperatorExportTypes.ONNX_ATEN)

test/jit/test_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,8 @@ def test_snli(self):
396396
self._test_snli(self, device='cpu')
397397

398398
if 'fbgemm' in torch.backends.quantized.supported_engines:
399+
# Suppression: this exercises a deprecated API
400+
@suppress_warnings
399401
def test_snli_quantized(self):
400402
self._test_snli(self, device='cpu', quantized=True)
401403

@@ -540,6 +542,8 @@ def test_vae(self):
540542
self._test_vae(self, device='cpu')
541543

542544
if 'fbgemm' in torch.backends.quantized.supported_engines:
545+
# Suppression: this exercises a deprecated API
546+
@suppress_warnings
543547
def test_vae_quantized(self):
544548
self._test_vae(self, device='cpu', quantized=True)
545549

test/jit/test_recursive_script.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import torch
77
import torch.nn as nn
8-
import torch.nn.functional as F
98
from torch import Tensor
109
from torch.testing import FileCheck
1110
from collections import OrderedDict
@@ -65,7 +64,7 @@ def __init__(self, fn):
6564
def forward(self, x):
6665
return self.fn(x)
6766

68-
mod = M(F.sigmoid)
67+
mod = M(torch.sigmoid)
6968

7069
self.checkModule(mod, (torch.randn(2, 2),))
7170

test/jit/test_unsupported_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def orthogonal_():
111111
return torch.nn.init.orthogonal_(torch.empty(3, 5))
112112

113113
def sparse():
114-
return torch.nn.init.sparse(torch.empty(3, 5), sparsity=.1)
114+
return torch.nn.init.sparse_(torch.empty(3, 5), sparsity=.1)
115115

116116
for func in [calculate_gain, eye_, dirac_, kaiming_uniform_, orthogonal_, sparse]:
117117
# doesn't error in eager

test/test_jit.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,7 +1177,6 @@ def forward(self, x):
11771177
weight=weight_observer._c)
11781178
}
11791179
torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, True)
1180-
print()
11811180
dtypes = set([obs.getattr('dtype') for x, obs in m.conv._modules._c.items()
11821181
if x.startswith('_observer_')])
11831182
assert len(dtypes) == 2, 'Expected to have 2 different types of dtype'
@@ -1263,6 +1262,10 @@ def forward(self, x, w0, w1, w2):
12631262

12641263
m = torch.jit.script(M())
12651264
observer = torch.jit.script(default_observer())
1265+
1266+
# run the observer once to avoid warning on an empty observer
1267+
observer(torch.rand(2, 2))
1268+
12661269
qconfig_dict = {
12671270
'':
12681271
QConfig(
@@ -2069,19 +2072,6 @@ def f(x):
20692072

20702073
self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),))
20712074

2072-
def test_legacy_fail(self):
2073-
class MyLegacyFn(Function):
2074-
def forward(self, x):
2075-
return x
2076-
2077-
def backward(self, grad_output):
2078-
return grad_output
2079-
2080-
x = torch.tensor([0.], requires_grad=True)
2081-
with warnings.catch_warnings(record=True):
2082-
with self.assertRaisesRegex(RuntimeError, "MyLegacyFn"):
2083-
torch.jit._get_trace_graph(lambda x: MyLegacyFn()(x), (x,))
2084-
20852075
def test_inplace_transplant(self):
20862076
x = torch.tensor([0.], requires_grad=True)
20872077

@@ -2235,6 +2225,9 @@ def full_with_shape_like(x):
22352225
self.assertEqual(ge(y).shape, y.shape)
22362226
self.assertEqual(ge(x).shape, x.shape)
22372227

2228+
# Suppression: we are intentionally slicing a tensor, we don't care that it
2229+
# will be constantified
2230+
@suppress_warnings
22382231
def do_trace_slice(self, requires_grad):
22392232
def slice(x):
22402233
results = []
@@ -4409,7 +4402,7 @@ def __init__(self, fn):
44094402
def forward(self, x):
44104403
return self.fn(x)
44114404

4412-
m = M(F.sigmoid)
4405+
m = M(torch.sigmoid)
44134406
inp = torch.rand(2, 3)
44144407
self.checkModule(m, (inp, ))
44154408

@@ -5017,12 +5010,12 @@ def weighted_kernel_sum(self, weight):
50175010
check_weight = torch.rand(1, 1, 3, 3)
50185011
check_forward_input = torch.rand(1, 1, 3, 3)
50195012
check_inputs.append({'forward' : check_forward_input, 'weighted_kernel_sum' : check_weight})
5020-
module = torch.jit.trace_module(n, inputs, True, True, check_inputs)
5013+
module = torch.jit.trace_module(n, inputs, check_trace=True, check_inputs=check_inputs)
50215014
self.assertTrue(module._c._has_method("forward"))
50225015
self.assertTrue(module._c._has_method("weighted_kernel_sum"))
50235016

50245017
module = torch.jit.trace(n.forward, example_forward_input)
5025-
module = torch.jit.trace(n.forward, example_forward_input, True, [example_forward_input])
5018+
module = torch.jit.trace(n.forward, example_forward_input, check_trace=True, check_inputs=[example_forward_input])
50265019
with self.assertRaisesRegex(AttributeError, "trace doesn't support compiling individual module's functions"):
50275020
module = torch.jit.trace(n.weighted_kernel_sum, inputs)
50285021

@@ -11148,6 +11141,8 @@ def forward(self, x):
1114811141
# test copy
1114911142
m_c = m.copy()
1115011143

11144+
# Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch.
11145+
@suppress_warnings
1115111146
@skipIfCompiledWithoutNumpy
1115211147
def test_rnn_trace_override(self):
1115311148
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
@@ -12123,6 +12118,7 @@ def forward(self, x):
1212312118
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
1212412119
example_outputs=outputs)
1212512120

12121+
@suppress_warnings
1212612122
def test_onnx_export_script_truediv(self):
1212712123
class ModuleToExport(torch.jit.ScriptModule):
1212812124
def __init__(self):
@@ -12135,8 +12131,9 @@ def forward(self, x):
1213512131

1213612132
mte = ModuleToExport()
1213712133
outputs = mte(torch.zeros(1, 2, 3))
12134+
1213812135
torch.onnx.export_to_pretty_string(
12139-
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
12136+
mte, (torch.zeros(1, 2, 3, dtype=torch.float),), None, verbose=False,
1214012137
example_outputs=outputs)
1214112138

1214212139
def test_onnx_raw_export_script_truediv(self):
@@ -12153,6 +12150,7 @@ def forward(self, x):
1215312150
outputs = mte(torch.zeros(1, 2, 3))
1215412151
torch.onnx.export_to_pretty_string(
1215512152
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
12153+
add_node_names=False, do_constant_folding=False,
1215612154
example_outputs=outputs, export_raw_ir=True)
1215712155

1215812156
def test_onnx_export_script_non_alpha_add_sub(self):
@@ -14734,6 +14732,8 @@ def forward(self, x):
1473414732
f = io.BytesIO()
1473514733
torch.onnx.export_to_pretty_string(
1473614734
FooMod(), (torch.rand(3, 4),), f,
14735+
add_node_names=False,
14736+
do_constant_folding=False,
1473714737
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
1473814738

1473914739
@suppress_warnings
@@ -14766,6 +14766,8 @@ def foo(x):
1476614766
traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
1476714767

1476814768
if 'fbgemm' in torch.backends.quantized.supported_engines:
14769+
# Suppression: using deprecated quant api
14770+
@suppress_warnings
1476914771
def test_quantization_modules(self):
1477014772
K1, N1 = 2, 2
1477114773

torch/onnx/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def _decide_external_data_format(use_external_data_format, operator_export_type,
252252
val_use_external_data_format = _resolve_args_by_export_type("use_external_data_format",
253253
use_external_data_format,
254254
operator_export_type)
255-
# f can be a non-string in regular-sized model export case, but for large model export, f must be a non-empty
255+
# f can be a non-string in regular-sized model export case, but for large model export, f must be a non-empty
256256
# string specifying the location of the model. For large model cases, if f is not a non-empty string,
257257
# then this method returns an empty string, which is an error condition for the large model export code
258258
# path later (but not for regular model export code path).
@@ -392,7 +392,8 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
392392
operator_export_type=None, export_type=ExportTypes.PROTOBUF_FILE,
393393
example_outputs=None, propagate=False, google_printer=False,
394394
opset_version=None, _retain_param_name=True,
395-
keep_initializers_as_inputs=None, custom_opsets=None):
395+
keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True,
396+
do_constant_folding=True):
396397
if aten or export_raw_ir:
397398
assert operator_export_type is None
398399
assert aten ^ export_raw_ir
@@ -403,6 +404,8 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
403404
input_names, output_names, operator_export_type,
404405
export_type, example_outputs, propagate, google_printer,
405406
opset_version, _retain_param_name,
407+
do_constant_folding=do_constant_folding,
408+
add_node_names=add_node_names,
406409
keep_initializers_as_inputs=keep_initializers_as_inputs,
407410
custom_opsets=custom_opsets)
408411

torch/testing/_internal/jit_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,9 @@ def checkScript(self,
385385
source,
386386
inputs,
387387
script.__name__,
388-
capture_output,
388+
optimize=optimize,
389+
inputs_requires_grad=inputs_requires_grad,
390+
capture_output=capture_output,
389391
profiling=profiling,
390392
frames_up=2)
391393

0 commit comments

Comments
 (0)