@@ -8262,6 +8262,44 @@ def test_irparser(self):
82628262 """
82638263 FileCheck().run(graph_str, parse_ir(graph_str))
82648264
8265+ def test_parse_tensor_constants(self):
8266+ def foo():
8267+ return torch.zeros([4, 4])
8268+
8269+ foo_s = torch.jit.script(foo)
8270+ torch._C._jit_pass_constant_propagation(foo_s.graph)
8271+
8272+ g = str(foo_s.graph)
8273+ g_parsed = parse_ir(g, parse_tensor_constants=True)
8274+ self.assertEqual(str(canonical(g_parsed)), str(canonical(foo_s.graph)))
8275+ func = torch._C._create_function_from_graph("forward", g_parsed)
8276+
8277+ out_parsed = func()
8278+ out_func = foo()
8279+ # not checking data, just dtype, size etc
8280+ out_parsed[:] = 0
8281+ out_func[:] = 0
8282+ self.assertEqual(out_func, out_parsed)
8283+
8284+ with self.assertRaises(RuntimeError):
8285+ parse_ir(g, parse_tensor_constants=False)
8286+
8287+ def test_parse_nested_names(self):
8288+ g_str = """
8289+ graph(%x.1 : Tensor):
8290+ %3 : int = prim::Constant[value=1]()
8291+ %2 : int = prim::Constant[value=2]()
8292+ %hi.submod.value.5 : Tensor = aten::add(%x.1, %2, %3)
8293+ return (%hi.submod.value.5)
8294+ """
8295+ g = parse_ir(g_str)
8296+ round_trip_g = parse_ir(str(g))
8297+ self.assertEqual(canonical(g), canonical(round_trip_g))
8298+
8299+ func1 = torch._C._create_function_from_graph("forward", g)
8300+ func2 = torch._C._create_function_from_graph("forward", round_trip_g)
8301+ self.assertEqual(func1(torch.ones([2])), func2(torch.ones([2])))
8302+
82658303 def test_is_after_use(self):
82668304 def sorted_input_use(g):
82678305 uses = list(next(g.inputs()).uses())
0 commit comments