Skip to content
Closed
3 changes: 3 additions & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ namespace c10 {
_(prim, IgnoredPythonOp) \
_(prim, Reverse) \
_(prim, Return) \
_(prim, BreakStmt) \
_(prim, ContinueStmt) \
_(prim, Store) \
_(prim, AutogradZero) \
_(prim, AutogradAnyNonZero) \
Expand Down Expand Up @@ -109,6 +111,7 @@ namespace c10 {
_(prim, TimePoint) \
_(prim, CallFunction) \
_(prim, CallMethod) \
_(prim, LoopContinuation) \
_(aten, append) \
_(aten, item) \
_(aten, format) \
Expand Down
3 changes: 3 additions & 0 deletions caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -430,12 +430,15 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp
${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp
${TORCH_SRC_DIR}/csrc/jit/script/convert_to_ssa.cpp
${TORCH_SRC_DIR}/csrc/jit/script/exit_transforms.cpp
${TORCH_SRC_DIR}/csrc/jit/script/inline_loop_condition.cpp
${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp
${TORCH_SRC_DIR}/csrc/jit/script/script_type_parser.cpp
${TORCH_SRC_DIR}/csrc/jit/script/sugared_value.cpp
${TORCH_SRC_DIR}/csrc/jit/script/class_type.cpp
${TORCH_SRC_DIR}/csrc/jit/script/parser.cpp
${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
${TORCH_SRC_DIR}/csrc/jit/script/canonicalize_modified_loop.cpp
${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
Expand Down
263 changes: 262 additions & 1 deletion test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Torch
from torch import Tensor
from torch._C import TensorType, parse_ir, _propagate_shapes, _jit_python_print
from torch._C import TensorType, BoolType, parse_ir, _propagate_shapes, _jit_python_print
from torch._six import inf, PY2, PY37, StringIO
from torch.autograd import Variable, Function
from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401
Expand Down Expand Up @@ -6008,6 +6008,267 @@ def forward(self, x):

self.assertEqual(script_output, eager_output)

def test_nested_breaks(self):
def no_bool_loop_outputs(g):
# testing that the "did exit" transform values are not loop block
# outputs (and thus not affecting one loop from another)
loops = g.findAllNodes("prim::Loop")
for loop in loops:
for out in loop.outputs():
self.assertTrue(out.type() != BoolType.get())

def test(y):
# type: (int)
ret = 0
tensor = torch.tensor(0)
while int(tensor.add_(1)) < 4:
if y == 1:
continue
for i in range(y):
continue
ret += 1
ret += 1
return ret, int(tensor)

self.checkScript(test, (1,))
self.checkScript(test, (2,))
no_bool_loop_outputs(torch.jit.script(test).graph)

def foo():
y = torch.tensor(0)
z = 0
while int(y.add_(1)) < 20:
if int(y) < 10:
for i in range(6):
if i == 3:
continue
else:
if i > 3:
break
z += 2
if int(y) == 18:
break
if int(y) == 15:
continue
z += 1
return int(y), z

no_bool_loop_outputs(torch.jit.script(foo).graph)
self.checkScript(foo, ())

def test_nested_two():
i = 0
k = 0
while i < 5:
for j in range(5):
k += 1
if j == 3:
continue
i += 1
k += 1
if i == 4:
break
return i, k

self.checkScript(test_nested_two, ())
no_bool_loop_outputs(torch.jit.script(test_nested_two).graph)

def test_breaks_continues(self):
def foo_continue(cond):
# type: (int)
j = 1
for i in range(5):
if i == cond:
continue
j += 1
return j

def foo_break(cond):
# type: (int)
j = 1
for i in range(5):
if i == cond:
break
j += 1
return j

for i in range(1, 4):
self.checkScript(foo_continue, (i,))
self.checkScript(foo_break, (i,))

def test_refine_outside_loop():
if True:
x = None
else:
x = 1
i = 0
j = 0
while (x is None or torch.jit._unwrap_optional(x) > 3):
if i < 3:
if i < 3:
x = torch.jit.annotate(Optional[int], None)
i += 1
continue
x = 1
else:
x = 1 if x is None else x
x = x + 1
j = x + x

return x, j

self.checkScript(test_refine_outside_loop, ())

def assign_after_break(y):
# type: (int)
x = 0
for i in range(y):
x = y * 2 + i
break
x = 4
return x

self.checkScript(assign_after_break, (1,))
self.checkScript(assign_after_break, (2,))
self.checkScript(assign_after_break, (3,))

def assign_after_break_nested(y):
# type: (int)
x = 0
for i in range(y):
if y == 1:
x = 5
break
assert 1 == 2
else:
x = x + 1
break
assert 1 == 2
x = -30
assert 1 == 2
return x

self.checkScript(assign_after_break_nested, (1,))
self.checkScript(assign_after_break_nested, (2,))
self.checkScript(assign_after_break_nested, (3,))

def may_break(y):
# type: (int)
x = 0
for i in range(y):
if y == 1:
x = 5
else:
x = x + 1
break
x = -30
return x

self.checkScript(may_break, (1,))
self.checkScript(may_break, (2,))
self.checkScript(may_break, (3,))

def test(x, y):
# type: (int, int)
a = 1
while (x > 0):
if y == 3:
for i in range(y):
a += (1 % (i + 1))
x -= 1
if x == 3:
a = x * 3
break
if x < 3:
if x == 1:
a -= 2
x -= 1
break
a -= 1
x -= 3
return a, x

self.checkScript(test, (10, 3))
self.checkScript(test, (10, 2))
self.checkScript(test, (3, 2))
self.checkScript(test, (5, 3))
self.checkScript(test, (2, 3))

def test_delete_after_break(x):
# type: (int)
a = 1
b = 1
for i in range(x):
a = i * 3
break
b = i * 5
return a, b

self.checkScript(test_delete_after_break, (0,))
self.checkScript(test_delete_after_break, (1,))

def test_will_break_after_guard(x):
# type: (int)
a = 1
for i in range(x):
if i == 4:
a = 3
break
a -= 1
break
assert 1 == 2
a -= -100
return a

self.checkScript(test_will_break_after_guard, (0,))
self.checkScript(test_will_break_after_guard, (2,))
self.checkScript(test_will_break_after_guard, (4,))

def test_varexit(cond):
# type: (int)
m = 0
for i in range(3):
if cond == 2:
if cond == 2:
m = 2
break
k = 1
else:
k = 2
m += k
return m

# use of k tests the pathway where we have to insert unitialized
self.checkScript(test_varexit, (3,))
self.checkScript(test_varexit, (2,))

def test_break_true():
i = 0
while True:
i += 1
if i == 3:
break
while False:
i += 1
return i

self.checkScript(test_break_true, ())

def test_break_continue_error(self):
with self.assertRaisesRegex(RuntimeError, "Syntax"):
cu = torch.jit.CompilationUnit('''
def other_func(a):
break
''')

with self.assertRaisesRegex(RuntimeError, "Syntax"):
cu = torch.jit.CompilationUnit('''
def other_func(a):
for i in range(5):
def foo():
break
''')

def test_python_call(self):
def pyfunc(a):
return a * 3.0
Expand Down
3 changes: 3 additions & 0 deletions tools/build_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@
"torch/csrc/jit/script/logging.cpp",
"torch/csrc/jit/script/final_returns.cpp",
"torch/csrc/jit/script/convert_to_ssa.cpp",
"torch/csrc/jit/script/exit_transforms.cpp",
"torch/csrc/jit/script/inline_loop_condition.cpp",
"torch/csrc/jit/script/canonicalize_modified_loop.cpp",
"torch/csrc/jit/script/script_type_parser.cpp",
"torch/csrc/jit/script/sugared_value.cpp",
"torch/csrc/jit/script/schema_matching.cpp",
Expand Down
38 changes: 38 additions & 0 deletions torch/csrc/jit/ir_views.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,44 @@ struct LoopView {
bodyBlock()->permuteInputs(adjusted_block_order);
}

void replaceMaxTripCount(Value* new_max_trip_count) {
node_->replaceInput(0, new_max_trip_count);
}
void replaceInputCondition(Value* new_input_condition) {
node_->replaceInput(1, new_input_condition);
}

// our way of encoding loops makes them difficult to turn back into python
// syntax. we have to check properties of the condition and trip count inputs
// to figure out which one it initially was. ModifiedLoops are not directly
// mappable to either For or While
enum LoopType { While, For, ModifiedLoop };

LoopType loopType() {
auto trip_count = toIValue(maxTripCount());
auto cond_input = toIValue(inputCond());
auto cond_next = toIValue(nextCond());

bool condition_is_always_true =
cond_input && cond_input->toBool() && cond_next && cond_next->toBool();
bool trip_count_is_specified = !trip_count || // trip is not a constant
trip_count->toInt() !=
std::numeric_limits<int64_t>::max() || // it is a constant but not
// the default one
currentTripCount()->uses().size() >
0; // it is actually being used in the body.

if (condition_is_always_true) {
// if the trip count was not specified this was a user-written while True:
return trip_count_is_specified ? For : While;
} else {
if (trip_count_is_specified) {
return ModifiedLoop;
}
return While;
}
}

private:
Node* node_;

Expand Down
Loading