Closed
Labels: todo (Not as important as medium or high priority tasks, but we will work on these.)
Description
While reading the PyTorch 0.4.0 migration guide, I came across the torch.device abstraction.
It is great for creating a device once, passing it around, and moving tensors and modules with the .to() methods, e.g.:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = data.to(device)  # data: any existing tensor or module
The torch.load function does not take a device, though. Instead, it takes a map_location argument, which can be a lambda, a mapping, or (since #4203) a string such as 'cpu'.
So the question is: why are there two different concepts, and can they be unified into a single device abstraction? As it stands, we can pass the device around everywhere except for serialization, where we have to translate the device into a map_location argument.
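As a rough sketch of the translation step described above, assuming a recent PyTorch install, the snippet below turns a single torch.device into each of the three map_location forms (string, mapping, callable). The in-memory buffer, the tensor, and the helper names load_with and restore_on_device are illustrative stand-ins, not part of any PyTorch API.

```python
import io

import torch

# Pick the target device once, following the migration-guide pattern quoted above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# An in-memory buffer stands in for a checkpoint file such as 'model.pth'.
buffer = io.BytesIO()
torch.save(torch.arange(3.0), buffer)


def load_with(map_location):
    """Hypothetical helper: rewind the buffer and load it with the given map_location."""
    buffer.seek(0)
    return torch.load(buffer, map_location=map_location)


# 1) String form (accepted since #4203): the device's name, e.g. "cpu" or "cuda".
from_string = load_with(str(device))

# 2) Mapping form: remap the location the tensor was saved from ("cpu") to the target.
from_mapping = load_with({"cpu": str(device)})


# 3) Callable form: receives (storage, location) and returns the relocated storage.
def restore_on_device(storage, location):
    if device.type == "cpu":
        return storage
    return storage.cuda(device.index)


from_callable = load_with(restore_on_device)

for restored in (from_string, from_mapping, from_callable):
    print(restored.device.type, restored.tolist())
```

All three calls express the same intent ("put everything on device"), which is the duplication the issue asks to remove.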
Can we unify these concepts behind an API like the following?
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
restored = torch.load('model.pth', device=device)
Related: #6630 - torch.save should also take a device.
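Until something like the proposed device= keyword exists, a thin wrapper can give call sites the device-only interface. The sketch below is a hypothetical shim (load_on is not a PyTorch function); it relies on later PyTorch releases accepting a torch.device object directly as map_location, and it round-trips a small state dict through an in-memory buffer instead of 'model.pth'.

```python
import io

import torch


def load_on(f, device):
    """Hypothetical shim for the proposed torch.load(..., device=...) API.

    Accepts anything torch.device() accepts (a torch.device or a string such
    as "cpu" or "cuda:0") and forwards it as map_location. Passing a
    torch.device object to map_location works in later PyTorch releases.
    """
    return torch.load(f, map_location=torch.device(device))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Save a small model's state dict to memory, standing in for 'model.pth'.
buffer = io.BytesIO()
torch.save(torch.nn.Linear(2, 1).state_dict(), buffer)
buffer.seek(0)

restored = load_on(buffer, device)
print({name: t.device.type for name, t in restored.items()})
```

Keeping the conversion in one place means the rest of the code only ever handles the torch.device object.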