Skip to content

Commit d4e5d90

Browse files
colesburysoumith
authored andcommitted
Fix indexing with all zero ByteTensors (#3926)
Fixes #3914
1 parent 157f949 commit d4e5d90

File tree

2 files changed

+44
-5
lines changed

2 files changed

+44
-5
lines changed

aten/src/ATen/native/Indexing.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,18 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
4242
throw std::runtime_error(ss.str());
4343
}
4444

45+
static void checkIndexTensorTypes(TensorList indices) {
46+
for (auto& tensor : indices) {
47+
if (tensor.defined()) {
48+
auto& type = tensor.type();
49+
auto scalarType = type.scalarType();
50+
if (scalarType != kLong && scalarType != kByte) {
51+
throw std::runtime_error("tensors used as indices must be long or byte tensors");
52+
}
53+
}
54+
}
55+
}
56+
4557
static std::vector<Tensor> expandByteTensors(const Tensor & self, TensorList indices) {
4658
// Expands byte tensors (masks) into the equivalent indexing by LongTensors
4759
std::vector<Tensor> result;
@@ -57,8 +69,15 @@ static std::vector<Tensor> expandByteTensors(const Tensor & self, TensorList ind
5769
}
5870
// Replace with nonzeros
5971
auto nonzero = index.nonzero();
60-
for (int64_t j = 0; j < nonzero.size(1); j++) {
61-
result.emplace_back(nonzero.select(1, j));
72+
auto is_empty = nonzero.numel() == 0;
73+
for (int64_t j = 0; j < index.dim(); j++) {
74+
if (is_empty) {
75+
// We can't call select on an empty tensor so we just create an empty
76+
// tensor.
77+
result.emplace_back(nonzero.type().tensor());
78+
} else {
79+
result.emplace_back(nonzero.select(1, j));
80+
}
6281
}
6382
} else {
6483
result.emplace_back(index);
@@ -100,7 +119,7 @@ transposeToFront(Tensor self, TensorList indices) {
100119
transposedIndices.emplace_back();
101120
}
102121
}
103-
return std::make_tuple<>(self.permute(dims), std::move(transposedIndices));
122+
return std::make_tuple(self.permute(dims), std::move(transposedIndices));
104123
}
105124

106125
static std::vector<int64_t> computeLinearStride(const Tensor & tensor) {
@@ -176,9 +195,22 @@ static Tensor computeLinearIndex(const Tensor & src, TensorList indices) {
176195
return linearIndex;
177196
}
178197

198+
static bool hasEmptyTensor(TensorList tensors) {
199+
for (auto& tensor : tensors) {
200+
if (tensor.defined() && tensor.numel() == 0) {
201+
return true;
202+
}
203+
}
204+
return false;
205+
}
206+
179207
static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig) {
208+
checkIndexTensorTypes(orig);
180209
// first expand ByteTensor (boolean masks) into 1 or more LongTensors
181210
auto indices = expandByteTensors(self, orig);
211+
if (hasEmptyTensor(indices)) {
212+
return std::make_tuple(self, self.type().toScalarType(kLong).tensor());
213+
}
182214
// next broadcast all index tensors together
183215
indices = expand_outplace(indices);
184216
// add missing null Tensors so that it matches self.dim()
@@ -191,7 +223,7 @@ static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig)
191223
std::tie(self, indices) = transposeToFront(self, indices);
192224
}
193225
auto linearIndex = computeLinearIndex(self, indices);
194-
return std::make_tuple<>(self, linearIndex);
226+
return std::make_tuple(self, linearIndex);
195227
}
196228

197229
Tensor index(const Tensor & self, TensorList indices) {

test/test_indexing.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def test_byte_mask(self):
4040
self.assertEqual(v[mask].shape, (3, 7, 3))
4141
self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
4242

43+
v = Variable(torch.Tensor([1]))
44+
self.assertEqual(v[v == 0], Variable(torch.Tensor()))
45+
4346
def test_multiple_byte_mask(self):
4447
v = Variable(torch.randn(5, 7, 3))
4548
# note: these broadcast together and are transposed to the first dim
@@ -75,6 +78,11 @@ def test_int_indices_broadcast(self):
7578
result = x[rows[:, None], columns]
7679
self.assertEqual(result.data.tolist(), [[0, 2], [9, 11]])
7780

81+
def test_empty_index(self):
82+
x = Variable(torch.arange(0, 12).view(4, 3))
83+
idx = Variable(torch.LongTensor())
84+
self.assertEqual(x[idx].numel(), 0)
85+
7886
def test_basic_advanced_combined(self):
7987
# From the NumPy indexing example
8088
x = Variable(torch.arange(0, 12).view(4, 3))
@@ -200,7 +208,6 @@ def test_empty_tuple_index(self):
200208
self.assertEqual(a[()], a)
201209
self.assertEqual(a[()].data_ptr(), a.data_ptr())
202210

203-
# @unittest.skip('failing')
204211
def test_empty_fancy_index(self):
205212
# Empty list index creates an empty array
206213
a = tensor([1, 2, 3])

0 commit comments

Comments
 (0)