55import random
66import operator
77import copy
8+ import shutil
89import torch
910import torch .cuda
1011import tempfile
@@ -6169,7 +6170,7 @@ def test_parsing_intlist(self):
61696170 self .assertRaises (TypeError , lambda : torch .ones (np .array (3 , 3 )))
61706171 self .assertRaises (TypeError , lambda : torch .ones ((np .array (3 , 3 ))))
61716172
6172- def _test_serialization (self , filecontext_lambda , filecontext_read_lambda = None , test_use_filename = True ):
6173+ def _test_serialization_data (self ):
61736174 a = [torch .randn (5 , 5 ).float () for i in range (2 )]
61746175 b = [a [i % 2 ] for i in range (4 )]
61756176 b += [a [0 ].storage ()]
@@ -6179,95 +6180,115 @@ def _test_serialization(self, filecontext_lambda, filecontext_read_lambda=None,
61796180 t2 = torch .FloatTensor ().set_ (a [0 ].storage ()[1 :4 ], 0 , (3 ,), (1 ,))
61806181 b += [(t1 .storage (), t1 .storage (), t2 .storage ())]
61816182 b += [a [0 ].storage ()[0 :2 ]]
6182- if test_use_filename :
6183- use_name_options = (False , True )
6184- else :
6185- use_name_options = (False ,)
6186- for use_name in use_name_options :
6183+ return b
6184+
6185+ def _test_serialization_assert (self , b , c ):
6186+ self .assertEqual (b , c , 0 )
6187+ self .assertTrue (isinstance (c [0 ], torch .FloatTensor ))
6188+ self .assertTrue (isinstance (c [1 ], torch .FloatTensor ))
6189+ self .assertTrue (isinstance (c [2 ], torch .FloatTensor ))
6190+ self .assertTrue (isinstance (c [3 ], torch .FloatTensor ))
6191+ self .assertTrue (isinstance (c [4 ], torch .FloatStorage ))
6192+ c [0 ].fill_ (10 )
6193+ self .assertEqual (c [0 ], c [2 ], 0 )
6194+ self .assertEqual (c [4 ], torch .FloatStorage (25 ).fill_ (10 ), 0 )
6195+ c [1 ].fill_ (20 )
6196+ self .assertEqual (c [1 ], c [3 ], 0 )
6197+ self .assertEqual (c [4 ][1 :4 ], c [5 ], 0 )
6198+
6199+ # check that serializing the same storage view object unpickles
6200+ # it as one object not two (and vice versa)
6201+ views = c [7 ]
6202+ self .assertEqual (views [0 ]._cdata , views [1 ]._cdata )
6203+ self .assertEqual (views [0 ], views [2 ])
6204+ self .assertNotEqual (views [0 ]._cdata , views [2 ]._cdata )
6205+
6206+ rootview = c [8 ]
6207+ self .assertEqual (rootview .data_ptr (), c [0 ].data_ptr ())
6208+
6209+ def test_serialization (self ):
6210+ # Test serialization with a real file
6211+ b = self ._test_serialization_data ()
6212+ for use_name in (False , True ):
61876213 # Passing filename to torch.save(...) will cause the file to be opened twice,
61886214 # which is not supported on Windows
61896215 if sys .platform == "win32" and use_name :
61906216 continue
6191- if filecontext_read_lambda :
6192- with filecontext_lambda () as f :
6193- handle = f if not use_name else f .name
6194- torch .save (b , handle )
6195- with filecontext_read_lambda () as f :
6196- handle = f if not use_name else f .name
6197- c = torch .load (handle )
6198- else :
6199- with filecontext_lambda () as f :
6200- handle = f if not use_name else f .name
6201- torch .save (b , handle )
6202- f .seek (0 )
6203- c = torch .load (handle )
6204- self .assertEqual (b , c , 0 )
6205- self .assertTrue (isinstance (c [0 ], torch .FloatTensor ))
6206- self .assertTrue (isinstance (c [1 ], torch .FloatTensor ))
6207- self .assertTrue (isinstance (c [2 ], torch .FloatTensor ))
6208- self .assertTrue (isinstance (c [3 ], torch .FloatTensor ))
6209- self .assertTrue (isinstance (c [4 ], torch .FloatStorage ))
6210- c [0 ].fill_ (10 )
6211- self .assertEqual (c [0 ], c [2 ], 0 )
6212- self .assertEqual (c [4 ], torch .FloatStorage (25 ).fill_ (10 ), 0 )
6213- c [1 ].fill_ (20 )
6214- self .assertEqual (c [1 ], c [3 ], 0 )
6215- self .assertEqual (c [4 ][1 :4 ], c [5 ], 0 )
6216-
6217- # check that serializing the same storage view object unpickles
6218- # it as one object not two (and vice versa)
6219- views = c [7 ]
6220- self .assertEqual (views [0 ]._cdata , views [1 ]._cdata )
6221- self .assertEqual (views [0 ], views [2 ])
6222- self .assertNotEqual (views [0 ]._cdata , views [2 ]._cdata )
6223-
6224- rootview = c [8 ]
6225- self .assertEqual (rootview .data_ptr (), c [0 ].data_ptr ())
6226-
6227- def test_serialization (self ):
6228- # Test serialization with a real file
6229- self ._test_serialization (tempfile .NamedTemporaryFile )
6217+ with tempfile .NamedTemporaryFile () as f :
6218+ handle = f if not use_name else f .name
6219+ torch .save (b , handle )
6220+ f .seek (0 )
6221+ c = torch .load (handle )
6222+ self ._test_serialization_assert (b , c )
62306223
62316224 def test_serialization_filelike (self ):
62326225 # Test serialization (load and save) with a filelike object
6233- self ._test_serialization (BytesIOContext , test_use_filename = False )
6226+ b = self ._test_serialization_data ()
6227+ with BytesIOContext () as f :
6228+ torch .save (b , f )
6229+ f .seek (0 )
6230+ c = torch .load (f )
6231+ self ._test_serialization_assert (b , c )
62346232
62356233 def test_serialization_gzip (self ):
6236- f = tempfile .NamedTemporaryFile (delete = False )
6237- self ._test_serialization (lambda : gzip .open (f .name , mode = 'w+b' ),
6238- filecontext_read_lambda = lambda : gzip .open (f .name , mode = 'r+b' ),
6239- test_use_filename = False )
6234+ # Test serialization with gzip file
6235+ b = self ._test_serialization_data ()
6236+ f1 = tempfile .NamedTemporaryFile (delete = False )
6237+ f2 = tempfile .NamedTemporaryFile (delete = False )
6238+ torch .save (b , f1 )
6239+ with open (f1 .name , 'rb' ) as f_in , gzip .open (f2 .name , 'wb' ) as f_out :
6240+ shutil .copyfileobj (f_in , f_out )
6241+
6242+ with gzip .open (f2 .name , 'rb' ) as f :
6243+ c = torch .load (f )
6244+ self ._test_serialization_assert (b , c )
62406245
6241- def _test_serialization_offset (self , filecontext_lambda , filecontext_read_lambda = None ):
6246+ def test_serialization_offset (self ):
62426247 a = torch .randn (5 , 5 )
62436248 i = 41
6244- if filecontext_read_lambda :
6245- with filecontext_lambda () as f :
6246- pickle .dump (i , f )
6247- torch .save (a , f )
6248- with filecontext_read_lambda () as f :
6249- j = pickle .load (f )
6250- b = torch .load (f )
6251- else :
6252- with filecontext_lambda () as f :
6249+ for use_name in (False , True ):
6250+ # Passing filename to torch.save(...) will cause the file to be opened twice,
6251+ # which is not supported on Windows
6252+ if sys .platform == "win32" and use_name :
6253+ continue
6254+ with tempfile .NamedTemporaryFile () as f :
6255+ handle = f if not use_name else f .name
62536256 pickle .dump (i , f )
62546257 torch .save (a , f )
62556258 f .seek (0 )
62566259 j = pickle .load (f )
62576260 b = torch .load (f )
6258- self .assertTrue (torch .equal (a , b ))
6259- self .assertEqual (i , j )
6260-
6261- def test_serialization_offset (self ):
6262- self ._test_serialization_offset (tempfile .TemporaryFile )
6261+ self .assertTrue (torch .equal (a , b ))
6262+ self .assertEqual (i , j )
62636263
62646264 def test_serialization_offset_filelike (self ):
6265- self ._test_serialization_offset (BytesIOContext )
6265+ a = torch .randn (5 , 5 )
6266+ i = 41
6267+ with BytesIOContext () as f :
6268+ pickle .dump (i , f )
6269+ torch .save (a , f )
6270+ f .seek (0 )
6271+ j = pickle .load (f )
6272+ b = torch .load (f )
6273+ self .assertTrue (torch .equal (a , b ))
6274+ self .assertEqual (i , j )
62666275
62676276 def test_serialization_offset_gzip (self ):
6268- f = tempfile .NamedTemporaryFile (delete = False )
6269- self ._test_serialization_offset (lambda : gzip .open (f .name , mode = 'w+b' ),
6270- filecontext_read_lambda = lambda : gzip .open (f .name , mode = 'r+b' ))
6277+ a = torch .randn (5 , 5 )
6278+ i = 41
6279+ f1 = tempfile .NamedTemporaryFile (delete = False )
6280+ f2 = tempfile .NamedTemporaryFile (delete = False )
6281+ with open (f1 .name , 'wb' ) as f :
6282+ pickle .dump (i , f )
6283+ torch .save (a , f )
6284+ with open (f1 .name , 'rb' ) as f_in , gzip .open (f2 .name , 'wb' ) as f_out :
6285+ shutil .copyfileobj (f_in , f_out )
6286+
6287+ with gzip .open (f2 .name , 'rb' ) as f :
6288+ j = pickle .load (f )
6289+ b = torch .load (f )
6290+ self .assertTrue (torch .equal (a , b ))
6291+ self .assertEqual (i , j )
62716292
62726293 def test_half_tensor (self ):
62736294 x = torch .randn (5 , 5 ).float ()
0 commit comments