@@ -1177,7 +1177,6 @@ def forward(self, x):
                 weight=weight_observer._c)
         }
         torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, True)
-        print()
         dtypes = set([obs.getattr('dtype') for x, obs in m.conv._modules._c.items()
                       if x.startswith('_observer_')])
         assert len(dtypes) == 2, 'Expected to have 2 different types of dtype'
@@ -1263,6 +1262,10 @@ def forward(self, x, w0, w1, w2):

         m = torch.jit.script(M())
         observer = torch.jit.script(default_observer())
+
+        # run the observer once to avoid warning on an empty observer
+        observer(torch.rand(2, 2))
+
         qconfig_dict = {
             '':
             QConfig(
@@ -2069,19 +2072,6 @@ def f(x):

         self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),))

-    def test_legacy_fail(self):
-        class MyLegacyFn(Function):
-            def forward(self, x):
-                return x
-
-            def backward(self, grad_output):
-                return grad_output
-
-        x = torch.tensor([0.], requires_grad=True)
-        with warnings.catch_warnings(record=True):
-            with self.assertRaisesRegex(RuntimeError, "MyLegacyFn"):
-                torch.jit._get_trace_graph(lambda x: MyLegacyFn()(x), (x,))
-
     def test_inplace_transplant(self):
         x = torch.tensor([0.], requires_grad=True)

@@ -2235,6 +2225,9 @@ def full_with_shape_like(x):
         self.assertEqual(ge(y).shape, y.shape)
         self.assertEqual(ge(x).shape, x.shape)

+    # Suppression: we are intentionally slicing a tensor, we don't care that it
+    # will be constantified
+    @suppress_warnings
     def do_trace_slice(self, requires_grad):
         def slice(x):
             results = []
@@ -4409,7 +4402,7 @@ def __init__(self, fn):
             def forward(self, x):
                 return self.fn(x)

-        m = M(F.sigmoid)
+        m = M(torch.sigmoid)
         inp = torch.rand(2, 3)
         self.checkModule(m, (inp, ))

@@ -5017,12 +5010,12 @@ def weighted_kernel_sum(self, weight):
         check_weight = torch.rand(1, 1, 3, 3)
         check_forward_input = torch.rand(1, 1, 3, 3)
         check_inputs.append({'forward' : check_forward_input, 'weighted_kernel_sum' : check_weight})
-        module = torch.jit.trace_module(n, inputs, True, True, check_inputs)
+        module = torch.jit.trace_module(n, inputs, check_trace=True, check_inputs=check_inputs)
         self.assertTrue(module._c._has_method("forward"))
         self.assertTrue(module._c._has_method("weighted_kernel_sum"))

         module = torch.jit.trace(n.forward, example_forward_input)
-        module = torch.jit.trace(n.forward, example_forward_input, True, [example_forward_input])
+        module = torch.jit.trace(n.forward, example_forward_input, check_trace=True, check_inputs=[example_forward_input])
         with self.assertRaisesRegex(AttributeError, "trace doesn't support compiling individual module's functions"):
             module = torch.jit.trace(n.weighted_kernel_sum, inputs)

@@ -11148,6 +11141,8 @@ def forward(self, x):
         # test copy
         m_c = m.copy()

+    # Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch.
+    @suppress_warnings
     @skipIfCompiledWithoutNumpy
     def test_rnn_trace_override(self):
         from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
@@ -12123,6 +12118,7 @@ def forward(self, x):
             mte, (torch.zeros(1, 2, 3),), None, verbose=False,
             example_outputs=outputs)

+    @suppress_warnings
     def test_onnx_export_script_truediv(self):
         class ModuleToExport(torch.jit.ScriptModule):
             def __init__(self):
@@ -12135,8 +12131,9 @@ def forward(self, x):

         mte = ModuleToExport()
         outputs = mte(torch.zeros(1, 2, 3))
+
         torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            mte, (torch.zeros(1, 2, 3, dtype=torch.float),), None, verbose=False,
             example_outputs=outputs)

     def test_onnx_raw_export_script_truediv(self):
@@ -12153,6 +12150,7 @@ def forward(self, x):
         outputs = mte(torch.zeros(1, 2, 3))
         torch.onnx.export_to_pretty_string(
             mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            add_node_names=False, do_constant_folding=False,
             example_outputs=outputs, export_raw_ir=True)

     def test_onnx_export_script_non_alpha_add_sub(self):
@@ -14734,6 +14732,8 @@ def forward(self, x):
         f = io.BytesIO()
         torch.onnx.export_to_pretty_string(
             FooMod(), (torch.rand(3, 4),), f,
+            add_node_names=False,
+            do_constant_folding=False,
             operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)

     @suppress_warnings
@@ -14766,6 +14766,8 @@ def foo(x):
         traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])

     if 'fbgemm' in torch.backends.quantized.supported_engines:
+        # Suppression: using deprecated quant api
+        @suppress_warnings
         def test_quantization_modules(self):
             K1, N1 = 2, 2
