
Conversation

@EthanSteinberg
Contributor

This pull request allows you to use torch.device as a map location when using torch.load. It's a pretty minor thing, but it helps make the API a bit more consistent by letting you use torch.device everywhere.

This partially solves #7178.

My main question here is how much to change the torch.load API. This PR makes the minimal change, which is simply adding torch.device as yet another option for specifying the map_location. The real question is whether torch.device should also be supported when a dict or function is passed. There is sort of a tradeoff here between consistency and stability. There is also the question of whether we want to support torch.device objects as dict keys.

Let me know what you guys think.
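For context, the discussion above covers four possible forms of map_location: a string, a dict, a function, and (after this PR) a device object. A standalone sketch of that dispatch, using a hypothetical `Device` stand-in and a simplified `remap` helper rather than the actual torch.serialization internals:

```python
# Hypothetical sketch of the map_location dispatch under discussion.
# `Device` and `remap` are illustrative stand-ins, not torch internals.

class Device:
    """Stand-in for torch.device, with .type and .index attributes."""
    def __init__(self, type, index=None):
        self.type = type
        self.index = index

def remap(location, map_location):
    # `location` is the tag a storage was saved with, e.g. 'cuda:0'.
    if map_location is None:
        return location                       # restore to the original device
    if isinstance(map_location, str):
        return map_location                   # e.g. 'cpu'
    if isinstance(map_location, dict):
        return map_location.get(location, location)
    if callable(map_location):
        return map_location(location)
    if isinstance(map_location, Device):      # the form this PR adds
        if map_location.index is None:
            return map_location.type
        return '{}:{}'.format(map_location.type, map_location.index)
    raise TypeError('unsupported map_location: {!r}'.format(map_location))

print(remap('cuda:0', 'cpu'))                 # cpu
print(remap('cuda:1', {'cuda:1': 'cuda:0'}))  # cuda:0
print(remap('cuda:0', Device('cpu')))         # cpu
```

Note that in the real API the callable form receives `(storage, location)`; the sketch collapses that to the location tag only.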

@EthanSteinberg EthanSteinberg changed the title Allow using torch.device for loading Allow the use of torch.device for loading May 7, 2018
elif isinstance(map_location, _string_classes):
    def restore_location(storage, location):
        return default_restore_location(storage, map_location)
elif isinstance(map_location, torch.device):


Contributor

@apaszke apaszke left a comment


The test needs to be deduplicated. Apart from that LGTM.

def load_bytes():
    with open(test_file_path, 'rb') as f:
        data = io.BytesIO(f.read())
    return data
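The review note about deduplication could be addressed by parametrizing a single test over the map_location variants instead of repeating near-identical bodies. A minimal sketch (not the PR's actual fix), with a hypothetical `fake_load` standing in for torch.load so it runs without torch installed:

```python
import io
import unittest

def fake_load(data, map_location=None):
    # Stand-in for torch.load: just returns the raw bytes and the
    # map_location it was given, so the test structure can be shown.
    return (data.read(), map_location)

class TestMapLocation(unittest.TestCase):
    def load_bytes(self):
        # Shared helper: each variant reads from a fresh in-memory "file".
        return io.BytesIO(b'tensor-bytes')

    def test_map_location_variants(self):
        # One parametrized loop replaces several duplicated test bodies.
        variants = ['cpu', {'cuda:0': 'cpu'}, lambda storage, loc: storage]
        for map_location in variants:
            with self.subTest(map_location=map_location):
                data, loc = fake_load(self.load_bytes(),
                                      map_location=map_location)
                self.assertEqual(data, b'tensor-bytes')
                self.assertIs(loc, map_location)
```

`unittest.TestCase.subTest` keeps each variant's failure report separate while sharing one test body.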

self.assertEqual(tensor, torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))

@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
def test_serialization_map_location_cuda(self):

Contributor

@apaszke apaszke left a comment


Two last minor things. Thanks!

elif map_location.type == 'cuda':
    map_str = 'cuda:{}'.format(map_location.index or 0)
else:
    raise ValueError("The given map_location device is not a cpu or cuda")
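The `map_location.index or 0` pattern in the hunk above treats a device with no explicit index as device 0. A tiny standalone illustration, with a hypothetical `Device` class standing in for torch.device:

```python
class Device:
    """Stand-in for torch.device: just a type string and optional index."""
    def __init__(self, type, index=None):
        self.type = type
        self.index = index

def to_map_str(d):
    # Mirrors the conversion in the diff: cpu passes through, cuda gets
    # an explicit index, anything else is rejected.
    if d.type == 'cpu':
        return 'cpu'
    elif d.type == 'cuda':
        # `index or 0` maps an unset index (None) to device 0.
        return 'cuda:{}'.format(d.index or 0)
    raise ValueError("The given map_location device is not a cpu or cuda")

print(to_map_str(Device('cuda')))     # cuda:0
print(to_map_str(Device('cuda', 2)))  # cuda:2
print(to_map_str(Device('cpu')))      # cpu
```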

for map_location in gpu_map_locations:
    tensor = torch.load(fileobject_lambda(), map_location=map_location)
    self.assertIsInstance(tensor, torch.cuda.FloatTensor)
    self.assertEqual(tensor, torch.cuda.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))

@apaszke
Contributor

apaszke commented May 10, 2018

@pytorchbot retest this please

@soumith soumith merged commit 9fa1dff into pytorch:master May 10, 2018
onnxbot added a commit to onnxbot/onnx-fb-universe that referenced this pull request May 10, 2018
weiyangfb pushed a commit to weiyangfb/pytorch that referenced this pull request Jun 11, 2018
* Allow using torch.device for loading

* Make recommended changes

* Better tests
