Closed
Issue description
After upgrading to PyTorch v0.4, the behavior of the casting API changed. Here is a minimal reproduction (it needs at least two GPUs):
import torch
a = torch.nn.Linear(1, 2)
a.cuda(1)  # move the parameters to the second GPU
print("weight of a:", a.weight)
a.type(torch.float16)  # or the older form: a.type(torch.cuda.HalfTensor)
print("weight of a:", a.weight)
Sample output:
weight of a: Parameter containing:
tensor([[ 0.1993],
        [ 0.7559]], device='cuda:1')
weight of a: Parameter containing:
tensor([[ 0.1993],
        [ 0.7559]], dtype=torch.float16, device='cuda:0')
In earlier versions, a cast tensor kept the device it was on before casting (here it would have stayed on cuda:1). In v0.4, a.type(torch.float16) moves it to cuda:0 instead.
By contrast, the .half() API respects the tensor's device. Sample output with a.half() in place of the a.type(...) call:
weight of a: Parameter containing:
tensor([[ 0.2513],
        [-0.9306]], device='cuda:1')
weight of a: Parameter containing:
tensor([[ 0.2512],
        [-0.9307]], dtype=torch.float16, device='cuda:1')
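One possible way to change a module's dtype without touching its device is Module.to(dtype=...), which converts floating-point parameters and buffers and leaves their device placement alone. This is a sketch, not a confirmed fix: it assumes a PyTorch release where Module.to accepts a dtype keyword (current releases do), it has not been checked against v0.4.0 specifically, and the helper names cast_module_keep_device and devices_of are hypothetical. It falls back to the CPU when fewer than two GPUs are present.

```python
import torch


def cast_module_keep_device(module: torch.nn.Module, dtype: torch.dtype) -> torch.nn.Module:
    # Hypothetical helper: cast floating-point parameters/buffers to `dtype`.
    # Module.to(dtype=...) only changes the dtype, not the device.
    return module.to(dtype=dtype)


def devices_of(module: torch.nn.Module) -> set:
    # Hypothetical helper: the set of devices the module's parameters live on.
    return {p.device for p in module.parameters()}


if __name__ == "__main__":
    # Use the second GPU as in the report when available, otherwise the CPU.
    if torch.cuda.device_count() > 1:
        device = torch.device("cuda", 1)
    else:
        device = torch.device("cpu")

    a = torch.nn.Linear(1, 2).to(device)
    before = devices_of(a)
    cast_module_keep_device(a, torch.float16)
    after = devices_of(a)

    print("devices before:", before, "after:", after)
    print("weight dtype:", a.weight.dtype)
```

Comparing the device sets before and after the cast gives a quick check for the regression described above, independent of which GPU index is in use.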