Skip to content

Commit 8a888c4

Browse files
zou3519 authored and facebook-github-bot committed
Reimplement as_strided in ATen. (#13185)
Summary: This moves away from using tensor.set_(...) for as_strided, which went through TH and was weirdly slow/complicated. The new as_strided has a new invariant that it will never resize the storage to a larger size (the previous as_strided allowed that behavior but it seemed weird and none of our code relied on it.) This offers a small speedup on as_strided: it went from 1300ns to 1100ns although the benchmarks get a little noisy here. Also on the changelog is a quick fix to resize_ code to avoid unsigned underflow. I'll rewrite the resize_ zero dim logic in a future diff, it doesn't make sense the way it is written right now. Pull Request resolved: #13185 Reviewed By: ezyang Differential Revision: D12809160 Pulled By: zou3519 fbshipit-source-id: 3885df9d863baab2b2f8d8e2f8e2bfe660a49d85
1 parent 8c2d0c8 commit 8a888c4

File tree

5 files changed

+97
-5
lines changed

5 files changed

+97
-5
lines changed

aten/src/ATen/core/TensorImpl.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,17 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
726726
storage_offset_ = storage_offset;
727727
}
728728

729+
/* Sets the storage of this tensor to be new_storage */
730+
void set_storage(const Storage& new_storage) {
731+
auto* new_storage_ = new_storage.unsafeGetStorageImpl();
732+
auto* old_storage_ = storage_.unsafeGetStorageImpl();
733+
AT_ASSERTM(old_storage_, "Tensor: invalid null storage");
734+
if (new_storage_ == old_storage_) {
735+
return;
736+
}
737+
storage_ = new_storage;
738+
}
739+
729740
/**
730741
* Like set_sizes_and_strides but assumes contiguous strides.
731742
*

aten/src/ATen/native/Resize.h

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@ inline TensorImpl* resize_impl_cpu_(
3030
return self;
3131
}
3232

33-
size_t storage_size = 1;
33+
int64_t storage_size = 1;
3434
if (stride) {
3535
self->set_sizes_and_strides(size, *stride);
3636
// NB: storage size can be different from numel.
3737
for (size_t dim = 0; dim < size.size(); ++dim) {
38+
// FIXME: Don't rely on storage_size being negative because this
39+
// may not be true for some edge cases.
3840
storage_size += (size[dim] - 1) * stride.value()[dim];
3941
}
4042
} else {
@@ -46,4 +48,63 @@ inline TensorImpl* resize_impl_cpu_(
4648
return self;
4749
}
4850

51+
// Number of storage elements a view with geometry (sizes, strides) spans:
// one element for the origin plus stride * (len - 1) per dimension.
// Returns 0 when any dimension is empty (such a tensor holds no elements).
static inline int64_t computeStorageSize(IntList sizes, IntList strides) {
  for (size_t d = 0; d < sizes.size(); ++d) {
    if (sizes[d] == 0) {
      return 0;
    }
  }
  int64_t span = 1;
  for (size_t d = 0; d < sizes.size(); ++d) {
    span += strides[d] * (sizes[d] - 1);
  }
  return span;
}
61+
62+
// Throws unless a view with geometry (size, stride) placed at storage_offset
// fits entirely inside new_storage.
static inline void checkInBoundsForStorage(
    IntList size,
    IntList stride,
    int64_t storage_offset,
    const Storage& new_storage) {
  const int64_t storage_size = computeStorageSize(size, stride);
  // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel,
  // so there is nothing to validate.
  if (storage_size == 0) {
    return;
  }
  const int64_t new_storage_size = new_storage.numel();
  AT_CHECK(
      storage_offset + storage_size <= new_storage_size,
      "setStorage: sizes ", size, ", strides ", stride, ","
      " and storage offset ", storage_offset,
      " requiring a storage size of ", storage_size + storage_offset,
      " are out of bounds for storage with numel ", new_storage_size);
}
80+
81+
/**
82+
* Set self's storage to be new_storage with sizes, strides, and storage_offset.
83+
* (size, stride, storage_offset) must be in bounds for the new storage.
84+
*/
85+
inline void setStorage(
86+
const Tensor& self,
87+
const Storage& new_storage,
88+
int64_t storage_offset,
89+
IntList size,
90+
IntList stride) {
91+
checkInBoundsForStorage(size, stride, storage_offset, new_storage);
92+
93+
auto* self_ = self.unsafeGetTensorImpl();
94+
95+
/* storage */
96+
self_->set_storage(new_storage);
97+
98+
/* storage offset */
99+
AT_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
100+
self_->set_storage_offset(storage_offset);
101+
102+
/* size and stride */
103+
AT_ASSERT(size.size() == stride.size());
104+
if (self_->sizes() == size && self_->strides() == stride) {
105+
return;
106+
}
107+
self_->set_sizes_and_strides(size, stride);
108+
}
109+
49110
}}

aten/src/ATen/native/TensorShape.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "ATen/WrapDimUtils.h"
99
#include "c10/util/Exception.h"
1010
#include "c10/util/Optional.h"
11+
#include "ATen/native/Resize.h"
1112
#include <ATen/SparseTensorUtils.h>
1213
#include <algorithm>
1314
#include <vector>
@@ -150,19 +151,32 @@ Tensor expand_as(const Tensor& self, const Tensor& other) {
150151
}
151152

152153
Tensor as_strided(const Tensor& self, IntList size, IntList stride, int64_t storage_offset) {
  // Create a fresh 0-element tensor with self's options, then point it at
  // self's storage with the requested geometry.
  auto result = at::empty({0}, self.options());
  setStorage(result, self.storage(), storage_offset, size, stride);
  return result;
}
155163

156164
Tensor &as_strided_(Tensor& self, IntList size, IntList stride, int64_t storage_offset) {
  // Restride self in place over its own storage.
  setStorage(self, self.storage(), storage_offset, size, stride);
  return self;
}
159173

160174
Tensor as_strided(const Tensor& self, IntList size, IntList stride) {
  // Overload that defaults the offset to self's current storage offset.
  const auto offset = self.storage_offset();
  return at::as_strided(self, size, stride, offset);
}
163177

164178
Tensor &as_strided_(Tensor& self, IntList size, IntList stride) {
  // Overload that defaults the offset; dispatches through the method form
  // of the in-place variant.
  const auto offset = self.storage_offset();
  return self.as_strided_(size, stride, offset);
}
167181

168182
Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_t length) {

aten/src/ATen/native/cuda/Resize.cuh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ inline TensorImpl* resize_impl_cuda_(
3838
guard = DeviceGuard(self->storage().device().index());
3939
}
4040

41-
size_t storage_size = 1;
41+
int64_t storage_size = 1;
4242
if (stride) {
4343
self->set_sizes_and_strides(size, *stride);
4444
// NB: storage size can be different from numel.
4545
for (size_t dim = 0; dim < size.size(); ++dim) {
46+
// FIXME: Don't rely on storage_size being negative because this
47+
// may not be true for some edge cases.
4648
storage_size += (size[dim] - 1) * stride.value()[dim];
4749
}
4850
} else {

aten/src/ATen/native/native_functions.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,17 +201,21 @@
201201

202202
- func: as_strided(Tensor self, IntList size, IntList stride) -> Tensor
203203
variants: function, method
204+
device_guard: false
204205

205206
- func: as_strided_(Tensor self, IntList size, IntList stride) -> Tensor
206207
variants: function, method
208+
device_guard: false
207209

208210
- func: as_strided(Tensor self, IntList size, IntList stride, int64_t storage_offset) -> Tensor
209211
variants: function, method
212+
device_guard: false
210213
python_default_init:
211214
storage_offset: self.storage_offset()
212215

213216
- func: as_strided_(Tensor self, IntList size, IntList stride, int64_t storage_offset) -> Tensor
214217
variants: function, method
218+
device_guard: false
215219
python_default_init:
216220
storage_offset: self.storage_offset()
217221

0 commit comments

Comments
 (0)