Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion aten/src/ATen/native/SummaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ Tensor _bincount_cpu_template(
if (minlength < 0) {
AT_ERROR("minlength should be >= 0");
}
if (self.dim() != 1 || self.numel() == 0 || *self.min().data<input_t>() < 0) {
if (self.dim() == 1 && self.numel() == 0) {
return native::zeros({minlength}, kLong);
}
if (self.dim() != 1 || *self.min().data<input_t>() < 0) {
AT_ERROR("bincount only supports 1-d non-negative integral inputs.");
}

Expand Down
5 changes: 4 additions & 1 deletion aten/src/ATen/native/cuda/SummaryOps.cu
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,10 @@ Tensor _bincount_cuda_template(
if (minlength < 0) {
AT_ERROR("minlength should be >= 0");
}
if (self.dim() != 1 || self.numel() == 0 ||
if (self.dim() == 1 && self.numel() == 0) {
return native::zeros({minlength}, device(kCUDA).dtype(kLong));
}
if (self.dim() != 1 ||
(!std::is_same<input_t, uint8_t>::value &&
*self.min().toBackend(kCPU).data<input_t>() < 0)) {
AT_ERROR("bincount only supports 1-d non-negative integral inputs.");
Expand Down
6 changes: 6 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8223,6 +8223,12 @@ def _test_bincount(self, device):
with self.assertRaisesRegex(RuntimeError, 'same length'):
torch.bincount(torch.tensor([1, 0], device=device),
torch.tensor([1., 0.3, 0.5], device=device))
# 1-d input with no elements and default minlength
self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)),
torch.zeros(0, dtype=torch.long, device=device))
# 1-d input with no elements and specified minlength
self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10),
torch.zeros(10, dtype=torch.long, device=device))

# test tensor method without weights
long_counts = torch.tensor(
Expand Down
11 changes: 7 additions & 4 deletions torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,19 +604,22 @@ def parse_kwargs(desc):
Count the frequency of each value in an array of non-negative ints.

The number of bins (size 1) is one larger than the largest value in
:attr:`input`. If :attr:`minlength` is specified, the number of bins is at least
:attr:`minlength`. If ``n`` is the value at position ``i``,
:attr:`input` unless :attr:`input` is empty, in which case the result is a
tensor of size 0. If :attr:`minlength` is specified, the number of bins is at least
:attr:`minlength` and if :attr:`input` is empty, then the result is tensor of size
:attr:`minlength` filled with zeros. If ``n`` is the value at position ``i``,
:math:`out[n] += weights[i]` if :attr:`weights` is specified else
:math:`out[n] += 1`.

Arguments:
input (Tensor): 1-d int tensor
weights (Tensor): optional, weight for each value in the input tensor.
Should be of same size as input tensor.
minlength (int): optional, min number of bins. Should be non-negative.
minlength (int): optional, minimum number of bins. Should be non-negative.

Shape:
output (Tensor): ``Size([max(input) + 1])``
output (Tensor): ``Size([max(input) + 1])`` if :attr:`input` is non-empty, else
``Size(0)``

Example::

Expand Down