Skip to content

Commit 881adb5

Browse files
Elias Ellisonfacebook-github-bot
authored andcommitted
fix tuple indexing bug (#21521)
Summary: lower tuples pass didn't check bounds for tuple index Pull Request resolved: #21521 Differential Revision: D15716813 Pulled By: eellison fbshipit-source-id: 8eead98c2c63118e7d24a8c8bf6184b02afb7dcd
1 parent a5cca4d commit 881adb5

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

test/test_jit.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10536,6 +10536,30 @@ def test_indexing_out_of_bounds_neg():
1053610536
self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
1053710537
"out of range")
1053810538

10539+
def negative_index():
10540+
tup = (1, 2, 3, 4)
10541+
return tup[-1]
10542+
10543+
self.checkScript(negative_index, [])
10544+
10545+
def really_negative_index():
10546+
tup = (1, 2, 3, 4)
10547+
return tup[-100]
10548+
10549+
self.checkScriptRaisesRegex(really_negative_index, [], Exception, "index out of range")
10550+
10551+
def negative_slice():
10552+
tup = (1, 2, 3, 4)
10553+
return tup[-3:4]
10554+
10555+
self.checkScript(negative_slice, [])
10556+
10557+
def really_slice_out_of_bounds():
10558+
tup = (1, 2, 3, 4)
10559+
return tup[-300:4000]
10560+
10561+
self.checkScript(really_slice_out_of_bounds, [])
10562+
1053910563
def test_namedtuple_attr(self):
1054010564
def f(x):
1054110565
return x.max(dim=1).indices + torch.max(x, dim=1).indices

torch/csrc/jit/passes/lower_tuples.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,16 @@ void removeTupleNodes(Node* n, bool must_remove_tuples) {
4848
}
4949
return;
5050
}
51-
n->output()->replaceAllUsesWith(construct->inputs().at(*maybe_int));
51+
auto int_idx = *maybe_int;
52+
auto len = construct->output()->type()->containedTypes().size();
53+
if (int_idx < 0) {
54+
int_idx += len;
55+
}
56+
// currently, we allow non-constant tuple index if the tuple is of one type.
57+
// so we need to check bounds here
58+
if (int_idx >= 0 && static_cast<size_t>(int_idx) < len) {
59+
n->output()->replaceAllUsesWith(construct->inputs().at(int_idx));
60+
}
5261
} else if (n->kind() == prim::TupleSlice) {
5362
std::vector<Value*> values;
5463
int64_t beg = n->i(attr::beg);

0 commit comments

Comments
 (0)