Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10959,6 +10959,30 @@ def test_indexing_out_of_bounds_neg():
self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
"out of range")

def negative_index():
tup = (1, 2, 3, 4)
return tup[-1]

self.checkScript(negative_index, [])

def really_negative_index():
tup = (1, 2, 3, 4)
return tup[-100]

self.checkScriptRaisesRegex(really_negative_index, [], Exception, "index out of range")

def negative_slice():
tup = (1, 2, 3, 4)
return tup[-3:4]

self.checkScript(negative_slice, [])

def really_slice_out_of_bounds():
tup = (1, 2, 3, 4)
return tup[-300:4000]

self.checkScript(really_slice_out_of_bounds, [])

def test_namedtuple_attr(self):
def f(x):
return x.max(dim=1).indices + torch.max(x, dim=1).indices
Expand Down
15 changes: 11 additions & 4 deletions torch/csrc/jit/passes/lower_tuples.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <c10/util/Exception.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/range_utils.h>

namespace torch {
namespace jit {
Expand Down Expand Up @@ -48,12 +49,18 @@ void removeTupleNodes(Node* n, bool must_remove_tuples) {
}
return;
}
n->output()->replaceAllUsesWith(construct->inputs().at(*maybe_int));
auto tuple = n->inputs().at(0);
const size_t tuple_len = tuple->type()->containedTypes().size();
const size_t normalized_idx = normalizeIndex(*maybe_int, tuple_len);
n->output()->replaceAllUsesWith(construct->inputs().at(normalized_idx));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the only change needed, but it's not safe to use the normalized_idx since it maybe out of bounds. I don't think tupleslice has a problem.

} else if (n->kind() == prim::TupleSlice) {
std::vector<Value*> values;
int64_t beg = n->i(attr::beg);
int64_t end = n->i(attr::end);
for (int64_t i = beg; i < end; i += 1) {
const size_t tuple_len = n->inputs().at(0)->type()->containedTypes().size();
int64_t beg;
int64_t end;
std::tie(beg, end) =
clamp_bounds(n->i(attr::beg), n->i(attr::end), tuple_len);
for (int64_t i = beg; i < end; i++) {
values.push_back(construct->inputs().at(i));
}
auto graph = n->owningGraph();
Expand Down
32 changes: 32 additions & 0 deletions torch/csrc/jit/range_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>

namespace torch {
namespace jit {
// Convert an python index (which may be negative) into an index usable for a
// C++ container
inline size_t normalizeIndex(int64_t idx, size_t list_size) {
if (idx < 0) {
// Handle negative indexing
idx = list_size + idx;
}
return idx;
}

// Clamp `start` and `end` in the way Python does for iterable slicing.
inline std::pair<size_t, size_t> clamp_bounds(
int64_t start,
int64_t end,
size_t list_size) {
const size_t normalized_start =
std::max((size_t)0, normalizeIndex(start, list_size));
const size_t normalized_end =
std::min(list_size, normalizeIndex(end, list_size));
return std::pair<size_t, size_t>(normalized_start, normalized_end);
}
} // namespace jit
} // namespace torch
32 changes: 13 additions & 19 deletions torch/csrc/jit/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/profiling_record.h>
#include <torch/csrc/jit/range_utils.h>
#include <torch/csrc/jit/script/compilation_unit.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/jit/script/jit_exception.h>
Expand Down Expand Up @@ -145,16 +146,6 @@ static at::Tensor to_dispatch(
}
}

// Convert an python index (which may be negative) into an index usable for a
// C++ container
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
if (idx < 0) {
// Handle negative indexing
idx = list_size + idx;
}
return idx;
}

RegisterOperators reg(
{Operator(
prim::profile,
Expand Down Expand Up @@ -718,8 +709,12 @@ RegisterOperators reg(
Operator(
prim::TupleSlice,
[](const Node* node) {
int64_t beg_ind = node->i(attr::beg);
int64_t end_ind = node->i(attr::end);
size_t tuple_len =
node->inputs().at(0)->type()->containedTypes().size();
int64_t beg_ind;
int64_t end_ind;
std::tie(beg_ind, end_ind) =
clamp_bounds(node->i(attr::beg), node->i(attr::end), tuple_len);
return [=](Stack& stack) {
auto t = pop(stack).toTuple();
const auto& elems = t->elements();
Expand Down Expand Up @@ -1138,8 +1133,7 @@ int stringSlice(Stack& stack) {
const int64_t size = string.size();

// Clamp start and end to the bounds of the list
start = std::max(int64_t(0), normalizeIndex(start, size));
end = std::min(size, normalizeIndex(end, size));
std::tie(start, end) = clamp_bounds(start, end, size);

if (end <= start) {
// Slice is empty
Expand Down Expand Up @@ -1573,13 +1567,13 @@ int listSlice(Stack& stack) {
int64_t step;

pop(stack, list, start, end, step);
const int64_t list_size = list->elements().size();
const size_t list_size = list->elements().size();

// clamp start and end to the bounds of the list
const auto normalized_start =
std::max((int64_t)0, normalizeIndex(start, list_size));
const auto normalized_end =
std::min(list_size, normalizeIndex(end, list_size));
size_t normalized_start;
size_t normalized_end;
std::tie(normalized_start, normalized_end) =
clamp_bounds(start, end, list_size);

std::vector<TElement> sliced_list;
if (normalized_end <= normalized_start) {
Expand Down