@@ -6359,17 +6359,43 @@ def map_location(storage, loc):
63596359
63606360 def load_bytes ():
63616361 with open (test_file_path , 'rb' ) as f :
6362- data = io .BytesIO (f .read ())
6363- return data
6362+ return io .BytesIO (f .read ())
63646363
63656364 fileobject_lambdas = [lambda : test_file_path , load_bytes ]
6366- map_locations = [map_location , {'cuda:0' : 'cpu' }, 'cpu' ]
6365+ cpu_map_locations = [
6366+ map_location ,
6367+ {'cuda:0' : 'cpu' },
6368+ 'cpu' ,
6369+ torch .device ('cpu' ),
6370+ ]
6371+ gpu_0_map_locations = [
6372+ {'cuda:0' : 'cuda:0' },
6373+ 'cuda' ,
6374+ 'cuda:0' ,
6375+ torch .device ('cuda' ),
6376+ torch .device ('cuda' , 0 )
6377+ ]
6378+ gpu_last_map_locations = [
6379+ 'cuda:{}' .format (torch .cuda .device_count () - 1 ),
6380+ ]
63676381
6368- for fileobject_lambda in fileobject_lambdas :
6369- for map_location in map_locations :
6370- tensor = torch .load (fileobject_lambda (), map_location = map_location )
6371- self .assertIsInstance (tensor , torch .FloatTensor )
6372- self .assertEqual (tensor , torch .FloatTensor ([[1.0 , 2.0 ], [3.0 , 4.0 ]]))
6382+ def check_map_locations (map_locations , tensor_class , intended_device ):
6383+ for fileobject_lambda in fileobject_lambdas :
6384+ for map_location in map_locations :
6385+ tensor = torch .load (fileobject_lambda (), map_location = map_location )
6386+
6387+ self .assertEqual (tensor .device , intended_device )
6388+ self .assertIsInstance (tensor , tensor_class )
6389+ self .assertEqual (tensor , tensor_class ([[1.0 , 2.0 ], [3.0 , 4.0 ]]))
6390+
6391+ check_map_locations (cpu_map_locations , torch .FloatTensor , torch .device ('cpu' ))
6392+ if torch .cuda .is_available ():
6393+ check_map_locations (gpu_0_map_locations , torch .cuda .FloatTensor , torch .device ('cuda' , 0 ))
6394+ check_map_locations (
6395+ gpu_last_map_locations ,
6396+ torch .cuda .FloatTensor ,
6397+ torch .device ('cuda' , torch .cuda .device_count () - 1 )
6398+ )
63736399
63746400 def test_serialization_filelike_api_requirements (self ):
63756401 filemock = FilelikeMock (b'' , has_readinto = False )
0 commit comments