
scatter_add_ should support scalar source (including Python scalar) #5405

@vadimkantorov


I was trying to use scatter_add_ to compute a bincount (the number of occurrences of each index value).
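For reference, the bincount goal can be expressed with scatter_add_ when the source is a tensor of ones shaped like the index. This is a minimal sketch written against a current PyTorch release rather than the 0.4.0a0 build in this report, and it uses torch.bincount only as a cross-check.

```python
import torch

# Non-negative integer values to count (same data as the repro below).
a = torch.tensor([2, 0, 3, 3])
num_bins = 5

# scatter_add_ sums src into r at the positions given by index, so a
# ones source with the index's shape adds one count per occurrence.
r = torch.zeros(num_bins, dtype=torch.long)
r.scatter_add_(0, a, torch.ones_like(a))

# torch.bincount computes the same histogram directly.
expected = torch.bincount(a, minlength=num_bins)
print(r.tolist())         # [1, 0, 1, 2, 0]
print(expected.tolist())  # [1, 0, 1, 2, 0]
```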

import torch
a = torch.LongTensor([2, 0, 3, 3])
r = torch.LongTensor(5)

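# works: scatter_ accepts a scalar value, but it overwrites instead of accumulating,
# so the duplicate index 3 is counted only once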
r.zero_().scatter_(0, a, 1)
# 1
# 0
# 1
# 1
# 0
# [torch.LongTensor of size 5]

# scatter_add_ accumulates, but it rejects a scalar source
r.zero_().scatter_add_(0, a, 1)
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
# TypeError: scatter_add_ received an invalid combination of arguments - got (int, torch.LongTensor, int), but expected (int dim, torch.LongTensor index, torch.LongTensor src)

# a 1-element src is not broadcast to the index's shape, and there is no bounds check either:
# only src[0] is valid, so reading src[1..3] puts garbage into r
r.zero_().scatter_add_(0, a, torch.LongTensor([1]))
# 1.4033e+14
# 0.0000e+00
# 1.0000e+00
# 5.4931e+18
# 0.0000e+00
# [torch.LongTensor of size 5]

# works: expanding src to the index's shape gives the correct counts (index 3 appears twice)
r.zero_().scatter_add_(0, a, torch.LongTensor([1]).expand_as(a))
# 1
# 0
# 1
# 2
# 0
# [torch.LongTensor of size 5]
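The expand_as workaround is cheap because expand returns a view rather than a copy. The sketch below (assuming a current PyTorch release) shows that the expanded source has stride 0 and shares storage with the one-element tensor, so no full-size ones tensor is allocated.

```python
import torch

a = torch.tensor([2, 0, 3, 3])
one = torch.tensor([1])

# expand_as returns a view with stride 0 along the expanded dimension:
# all four elements alias the single stored value, so nothing is copied.
src = one.expand_as(a)
print(src.shape)                         # torch.Size([4])
print(src.stride())                      # (0,)
print(src.data_ptr() == one.data_ptr())  # True

r = torch.zeros(5, dtype=torch.long)
r.scatter_add_(0, a, src)
print(r.tolist())                        # [1, 0, 1, 2, 0]
```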

Observed at PyTorch version 0.4.0a0+1fdb392.
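Until scatter_add_ accepts a scalar source, the requested behavior can be approximated in user code. The helper below, scatter_add_scalar_, is hypothetical and not a PyTorch API; it is a sketch that assumes wrapping the scalar in a 0-d tensor of self's dtype and expanding it to the index's shape is an acceptable stand-in for native scalar support.

```python
import numbers

import torch


def scatter_add_scalar_(self, dim, index, value):
    """Hypothetical helper (not part of PyTorch): scatter_add_ that also
    accepts a Python scalar or a 0-d tensor as the source."""
    if isinstance(value, numbers.Number):
        # Wrap the Python scalar in a 0-d tensor matching self.
        value = torch.tensor(value, dtype=self.dtype, device=self.device)
    if value.dim() == 0:
        # Expand the 0-d tensor to index's shape as a stride-0 view (no copy).
        value = value.to(self.dtype).expand_as(index)
    return self.scatter_add_(dim, index, value)


a = torch.tensor([2, 0, 3, 3])
r = torch.zeros(5, dtype=torch.long)
scatter_add_scalar_(r, 0, a, 1)
print(r.tolist())  # [1, 0, 1, 2, 0]
```

Expanding instead of calling torch.full_like keeps the source a zero-copy view, matching the expand_as workaround from the repro.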

cc @mikaylagawarecki


Labels: module: scatter & gather ops, triaged
