test/test_torch.py (42 changes: 34 additions & 8 deletions)
@@ -6261,17 +6261,43 @@ def map_location(storage, loc):

        def load_bytes():
            with open(test_file_path, 'rb') as f:
-                data = io.BytesIO(f.read())
-                return data
+                return io.BytesIO(f.read())

        fileobject_lambdas = [lambda: test_file_path, load_bytes]
-        map_locations = [map_location, {'cuda:0': 'cpu'}, 'cpu']
+        cpu_map_locations = [
+            map_location,
+            {'cuda:0': 'cpu'},
+            'cpu',
+            torch.device('cpu'),
+        ]
+        gpu_0_map_locations = [
+            {'cuda:0': 'cuda:0'},
+            'cuda',
+            'cuda:0',
+            torch.device('cuda'),
+            torch.device('cuda', 0)
+        ]
+        gpu_last_map_locations = [
+            'cuda:{}'.format(torch.cuda.device_count() - 1),
+        ]

-        for fileobject_lambda in fileobject_lambdas:
-            for map_location in map_locations:
-                tensor = torch.load(fileobject_lambda(), map_location=map_location)
-                self.assertIsInstance(tensor, torch.FloatTensor)
-                self.assertEqual(tensor, torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))
+        def check_map_locations(map_locations, tensor_class, intended_device):
+            for fileobject_lambda in fileobject_lambdas:
+                for map_location in map_locations:
+                    tensor = torch.load(fileobject_lambda(), map_location=map_location)
+
+                    self.assertEqual(tensor.device, intended_device)
+                    self.assertIsInstance(tensor, tensor_class)
+                    self.assertEqual(tensor, tensor_class([[1.0, 2.0], [3.0, 4.0]]))
+
+        check_map_locations(cpu_map_locations, torch.FloatTensor, torch.device('cpu'))
+        if torch.cuda.is_available():
+            check_map_locations(gpu_0_map_locations, torch.cuda.FloatTensor, torch.device('cuda', 0))
+            check_map_locations(
+                gpu_last_map_locations,
+                torch.cuda.FloatTensor,
+                torch.device('cuda', torch.cuda.device_count() - 1)
+            )

    def test_serialization_filelike_api_requirements(self):
        filemock = FilelikeMock(b'', has_readinto=False)
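The new test drives torch.load through every accepted map_location form and asserts that the loaded tensor lands on the intended device. A minimal standalone sketch of the same idea (the file name 'tensors.pt' is illustrative, and passing a torch.device requires the serialization.py change below):

    import torch

    tensor = torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]])
    torch.save(tensor, 'tensors.pt')

    # Every CPU-targeting form of map_location should give back a CPU tensor.
    cpu_map_locations = [
        'cpu',
        {'cuda:0': 'cpu'},
        torch.device('cpu'),
        lambda storage, loc: storage,
    ]
    for map_location in cpu_map_locations:
        loaded = torch.load('tensors.pt', map_location=map_location)
        assert loaded.device == torch.device('cpu')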
torch/serialization.py (12 changes: 9 additions & 3 deletions)
@@ -66,7 +66,10 @@ def _cpu_deserialize(obj, location):

def _cuda_deserialize(obj, location):
    if location.startswith('cuda'):
-        device = max(int(location[5:]), 0)
+        if location[5:] == '':
+            device = 0
+        else:
+            device = max(int(location[5:]), 0)
        return obj.cuda(device)


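The extra check in _cuda_deserialize is needed because str(torch.device('cuda')) is just 'cuda': the old code always did int(location[5:]), which raises ValueError on the empty string that an index-less 'cuda' tag leaves behind. A small sketch of the same parsing logic (the helper name is made up for illustration):

    def parse_cuda_location(location):
        # Text after the 'cuda' prefix, minus the ':' separator ('' for plain 'cuda').
        suffix = location[5:]
        return 0 if suffix == '' else max(int(suffix), 0)

    assert parse_cuda_location('cuda') == 0      # previously int('') raised ValueError
    assert parse_cuda_location('cuda:1') == 1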
@@ -273,15 +276,15 @@ def load(f, map_location=None, pickle_module=pickle):
    Args:
        f: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name
-        map_location: a function, string or a dict specifying how to remap storage
+        map_location: a function, torch.device, string or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (has to
            match the pickle_module used to serialize file)

    Example:
        >>> torch.load('tensors.pt')
        # Load all tensors onto the CPU
-        >>> torch.load('tensors.pt', map_location='cpu')
+        >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        # Load all tensors onto GPU 1
@@ -318,6 +321,9 @@ def restore_location(storage, location):
    elif isinstance(map_location, _string_classes):
        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
+    elif isinstance(map_location, torch.device):
+        def restore_location(storage, location):
+            return default_restore_location(storage, str(map_location))
    else:
        def restore_location(storage, location):
            result = map_location(storage, location)
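Putting the serialization.py changes together, load() turns whatever map_location it was given into a restore_location(storage, location) callable; the new torch.device branch simply stringifies the device and reuses the existing string path. A trimmed sketch of that dispatch (simplified from torch/serialization.py; the None, dict, and callable branches are reconstructed from the surrounding code, not from this diff):

    import torch
    from torch.serialization import default_restore_location

    def make_restore_location(map_location):
        if map_location is None:
            return default_restore_location
        elif isinstance(map_location, dict):
            def restore_location(storage, location):
                # Remap the recorded location tag, then fall through to the default.
                return default_restore_location(storage, map_location.get(location, location))
        elif isinstance(map_location, str):
            def restore_location(storage, location):
                return default_restore_location(storage, map_location)
        elif isinstance(map_location, torch.device):
            # New in this change: a device behaves like its string form ('cpu', 'cuda:1', ...).
            def restore_location(storage, location):
                return default_restore_location(storage, str(map_location))
        else:
            def restore_location(storage, location):
                result = map_location(storage, location)
                return result if result is not None else default_restore_location(storage, location)
        return restore_location

torch.load then calls the resulting restore_location(storage, location) for every storage it unpickles from the file.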