Commit f6658bc

jotsif authored and soumith committed
Issue 14984: Remove divide by zero error in index_put_ (#14986)
Summary:
No check for a zero-size index tensor was done in the accumulate=True (serial) case in the new TensorIterator code since #13420.

Fixes #14984

Pull Request resolved: #14986
Differential Revision: D13417861
Pulled By: colesbury
fbshipit-source-id: e6ed1af8f708b53a35803fc157ed1f043169ec89

1 parent: bd51538
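For context, a minimal reproduction sketch of the failure described in the summary. It mirrors the regression test added below and assumes a build without this fix, where the call crashed instead of being a no-op:

import torch

# An all-zero byte mask selects nothing, so the index tensor is zero-size.
mask = torch.zeros(10, dtype=torch.uint8)
y = torch.ones(10, 10)

# y[mask] is an empty (0, 10) tensor. With accumulate=True, the serial
# TensorIterator path ran over a zero-size range and hit an integer
# divide-by-zero; with the fix, the call is a harmless no-op.
y.index_put_((mask,), y[mask], accumulate=True)
assert torch.equal(y, torch.ones(10, 10))  # y is unchanged

(Byte masks match the PyTorch of this era; newer releases expect torch.bool masks instead of torch.uint8.)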

2 files changed: +13 -2 lines

aten/src/ATen/native/TensorIterator.cpp
Lines changed: 7 additions & 2 deletions

@@ -385,6 +385,9 @@ void TensorIterator::serial_for_each(const loop_t& loop, Range range) const {
 }

 void TensorIterator::serial_for_each(const loop2d_t& loop, Range range) const {
+  if (range.size() == 0) {
+    return;
+  }
   auto strides = get_strides();
   while (strides.size() < 2 * ntensors()) {
     strides.push_back(0);
@@ -677,8 +680,10 @@ DimCounter::DimCounter(IntList shape, Range range)
   int64_t ndim = values.size();
   for (int dim = 0; dim < ndim; dim++) {
     int64_t size = shape[dim];
-    values[dim] = linear_offset % size;
-    linear_offset /= size;
+    if (size > 0) {
+      values[dim] = linear_offset % size;
+      linear_offset /= size;
+    }
   }
   AT_ASSERT(linear_offset == 0);
 }

test/test_indexing.py
Lines changed: 6 additions & 0 deletions

@@ -45,6 +45,12 @@ def test_byte_mask(self):
         v = torch.tensor([1.])
         self.assertEqual(v[v == 0], torch.tensor([]))

+    def test_byte_mask_accumulate(self):
+        mask = torch.zeros(size=(10, ), dtype=torch.uint8)
+        y = torch.ones(size=(10, 10))
+        y.index_put_((mask, ), y[mask], accumulate=True)
+        self.assertEqual(y, torch.ones(size=(10, 10)))
+
     def test_multiple_byte_mask(self):
         v = torch.randn(5, 7, 3)
         # note: these broadcast together and are transposed to the first dim
