75 changes: 65 additions & 10 deletions aten/src/THC/THCDeviceTensor-inl.cuh
@@ -162,9 +162,45 @@ __host__ __device__ bool
THCDeviceTensor<T, Dim, IndexT, PtrTraits>::isContiguousRange(
int first, int last) const {

int64_t prevSize = last < Dim ? getStride(last) * getSize(last) : 1;
// We're testing if dimensions [first, last) are a part of a contiguous range
// Call a dimension untrivial if it has size > 1 and trivial if it has size 1
//
// There is an edge case when the sizes of the last dimensions of
// [first, last) are all 1's. The following needs to happen:
// 1. If [first, last) ends with a trivial dim, find the next untrivial dimension
// at or after last, and use its stride * size as the initial prevSize. If no such
// dimension exists, use 1 instead.
// 2. Find the dim newLast such that newLast is the largest dim < last
// such that newLast is untrivial.
//
// We now test that [first, newLast] is a contiguous range, using the stride * size
// of the next untrivial dimension after newLast to compute prevSize, or
// prevSize = 1 if all the dimensions following newLast are trivial.
// Dims in (newLast, last) are ignored because they are trivial and don't matter.
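//
// A worked example with hypothetical numbers (not taken from the change itself):
// Dim = 4, sizes = [4, 3, 1, 1], strides = [3, 1, 7, 9], first = 0, last = 3.
// Dim `last` is trivial, so its stride (9) must not be used to seed prevSize.
// There is no untrivial dim at or after last, so prevSize starts at 1; dim 2 is
// skipped (trivial), dim 1 passes (stride 1 == prevSize 1), dim 0 passes
// (stride 3 == prevSize 3), and the range is reported contiguous. The old
// prevSize = getStride(last) * getSize(last) = 9 would have wrongly rejected it.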

int next_untrivial_dim = last;
if (next_untrivial_dim < Dim && getSize(next_untrivial_dim) == 1) {
// Find the next untrivial dim that is > last
while (next_untrivial_dim < Dim && getSize(next_untrivial_dim) == 1) {
++next_untrivial_dim;
}

// Find the largest untrivial dim that is < last
int newLast = last;
while (newLast >= first && getSize(newLast) == 1) {
--newLast;
}

// Every dim in (first, last) is trivial here, so dim first is the only one
// that matters; treat the range as contiguous
if (newLast == first) {
return true;
}
}

for (int i = last - 1; i >= first; --i) {
int64_t prevSize = next_untrivial_dim < Dim ?
getStride(next_untrivial_dim) * getSize(next_untrivial_dim) : 1;

for (int i = last - 1; i >= first; --i) {
if (getSize(i) != (IndexT) 1) {
if (getStride(i) == prevSize) {
prevSize *= getSize(i);
@@ -305,18 +341,37 @@ THCDeviceTensor<T, Dim, IndexT, PtrTraits>::downcastOuter() {
if (i < ignoredDims) {
// Collapse these dimensions
collapsedSize *= getSize(i);
} else {
// Non-collapsed dimensions
if (i == ignoredDims) {
// This is the first non-collapsed dimension
newSize[i - ignoredDims] = collapsedSize * getSize(i);
} else {
// Subsequent non-collapsed dimensions
newSize[i - ignoredDims] = getSize(i);
continue;
}

// Non-collapsed dimensions

if (i == ignoredDims) {
// This is the first non-collapsed dimension
newSize[i - ignoredDims] = collapsedSize * getSize(i);

// If the size of this dimension is 1, the stride could
// be anything. Recompute a reasonable stride, assuming
// the remaining inner dimensions are laid out
// contiguously.
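//
// For instance (hypothetical numbers): downcasting a size [3, 2, 1, 5]
// tensor with strides [10, 5, 999, 1] to 2 dims (ignoredDims = 2), the
// first kept dim has size 1, so its meaningless stride (999) is replaced
// by innerSize = 5, the extent of the inner dims, giving a [6, 5] view
// with strides [5, 1].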
if (getSize(i) == 1) {
int innerSize = 1;
for (int j = ignoredDims + 1; j < Dim; ++j) {
innerSize *= getSize(j);
}
newStride[i - ignoredDims] = innerSize;
continue;
}

// If the size of this dimension wasn't 1, then
// use the stride information
newStride[i - ignoredDims] = getStride(i);
continue;
}

// Subsequent non-collapsed dimensions
newSize[i - ignoredDims] = getSize(i);
newStride[i - ignoredDims] = getStride(i);
}

return THCDeviceTensor<T, NewDim, IndexT, PtrTraits>(
15 changes: 15 additions & 0 deletions test/test_nn.py
@@ -1262,6 +1262,21 @@ def _test_vs_Embedding(N, D, B, L):
offset[-1] = 100
self.assertRaises(ValueError, lambda: es(input.view(-1), offset))

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_maxpool3d_contiguous_range(self):
# Mostly just a test that THCDeviceTensor::isContiguousRange
# copes with dims of size 1 that have crazy strides
x = Variable(torch.randn(7, 1, 5, 3, 2).cuda())
strange_strides = [30, 1234, 6, 2, 1]
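# The 1234 stride belongs to a size-1 dim, so it never contributes to any
# element's offset; the tensor is valid, it just must not trip up the
# contiguity checks on the CUDA side.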
y = x.as_strided(x.size(), strange_strides)
x = x.cpu().as_strided(x.size(), strange_strides)

# Should not crash
out_y = F.max_pool3d(y, (5, 1, 1), stride=(5, 1, 1))
out_x = F.max_pool3d(x, (5, 1, 1), stride=(5, 1, 1))

self.assertEqual(out_y, out_x.cuda())

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_AvgPool3d_backward_after_cat_dim1_cuda(self):
# x has to have batch_size 1 to test contiguous checks