Skip to content

Commit cf2889a

Browse files
Elias Ellisonfacebook-github-bot
authored andcommitted
add support for breaks and continues (#21692)
Summary: Add support for breaks and continues in the jit. We do with a Graph transform pre-SSA. A graph of the form ``` def test(): while i < 5: if i == 3: break i += 1 print(i) ``` has the body of the loop transformed to ``` if i == 3: did_break = True else: did_break = False if did_break: loop_exit = True else: i += 1 print(i) loop_exit = i < 5 ``` I am going to add more tests but I think it is ready for review now. Pull Request resolved: #21692 Differential Revision: D16215807 Pulled By: eellison fbshipit-source-id: 365102f42de4861d9323caaeb39a96de7619a667
1 parent b3147bc commit cf2889a

19 files changed

+1131
-111
lines changed

aten/src/ATen/core/interned_strings.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ namespace c10 {
4949
_(prim, IgnoredPythonOp) \
5050
_(prim, Reverse) \
5151
_(prim, Return) \
52+
_(prim, BreakStmt) \
53+
_(prim, ContinueStmt) \
5254
_(prim, Store) \
5355
_(prim, AutogradZero) \
5456
_(prim, AutogradAnyNonZero) \
@@ -109,6 +111,7 @@ namespace c10 {
109111
_(prim, TimePoint) \
110112
_(prim, CallFunction) \
111113
_(prim, CallMethod) \
114+
_(prim, LoopContinuation) \
112115
_(aten, append) \
113116
_(aten, item) \
114117
_(aten, format) \

caffe2/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,12 +430,15 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
430430
${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp
431431
${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp
432432
${TORCH_SRC_DIR}/csrc/jit/script/convert_to_ssa.cpp
433+
${TORCH_SRC_DIR}/csrc/jit/script/exit_transforms.cpp
434+
${TORCH_SRC_DIR}/csrc/jit/script/inline_loop_condition.cpp
433435
${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp
434436
${TORCH_SRC_DIR}/csrc/jit/script/script_type_parser.cpp
435437
${TORCH_SRC_DIR}/csrc/jit/script/sugared_value.cpp
436438
${TORCH_SRC_DIR}/csrc/jit/script/class_type.cpp
437439
${TORCH_SRC_DIR}/csrc/jit/script/parser.cpp
438440
${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
441+
${TORCH_SRC_DIR}/csrc/jit/script/canonicalize_modified_loop.cpp
439442
${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
440443
${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
441444
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp

test/test_jit.py

Lines changed: 262 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# Torch
44
from torch import Tensor
5-
from torch._C import TensorType, parse_ir, _propagate_shapes, _jit_python_print
5+
from torch._C import TensorType, BoolType, parse_ir, _propagate_shapes, _jit_python_print
66
from torch._six import inf, PY2, PY37, StringIO
77
from torch.autograd import Variable, Function
88
from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401
@@ -6008,6 +6008,267 @@ def forward(self, x):
60086008

60096009
self.assertEqual(script_output, eager_output)
60106010

6011+
def test_nested_breaks(self):
6012+
def no_bool_loop_outputs(g):
6013+
# testing that the "did exit" transform values are not loop block
6014+
# outputs (and thus not affecting one loop from another)
6015+
loops = g.findAllNodes("prim::Loop")
6016+
for loop in loops:
6017+
for out in loop.outputs():
6018+
self.assertTrue(out.type() != BoolType.get())
6019+
6020+
def test(y):
6021+
# type: (int)
6022+
ret = 0
6023+
tensor = torch.tensor(0)
6024+
while int(tensor.add_(1)) < 4:
6025+
if y == 1:
6026+
continue
6027+
for i in range(y):
6028+
continue
6029+
ret += 1
6030+
ret += 1
6031+
return ret, int(tensor)
6032+
6033+
self.checkScript(test, (1,))
6034+
self.checkScript(test, (2,))
6035+
no_bool_loop_outputs(torch.jit.script(test).graph)
6036+
6037+
def foo():
6038+
y = torch.tensor(0)
6039+
z = 0
6040+
while int(y.add_(1)) < 20:
6041+
if int(y) < 10:
6042+
for i in range(6):
6043+
if i == 3:
6044+
continue
6045+
else:
6046+
if i > 3:
6047+
break
6048+
z += 2
6049+
if int(y) == 18:
6050+
break
6051+
if int(y) == 15:
6052+
continue
6053+
z += 1
6054+
return int(y), z
6055+
6056+
no_bool_loop_outputs(torch.jit.script(foo).graph)
6057+
self.checkScript(foo, ())
6058+
6059+
def test_nested_two():
6060+
i = 0
6061+
k = 0
6062+
while i < 5:
6063+
for j in range(5):
6064+
k += 1
6065+
if j == 3:
6066+
continue
6067+
i += 1
6068+
k += 1
6069+
if i == 4:
6070+
break
6071+
return i, k
6072+
6073+
self.checkScript(test_nested_two, ())
6074+
no_bool_loop_outputs(torch.jit.script(test_nested_two).graph)
6075+
6076+
def test_breaks_continues(self):
6077+
def foo_continue(cond):
6078+
# type: (int)
6079+
j = 1
6080+
for i in range(5):
6081+
if i == cond:
6082+
continue
6083+
j += 1
6084+
return j
6085+
6086+
def foo_break(cond):
6087+
# type: (int)
6088+
j = 1
6089+
for i in range(5):
6090+
if i == cond:
6091+
break
6092+
j += 1
6093+
return j
6094+
6095+
for i in range(1, 4):
6096+
self.checkScript(foo_continue, (i,))
6097+
self.checkScript(foo_break, (i,))
6098+
6099+
def test_refine_outside_loop():
6100+
if True:
6101+
x = None
6102+
else:
6103+
x = 1
6104+
i = 0
6105+
j = 0
6106+
while (x is None or torch.jit._unwrap_optional(x) > 3):
6107+
if i < 3:
6108+
if i < 3:
6109+
x = torch.jit.annotate(Optional[int], None)
6110+
i += 1
6111+
continue
6112+
x = 1
6113+
else:
6114+
x = 1 if x is None else x
6115+
x = x + 1
6116+
j = x + x
6117+
6118+
return x, j
6119+
6120+
self.checkScript(test_refine_outside_loop, ())
6121+
6122+
def assign_after_break(y):
6123+
# type: (int)
6124+
x = 0
6125+
for i in range(y):
6126+
x = y * 2 + i
6127+
break
6128+
x = 4
6129+
return x
6130+
6131+
self.checkScript(assign_after_break, (1,))
6132+
self.checkScript(assign_after_break, (2,))
6133+
self.checkScript(assign_after_break, (3,))
6134+
6135+
def assign_after_break_nested(y):
6136+
# type: (int)
6137+
x = 0
6138+
for i in range(y):
6139+
if y == 1:
6140+
x = 5
6141+
break
6142+
assert 1 == 2
6143+
else:
6144+
x = x + 1
6145+
break
6146+
assert 1 == 2
6147+
x = -30
6148+
assert 1 == 2
6149+
return x
6150+
6151+
self.checkScript(assign_after_break_nested, (1,))
6152+
self.checkScript(assign_after_break_nested, (2,))
6153+
self.checkScript(assign_after_break_nested, (3,))
6154+
6155+
def may_break(y):
6156+
# type: (int)
6157+
x = 0
6158+
for i in range(y):
6159+
if y == 1:
6160+
x = 5
6161+
else:
6162+
x = x + 1
6163+
break
6164+
x = -30
6165+
return x
6166+
6167+
self.checkScript(may_break, (1,))
6168+
self.checkScript(may_break, (2,))
6169+
self.checkScript(may_break, (3,))
6170+
6171+
def test(x, y):
6172+
# type: (int, int)
6173+
a = 1
6174+
while (x > 0):
6175+
if y == 3:
6176+
for i in range(y):
6177+
a += (1 % (i + 1))
6178+
x -= 1
6179+
if x == 3:
6180+
a = x * 3
6181+
break
6182+
if x < 3:
6183+
if x == 1:
6184+
a -= 2
6185+
x -= 1
6186+
break
6187+
a -= 1
6188+
x -= 3
6189+
return a, x
6190+
6191+
self.checkScript(test, (10, 3))
6192+
self.checkScript(test, (10, 2))
6193+
self.checkScript(test, (3, 2))
6194+
self.checkScript(test, (5, 3))
6195+
self.checkScript(test, (2, 3))
6196+
6197+
def test_delete_after_break(x):
6198+
# type: (int)
6199+
a = 1
6200+
b = 1
6201+
for i in range(x):
6202+
a = i * 3
6203+
break
6204+
b = i * 5
6205+
return a, b
6206+
6207+
self.checkScript(test_delete_after_break, (0,))
6208+
self.checkScript(test_delete_after_break, (1,))
6209+
6210+
def test_will_break_after_guard(x):
6211+
# type: (int)
6212+
a = 1
6213+
for i in range(x):
6214+
if i == 4:
6215+
a = 3
6216+
break
6217+
a -= 1
6218+
break
6219+
assert 1 == 2
6220+
a -= -100
6221+
return a
6222+
6223+
self.checkScript(test_will_break_after_guard, (0,))
6224+
self.checkScript(test_will_break_after_guard, (2,))
6225+
self.checkScript(test_will_break_after_guard, (4,))
6226+
6227+
def test_varexit(cond):
6228+
# type: (int)
6229+
m = 0
6230+
for i in range(3):
6231+
if cond == 2:
6232+
if cond == 2:
6233+
m = 2
6234+
break
6235+
k = 1
6236+
else:
6237+
k = 2
6238+
m += k
6239+
return m
6240+
6241+
# use of k tests the pathway where we have to insert unitialized
6242+
self.checkScript(test_varexit, (3,))
6243+
self.checkScript(test_varexit, (2,))
6244+
6245+
def test_break_true():
6246+
i = 0
6247+
while True:
6248+
i += 1
6249+
if i == 3:
6250+
break
6251+
while False:
6252+
i += 1
6253+
return i
6254+
6255+
self.checkScript(test_break_true, ())
6256+
6257+
def test_break_continue_error(self):
6258+
with self.assertRaisesRegex(RuntimeError, "Syntax"):
6259+
cu = torch.jit.CompilationUnit('''
6260+
def other_func(a):
6261+
break
6262+
''')
6263+
6264+
with self.assertRaisesRegex(RuntimeError, "Syntax"):
6265+
cu = torch.jit.CompilationUnit('''
6266+
def other_func(a):
6267+
for i in range(5):
6268+
def foo():
6269+
break
6270+
''')
6271+
60116272
def test_python_call(self):
60126273
def pyfunc(a):
60136274
return a * 3.0

tools/build_variables.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@
117117
"torch/csrc/jit/script/logging.cpp",
118118
"torch/csrc/jit/script/final_returns.cpp",
119119
"torch/csrc/jit/script/convert_to_ssa.cpp",
120+
"torch/csrc/jit/script/exit_transforms.cpp",
121+
"torch/csrc/jit/script/inline_loop_condition.cpp",
122+
"torch/csrc/jit/script/canonicalize_modified_loop.cpp",
120123
"torch/csrc/jit/script/script_type_parser.cpp",
121124
"torch/csrc/jit/script/sugared_value.cpp",
122125
"torch/csrc/jit/script/schema_matching.cpp",

torch/csrc/jit/ir_views.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,44 @@ struct LoopView {
101101
bodyBlock()->permuteInputs(adjusted_block_order);
102102
}
103103

104+
void replaceMaxTripCount(Value* new_max_trip_count) {
105+
node_->replaceInput(0, new_max_trip_count);
106+
}
107+
void replaceInputCondition(Value* new_input_condition) {
108+
node_->replaceInput(1, new_input_condition);
109+
}
110+
111+
// our way of encoding loops makes them difficult to turn back into python
112+
// syntax. we have to check properties of the condition and trip count inputs
113+
// to figure out which one it initially was. ModifiedLoops are not directly
114+
// mappable to either For or While
115+
enum LoopType { While, For, ModifiedLoop };
116+
117+
LoopType loopType() {
118+
auto trip_count = toIValue(maxTripCount());
119+
auto cond_input = toIValue(inputCond());
120+
auto cond_next = toIValue(nextCond());
121+
122+
bool condition_is_always_true =
123+
cond_input && cond_input->toBool() && cond_next && cond_next->toBool();
124+
bool trip_count_is_specified = !trip_count || // trip is not a constant
125+
trip_count->toInt() !=
126+
std::numeric_limits<int64_t>::max() || // it is a constant but not
127+
// the default one
128+
currentTripCount()->uses().size() >
129+
0; // it is actually being used in the body.
130+
131+
if (condition_is_always_true) {
132+
// if the trip count was not specified this was a user-written while True:
133+
return trip_count_is_specified ? For : While;
134+
} else {
135+
if (trip_count_is_specified) {
136+
return ModifiedLoop;
137+
}
138+
return While;
139+
}
140+
}
141+
104142
private:
105143
Node* node_;
106144

0 commit comments

Comments
 (0)