1010import warnings
1111from contextlib import closing , contextmanager
1212from ._utils import _import_dotted_name
13+ from ._six import string_classes as _string_classes
1314if sys .version_info [0 ] == 2 :
1415 import cPickle as pickle
1516else :
@@ -225,7 +226,10 @@ def load(f, map_location=None, pickle_module=pickle):
225226 the right device. Otherwise, torch.load will fall back to the default behavior,
226227 as if map_location wasn't specified.
227228
228- If map_location is a dict, it will be used to remap location tags
229+ If map_location is a string, it should be a device tag, where all tensors
230+ should be loaded.
231+
232+ Otherwise, if map_location is a dict, it will be used to remap location tags
229233 appearing in the file (keys), to ones that specify where to put the
230234 storages (values).
231235
@@ -236,14 +240,16 @@ def load(f, map_location=None, pickle_module=pickle):
236240 f: a file-like object (has to implement fileno that returns a file
237241 descriptor, and must implement seek), or a string containing a file
238242 name
239- map_location: a function or a dict specifying how to remap storage
243+ map_location: a function, string or a dict specifying how to remap storage
240244 locations
241245 pickle_module: module used for unpickling metadata and objects (has to
242246 match the pickle_module used to serialize file)
243247
244248 Example:
245249 >>> torch.load('tensors.pt')
246250 # Load all tensors onto the CPU
251+ >>> torch.load('tensors.pt', map_location='cpu')
252+ # Load all tensors onto the CPU, using a function
247253 >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
248254 # Load all tensors onto GPU 1
249255 >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
@@ -273,6 +279,9 @@ def _load(f, map_location, pickle_module):
273279 def restore_location (storage , location ):
274280 location = map_location .get (location , location )
275281 return default_restore_location (storage , location )
282+ elif isinstance (map_location , _string_classes ):
283+ def restore_location (storage , location ):
284+ return default_restore_location (storage , map_location )
276285 else :
277286 def restore_location (storage , location ):
278287 result = map_location (storage , location )
0 commit comments