Allow the use of torch.device for loading #7339
Conversation
Force-pushed from b0a1498 to ec6af39.
    elif isinstance(map_location, _string_classes):
        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
    elif isinstance(map_location, torch.device):
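The string branch above builds a `restore_location` closure that ignores the serialized location tag and always restores to the requested device string. A minimal stand-alone sketch of the same pattern, with a hypothetical `default_restore_location` stand-in rather than the real torch internals:

```python
def make_restore_location(map_location):
    # Hypothetical stand-in for torch's default_restore_location:
    # here it just reports which device tag would be used.
    def default_restore_location(storage, location):
        return (storage, location)

    if isinstance(map_location, str):
        # String form: ignore the serialized location tag and
        # restore to the requested device string instead.
        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
        return restore_location
    raise TypeError("unsupported map_location type")

restore = make_restore_location('cpu')
print(restore('fake-storage', 'cuda:0'))  # ('fake-storage', 'cpu')
```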
apaszke left a comment:
The test needs to be deduplicated. Apart from that LGTM.
test/test_torch.py (outdated):

    def load_bytes():
        with open(test_file_path, 'rb') as f:
            data = io.BytesIO(f.read())
        return data
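The helper reads the saved file fully into memory so each load can start from a fresh, seekable file-like object. A self-contained sketch of the same round-trip pattern, using `pickle` in place of `torch.save`/`torch.load` so it runs without torch (the file path is created via `tempfile` rather than the test's fixture):

```python
import io
import os
import pickle
import tempfile

payload = [[1.0, 2.0], [3.0, 4.0]]

# Save to a real file, as the test does with torch.save.
fd, test_file_path = tempfile.mkstemp()
os.close(fd)
with open(test_file_path, 'wb') as f:
    pickle.dump(payload, f)

def load_bytes():
    # Read the file into an in-memory buffer so every load
    # starts from a fresh file-like object at offset 0.
    with open(test_file_path, 'rb') as f:
        return io.BytesIO(f.read())

restored = pickle.load(load_bytes())
os.remove(test_file_path)
print(restored)  # [[1.0, 2.0], [3.0, 4.0]]
```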
test/test_torch.py (outdated):

        self.assertEqual(tensor, torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))

    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
    def test_serialization_map_location_cuda(self):
apaszke left a comment:
Two last minor things. Thanks!
torch/serialization.py (outdated):

    elif map_location.type == 'cuda':
        map_str = 'cuda:{}'.format(map_location.index or 0)
    else:
        raise ValueError("The given map_location device is not a cpu or cuda")
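The device-to-string conversion above can be exercised in isolation. A plain-Python sketch mirroring that branch, where the `DeviceLike` namedtuple is a hypothetical stand-in for `torch.device` (which exposes `.type` and `.index`):

```python
from collections import namedtuple

# Hypothetical stand-in for torch.device for illustration.
DeviceLike = namedtuple('DeviceLike', ['type', 'index'])

def device_to_map_str(device):
    # Mirrors the PR's conversion: 'cpu' maps to 'cpu', 'cuda' maps
    # to 'cuda:<index>' with a missing index defaulting to 0.
    if device.type == 'cpu':
        return 'cpu'
    elif device.type == 'cuda':
        return 'cuda:{}'.format(device.index or 0)
    raise ValueError("The given map_location device is not a cpu or cuda")

print(device_to_map_str(DeviceLike('cuda', None)))  # cuda:0
print(device_to_map_str(DeviceLike('cuda', 1)))     # cuda:1
print(device_to_map_str(DeviceLike('cpu', None)))   # cpu
```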
test/test_torch.py (outdated):

    for map_location in gpu_map_locations:
        tensor = torch.load(fileobject_lambda(), map_location=map_location)
        self.assertIsInstance(tensor, torch.cuda.FloatTensor)
        self.assertEqual(tensor, torch.cuda.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))
@pytorchbot retest this please
- Allow using torch.device for loading
- Make recommended changes
- Better tests
This pull request allows torch.device to be used as the map_location argument of torch.load. It's a pretty minor change, but it makes the API a bit more consistent by letting you use torch.device everywhere.
This partially solves #7178.
My main question here is how much to change the torch.load API. This PR makes the minimal change: it simply adds torch.device as another accepted type for map_location. The real question is whether torch.device should also be supported when a dict or a function is passed; there is a tradeoff here between consistency and stability. There is also the question of whether we want to support torch.device instances as dict keys.
Let me know what you guys think.
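One concrete way to read the dict-keys question: dict remapping keys are location strings today, so accepting torch.device keys would mean normalizing them to strings first. A hedged plain-Python sketch, where the `Device` class is a hypothetical stand-in for torch.device:

```python
class Device:
    # Minimal stand-in for torch.device, for illustration only.
    def __init__(self, type, index=None):
        self.type, self.index = type, index

    def __str__(self):
        if self.type == 'cpu' or self.index is None:
            return self.type
        return '{}:{}'.format(self.type, self.index)

def normalize_key(key):
    # Accept either a location string or a Device, converting
    # devices to their string tag form before remapping.
    return str(key) if isinstance(key, Device) else key

raw_map = {Device('cuda', 0): Device('cpu')}
string_map = {normalize_key(k): normalize_key(v) for k, v in raw_map.items()}
print(string_map)  # {'cuda:0': 'cpu'}
```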