@@ -273,15 +273,15 @@ def load(f, map_location=None, pickle_module=pickle):
273273 Args:
274274 f: a file-like object (has to implement read, readline, tell, and seek),
275275 or a string containing a file name
276- map_location: a function, string or a dict specifying how to remap storage
276+ map_location: a function, torch.device, string or a dict specifying how to remap storage
277277 locations
278278 pickle_module: module used for unpickling metadata and objects (has to
279279 match the pickle_module used to serialize file)
280280
281281 Example:
282282 >>> torch.load('tensors.pt')
283283 # Load all tensors onto the CPU
284- >>> torch.load('tensors.pt', map_location='cpu')
284+ >>> torch.load('tensors.pt', map_location=torch.device( 'cpu') )
285285 # Load all tensors onto the CPU, using a function
286286 >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
287287 # Load all tensors onto GPU 1
@@ -318,6 +318,16 @@ def restore_location(storage, location):
318318 elif isinstance (map_location , _string_classes ):
319319 def restore_location (storage , location ):
320320 return default_restore_location (storage , map_location )
321+ elif isinstance (map_location , torch .device ):
322+ if map_location .type == 'cpu' :
323+ map_str = 'cpu'
324+ elif map_location .type == 'cuda' :
325+ map_str = 'cuda:{}' .format (map_location .index or 0 )
326+ else :
327+ raise ValueError ("The given map_location device is not a cpu or cuda" )
328+
329+ def restore_location (storage , location ):
330+ return default_restore_location (storage , map_str )
321331 else :
322332 def restore_location (storage , location ):
323333 result = map_location (storage , location )
0 commit comments