Status: Closed
Labels: high priority, module: cuda (Related to torch.cuda, and CUDA support in general), triage review
Description
🐛 Bug
`index_put_` with `accumulate=True` no longer works as expected for int64 CUDA tensors.
To Reproduce
```python
import torch

print("My pytorch version {}\n\n".format(torch.__version__))

def test_index_put(device, dtype):
    print("===== {}, {} =====".format(str(device), str(dtype)))
    A = torch.zeros((2, 2), device=device, dtype=dtype)
    values = torch.as_tensor([1, 1, 1], device=device, dtype=dtype)
    rows = torch.as_tensor([0, 1, 1], device=device, dtype=torch.int64)
    cols = torch.as_tensor([0, 0, 0], device=device, dtype=torch.int64)
    idxes = (rows, cols)
    A.index_put_(idxes, values, accumulate=True)
    print(A)

test_index_put('cpu', torch.float32)
test_index_put('cpu', torch.int64)
test_index_put('cuda', torch.float32)
test_index_put('cuda', torch.int64)
```

With PyTorch 1.1.0 this produces:
```
My pytorch version 1.1.0

===== cpu, torch.float32 =====
tensor([[1., 0.],
        [2., 0.]])
===== cpu, torch.int64 =====
tensor([[1, 0],
        [2, 0]])
===== cuda, torch.float32 =====
tensor([[1., 0.],
        [2., 0.]], device='cuda:0')
===== cuda, torch.int64 =====
tensor([[1, 0],
        [2, 0]], device='cuda:0')
```
However, with PyTorch 1.3.1 the following happens:
```
My pytorch version 1.3.1

===== cpu, torch.float32 =====
tensor([[1., 0.],
        [2., 0.]])
===== cpu, torch.int64 =====
tensor([[1, 0],
        [2, 0]])
===== cuda, torch.float32 =====
tensor([[1., 0.],
        [2., 0.]], device='cuda:0')
===== cuda, torch.int64 =====
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-2de03e14a631> in <module>
     18 test_index_put('cpu', torch.int64)
     19 test_index_put('cuda', torch.float32)
---> 20 test_index_put('cuda', torch.int64)

<ipython-input-2-2de03e14a631> in test_index_put(device, dtype)
     11     cols = torch.as_tensor([0, 0, 0], device = device, dtype = torch.int64)
     12     idxes = (rows, cols)
---> 13     A.index_put_(idxes, values, accumulate=True)
     14
     15     print(A)

RuntimeError: "embedding_backward" not implemented for 'Long'
```
Expected behavior
index_put_ should work regardless of the device. 😉
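Until this is fixed, one possible workaround is to emulate the accumulating `index_put_` by flattening the 2-D indices into linear indices and using `scatter_add_`, which (to my understanding) supports integer dtypes on CUDA. This is a sketch, not part of the original report; the helper name `index_put_accumulate` is hypothetical:

```python
import torch

def index_put_accumulate(A, idxes, values):
    # Hypothetical workaround: emulate A.index_put_(idxes, values, accumulate=True)
    # for a 2-D tensor by converting (row, col) pairs to linear indices and
    # accumulating with scatter_add_, which also handles repeated indices.
    rows, cols = idxes
    flat = rows * A.size(1) + cols       # linear index into A.view(-1)
    A.view(-1).scatter_add_(0, flat, values)
    return A

A = torch.zeros((2, 2), dtype=torch.int64)
rows = torch.as_tensor([0, 1, 1], dtype=torch.int64)
cols = torch.as_tensor([0, 0, 0], dtype=torch.int64)
values = torch.as_tensor([1, 1, 1], dtype=torch.int64)
index_put_accumulate(A, (rows, cols), values)
print(A)  # tensor([[1, 0], [2, 0]])
```

On CPU this matches the `index_put_` result shown above; whether it avoids the CUDA error in 1.3.1 would need to be confirmed on a GPU build.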
Environment
- PyTorch version: 1.3.1
- OS: Ubuntu 18.04.3 LTS
- Installed PyTorch with: pip
- Python version: 3.6
- CUDA runtime version: 10.1.243, cuDNN: 7.6.3
- GPU models and configuration: GPU 0: GeForce RTX 2070
- Nvidia driver version: 418.87.01
- Any other relevant information:
Versions of relevant libraries:
[pip3] numpy==1.17.1
[pip3] torch==1.3.1
[pip3] torchvision==0.4.2