Should torch.load and torch.save take a device (since PyTorch 0.4) #7178

@daniel-j-h

Description

Reading the PyTorch 0.4.0 migration guide, I came across the torch.device abstraction.

It is great for creating a device once, passing it around, and using the .to functions, e.g.:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = data.to(device)

The torch.load function does not take a device, though. Instead it takes a map_location argument, which can be a lambda, a mapping, or, since #4203, a string like 'cpu'.
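
For reference, the existing forms look like this (a minimal sketch; 'model.pth' is a placeholder path):

import torch

# callable: (storage, location) -> storage, here forcing everything onto CPU
restored = torch.load('model.pth', map_location=lambda storage, location: storage)

# dict: remap serialized device tags
restored = torch.load('model.pth', map_location={'cuda:0': 'cpu'})

# string (since #4203)
restored = torch.load('model.pth', map_location='cpu')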

Now the question is: why are there two different concepts, and can they be unified into one device abstraction? Otherwise we can pass the device around everywhere except serialization, where we have to translate the device into a map_location argument.

Can we unify these concepts behind an API like the following?

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
restored = torch.load('model.pth', device=device)
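
In the meantime, a thin wrapper can keep the device abstraction at the call site (a minimal sketch; load_to is a hypothetical helper that relies on the string form of map_location from #4203):

import torch

def load_to(path, device):
    # Hypothetical helper: translate the torch.device into the string
    # map_location that torch.load understands today.
    return torch.load(path, map_location=str(device))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
restored = load_to('model.pth', device)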

Related: #6630 - torch.save should also take a device

Labels: todo (not as important as medium or high priority tasks, but we will work on these)
