Skip to content

Commit eb5daa9

Browse files
authored
Make cat/cat_out native function that rejects scalar inputs. (#4992)
* Make cat/cat_out native function that rejects scalar inputs. * Print position of scalar in error message.
1 parent a8bda67 commit eb5daa9

File tree

5 files changed

+34
-2
lines changed

5 files changed

+34
-2
lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4162,7 +4162,7 @@
41624162
]]
41634163

41644164
[[
4165-
name: cat
4165+
name: _cat
41664166
cname: catArray
41674167
variants: [function]
41684168
return: self

aten/src/ATen/native/Convolution.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ at::Tensor _convolution(
382382
outputs[g] = at::_convolution_nogroup(
383383
input_g, weight_g, bias_g, params.stride, params.padding, params.dilation, params.transposed, params.output_padding);
384384
}
385-
output = cat(outputs, 1);
385+
output = at::cat(outputs, 1);
386386
}
387387
}
388388

aten/src/ATen/native/TensorShape.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,25 @@
88
namespace at {
99
namespace native {
1010

11+
static void check_cat_no_zero_dim(TensorList tensors) {
12+
for(size_t i = 0; i < tensors.size(); ++i) {
13+
auto& t = tensors[i];
14+
if (t.dim() == 0) {
15+
runtime_error("zero-dimensional tensor (at position %zu) cannot be concatenated", i);
16+
}
17+
}
18+
}
19+
20+
Tensor & cat_out(Tensor & result, TensorList tensors, int64_t dim) {
21+
check_cat_no_zero_dim(tensors);
22+
return at::_cat_out(result, tensors, dim);
23+
}
24+
25+
Tensor cat(TensorList tensors, int64_t dim) {
26+
check_cat_no_zero_dim(tensors);
27+
return at::_cat(tensors, dim);
28+
}
29+
1130
std::vector<Tensor> chunk(const Tensor& self, int64_t chunks, int64_t dim) {
1231
if (self.dim() == 0) {
1332
throw std::runtime_error("chunk expects at least a 1-dimensional tensor");

aten/src/ATen/native/native_functions.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@
4747

4848
- func: bernoulli_(Tensor self, double p=0.5, Generator* generator=nullptr) -> Tensor
4949

50+
- func: cat(TensorList tensors, int64_t dim=0) -> Tensor
51+
variants: function
52+
53+
- func: cat_out(Tensor result, TensorList tensors, int64_t dim=0) -> Tensor
54+
variants: function
55+
5056
- func: chunk(Tensor self, int64_t chunks, int64_t dim=0) -> TensorList
5157

5258
- func: cudnn_is_acceptable(Tensor self) -> bool

test/test_torch.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1966,6 +1966,13 @@ def test_cat_bad_input_sizes(self):
19661966
z = torch.randn(2, 2, 1)
19671967
self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z], dim=1))
19681968

1969+
@unittest.skipIf(not torch._C._with_scalars(), "scalars not enabled")
1970+
def test_cat_scalars(self):
1971+
from torch.autograd import variable
1972+
x = variable(0)
1973+
y = variable(1)
1974+
self.assertRaises(RuntimeError, lambda: torch.cat([x, y]))
1975+
19691976
def test_stack(self):
19701977
x = torch.rand(2, 3, 4)
19711978
y = torch.rand(2, 3, 4)

0 commit comments

Comments
 (0)