Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9266,6 +9266,40 @@ def foo(a):
return b + 1
self.checkScript(foo, (torch.rand(3),))

def test_tuple_assignments(self):
def var_tuple_assign(x, y):
# type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
(a, b), c = x, y
return a + b + c

tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4))
self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4)))

def nested_tuple_assign(x, y, z):
# type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int
a, (b, (c, d)), (e, f) = x, y, z
return a + b + c + d + e + f

self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6))))

def subscript_tuple_assign(a, x, i):
# type: (List[int], Tensor, int) -> Tuple[int, Tensor, int]
a[i], (x[i], b) = 1, (2, 3)
return a[i] + 1, x + 5, b

self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0))

# python 2 does not support star assignments so we use compilation unit to test instead
star_code = '''
def star_tuple_assign():
# type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]]
a, (b, *c), *d = 1, (2, 3, 4), 5, 6
return a, b, c, d
'''

self.checkScript(star_code, (), name='star_tuple_assign', outputs=(1, 2, (3, 4), (5, 6)))


def test_multi_reduction(self):
with self.assertRaisesRegex(
RuntimeError,
Expand Down
36 changes: 29 additions & 7 deletions torch/csrc/jit/script/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1613,7 +1613,7 @@ struct to_ir {
// Validate that the `lhs` Expr's in an assignment statement are valid. That
// is:
//
// 1) All lhs Expr's are either Var or Starred nodes
// 1) All lhs Expr's are either Var, Tuple or Starred nodes
// 2) There is at most one Starred node in the lhs Expr
// 3) A Starred node can only appear when there is another non-Starred lhs
// Expr. Concretely this means that `*abc = func()` is illegal. Unpacking
Expand All @@ -1622,7 +1622,8 @@ struct to_ir {
size_t num_normal_assign = 0;
size_t num_starred = 0;
for (const auto& assignee : lhs) {
if (assignee.kind() == TK_VAR || assignee.kind() == TK_SUBSCRIPT) {
if (assignee.kind() == TK_VAR || assignee.kind() == TK_SUBSCRIPT
|| assignee.kind() == TK_TUPLE_LITERAL) {
num_normal_assign++;
} else if (assignee.kind() == TK_STARRED) {
num_starred++;
Expand Down Expand Up @@ -1911,8 +1912,13 @@ struct to_ir {
if (starred_unpack)
n_binders--;
auto output = emitSugaredExpr(rhs, n_binders);
auto outputs = output->asTuple(
rhs.range(),
emitTupleAssign(tl, output, rhs.range(), n_binders, starred_unpack);
}

void emitTupleAssign(const TupleLiteral& tl, const SugaredValuePtr& rhs_output,
const SourceRange& rhs_loc, size_t n_binders, bool starred_unpack) {
auto outputs = rhs_output->asTuple(
rhs_loc,
method,
starred_unpack ? c10::nullopt : c10::optional<size_t>{n_binders});
if (outputs.size() < n_binders) {
Expand All @@ -1924,15 +1930,21 @@ struct to_ir {
throw ErrorReport(tl) << "too many values to unpack: need " << n_binders
<< " but found " << outputs.size();
}

emitExprsAssign(tl.inputs(), outputs, rhs_loc, n_binders);
}

void emitExprsAssign(const List<Expr>& lhs_exprs, const at::ArrayRef<SugaredValuePtr> outputs,
const SourceRange& rhs_loc, size_t n_binders) {
int i = 0;
for (auto assignee : tl.inputs()) {
for (auto assignee : lhs_exprs) {
switch (assignee.kind()) {
case TK_SUBSCRIPT:
emitSubscriptAssign(
rhs.range(),
rhs_loc,
Subscript(assignee),
NamedValue(
rhs.range(), outputs.at(i)->asValue(rhs.range(), method)));
rhs_loc, outputs.at(i)->asValue(rhs_loc, method)));
i++;
break;
case TK_VAR:
Expand All @@ -1957,6 +1969,16 @@ struct to_ir {
environment_stack->setVar(var.range(), Var(var).name().name(), tup);
i += n_matched;
} break;
case TK_TUPLE_LITERAL: {
// recursively emit tuple assignments
TupleLiteral sub_tl = TupleLiteral(assignee);
size_t sub_n_binders = sub_tl.inputs().size();
bool sub_starred_unpack = calcNumStarredUnpack(sub_tl.inputs(), sub_tl.range());
if (sub_starred_unpack)
sub_n_binders--;
emitTupleAssign(sub_tl, outputs.at(i), rhs_loc, sub_n_binders, sub_starred_unpack);
i ++;
} break;
default:
throw ErrorReport(assignee)
<< "unexpected expression on the left-hand side";
Expand Down