Skip to content

Commit 9923701

Browse files
zou3519soumith
authored andcommitted
Fix crash when cat-ing empty cuda tensors (#5971)
Fixes #5739. The CUDA path for `torch.cat` was missing a check for the case where all input tensors are empty.
1 parent 641fb21 commit 9923701

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

aten/src/THC/generic/THCTensorMath.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
139139
}
140140
}
141141

142+
// If all inputs are empty tensors, return an empty tensor
143+
if (notEmptyTensor == NULL) {
144+
return;
145+
}
146+
142147
// In the event that the user specified -1 as the concat dimension, then
143148
// we want to pick the nDims as dimension to cat along (and thus nDims - 1 as the
144149
// value due to 0-based indexing). If the nDims is // 0 (i.e. we are catting all

test/test_cuda.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,6 +1078,9 @@ def test_cat(self):
10781078
z = torch.cat([x, y])
10791079
self.assertEqual(z.size(), (21, SIZE, SIZE))
10801080

1081+
def test_cat_empty(self):
1082+
TestTorch._test_cat_empty(self, use_cuda=True)
1083+
10811084
def test_bernoulli(self):
10821085
x = torch.tensor([0, 1], dtype=torch.cuda.float32)
10831086
self.assertEqual(x.bernoulli().tolist(), [0, 1])

test/test_torch.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2353,17 +2353,25 @@ def test_cat_scalars(self):
23532353
'zero-dimensional.*cannot be concatenated'):
23542354
torch.cat([x, y])
23552355

2356-
def test_cat_empty(self):
2356+
@staticmethod
2357+
def _test_cat_empty(self, use_cuda=False):
23572358
# FIXME: this is legacy behavior and should be removed
23582359
# when we support empty tensors with arbitrary sizes
2359-
x = torch.randn(4, 3, 32, 32)
2360-
empty = torch.randn(0)
2360+
if use_cuda:
2361+
dtype = torch.cuda.float32
2362+
else:
2363+
dtype = torch.float32
2364+
2365+
x = torch.randn((4, 3, 32, 32), dtype=dtype)
2366+
empty = torch.randn((0,), dtype=dtype)
23612367

23622368
res1 = torch.cat([x, empty], dim=1)
23632369
res2 = torch.cat([empty, x], dim=1)
23642370
self.assertEqual(res1, res2)
23652371

2366-
conv = torch.nn.Conv2d(3, 3, kernel_size=1)
2372+
conv = torch.nn.Conv2d(3, 3, kernel_size=1).float()
2373+
if use_cuda:
2374+
conv = conv.cuda()
23672375
res1 = torch.cat([conv(x), empty], dim=1)
23682376
res2 = torch.cat([empty, conv(x)], dim=1)
23692377
self.assertEqual(res1, res2)
@@ -2375,6 +2383,9 @@ def test_cat_empty(self):
23752383
'expected a non-empty list of Tensors'):
23762384
torch.cat([], dim=1)
23772385

2386+
def test_cat_empty(self):
2387+
self._test_cat_empty(self)
2388+
23782389
def test_stack(self):
23792390
x = torch.rand(2, 3, 4)
23802391
y = torch.rand(2, 3, 4)

0 commit comments

Comments
 (0)