The `torch.sum()` function writes its output to erroneous locations when the `out` kwarg is used. Repro:

```python
import torch

i = 1
a = torch.zeros(5, 3)
b = torch.randn(3, 5)

ac = a.clone()
ac[:, i].copy_(b.sum(0))
print(ac)

ac = a.clone()
b.sum(0, out=ac[:, i])
print(ac)
```

Output:
```
lvdmaaten-mbp:Desktop lvdmaaten$ python bug.py
 0.0000 -0.1515  0.0000
 0.0000  1.8761  0.0000
 0.0000 -1.7563  0.0000
 0.0000  0.5194  0.0000
 0.0000  1.5322  0.0000
[torch.FloatTensor of size 5x3]

 0.0000 -0.1515  1.8761
-1.7563  0.5194  1.5322
 0.0000  0.0000  0.0000
 0.0000  0.0000  0.0000
 0.0000  0.0000  0.0000
[torch.FloatTensor of size 5x3]
```

I presume what happens is that `torch.sum()` does not respect the stride of `ac[:, i]` when writing its output in place.
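For reference, the column view really is non-contiguous: its elements are spaced a full row apart in memory. A quick check (written against a recent PyTorch build, where `Tensor.stride()` and `Tensor.is_contiguous()` are available) illustrates this:

```python
import torch

a = torch.zeros(5, 3)  # row-major layout: stride (3, 1)
col = a[:, 1]          # column view into a

# The view steps 3 elements through a's storage per entry,
# so a kernel that assumes a contiguous output would scribble
# over the wrong locations.
print(col.stride())         # (3,)
print(col.is_contiguous())  # False
```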
Tested on PyTorch version 0.3.0.post4.
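Until the `out=` path handles strided outputs, a possible workaround (a sketch, not an official fix) is to sum into a contiguous temporary and then `copy_` the result into the view, as the first half of the repro does; `copy_` honors the destination's stride:

```python
import torch

i = 1
a = torch.zeros(5, 3)
b = torch.randn(3, 5)

ac = a.clone()
tmp = b.sum(0)       # contiguous result of shape (5,)
ac[:, i].copy_(tmp)  # strided copy into the column view

# Only column i should be nonzero.
print(ac)
```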