
Commit 9fa1dff

EthanSteinberg authored and soumith committed
Allow the use of torch.device for loading (#7339)
* Allow using torch.device for loading
* Make recommended changes
* Better tests
1 parent b6adf68 commit 9fa1dff
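In short, this change lets torch.load accept a torch.device as map_location, alongside the existing function, dict, and string forms. A minimal usage sketch of what the commit enables (the checkpoint name 'tensors.pt' is taken from the docstring example below and is illustrative):

import torch

# Load a checkpoint onto the CPU; map_location may now be a torch.device
# instead of the string 'cpu'.
tensor = torch.load('tensors.pt', map_location=torch.device('cpu'))

# On a CUDA build, a device with an explicit index is accepted as well:
# tensor = torch.load('tensors.pt', map_location=torch.device('cuda', 0))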

File tree

2 files changed (+43, -11 lines)


test/test_torch.py

Lines changed: 34 additions & 8 deletions
@@ -6359,17 +6359,43 @@ def map_location(storage, loc):
 
         def load_bytes():
             with open(test_file_path, 'rb') as f:
-                data = io.BytesIO(f.read())
-            return data
+                return io.BytesIO(f.read())
 
         fileobject_lambdas = [lambda: test_file_path, load_bytes]
-        map_locations = [map_location, {'cuda:0': 'cpu'}, 'cpu']
+        cpu_map_locations = [
+            map_location,
+            {'cuda:0': 'cpu'},
+            'cpu',
+            torch.device('cpu'),
+        ]
+        gpu_0_map_locations = [
+            {'cuda:0': 'cuda:0'},
+            'cuda',
+            'cuda:0',
+            torch.device('cuda'),
+            torch.device('cuda', 0)
+        ]
+        gpu_last_map_locations = [
+            'cuda:{}'.format(torch.cuda.device_count() - 1),
+        ]
 
-        for fileobject_lambda in fileobject_lambdas:
-            for map_location in map_locations:
-                tensor = torch.load(fileobject_lambda(), map_location=map_location)
-                self.assertIsInstance(tensor, torch.FloatTensor)
-                self.assertEqual(tensor, torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))
+        def check_map_locations(map_locations, tensor_class, intended_device):
+            for fileobject_lambda in fileobject_lambdas:
+                for map_location in map_locations:
+                    tensor = torch.load(fileobject_lambda(), map_location=map_location)
+
+                    self.assertEqual(tensor.device, intended_device)
+                    self.assertIsInstance(tensor, tensor_class)
+                    self.assertEqual(tensor, tensor_class([[1.0, 2.0], [3.0, 4.0]]))
+
+        check_map_locations(cpu_map_locations, torch.FloatTensor, torch.device('cpu'))
+        if torch.cuda.is_available():
+            check_map_locations(gpu_0_map_locations, torch.cuda.FloatTensor, torch.device('cuda', 0))
+            check_map_locations(
+                gpu_last_map_locations,
+                torch.cuda.FloatTensor,
+                torch.device('cuda', torch.cuda.device_count() - 1)
+            )
 
     def test_serialization_filelike_api_requirements(self):
         filemock = FilelikeMock(b'', has_readinto=False)
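The new test loads both from a file path and from an in-memory file object, and checks the resulting tensor's device for every map_location form. A self-contained sketch of the same round trip, assuming the patched torch.load and a CUDA-enabled build for the GPU branch (the BytesIO setup here is illustrative, not the test's exact fixture):

import io
import torch

t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
buf = io.BytesIO()
torch.save(t, buf)

# All of these map_location forms should be accepted and land the tensor on the CPU.
for map_location in [{'cuda:0': 'cpu'}, 'cpu', torch.device('cpu'), lambda storage, loc: storage]:
    buf.seek(0)
    loaded = torch.load(buf, map_location=map_location)
    assert loaded.device == torch.device('cpu')

if torch.cuda.is_available():
    # Equivalent ways of asking for GPU 0, including the new torch.device forms.
    for map_location in ['cuda', 'cuda:0', torch.device('cuda'), torch.device('cuda', 0)]:
        buf.seek(0)
        loaded = torch.load(buf, map_location=map_location)
        assert loaded.device == torch.device('cuda', 0)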

torch/serialization.py

Lines changed: 9 additions & 3 deletions
@@ -66,7 +66,10 @@ def _cpu_deserialize(obj, location):
 
 def _cuda_deserialize(obj, location):
     if location.startswith('cuda'):
-        device = max(int(location[5:]), 0)
+        if location[5:] == '':
+            device = 0
+        else:
+            device = max(int(location[5:]), 0)
         return obj.cuda(device)
 
 
@@ -273,15 +276,15 @@ def load(f, map_location=None, pickle_module=pickle):
     Args:
         f: a file-like object (has to implement read, readline, tell, and seek),
             or a string containing a file name
-        map_location: a function, string or a dict specifying how to remap storage
+        map_location: a function, torch.device, string or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (has to
            match the pickle_module used to serialize file)
 
    Example:
        >>> torch.load('tensors.pt')
        # Load all tensors onto the CPU
-        >>> torch.load('tensors.pt', map_location='cpu')
+        >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        # Load all tensors onto GPU 1
@@ -318,6 +321,9 @@ def restore_location(storage, location):
     elif isinstance(map_location, _string_classes):
         def restore_location(storage, location):
             return default_restore_location(storage, map_location)
+    elif isinstance(map_location, torch.device):
+        def restore_location(storage, location):
+            return default_restore_location(storage, str(map_location))
     else:
         def restore_location(storage, location):
             result = map_location(storage, location)
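The new branch simply converts the torch.device to its string form and reuses the existing string handling, which is also why the bare-'cuda' tag in _cuda_deserialize above needed the default-to-0 fix. A quick illustration of that conversion (assumes any torch build with torch.device):

import torch

# str() of a torch.device yields the same tags that string map_locations already use.
assert str(torch.device('cpu')) == 'cpu'
assert str(torch.device('cuda', 1)) == 'cuda:1'
assert str(torch.device('cuda')) == 'cuda'  # no index; deserialization now defaults this to device 0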
