
Commit 92b82ec

*_scatter ops should preserve input stride/storage_offset
ghstack-source-id: 484c48a
Pull Request resolved: #91029
1 parent: 2e46969
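The change makes the *_scatter ops return an output whose size/stride/storage_offset match the base ("self") argument instead of a fresh contiguous clone. As a rough Python-level illustration of the intended behavior (this snippet is not part of the commit; it just exercises torch.slice_scatter on a non-contiguous base):

import torch

base = torch.zeros(4, 6).t()        # non-contiguous view: size (6, 4), stride (1, 6)
src = torch.ones(2, 4)
out = torch.slice_scatter(base, src, dim=0, start=0, end=2)

# With this change the output is expected to inherit the base's layout
# rather than coming back as a contiguous clone.
print(base.stride(), out.stride())                  # expected: (1, 6) (1, 6)
print(base.storage_offset(), out.storage_offset())  # expected: 0 0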

8 files changed, +313 -196 lines changed

aten/src/ATen/FunctionalTensorWrapper.cpp

Lines changed: 8 additions & 4 deletions
@@ -146,10 +146,14 @@ functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_st
 void FunctionalTensorWrapper::commit_update() {
   auto storage_impl = functional_storage_impl();
   storage_impl->add_update(value_, view_metas_);
-  // Invariant: commit_update() is called during an inplace operation.
-  // Tensor inputs to the operation are synced before running the op,
-  // so the current tensor must be up-to-date with its alias at this point.
-  generation_ = storage_impl->generation();
+  // As an optimization, we used to mark the tensor here as "up-to-date".
+  // That way, code like:
+  //   x = torch.ones(1'000'000)
+  //   x[0].add_(1)
+  // doesn't result in an unnecessary materialization of the base.
+  // This optimization results in the slice temporarily having incorrect
+  // stride/storage_offset though, and DCE should handle that optimization anyway.
+  // generation_ = storage_impl->generation();
 }

 bool FunctionalTensorWrapper::is_up_to_date() const {

aten/src/ATen/MemoryOverlap.cpp

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@ MemOverlap has_internal_overlap(TensorImpl* t) {
     return MemOverlap::No;
   }

-  auto strides = t->strides();
-  auto sizes = t->sizes();
+  auto strides = t->sym_strides();
+  auto sizes = t->sym_sizes();
   for (const auto i : c10::irange(strides.size())) {
     if (strides[i] == 0 && sizes[i] > 1) {
       return MemOverlap::Yes;

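has_internal_overlap now reads the symbolic strides/sizes, but the check itself is unchanged: any dimension with stride 0 and size > 1 means several logical elements alias a single memory location. A small Python illustration of that pattern (not from the commit; it just shows the situation the loop above flags):

import torch

x = torch.ones(1).expand(10)   # size (10,), stride (0,): all 10 elements share one slot
print(x.size(), x.stride())    # torch.Size([10]) (0,)

# In-place writes into such a tensor are rejected by ATen's overlap checks,
# which is why clone_preserve_strides below falls back to a plain clone() for it.
try:
    x.add_(1)
except RuntimeError as e:
    print("refused:", e)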
aten/src/ATen/native/Copy.cpp

Lines changed: 1 addition & 26 deletions
@@ -8,6 +8,7 @@
 #include <ATen/native/quantized/Copy.h>
 #include <ATen/native/mps/Copy.h>
 #include <ATen/native/vulkan/ops/Copy.h>
+#include <ATen/native/TensorShape.h>
 #include <ATen/quantized/Quantizer.h>
 #include <ATen/vulkan/Context.h>
 #include <ATen/metal/Context.h>
@@ -278,32 +279,6 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
   return self;
 }

-// NB: cribbed from https://github.com/pytorch/pytorch/pull/88198
-at::Tensor clone_preserve_strides(const at::Tensor& self) {
-  TORCH_INTERNAL_ASSERT(self.has_storage());
-  // In cases where the input tensor has internal memory overlap, we cannot actually
-  // preserve the strides/storage_offset of the input tensor, because
-  // *_scatter ops will try to copy_() into the cloned tensor.
-  // However, this should **never** show up in functionalized user code;
-  // most aten ops that try to mutate a tensor with internal memory overlap would error anyway.
-  //
-  // The one place that this does come up is in autograd - if there's a select_scatter
-  // in the forward, then autograd will generate one for the backward.
-  // If the input to the select_scatter is grad_output, then this could be an expanded tensor
-  // with internal overlap.
-  //if (at::has_internal_overlap(self) == at::MemOverlap::Yes) {
-  //  return self.clone();
-  //}
-  auto dtype_size = self.dtype().itemsize();
-  auto nbytes = self.storage().sym_nbytes();
-  TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
-  auto numel = nbytes / dtype_size;
-  auto self_full_size = self.as_strided_symint({numel}, {1}, 0);
-  auto clone = self_full_size.clone();
-  auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
-  return out;
-}
-
 Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) {
   // copy() is the "functional" form of copy_(). It exists so we can properly functionalize copy_(), but:
   // (1) It isn't exposed to the frontend (no python bindings)

aten/src/ATen/native/TensorShape.cpp

Lines changed: 41 additions & 4 deletions
@@ -3801,22 +3801,58 @@ std::vector<Tensor> unflatten_dense_tensors(const Tensor& flat, TensorList tenso
   return outputs;
 }

+
+// Clones a tensor by cloning the underlying storage that it came from,
+// which allows us to replicate the exact strides/storage_offset in the cloned tensor.
+// Note [*_scatter ops preserve strides]
+// In order for functionalization to preserve stride correctness, the *_scatter
+// operators that it calls must preserve the striding behavior of their inputs.
+// Specifically, the output of *_scatter(base, mutated_view, ...)
+// should have identical size/stride/storage_offset to "base".
+at::Tensor clone_preserve_strides(const at::Tensor& self) {
+  TORCH_INTERNAL_ASSERT(self.has_storage());
+  // In cases where the input tensor has internal memory overlap, we cannot actually
+  // preserve the strides/storage_offset of the input tensor, because
+  // *_scatter ops will try to copy_() into the cloned tensor.
+  // However, this should **never** show up in functionalized user code;
+  // most aten ops that try to mutate a tensor with internal memory overlap would error anyway.
+  //
+  // The one place that this does come up is in autograd - if there's a select_scatter
+  // in the forward, then autograd will generate one for the backward.
+  // If the input to the select_scatter is grad_output, then this could be an expanded tensor
+  // with internal overlap.
+  if (at::has_internal_overlap(self) == at::MemOverlap::Yes) {
+    return self.clone();
+  }
+  auto dtype_size = self.dtype().itemsize();
+  auto nbytes = self.storage().sym_nbytes();
+  TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
+  auto numel = nbytes / dtype_size;
+  auto self_full_size = self.as_strided_symint({numel}, {1}, 0);
+  auto clone = self_full_size.clone();
+  auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
+  return out;
+}
+
+
 at::Tensor slice_scatter(const at::Tensor& self, const at::Tensor& src, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step) {
-  auto output = self.clone();
+  // See Note [*_scatter ops preserve strides]
+  auto output = clone_preserve_strides(self);
   auto slice = output.slice(dim, start, end, step);
   TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
   slice.copy_(src);
   return output;
 }
 at::Tensor select_scatter_symint(const at::Tensor& self, const at::Tensor& src, int64_t dim, c10::SymInt index) {
-  auto output = self.clone();
+  auto output = clone_preserve_strides(self);
   auto slice = output.select_symint(dim, index);
   TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
   slice.copy_(src);
   return output;
 }
 at::Tensor diagonal_scatter(const at::Tensor& self, const at::Tensor& src, int64_t offset, int64_t dim1, int64_t dim2) {
-  auto output = self.clone();
+  // See Note [*_scatter ops preserve strides]
+  auto output = clone_preserve_strides(self);
   auto slice = output.diagonal(offset, dim1, dim2);
   TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
   slice.copy_(src);
@@ -3825,7 +3861,8 @@ at::Tensor diagonal_scatter(const at::Tensor& self, const at::Tensor& src, int64
 at::Tensor as_strided_scatter_symint(const at::Tensor& self, const at::Tensor& src, at::SymIntArrayRef size, at::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset) {
   // See Note [as_strided_scatter backward support]
   TORCH_INTERNAL_ASSERT(!self.requires_grad() || self.is_contiguous(), "as_strided_scatter is currently only supported for contiguous inputs");
-  auto output = self.clone();
+  // See Note [*_scatter ops preserve strides]
+  auto output = clone_preserve_strides(self);
   auto slice = output.as_strided_symint(size, stride, std::move(storage_offset));
   TORCH_CHECK(slice.sym_sizes() == src.sym_sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sym_sizes(), ", slice size = ", slice.sym_sizes());
   slice.copy_(src);

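For reference, the idea behind clone_preserve_strides above can be sketched at the Python level: view the entire underlying storage as a flat 1-D tensor, clone that, and re-apply the original size/stride/storage_offset on top of the clone. The helper below is a hypothetical sketch (it assumes a recent PyTorch with Tensor.untyped_storage()), not the ATen implementation itself:

import torch

def clone_preserve_strides_sketch(t: torch.Tensor) -> torch.Tensor:
    # Number of elements covered by the whole storage, not just this view.
    numel = t.untyped_storage().nbytes() // t.element_size()
    # View the full storage as a flat contiguous 1-D tensor and clone it.
    full_clone = t.as_strided((numel,), (1,), 0).clone()
    # Re-apply the original layout on top of the cloned storage.
    return full_clone.as_strided(t.size(), t.stride(), t.storage_offset())

base = torch.arange(24.).reshape(4, 6)[:, 1::2]   # size (4, 3), stride (6, 2), offset 1
out = clone_preserve_strides_sketch(base)
print(out.size(), out.stride(), out.storage_offset())   # expected: (4, 3) (6, 2) 1
torch.testing.assert_close(out, base)

Cloning the whole storage is what lets the result carry an arbitrary storage_offset and strides without aliasing the input; the internal-overlap fallback in the diff exists because a copy_() into an overlapping clone would be rejected.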
aten/src/ATen/native/TensorShape.h

Lines changed: 3 additions & 0 deletions
@@ -5,6 +5,9 @@
 
 namespace at {
 namespace native {
+
+TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self);
+
 inline bool cat_should_skip_tensor(const Tensor& t) {
   return t.numel() == 0 && t.dim() == 1;
 }
