Skip to content
Closed
24 changes: 16 additions & 8 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9364,14 +9364,22 @@ def list_iterables(x):
''')

def test_for_tuple_unpack(self):
with self.assertRaisesRegex(RuntimeError, 'Iteration variable unpacking is not supported'):
cu = torch.jit.CompilationUnit('''
def for_tuple_unpack(x, y):
for i, j in [[3, 4], [5, 6], [7, 8]]:
x += i
y += j
return x, y
''')
def for_tuple_unpack(x, y):
for i, j in [[3, 4], [5, 6], [7, 8]]:
x += i
y += j
return x, y

self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5)))

def nested_tuple_unpack(x, y):
# type: (List[int], List[int]) -> int
sum = 0
for i, (j, k), v in zip(x, enumerate(x), y):
sum += i + j + k + v
return sum

self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6]))

def test_for_tuple_assign(self):
def test_simple_assign(x):
Expand Down
24 changes: 11 additions & 13 deletions torch/csrc/jit/script/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1314,12 +1314,15 @@ struct to_ir {
Value* cur_elem = iter_val->getelem(range, method, trip_count);
SugaredValuePtr sv = std::make_shared<SimpleValue>(cur_elem);
List<Expr> target_exprs = targets.value();
size_t n_binders = target_exprs.size();
validateAssignLhsExpr(target_exprs, range);

bool starred_unpack = calcNumStarredUnpack(target_exprs, range);
if (starred_unpack)
n_binders--;
emitExprsAssign(target_exprs, {sv}, range, n_binders);
// if target exprs are more than 1, it means iteration unpacking on LHS
// we create Tuple literal to wrap those target exprs for assignments
if (target_exprs.size() > 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it necessary to create a list here vs just passing more than one target_expr to emitExprsAssign?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh it's because sv here is the root IterableTree (which create TupleConstruct in getelem) and need to call emitTupleAssign to unwrap the tuple using asTuple, but the target_exprs here is not a tupleliteral, so emitExprsAssign could not match these two together, and we need to create a TupleLiteral wrap here.

Expr tl = TupleLiteral::create(range, target_exprs);
target_exprs = List<Expr>::create(range, {tl});
}
emitExprsAssign(target_exprs, {sv}, range, /*n_binders=*/1);
}

emitStatements(body);
Expand All @@ -1336,11 +1339,6 @@ struct to_ir {
throw ErrorReport(stmt)
<< "List of iterables is not supported currently.";
}
if (targets.size() != 1) {
throw ErrorReport(stmt)
<< "Iteration variable unpacking is not supported";
}

// Emit loop information for builtinFunction values like range(), zip(),
// enumerate() or SimpleValue like List, Tensor, Dict, etc.
auto sv = emitSugaredExpr(itrs[0], 1);
Expand Down Expand Up @@ -1425,7 +1423,7 @@ struct to_ir {
// 3) A Starred node can only appear when there is another non-Starred lhs
// Expr. Concretely this means that `*abc = func()` is illegal. Unpacking
// all outputs into a tuple is covered by `abc = func()`.
bool calcNumStarredUnpack(const List<Expr>& lhs, const SourceRange& r) {
bool validateAssignLhsExpr(const List<Expr>& lhs, const SourceRange& r) {
size_t num_normal_assign = 0;
size_t num_starred = 0;
for (const auto& assignee : lhs) {
Expand Down Expand Up @@ -1715,7 +1713,7 @@ struct to_ir {

void emitTupleAssign(const TupleLiteral& tl, const Expr& rhs) {
size_t n_binders = tl.inputs().size();
bool starred_unpack = calcNumStarredUnpack(tl.inputs(), tl.range());
bool starred_unpack = validateAssignLhsExpr(tl.inputs(), tl.range());
if (starred_unpack)
n_binders--;
auto output = emitSugaredExpr(rhs, n_binders);
Expand Down Expand Up @@ -1780,7 +1778,7 @@ struct to_ir {
// recursively emit tuple assignments on tuple literal input
TupleLiteral sub_tl = TupleLiteral(assignee);
size_t sub_n_binders = sub_tl.inputs().size();
bool sub_starred_unpack = calcNumStarredUnpack(sub_tl.inputs(), sub_tl.range());
bool sub_starred_unpack = validateAssignLhsExpr(sub_tl.inputs(), sub_tl.range());
if (sub_starred_unpack)
sub_n_binders--;
emitTupleAssign(sub_tl, outputs.at(i), rhs_loc, sub_n_binders, sub_starred_unpack);
Expand Down