Skip to content

Commit 931616a

Browse files
author
Elias Ellison
committed
Add support for nested var names in parser
ghstack-source-id: f4592e5 Pull Request resolved: #75124
1 parent 406414e commit 931616a

File tree

2 files changed

+54
-10
lines changed

2 files changed

+54
-10
lines changed

test/test_jit.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8262,6 +8262,44 @@ def test_irparser(self):
82628262
"""
82638263
FileCheck().run(graph_str, parse_ir(graph_str))
82648264

8265+
def test_parse_tensor_constants(self):
8266+
def foo():
8267+
return torch.zeros([4, 4])
8268+
8269+
foo_s = torch.jit.script(foo)
8270+
torch._C._jit_pass_constant_propagation(foo_s.graph)
8271+
8272+
g = str(foo_s.graph)
8273+
g_parsed = parse_ir(g, parse_tensor_constants=True)
8274+
self.assertEqual(str(canonical(g_parsed)), str(canonical(foo_s.graph)))
8275+
func = torch._C._create_function_from_graph("forward", g_parsed)
8276+
8277+
out_parsed = func()
8278+
out_func = foo()
8279+
# not checking data, just dtype, size etc
8280+
out_parsed[:] = 0
8281+
out_func[:] = 0
8282+
self.assertEqual(out_func, out_parsed)
8283+
8284+
with self.assertRaises(RuntimeError):
8285+
parse_ir(g, parse_tensor_constants=False)
8286+
8287+
def test_parse_nested_names(self):
8288+
g_str = """
8289+
graph(%x.1 : Tensor):
8290+
%3 : int = prim::Constant[value=1]()
8291+
%2 : int = prim::Constant[value=2]()
8292+
%hi.submod.value.5 : Tensor = aten::add(%x.1, %2, %3)
8293+
return (%hi.submod.value.5)
8294+
"""
8295+
g = parse_ir(g_str)
8296+
round_trip_g = parse_ir(str(g))
8297+
self.assertEqual(canonical(g), canonical(round_trip_g))
8298+
8299+
func1 = torch._C._create_function_from_graph("forward", g)
8300+
func2 = torch._C._create_function_from_graph("forward", round_trip_g)
8301+
self.assertEqual(func1(torch.ones([2])), func2(torch.ones([2])))
8302+
82658303
def test_is_after_use(self):
82668304
def sorted_input_use(g):
82678305
uses = list(next(g.inputs()).uses())

torch/csrc/jit/ir/irparser.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -128,17 +128,23 @@ VarWithType IRParser::parseVarWithType(bool allow_optional) {
128128

129129
std::string IRParser::parseVar() {
130130
L.expect('%');
131-
if (L.cur().kind == TK_IDENT) {
132-
auto name = L.expect(TK_IDENT).text();
133-
if (L.cur().kind == TK_NUMBER) {
134-
auto suffix = L.expect(TK_NUMBER).text();
135-
AT_ASSERT(suffix[0] == '.');
136-
name += suffix;
131+
std::string name;
132+
bool continue_parsing;
133+
do {
134+
if (L.cur().kind == TK_IDENT) {
135+
name += L.expect(TK_IDENT).text();
136+
} else {
137+
name += L.expect(TK_NUMBER).text();
137138
}
138-
return name;
139-
} else {
140-
return L.expect(TK_NUMBER).text();
141-
}
139+
continue_parsing = false;
140+
if (L.nextIf('.')) {
141+
continue_parsing = true;
142+
name += '.';
143+
} else if (L.cur().kind == TK_NUMBER && L.cur().text()[0] == '.') {
144+
continue_parsing = true;
145+
}
146+
} while (continue_parsing);
147+
return name;
142148
}
143149

144150
void IRParser::parseOperatorOutputs(std::vector<VarWithType>* outs) {

0 commit comments

Comments
 (0)