Skip to content

Commit 8307f21

Browse files
committed
Allow map_location in torch.load to be a string
1 parent e393a4f commit 8307f21

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

test/test_torch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4400,6 +4400,10 @@ def map_location(storage, loc):
44004400
self.assertEqual(type(tensor), torch.FloatTensor)
44014401
self.assertEqual(tensor, torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))
44024402

4403+
tensor = torch.load(test_file_path, map_location='cpu')
4404+
self.assertEqual(type(tensor), torch.FloatTensor)
4405+
self.assertEqual(tensor, torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))
4406+
44034407
def test_from_buffer(self):
44044408
a = bytearray([1, 2, 3, 4])
44054409
self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4])

torch/serialization.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import warnings
1111
from contextlib import closing, contextmanager
1212
from ._utils import _import_dotted_name
13+
from ._six import string_classes as _string_classes
1314
if sys.version_info[0] == 2:
1415
import cPickle as pickle
1516
else:
@@ -225,7 +226,10 @@ def load(f, map_location=None, pickle_module=pickle):
225226
the right device. Otherwise, torch.load will fall back to the default behavior,
226227
as if map_location wasn't specified.
227228
228-
If map_location is a dict, it will be used to remap location tags
229+
If map_location is a string, it should be a device tag, where all tensors
230+
should be loaded.
231+
232+
Otherwise, if map_location is a dict, it will be used to remap location tags
229233
appearing in the file (keys), to ones that specify where to put the
230234
storages (values).
231235
@@ -236,14 +240,16 @@ def load(f, map_location=None, pickle_module=pickle):
236240
f: a file-like object (has to implement fileno that returns a file
237241
descriptor, and must implement seek), or a string containing a file
238242
name
239-
map_location: a function or a dict specifying how to remap storage
243+
map_location: a function, string or a dict specifying how to remap storage
240244
locations
241245
pickle_module: module used for unpickling metadata and objects (has to
242246
match the pickle_module used to serialize file)
243247
244248
Example:
245249
>>> torch.load('tensors.pt')
246250
# Load all tensors onto the CPU
251+
>>> torch.load('tensors.pt', map_location='cpu')
252+
# Load all tensors onto the CPU, using a function
247253
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
248254
# Load all tensors onto GPU 1
249255
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
@@ -273,6 +279,9 @@ def _load(f, map_location, pickle_module):
273279
def restore_location(storage, location):
274280
location = map_location.get(location, location)
275281
return default_restore_location(storage, location)
282+
elif isinstance(map_location, _string_classes):
283+
def restore_location(storage, location):
284+
return default_restore_location(storage, map_location)
276285
else:
277286
def restore_location(storage, location):
278287
result = map_location(storage, location)

0 commit comments

Comments
 (0)