Skip to content

Commit 8c3e1b7

Browse files
zou3519 authored and soumith committed
Add proper shape checking to torch.cat (#4087)
* Fix catArray in THTensor Asserts that the inputs have the same size except in the cat dimension or are empty (or a mix of both). * Fix catArray for THCTensor * Document torch.cat shape checks * Fix types
1 parent 0185d5a commit 8c3e1b7

File tree

5 files changed

+190
-139
lines changed

5 files changed

+190
-139
lines changed

test/test_cuda.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,38 @@ def test_cat_autogpu(self):
719719
z = torch.cat([x, y], 0)
720720
self.assertEqual(z.get_device(), x.get_device())
721721

722+
def test_cat(self):
    """Exercise torch.cat on CUDA tensors along every dimension, including
    negative dim indices, then round-trip through split/chunk."""
    SIZE = 10
    lengths = (13, 17, 19)
    for dim in range(-3, 3):
        # Normalize a possibly-negative dim to its positive counterpart.
        pos_dim = dim + 3 if dim < 0 else dim
        parts = [torch.rand(n, SIZE, SIZE).transpose(0, pos_dim).cuda()
                 for n in lengths]

        joined = torch.cat(tuple(parts), dim)
        # Each input must occupy its own contiguous slice of the result.
        offset = 0
        for part, n in zip(parts, lengths):
            self.assertEqual(joined.narrow(pos_dim, offset, n), part, 0)
            offset += n

    # cat is the inverse of split and chunk.
    big = torch.randn(20, SIZE, SIZE).cuda()
    self.assertEqual(torch.cat(torch.split(big, 7)), big)
    self.assertEqual(torch.cat(torch.chunk(big, 7)), big)

    extra = torch.randn(1, SIZE, SIZE).cuda()
    combined = torch.cat([big, extra])
    self.assertEqual(combined.size(), (21, SIZE, SIZE))
742+
743+
def test_cat_bad_input_sizes(self):
744+
x = torch.randn(2, 1).cuda()
745+
y = torch.randn(2, 1, 1).cuda()
746+
z = torch.randn(2, 1, 1).cuda()
747+
self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z]))
748+
749+
x = torch.randn(2, 1, 2).cuda()
750+
y = torch.randn(2, 1, 1).cuda()
751+
z = torch.randn(2, 2, 1).cuda()
752+
self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z], dim=1))
753+
722754
def test_serialization(self):
723755
x = torch.randn(4, 4).cuda()
724756
with tempfile.NamedTemporaryFile() as f:

test/test_torch.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1918,6 +1918,17 @@ def test_cat(self):
19181918

19191919
self.assertRaises(RuntimeError, lambda: torch.cat([]))
19201920

1921+
def test_cat_bad_input_sizes(self):
    """torch.cat must reject inputs whose shapes disagree outside the cat
    dimension."""
    # Mismatched number of dimensions (2-D mixed with 3-D).
    mixed_rank = [torch.randn(2, 1),
                  torch.randn(2, 1, 1),
                  torch.randn(2, 1, 1)]
    self.assertRaises(RuntimeError, lambda: torch.cat(mixed_rank))

    # Same rank, but sizes differ in a dimension other than the cat dim.
    mixed_size = [torch.randn(2, 1, 2),
                  torch.randn(2, 1, 1),
                  torch.randn(2, 2, 1)]
    self.assertRaises(RuntimeError, lambda: torch.cat(mixed_size, dim=1))
1931+
19211932
def test_stack(self):
19221933
x = torch.rand(2, 3, 4)
19231934
y = torch.rand(2, 3, 4)

torch/_torch_docs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,14 +594,18 @@
594594
cat(seq, dim=0, out=None) -> Tensor
595595
596596
Concatenates the given sequence of :attr:`seq` tensors in the given dimension.
597+
All tensors must either have the same shape (except in the cat dimension) or be
598+
empty.
597599
598600
:func:`torch.cat` can be seen as an inverse operation for :func:`torch.split`
599601
and :func:`torch.chunk`
600602
601603
:func:`cat` can be best understood via examples.
602604
603605
Args:
604-
seq (sequence of tensors): any python sequence of tensors of the same type
606+
seq (sequence of tensors): any python sequence of tensors of the same type.
607+
Non-empty tensors provided must have the same shape, except in the
608+
cat dimension.
605609
dim (int, optional): the dimension over which the tensors are concatenated
606610
out (Tensor, optional): the output tensor
607611

torch/lib/TH/generic/THTensorMath.c

Lines changed: 90 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -2834,116 +2834,114 @@ void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension)
28342834
THTensor_(catArray)(r_, inputs, 2, dimension);
28352835
}
28362836

2837+
/* Forward declaration so catArray (defined below) can call the helper. */
void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension);
/* Error out unless `first` and `second` have the same number of dimensions
 * and identical sizes in every dimension other than `dimension`. */
inline void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension)
{
  int ndims_a = first->nDimension;
  int ndims_b = second->nDimension;
  THArgCheck(ndims_a == ndims_b, 0,
      "Tensors must have same number of dimensions: got %d and %d",
      ndims_a, ndims_b);
  for (int d = 0; d < ndims_a; d++) {
    /* The cat dimension itself is allowed to differ. */
    if (d != dimension) {
      int64_t size_a = first->size[d];
      int64_t size_b = second->size[d];
      THArgCheck(size_a == size_b, 0,
          "Sizes of tensors must match except in dimension %d. Got %lld and %lld in dimension %d",
          dimension, (long long)size_a, (long long)size_b, d);
    }
  }
}
2856+
28372857
/* Concatenate `numInputs` tensors into `result` along `dimension`.
 * Empty (0-dimensional) inputs are skipped; all non-empty inputs must have
 * the same shape except along the cat dimension (enforced by
 * check_shape_except_dim). If every input is empty, `result` is untouched. */
void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int dimension)
{
  // Find a non-empty tensor to record nDims
  int allEmpty = 1;
  int nDims = 0;
  THTensor *notEmptyTensor = NULL;  /* NULL-init silences maybe-uninitialized warnings */
  for (int i = 0; i < numInputs; i++) {
    int input_dims = inputs[i]->nDimension;
    if (input_dims == 0) {
      continue;
    }
    // We've found a non-empty tensor
    allEmpty = 0;
    notEmptyTensor = inputs[i];
    nDims = input_dims;
    break;
  }
  if (allEmpty) {
    return;
  }

  // Compute cat_dimension based on the non-empty tensor
  THArgCheck(dimension >= -1 && dimension < nDims, 4, "invalid dimension %d", dimension);
  // When the user input dimension is -1 (i.e. -2 in C)
  // Then we pick the last dimension across non-empty tensors.
  int cat_dimension = dimension;
  if (dimension + TH_INDEX_BASE == -1) {
    cat_dimension = nDims ? nDims - 1 : 0;
  }

  THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);

  // Compute size of the result in the cat dimension; every non-empty input
  // must match notEmptyTensor in all the other dimensions.
  int64_t cat_dim_size = 0;
  for (int i = 0; i < numInputs; i++) {
    THTensor *tensor = inputs[i];
    if (tensor->nDimension == 0) {
      continue;
    }
    THTensor_(check_shape_except_dim)(notEmptyTensor, tensor, cat_dimension);
    cat_dim_size += tensor->size[cat_dimension];
  }

  // Compute the size of the result
  THLongStorage *size = THLongStorage_newWithSize(nDims);
  for (int dim = 0; dim < nDims; dim++) {
    int64_t result_dim_size = notEmptyTensor->size[dim];
    if (dim == cat_dimension) {
      result_dim_size = cat_dim_size;
    }
    size->data[dim] = result_dim_size;
  }
  THTensor_(resize)(result, size, NULL);
  // BUGFIX: `size` was previously leaked. resize copies the sizes and does
  // not take ownership of the storage, so it must be freed here.
  THLongStorage_free(size);

  // Check contiguity of all inputs and result
  int allContiguous = 1;
  for (int i = 0; i < numInputs; i++) {
    if(inputs[i]->nDimension) {
      allContiguous = allContiguous && THTensor_(isContiguous)(inputs[i]);
    }
  }
  allContiguous = allContiguous && THTensor_(isContiguous)(result);

  // First path is for contiguous inputs along dim 0
  // Second path for non-contiguous
  int64_t offset;
  if (cat_dimension == 0 && allContiguous) {
    // Fast path: everything is contiguous and we cat along the outermost
    // dimension, so each input occupies one contiguous span of the result.
    real* result_data = result->storage->data + result->storageOffset;
    offset = 0;
    for (int j = 0; j < numInputs; j++) {
      if (inputs[j]->nDimension) {
        THTensor* input0 = inputs[j];
        real* input0_data = input0->storage->data + input0->storageOffset;
        int64_t input0_size = THTensor_(nElement)(input0);
        memcpy(result_data + offset, input0_data, input0_size*sizeof(real));
        offset += input0_size;
      }
    }
  } else {
    // General path: copy each input into a narrowed view of the result.
    offset = 0;
    for (int j = 0; j < numInputs; j++) {
      if (inputs[j]->nDimension) {
        int64_t dimSize = cat_dimension < inputs[j]->nDimension ? inputs[j]->size[cat_dimension] : 1;
        THTensor *nt = THTensor_(newWithTensor)(result);
        THTensor_(narrow)(nt, NULL, cat_dimension, offset, dimSize);
        THTensor_(copy)(nt, inputs[j]);
        THTensor_(free)(nt);
        offset += dimSize;
      }
    }
  }
}

0 commit comments

Comments
 (0)