
Commit 13abcd1

Check for internal memory overlap in some indexing-type functions
ghstack-source-id: 2e93fdc
Pull Request resolved: #43423
1 parent b59e1f4 commit 13abcd1
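
What the commit guards against: "internal overlap" means a tensor whose strides make several indices alias the same memory location, most commonly the result of expand(). Writing into such a tensor through an in-place or out= indexing op is nondeterministic, so the functions touched here now call at::assert_no_internal_overlap on the written-to tensor and raise the existing "unsupported operation" error instead. A minimal repro in the spirit of the new tests below (device and shapes are illustrative):

    import torch

    x = torch.rand((1,)).expand((6,))   # six indices alias one element -> internal overlap
    ind = torch.tensor([0, 2, 3])
    value = torch.rand((3,))

    try:
        x.index_put_((ind,), value)     # rejected after this change
    except RuntimeError as e:
        print(e)                        # message contains "unsupported operation"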


9 files changed: 78 additions, 4 deletions

aten/src/ATen/native/LegacyDefinitions.cpp

Lines changed: 2 additions & 0 deletions
@@ -3,12 +3,14 @@
 #include <ATen/LegacyTHFunctionsCPU.h>
 #include <ATen/NamedTensorUtils.h>
 #include <ATen/ExpandUtils.h>
+#include <ATen/MemoryOverlap.h>
 
 namespace at { namespace native {
 
 // Methods
 
 Tensor & masked_scatter__cpu(Tensor& self, const Tensor & mask, const Tensor & source) {
+  at::assert_no_internal_overlap(self);
   Tensor b_mask;
   std::tie(b_mask) = expand_inplace(self, mask, "masked_scatter_");
   // As we dispatch on self and TH is type-checked, we need different definitions.

aten/src/ATen/native/TensorAdvancedIndexing.cpp

Lines changed: 4 additions & 0 deletions
@@ -272,6 +272,7 @@ Tensor index(const Tensor & self, TensorList indices) {
 
 Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) {
   TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
+  at::assert_no_internal_overlap(result);
 
   auto info = make_info(self, indices);
   auto iter = make_index_out_iterator(info, result);
@@ -291,6 +292,7 @@ Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & valu
     index_put_accum_stub(self.device().type(), self, indices, value, unsafe);
     return self;
   }
+  at::assert_no_internal_overlap(self);
   auto info = make_info(self, indices);
   auto iter = make_index_put_iterator(info, value);
   index_put_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides, accumulate);
@@ -429,6 +431,7 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
               "index_select(): self and result must have the same scalar type");
   TORCH_CHECK(dim == 0 || dim < self.dim(),
               "index_select(): Indexing dim ", dim, " is out of bounds of tensor");
+  at::assert_no_internal_overlap(result);
 
   auto result_size = self.sizes().vec();
   if (self.dim() > 0) {
@@ -698,6 +701,7 @@ static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self,
               "masked_select: expected BoolTensor or ByteTensor for mask");
   TORCH_CHECK(self.scalar_type() == result.scalar_type(),
               "masked_select(): self and result must have the same scalar type");
+  at::assert_no_internal_overlap(result);
 
   at::assert_no_internal_overlap(result);
   at::assert_no_partial_overlap(result, self);
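
A note on the check itself: at::assert_no_internal_overlap (declared in ATen/MemoryOverlap.h, included above) only throws when it can prove the destination overlaps with itself; layouts it cannot classify are let through. A rough, non-authoritative Python sketch of that heuristic, not the actual implementation (which is the C++ in ATen/MemoryOverlap.cpp):

    def has_internal_overlap(t):
        # A contiguous tensor can never alias itself.
        if t.is_contiguous():
            return False
        # A zero stride over a dimension of size > 1 always does
        # (e.g. the result of expand()). Anything else is treated
        # as "too hard" and conservatively allowed.
        return any(stride == 0 and size > 1
                   for stride, size in zip(t.stride(), t.size()))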

aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 1 addition & 0 deletions
@@ -813,6 +813,7 @@ Tensor& index_select_out_cuda(Tensor& out, const Tensor& self, int64_t dim,
 
   TORCH_CHECK(at::cuda::check_device({out, self, index}),
               "Input, output and indices must be on the current device");
+  at::assert_no_internal_overlap(out);
 
   dim = at::maybe_wrap_dim(dim, self);
   TORCH_CHECK(self.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING);

aten/src/ATen/native/cuda/LegacyDefinitions.cpp

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,7 @@ namespace at { namespace native {
 
 Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, Scalar value) {
   auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_");
+  at::assert_no_internal_overlap(self);
   Tensor b_mask;
   std::tie(b_mask) = expand_inplace(self, mask, "masked_fill_");
   // As we dispatch on self and TH is type-checked, we need different definitions.
@@ -28,6 +29,7 @@ Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, Scalar value) {
 
 Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, const Tensor & value) {
   auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_");
+  at::assert_no_internal_overlap(self);
 
   TORCH_CHECK(value.dim() == 0, "masked_fill_ only supports a 0-dimensional value tensor, but got tensor "
       "with ", value.dim(), " dimension(s).");
@@ -47,6 +49,7 @@ Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, const Tensor & val
 }
 
 Tensor & masked_scatter__cuda(Tensor& self, const Tensor & mask, const Tensor & source) {
+  at::assert_no_internal_overlap(self);
   Tensor b_mask;
   std::tie(b_mask) = expand_inplace(self, mask, "masked_scatter_");
   // As we dispatch on self and TH is type-checked, we need different definitions.

aten/src/ATen/native/cuda/Shape.cu

Lines changed: 1 addition & 0 deletions
@@ -314,6 +314,7 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) {
                 "of the output memory locations. Found overlap in input "
                 "tensor ", i);
   }
+  at::assert_no_internal_overlap(out);
 
   for (int i = 0; i < inputs.size(); i++)
   {

test/test_distributions.py

Lines changed: 2 additions & 2 deletions
@@ -4926,8 +4926,8 @@ def f(*values):
 
             # check on different data
             values, sample = self._perturb(Dist, keys, values, sample)
-            expected = f(*values)
-            actual = traced_f(*values)
+            expected = f(*values).clone()
+            actual = traced_f(*values).clone()
             expected[expected == float('inf')] = 0.
             actual[actual == float('inf')] = 0.
             self.assertEqual(expected, actual,

test/test_torch.py

Lines changed: 61 additions & 0 deletions
@@ -14414,6 +14414,67 @@ def test_bernoulli_mem_overlap(self, device):
         with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
             torch.bernoulli(torch.rand_like(x), out=x)
 
+    def test_index_put_mem_overlap(self, device):
+        x = torch.rand((1,), device=device).expand((6,))
+        y = torch.rand((6,), device=device)
+        ind = torch.tensor([0, 2, 3], device=device)
+        value = torch.rand((3,), device=device)
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            x.index_put_((ind,), value)
+
+    def test_masked_fill_mem_overlap(self, device):
+        x = torch.rand((1,), device=device).expand((6,))
+        mask = torch.tensor([True, False, True, True, False, False], device=device)
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            x.masked_fill_(mask, 0.)
+
+        fill_val = torch.tensor(0., device=device)
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            x.masked_fill_(mask, fill_val)
+
+    def test_masked_select_mem_overlap(self, device):
+        x = torch.rand((1,), device=device).expand((3,))
+        y = torch.rand((6,), device=device)
+        mask = torch.tensor([True, False, True, True, False, False], device=device)
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            torch.masked_select(y, mask, out=x)
+
+    def test_masked_scatter_mem_overlap(self, device):
+        x = torch.rand((1,), device=device).expand((6,))
+        src = torch.rand((3,), device=device)
+        mask = torch.tensor([True, False, True, True, False, False], device=device)
+
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            x.masked_scatter_(mask, src)
+
+    def test_index_select_mem_overlap(self, device):
+        x = torch.rand((1, 6), device=device).expand((2, 6))
+        y = torch.rand((3, 6), device=device)
+        ind = torch.tensor([0, 1], dtype=torch.int64, device=device)
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            torch.index_select(y, 1, ind, out=x)
+
+    def test_cat_mem_overlap(self, device):
+        x = torch.rand((1, 3), device=device).expand((6, 3))
+        y = torch.rand((3, 3), device=device)
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            torch.cat([y, y], out=x)
+
+    def test_scatter_mem_overlap(self, device):
+        x = torch.rand((1,), device=device).expand((6,))
+        src = torch.rand((3,), device=device)
+        ind = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
+
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            x.scatter_(0, ind, src)
+
+    def test_gather_mem_overlap(self, device):
+        x = torch.rand((1,), device=device).expand((3,))
+        src = torch.rand((6,), device=device)
+        ind = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
+        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
+            torch.gather(src, 0, ind, out=x)
+
     def test_linlogspace_mem_overlap(self, device):
         x = torch.rand(1, device=device).expand(10)
         with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):

torch/distributions/geometric.py

Lines changed: 2 additions & 1 deletion
@@ -88,7 +88,8 @@ def sample(self, sample_shape=torch.Size()):
     def log_prob(self, value):
         if self._validate_args:
             self._validate_sample(value)
-        value, probs = broadcast_all(value, self.probs.clone(memory_format=torch.contiguous_format))
+        value, probs = broadcast_all(value, self.probs)
+        probs = probs.clone(memory_format=torch.contiguous_format)
         probs[(probs == 1) & (value == 0)] = 0
         return value * (-probs).log1p() + self.probs.log()

torch/distributions/multinomial.py

Lines changed: 2 additions & 1 deletion
@@ -101,7 +101,8 @@ def sample(self, sample_shape=torch.Size()):
     def log_prob(self, value):
         if self._validate_args:
             self._validate_sample(value)
-        logits, value = broadcast_all(self.logits.clone(memory_format=torch.contiguous_format), value)
+        logits, value = broadcast_all(self.logits, value)
+        logits = logits.clone(memory_format=torch.contiguous_format)
         log_factorial_n = torch.lgamma(value.sum(-1) + 1)
         log_factorial_xs = torch.lgamma(value + 1).sum(-1)
         logits[(value == 0) & (logits == -inf)] = 0
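
Why the two distributions changes (and the test_distributions.py clone() above): broadcast_all returns expanded views, so cloning the parameter before broadcasting still leaves the broadcast result with zero strides, and the in-place masked assignment that follows would now trip the index_put_ overlap check. Cloning after broadcasting materializes the memory first. A small sketch of the fixed pattern, with made-up values for illustration:

    import torch
    from torch.distributions.utils import broadcast_all

    probs = torch.tensor(0.5)                    # scalar parameter
    value = torch.zeros(4)
    value, probs = broadcast_all(value, probs)   # probs is now an expanded view (stride 0)
    probs = probs.clone(memory_format=torch.contiguous_format)  # materialize before writing
    probs[(probs == 1) & (value == 0)] = 0       # safe: no internal overlap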
