Skip to content

Commit ec6af39

Browse files
Allow using torch.device for loading
1 parent 4c51107 commit ec6af39

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

test/test_torch.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6265,14 +6265,32 @@ def load_bytes():
62656265
return data
62666266

62676267
fileobject_lambdas = [lambda: test_file_path, load_bytes]
6268-
map_locations = [map_location, {'cuda:0': 'cpu'}, 'cpu']
6268+
map_locations = [map_location, {'cuda:0': 'cpu'}, 'cpu', torch.device('cpu')]
62696269

62706270
for fileobject_lambda in fileobject_lambdas:
62716271
for map_location in map_locations:
62726272
tensor = torch.load(fileobject_lambda(), map_location=map_location)
62736273
self.assertIsInstance(tensor, torch.FloatTensor)
62746274
self.assertEqual(tensor, torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))
62756275

6276+
@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
6277+
def test_serialization_map_location_cuda(self):
6278+
test_file_path = download_file('https://download.pytorch.org/test_data/gpu_tensors.pt')
6279+
6280+
def load_bytes():
6281+
with open(test_file_path, 'rb') as f:
6282+
data = io.BytesIO(f.read())
6283+
return data
6284+
6285+
fileobject_lambdas = [lambda: test_file_path, load_bytes]
6286+
map_locations = [{'cuda:0': 'cuda:0'}, 'cuda:0', torch.device('cuda'), torch.device('cuda', 0)]
6287+
6288+
for fileobject_lambda in fileobject_lambdas:
6289+
for map_location in map_locations:
6290+
tensor = torch.load(fileobject_lambda(), map_location=map_location)
6291+
self.assertIsInstance(tensor, torch.cuda.FloatTensor)
6292+
self.assertEqual(tensor, torch.cuda.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))
6293+
62766294
def test_serialization_filelike_api_requirements(self):
62776295
filemock = FilelikeMock(b'', has_readinto=False)
62786296
tensor = torch.randn(3, 5)

torch/serialization.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,15 +273,15 @@ def load(f, map_location=None, pickle_module=pickle):
273273
Args:
274274
f: a file-like object (has to implement read, readline, tell, and seek),
275275
or a string containing a file name
276-
map_location: a function, string or a dict specifying how to remap storage
276+
map_location: a function, torch.device, string or a dict specifying how to remap storage
277277
locations
278278
pickle_module: module used for unpickling metadata and objects (has to
279279
match the pickle_module used to serialize file)
280280
281281
Example:
282282
>>> torch.load('tensors.pt')
283283
# Load all tensors onto the CPU
284-
>>> torch.load('tensors.pt', map_location='cpu')
284+
>>> torch.load('tensors.pt', map_location=torch.device('cpu'))
285285
# Load all tensors onto the CPU, using a function
286286
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
287287
# Load all tensors onto GPU 1
@@ -318,6 +318,16 @@ def restore_location(storage, location):
318318
elif isinstance(map_location, _string_classes):
319319
def restore_location(storage, location):
320320
return default_restore_location(storage, map_location)
321+
elif isinstance(map_location, torch.device):
322+
if map_location.type == 'cpu':
323+
map_str = 'cpu'
324+
elif map_location.type == 'cuda':
325+
map_str = 'cuda:{}'.format(map_location.index or 0)
326+
else:
327+
raise ValueError("The given map_location device is not a cpu or cuda")
328+
329+
def restore_location(storage, location):
330+
return default_restore_location(storage, map_str)
321331
else:
322332
def restore_location(storage, location):
323333
result = map_location(storage, location)

0 commit comments

Comments
 (0)