Status: Open
Labels: module: scatter & gather ops, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
I wanted to use `scatter_add_` to compute a bincount.
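A bincount is a scatter-add of ones into a zeroed output: for each position `i`, `out[index[i]] += src[i]` with every `src[i] == 1`. The sketch below models that in plain Python. `scatter_add_1d` and `bincount` are hypothetical helpers written for illustration, not PyTorch functions. Unlike the behaviour reported below, the model refuses a `src` shorter than `index` instead of reading past its end.

```python
from typing import List, Sequence


def scatter_add_1d(out: List[int], index: Sequence[int], src: Sequence[int]) -> List[int]:
    """Hypothetical 1-D model of scatter-add: out[index[i]] += src[i]."""
    # Require one src element per index element; no implicit broadcasting.
    if len(src) < len(index):
        raise ValueError(f"src has {len(src)} elements but index has {len(index)}")
    for i, j in enumerate(index):
        out[j] += src[i]
    return out


def bincount(index: Sequence[int], minlength: int = 0) -> List[int]:
    """Hypothetical bincount built on scatter_add_1d: scatter ones into zeros."""
    size = max(minlength, max(index, default=-1) + 1)
    return scatter_add_1d([0] * size, index, [1] * len(index))


print(bincount([2, 0, 3, 3], minlength=5))  # [1, 0, 1, 2, 0]
```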
```python
import torch

a = torch.LongTensor([2, 0, 3, 3])
r = torch.LongTensor(5)  # uninitialized; zeroed before each call below

# works
r.zero_().scatter_(0, a, 1)
# 1
# 0
# 1
# 1
# 0
# [torch.LongTensor of size 5]

# scalar source doesn't work
r.zero_().scatter_add_(0, a, 1)
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
# TypeError: scatter_add_ received an invalid combination of arguments - got (int, torch.LongTensor, int), but expected (int dim, torch.LongTensor index, torch.LongTensor src)

# no broadcasting? but no memory-bounds check either?
# entries 0 and 3 look like garbage read past the end of the 1-element src
r.zero_().scatter_add_(0, a, torch.LongTensor([1]))
# 1.4033e+14
# 0.0000e+00
# 1.0000e+00
# 5.4931e+18
# 0.0000e+00
# [torch.LongTensor of size 5]

# works
r.zero_().scatter_add_(0, a, torch.LongTensor([1]).expand_as(a))
# 1
# 0
# 1
# 2
# 0
# [torch.LongTensor of size 5]
```

Observed at PyTorch 0.4.0a0+1fdb392.
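For comparison, a sketch of the same bincount assuming a recent PyTorch release rather than the 0.4.0a0 build above. It passes a `src` of ones with the same shape as the index, which plays the role of the `expand_as` workaround, and checks it against `torch.bincount`, which computes the counts directly.

```python
import torch

a = torch.tensor([2, 0, 3, 3])  # int64 by default

# Scatter-add a src of ones with the same shape as the index, so no broadcasting is needed.
r = torch.zeros(5, dtype=torch.long).scatter_add_(0, a, torch.ones_like(a))

# Built-in alternative: counts occurrences of each non-negative integer value.
c = torch.bincount(a, minlength=5)

print(r.tolist())  # [1, 0, 1, 2, 0]
print(c.tolist())  # [1, 0, 1, 2, 0]
```

`torch.ones_like(a)` allocates a full-size tensor. The `expand_as` form in the report avoids that allocation by creating a stride-0 view of a single element, which is why it is a reasonable workaround when the index is large.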