Skip to content

Commit e0d5d1b

Browse files
ssnl authored and soumith committed
view in certain noncontig case (#4062)
1 parent 9394e65 commit e0d5d1b

File tree

5 files changed

+165
-11
lines changed

5 files changed

+165
-11
lines changed

aten/src/TH/generic/THTensor.c

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,14 +226,61 @@ THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, int64_t size_,
226226
return self;
227227
}
228228

229+
// Also sets new_stride if viewable.
230+
//
231+
// On a high level,
232+
// 1. separate tensor->size into chunks of dimensions, where the dimensions are
233+
// ``contiguous'' in each chunk, i.e., stride[i] = size[i+1] * stride[i+1]
234+
// 2. view_size must be able to be separated into same number of chunks, where
235+
// each chunk pair has matching ``numel'', i.e., number of subspaces.
236+
static int THTensor_(isViewable)(THTensor *tensor, THLongStorage *view_size, THLongStorage *new_stride) {
237+
// dim indices
238+
int64_t tensor_d = tensor->nDimension - 1;
239+
if (tensor_d < 0) {
240+
return 1;
241+
}
242+
int64_t view_d = view_size->size - 1;
243+
// stride for each subspace in the chunk
244+
int64_t chunk_base_stride = tensor->stride[tensor_d];
245+
// numel in current chunk
246+
int64_t tensor_numel = 1;
247+
int64_t view_numel = 1;
248+
for (; tensor_d >= 0; tensor_d--) {
249+
tensor_numel *= tensor->size[tensor_d];
250+
// if end of tensor size chunk, check view
251+
if ((tensor_d == 0) ||
252+
(tensor->size[tensor_d - 1] != 1 && tensor->stride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
253+
while ((view_numel < tensor_numel || view_size->data[view_d] == 1) && view_d >= 0) {
254+
new_stride->data[view_d] = view_numel * chunk_base_stride;
255+
view_numel *= view_size->data[view_d];
256+
view_d--;
257+
}
258+
if (view_numel != tensor_numel) {
259+
return 0;
260+
}
261+
if (tensor_d > 0) {
262+
chunk_base_stride = tensor->stride[tensor_d - 1];
263+
tensor_numel = 1;
264+
view_numel = 1;
265+
}
266+
}
267+
}
268+
// check that we iterated through all view size
269+
return view_d == -1;
270+
}
271+
229272
THTensor *THTensor_(newView)(THTensor *tensor, THLongStorage *size)
230273
{
231-
THArgCheck(THTensor_(isContiguous)(tensor), 1, "input is not contiguous");
232274
ptrdiff_t numel = THTensor_(nElement)(tensor);
233275
THTensor *self = THTensor_(new)();
234276
THLongStorage *inferred_size = THLongStorage_newInferSize(size, numel);
235-
THTensor_(setStorage)(self, tensor->storage, tensor->storageOffset, inferred_size, NULL);
277+
THLongStorage *new_stride = THLongStorage_newWithSize(size->size);
278+
THArgCheck(THTensor_(isViewable)(tensor, inferred_size, new_stride), 2, "view size is "
279+
"not compatible with input tensor's size and stride (at least one dimension spans "
280+
"across two contiguous subspaces). Call .contiguous() before .view().");
281+
THTensor_(setStorage)(self, tensor->storage, tensor->storageOffset, inferred_size, new_stride);
236282
THLongStorage_free(inferred_size);
283+
THLongStorage_free(new_stride);
237284
return self;
238285
}
239286

aten/src/THC/generic/THCTensor.c

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,14 +222,61 @@ THCTensor *THCTensor_(newUnfold)(THCState *state, THCTensor *tensor, int dimensi
222222
return self;
223223
}
224224

225+
// Also sets new_stride if viewable.
226+
//
227+
// On a high level,
228+
// 1. separate tensor->size into chunks of dimensions, where the dimensions are
229+
// ``contiguous'' in each chunk, i.e., stride[i] = size[i+1] * stride[i+1]
230+
// 2. view_size must be able to be separated into same number of chunks, where
231+
// each chunk pair has matching ``numel'', i.e., number of subspaces.
232+
static int THCTensor_(isViewable)(THCState *state, THCTensor *tensor, THLongStorage *view_size, THLongStorage *new_stride) {
233+
// dim indices
234+
int64_t tensor_d = tensor->nDimension - 1;
235+
if (tensor_d < 0) {
236+
return 1;
237+
}
238+
int64_t view_d = view_size->size - 1;
239+
// stride for each subspace in the chunk
240+
int64_t chunk_base_stride = tensor->stride[tensor_d];
241+
// numel in current chunk
242+
int64_t tensor_numel = 1;
243+
int64_t view_numel = 1;
244+
for (; tensor_d >= 0; tensor_d--) {
245+
tensor_numel *= tensor->size[tensor_d];
246+
// if end of tensor size chunk, check view
247+
if ((tensor_d == 0) ||
248+
(tensor->size[tensor_d - 1] != 1 && tensor->stride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
249+
while ((view_numel < tensor_numel || view_size->data[view_d] == 1) && view_d >= 0) {
250+
new_stride->data[view_d] = view_numel * chunk_base_stride;
251+
view_numel *= view_size->data[view_d];
252+
view_d--;
253+
}
254+
if (view_numel != tensor_numel) {
255+
return 0;
256+
}
257+
if (tensor_d > 0) {
258+
chunk_base_stride = tensor->stride[tensor_d - 1];
259+
tensor_numel = 1;
260+
view_numel = 1;
261+
}
262+
}
263+
}
264+
// check that we iterated through all view size
265+
return view_d == -1;
266+
}
267+
225268
THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, THLongStorage *size)
226269
{
227-
THArgCheck(THCTensor_(isContiguous)(state, tensor), 2, "input is not contiguous");
228270
ptrdiff_t numel = THCTensor_(nElement)(state, tensor);
229271
THCTensor *self = THCTensor_(new)(state);
230272
THLongStorage *inferred_size = THLongStorage_newInferSize(size, numel);
231-
THCTensor_(setStorage)(state, self, tensor->storage, tensor->storageOffset, inferred_size, NULL);
273+
THLongStorage *new_stride = THLongStorage_newWithSize(size->size);
274+
THArgCheck(THCTensor_(isViewable)(state, tensor, inferred_size, new_stride), 2, "View size is "
275+
"not compatible with input tensor's size and stride (at least one dimension spans "
276+
"across two contiguous subspaces). Call .contiguous() before .view().");
277+
THCTensor_(setStorage)(state, self, tensor->storage, tensor->storageOffset, inferred_size, new_stride);
232278
THLongStorage_free(inferred_size);
279+
THLongStorage_free(new_stride);
233280
return self;
234281
}
235282

test/test_cuda.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def tmp(t):
302302
('triu', medium_2d, lambda t: [-2], 'negative'),
303303
('unsqueeze', new_t(2, 3, 4), lambda t: [2],),
304304
('unsqueeze', new_t(2, 3, 4), lambda t: [-2], 'neg_dim'),
305-
('view', small_3d, lambda t: [100, 10],),
305+
('view', small_3d, lambda t: [100, 10], 'contiguous'),
306306
('view_as', small_3d, lambda t: [t(100, 10)],),
307307
('zero', small_3d, lambda t: [],),
308308
('zeros', small_3d, lambda t: [1, 2, 3, 4],),
@@ -989,6 +989,9 @@ def _select_broadcastable_dims(dims_full=None):
989989
def test_det(self):
990990
TestTorch._test_det(self, lambda t: t.cuda())
991991

992+
def test_view(self):
993+
TestTorch._test_view(self, lambda t: t.cuda())
994+
992995
def test_broadcast(self):
993996
TestTorch._test_broadcast(self, lambda t: t.cuda())
994997

test/test_torch.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3828,10 +3828,11 @@ def test_var_stability(self):
38283828
self.assertEqual(tensor.var(0)[0], 0.03125)
38293829
self.assertEqual(tensor.var(), 0.03125)
38303830

3831-
def test_view(self):
3832-
tensor = torch.rand(15)
3833-
template = torch.rand(3, 5)
3834-
empty = torch.Tensor()
3831+
@staticmethod
3832+
def _test_view(self, cast):
3833+
tensor = cast(torch.rand(15))
3834+
template = cast(torch.rand(3, 5))
3835+
empty = cast(torch.Tensor())
38353836
target = template.size()
38363837
self.assertEqual(tensor.view_as(template).size(), target)
38373838
self.assertEqual(tensor.view(3, 5).size(), target)
@@ -3848,6 +3849,52 @@ def test_view(self):
38483849
self.assertRaises(RuntimeError, lambda: tensor.view(15, 0))
38493850
self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
38503851
self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
3852+
# test view when tensor is not contiguous in every dimension, but only
3853+
# contiguous dimensions are touched.
3854+
tensor = cast(torch.rand(4, 2, 5, 1, 6, 2, 9, 3)).transpose(-1, 2).transpose(-2, 3)
3855+
# size: [ 4, 2, 3, 9, 6, 2, 1, 5]
3856+
# stride: [3840, 1620, 1, 3, 54, 27, 324, 324]
3857+
# contiguous dim chunks: [__________, ____, ____, __________, ____, ____]
3858+
# merging 1 to chunk after: [__________, ____, ____, __________, __________]
3859+
contig_tensor = tensor.clone()
3860+
# [4, 2] => [8, 1]
3861+
# [3] => [3]
3862+
# [9] => [3, 3]
3863+
# [6, 2] => [4, 1, 3]
3864+
# [1, 5] => [5]
3865+
view_size = [8, 1, 3, 3, 3, 4, 1, 3, 5]
3866+
self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))
3867+
# [4, 2] => [2, 4]
3868+
# [3] => [3]
3869+
# [9] => [1, 9]
3870+
# [6, 2] => [2, 2, 3]
3871+
# [1, 5] => [5, 1]
3872+
view_size = [2, 4, 3, 1, 9, 2, 2, 3, 5, 1]
3873+
self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))
3874+
# adding size 1 dims
3875+
view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1]
3876+
self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))
3877+
3878+
# invalid views
3879+
self.assertRaises(RuntimeError, lambda: tensor.view(-1))
3880+
# crossing [4, 2], [3]
3881+
self.assertRaises(RuntimeError, lambda: tensor.view(24, 9, 6, 2, 1, 5))
3882+
# crossing [6, 2], [1, 5]
3883+
self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 9, 6, 10))
3884+
# crossing [9], [6, 2]
3885+
self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 54, 2, 1, 5))
3886+
3887+
# view with stride 0 dims
3888+
tensor = cast(torch.Tensor(1, 1)).expand(3, 4) # all dims are contiguous
3889+
contig_tensor = tensor.clone()
3890+
self.assertEqual(tensor.view(-1), contig_tensor.view(-1))
3891+
self.assertEqual(tensor.view(1, -1, 1), contig_tensor.view(1, -1, 1))
3892+
self.assertEqual(tensor.view(-1, 1), contig_tensor.view(-1, 1))
3893+
self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1))
3894+
self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1))
3895+
3896+
def test_view(self):
3897+
TestTorch._test_view(self, lambda x: x)
38513898

38523899
def test_expand(self):
38533900
tensor = torch.rand(1, 8, 1)

torch/_tensor_docs.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,8 +1789,18 @@ def callable(a, b) -> number
17891789
different size.
17901790
17911791
The returned tensor shares the same data and must have the same number
1792-
of elements, but may have a different size. A tensor must be
1793-
:func:`contiguous` to be viewed.
1792+
of elements, but may have a different size. For a tensor to be viewed, the new
1793+
view size must be compatible with its original size and stride, i.e., each new
1794+
view dimension must either be a subspace of an original dimension, or only span
1795+
across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following
1796+
contiguity-like condition that :math:`\forall i = 0, \dots, k-1`,
1797+
1798+
.. math::
1799+
1800+
stride[i] = stride[i+1] \times size[i+1]
1801+
1802+
Otherwise, :func:`contiguous` needs to be called before the tensor can be
1803+
viewed.
17941804
17951805
Args:
17961806
args (torch.Size or int...): the desired size

0 commit comments

Comments
 (0)