Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions aten/src/ATen/native/Indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
throw std::runtime_error(ss.str());
}

// Validates that every defined tensor used as an index has an allowed
// scalar type (kLong for integer indexing, kByte for boolean masks).
// Undefined tensors (gaps in the index list) are permitted and skipped.
// Throws std::runtime_error for any other scalar type.
static void checkIndexTensorTypes(TensorList indices) {
  for (auto& index : indices) {
    if (!index.defined()) {
      continue;  // missing index slot — nothing to check
    }
    const auto dtype = index.type().scalarType();
    if (dtype == kLong || dtype == kByte) {
      continue;
    }
    throw std::runtime_error("tensors used as indices must be long or byte tensors");
  }
}

static std::vector<Tensor> expandByteTensors(const Tensor & self, TensorList indices) {
// Expands byte tensors (masks) into the equivalent indexing by LongTensors
std::vector<Tensor> result;
Expand All @@ -57,8 +69,15 @@ static std::vector<Tensor> expandByteTensors(const Tensor & self, TensorList ind
}
// Replace with nonzeros
auto nonzero = index.nonzero();
for (int64_t j = 0; j < nonzero.size(1); j++) {
result.emplace_back(nonzero.select(1, j));
auto is_empty = nonzero.numel() == 0;
for (int64_t j = 0; j < index.dim(); j++) {
if (is_empty) {
// We can't call select on an empty tensor so we just create an empty
// tensor.
result.emplace_back(nonzero.type().tensor());
} else {
result.emplace_back(nonzero.select(1, j));
}
}
} else {
result.emplace_back(index);
Expand Down Expand Up @@ -100,7 +119,7 @@ transposeToFront(Tensor self, TensorList indices) {
transposedIndices.emplace_back();
}
}
return std::make_tuple<>(self.permute(dims), std::move(transposedIndices));
return std::make_tuple(self.permute(dims), std::move(transposedIndices));
}

static std::vector<int64_t> computeLinearStride(const Tensor & tensor) {
Expand Down Expand Up @@ -176,9 +195,22 @@ static Tensor computeLinearIndex(const Tensor & src, TensorList indices) {
return linearIndex;
}

// Returns true iff at least one defined tensor in `tensors` contains
// zero elements. Undefined tensors are ignored.
static bool hasEmptyTensor(TensorList tensors) {
  for (auto& t : tensors) {
    if (!t.defined()) {
      continue;
    }
    if (t.numel() == 0) {
      return true;
    }
  }
  return false;
}

static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig) {
checkIndexTensorTypes(orig);
// first expand ByteTensor (boolean masks) into 1 or more LongTensors
auto indices = expandByteTensors(self, orig);
if (hasEmptyTensor(indices)) {
return std::make_tuple(self, self.type().toScalarType(kLong).tensor());
}
// next broadcast all index tensors together
indices = expand_outplace(indices);
// add missing null Tensors so that it matches self.dim()
Expand All @@ -191,7 +223,7 @@ static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig)
std::tie(self, indices) = transposeToFront(self, indices);
}
auto linearIndex = computeLinearIndex(self, indices);
return std::make_tuple<>(self, linearIndex);
return std::make_tuple(self, linearIndex);
}

Tensor index(const Tensor & self, TensorList indices) {
Expand Down
9 changes: 8 additions & 1 deletion test/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def test_byte_mask(self):
self.assertEqual(v[mask].shape, (3, 7, 3))
self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))

v = Variable(torch.Tensor([1]))
self.assertEqual(v[v == 0], Variable(torch.Tensor()))

def test_multiple_byte_mask(self):
v = Variable(torch.randn(5, 7, 3))
# note: these broadcast together and are transposed to the first dim
Expand Down Expand Up @@ -75,6 +78,11 @@ def test_int_indices_broadcast(self):
result = x[rows[:, None], columns]
self.assertEqual(result.data.tolist(), [[0, 2], [9, 11]])

def test_empty_index(self):
x = Variable(torch.arange(0, 12).view(4, 3))
idx = Variable(torch.LongTensor())
self.assertEqual(x[idx].numel(), 0)

def test_basic_advanced_combined(self):
# From the NumPy indexing example
x = Variable(torch.arange(0, 12).view(4, 3))
Expand Down Expand Up @@ -200,7 +208,6 @@ def test_empty_tuple_index(self):
self.assertEqual(a[()], a)
self.assertEqual(a[()].data_ptr(), a.data_ptr())

# @unittest.skip('failing')
def test_empty_fancy_index(self):
# Empty list index creates an empty array
a = tensor([1, 2, 3])
Expand Down