Skip to content

Commit 9ed3eba

Browse files
committed
Fix error message for cat-ing zero-dim tensors
1 parent 940a0ab commit 9ed3eba

File tree

5 files changed

+24
-16
lines changed

5 files changed

+24
-16
lines changed

test/test_torch.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -2305,7 +2305,9 @@ def test_cat_bad_input_sizes(self):
     def test_cat_scalars(self):
         x = torch.tensor(0)
         y = torch.tensor(1)
-        self.assertRaises(RuntimeError, lambda: torch.cat([x, y]))
+        with self.assertRaisesRegexp(RuntimeError,
+                                     'zero-dimensional.*cannot be concatenated'):
+            torch.cat([x, y])

     def test_stack(self):
         x = torch.rand(2, 3, 4)
```

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -155,7 +155,7 @@
   self: not_implemented("btrisolve")

 - name: cat(TensorList tensors, int64_t dim)
-  tensors: cat_tensors_backward(grad, to_arg_sizes(tensors, dim), dim)
+  tensors: cat_tensors_backward(grad, to_args_sizes(tensors), dim)

 - name: cauchy_(Tensor self, double median, double sigma, Generator generator)
   self: zeros_like(grad)
```

tools/autograd/load_derivatives.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -294,10 +294,10 @@ def saved_variables(formula, args):
         'suffix': '_numel',
         'type': 'int64_t',
     }),
-    # replace to_arg_sizes(self, 2) with self_argsizes_2
-    (r'to_arg_sizes\({}, (\w+)\)', {
-        'suffix': lambda m: '_sizes_{}'.format(*m.groups()),
-        'type': 'IntList',
+    # replace to_args_sizes(self) with self_args_sizes
+    (r'to_args_sizes\({}\)', {
+        'suffix': '_args_sizes',
+        'type': 'std::vector<std::vector<int64_t>>',
     }),
     # replace TensorGeometry(self) with self_geometry
     (r'TensorGeometry\({}\)', {
```

tools/autograd/templates/Functions.cpp

Lines changed: 12 additions & 6 deletions
```diff
@@ -364,17 +364,23 @@ Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntList sizes) {
   return self;
 }

-std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<int64_t> &sizes, int64_t dim) {
+std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, int64_t dim) {
+  if (sizes.size() > 0) {
+    // cat wraps dim to the first tensor's shape
+    dim = at::maybe_wrap_dim(dim, sizes[0].size());
+  }
   std::vector<Tensor> grad_inputs(sizes.size());
   int64_t accumulate = 0;
   for (size_t i = 0; i < sizes.size(); ++i) {
-    auto size = sizes[i];
-    accumulate += size;
-    if (size == 0) {
+    auto& shape = sizes[i];
+    // If input was empty tensor, gradInput should be empty tensor.
+    if (shape[0] == 0) {
       grad_inputs[i] = at::zeros(grad.type(), {0});
-    } else {
-      grad_inputs[i] = grad.narrow(dim, accumulate - size, size);
+      continue;
     }
+    auto size = shape[dim];
+    accumulate += size;
+    grad_inputs[i] = grad.narrow(dim, accumulate - size, size);
   }
   return grad_inputs;
 }
```

tools/autograd/templates/VariableType.cpp

Lines changed: 4 additions & 4 deletions
```diff
@@ -409,12 +409,12 @@ Tensor VariableType::contiguous(const Tensor & self) const {
   return self.clone();
 }

-static std::vector<int64_t> to_arg_sizes(TensorList tensors, int64_t dim) {
-  std::vector<int64_t> arg_sizes(tensors.size());
+static std::vector<std::vector<int64_t>> to_args_sizes(TensorList tensors) {
+  std::vector<std::vector<int64_t>> args_sizes(tensors.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
-    arg_sizes[i] = tensors[i].size(dim);
+    args_sizes[i] = tensors[i].sizes();
   }
-  return arg_sizes;
+  return args_sizes;
 }

 ${type_derived_method_definitions}
```

0 commit comments

Comments (0)