Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions aten/src/THC/generic/THCTensorMath.cu
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
}
}

// If all inputs are empty tensors, return an empty tensor
if (notEmptyTensor == NULL) {
return;

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

}

// In the event that the user specified -1 as the concat dimension, then
// we want to pick the nDims as dimension to cat along (and thus nDims - 1 as the
// value due to 0-based indexing). If the nDims is // 0 (i.e. we are catting all
Expand Down
3 changes: 3 additions & 0 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,9 @@ def test_cat(self):
z = torch.cat([x, y])
self.assertEqual(z.size(), (21, SIZE, SIZE))

def test_cat_empty(self):
TestTorch._test_cat_empty(self, use_cuda=True)

def test_bernoulli(self):
x = torch.tensor([0, 1], dtype=torch.cuda.float32)
self.assertEqual(x.bernoulli().tolist(), [0, 1])
Expand Down
19 changes: 15 additions & 4 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2353,17 +2353,25 @@ def test_cat_scalars(self):
'zero-dimensional.*cannot be concatenated'):
torch.cat([x, y])

def test_cat_empty(self):
@staticmethod
def _test_cat_empty(self, use_cuda=False):
# FIXME: this is legacy behavior and should be removed
# when we support empty tensors with arbitrary sizes
x = torch.randn(4, 3, 32, 32)
empty = torch.randn(0)
if use_cuda:
dtype = torch.cuda.float32
else:
dtype = torch.float32

x = torch.randn((4, 3, 32, 32), dtype=dtype)
empty = torch.randn((0,), dtype=dtype)

This comment was marked as off-topic.

This comment was marked as off-topic.


res1 = torch.cat([x, empty], dim=1)
res2 = torch.cat([empty, x], dim=1)
self.assertEqual(res1, res2)

conv = torch.nn.Conv2d(3, 3, kernel_size=1)
conv = torch.nn.Conv2d(3, 3, kernel_size=1).float()
if use_cuda:
conv = conv.cuda()
res1 = torch.cat([conv(x), empty], dim=1)
res2 = torch.cat([empty, conv(x)], dim=1)
self.assertEqual(res1, res2)
Expand All @@ -2375,6 +2383,9 @@ def test_cat_empty(self):
'expected a non-empty list of Tensors'):
torch.cat([], dim=1)

def test_cat_empty(self):
self._test_cat_empty(self)

def test_stack(self):
x = torch.rand(2, 3, 4)
y = torch.rand(2, 3, 4)
Expand Down