55import random
66import operator
77import copy
8+ import shutil
89import torch
910import torch .cuda
1011import tempfile
1112import unittest
1213import warnings
1314import pickle
15+ import gzip
1416from torch .utils .dlpack import from_dlpack , to_dlpack
1517from torch ._utils import _rebuild_tensor
1618from itertools import product , combinations
@@ -6169,7 +6171,7 @@ def test_parsing_intlist(self):
61696171 self .assertRaises (TypeError , lambda : torch .ones (np .array (3 , 3 )))
61706172 self .assertRaises (TypeError , lambda : torch .ones ((np .array (3 , 3 ))))
61716173
6172- def _test_serialization (self , filecontext_lambda , test_use_filename = True ):
6174+ def _test_serialization_data (self ):
61736175 a = [torch .randn (5 , 5 ).float () for i in range (2 )]
61746176 b = [a [i % 2 ] for i in range (4 )]
61756177 b += [a [0 ].storage ()]
@@ -6179,68 +6181,115 @@ def _test_serialization(self, filecontext_lambda, test_use_filename=True):
61796181 t2 = torch .FloatTensor ().set_ (a [0 ].storage ()[1 :4 ], 0 , (3 ,), (1 ,))
61806182 b += [(t1 .storage (), t1 .storage (), t2 .storage ())]
61816183 b += [a [0 ].storage ()[0 :2 ]]
6182- if test_use_filename :
6183- use_name_options = (False , True )
6184- else :
6185- use_name_options = (False ,)
6186- for use_name in use_name_options :
6184+ return b
6185+
6186+ def _test_serialization_assert (self , b , c ):
6187+ self .assertEqual (b , c , 0 )
6188+ self .assertTrue (isinstance (c [0 ], torch .FloatTensor ))
6189+ self .assertTrue (isinstance (c [1 ], torch .FloatTensor ))
6190+ self .assertTrue (isinstance (c [2 ], torch .FloatTensor ))
6191+ self .assertTrue (isinstance (c [3 ], torch .FloatTensor ))
6192+ self .assertTrue (isinstance (c [4 ], torch .FloatStorage ))
6193+ c [0 ].fill_ (10 )
6194+ self .assertEqual (c [0 ], c [2 ], 0 )
6195+ self .assertEqual (c [4 ], torch .FloatStorage (25 ).fill_ (10 ), 0 )
6196+ c [1 ].fill_ (20 )
6197+ self .assertEqual (c [1 ], c [3 ], 0 )
6198+ self .assertEqual (c [4 ][1 :4 ], c [5 ], 0 )
6199+
6200+ # check that serializing the same storage view object unpickles
6201+ # it as one object not two (and vice versa)
6202+ views = c [7 ]
6203+ self .assertEqual (views [0 ]._cdata , views [1 ]._cdata )
6204+ self .assertEqual (views [0 ], views [2 ])
6205+ self .assertNotEqual (views [0 ]._cdata , views [2 ]._cdata )
6206+
6207+ rootview = c [8 ]
6208+ self .assertEqual (rootview .data_ptr (), c [0 ].data_ptr ())
6209+
6210+ def test_serialization (self ):
6211+ # Test serialization with a real file
6212+ b = self ._test_serialization_data ()
6213+ for use_name in (False , True ):
61876214 # Passing filename to torch.save(...) will cause the file to be opened twice,
61886215 # which is not supported on Windows
61896216 if sys .platform == "win32" and use_name :
61906217 continue
6191- with filecontext_lambda () as f :
6218+ with tempfile . NamedTemporaryFile () as f :
61926219 handle = f if not use_name else f .name
61936220 torch .save (b , handle )
61946221 f .seek (0 )
61956222 c = torch .load (handle )
6196- self .assertEqual (b , c , 0 )
6197- self .assertTrue (isinstance (c [0 ], torch .FloatTensor ))
6198- self .assertTrue (isinstance (c [1 ], torch .FloatTensor ))
6199- self .assertTrue (isinstance (c [2 ], torch .FloatTensor ))
6200- self .assertTrue (isinstance (c [3 ], torch .FloatTensor ))
6201- self .assertTrue (isinstance (c [4 ], torch .FloatStorage ))
6202- c [0 ].fill_ (10 )
6203- self .assertEqual (c [0 ], c [2 ], 0 )
6204- self .assertEqual (c [4 ], torch .FloatStorage (25 ).fill_ (10 ), 0 )
6205- c [1 ].fill_ (20 )
6206- self .assertEqual (c [1 ], c [3 ], 0 )
6207- self .assertEqual (c [4 ][1 :4 ], c [5 ], 0 )
6208-
6209- # check that serializing the same storage view object unpickles
6210- # it as one object not two (and vice versa)
6211- views = c [7 ]
6212- self .assertEqual (views [0 ]._cdata , views [1 ]._cdata )
6213- self .assertEqual (views [0 ], views [2 ])
6214- self .assertNotEqual (views [0 ]._cdata , views [2 ]._cdata )
6215-
6216- rootview = c [8 ]
6217- self .assertEqual (rootview .data_ptr (), c [0 ].data_ptr ())
6218-
6219- def test_serialization (self ):
6220- # Test serialization with a real file
6221- self ._test_serialization (tempfile .NamedTemporaryFile )
6223+ self ._test_serialization_assert (b , c )
62226224
62236225 def test_serialization_filelike (self ):
62246226 # Test serialization (load and save) with a filelike object
6225- self ._test_serialization (BytesIOContext , test_use_filename = False )
6227+ b = self ._test_serialization_data ()
6228+ with BytesIOContext () as f :
6229+ torch .save (b , f )
6230+ f .seek (0 )
6231+ c = torch .load (f )
6232+ self ._test_serialization_assert (b , c )
6233+
6234+ def test_serialization_gzip (self ):
6235+ # Test serialization with gzip file
6236+ b = self ._test_serialization_data ()
6237+ f1 = tempfile .NamedTemporaryFile (delete = False )
6238+ f2 = tempfile .NamedTemporaryFile (delete = False )
6239+ torch .save (b , f1 )
6240+ with open (f1 .name , 'rb' ) as f_in , gzip .open (f2 .name , 'wb' ) as f_out :
6241+ shutil .copyfileobj (f_in , f_out )
6242+
6243+ with gzip .open (f2 .name , 'rb' ) as f :
6244+ c = torch .load (f )
6245+ self ._test_serialization_assert (b , c )
6246+
6247+ def test_serialization_offset (self ):
6248+ a = torch .randn (5 , 5 )
6249+ i = 41
6250+ for use_name in (False , True ):
6251+ # Passing filename to torch.save(...) will cause the file to be opened twice,
6252+ # which is not supported on Windows
6253+ if sys .platform == "win32" and use_name :
6254+ continue
6255+ with tempfile .NamedTemporaryFile () as f :
6256+ handle = f if not use_name else f .name
6257+ pickle .dump (i , f )
6258+ torch .save (a , f )
6259+ f .seek (0 )
6260+ j = pickle .load (f )
6261+ b = torch .load (f )
6262+ self .assertTrue (torch .equal (a , b ))
6263+ self .assertEqual (i , j )
62266264
6227- def _test_serialization_offset (self , filecontext_lambda ):
6265+ def test_serialization_offset_filelike (self ):
62286266 a = torch .randn (5 , 5 )
62296267 i = 41
6230- with tempfile . TemporaryFile () as f :
6268+ with BytesIOContext () as f :
62316269 pickle .dump (i , f )
62326270 torch .save (a , f )
62336271 f .seek (0 )
62346272 j = pickle .load (f )
62356273 b = torch .load (f )
6236- self .assertTrue (torch .equal (a , b ))
6237- self .assertEqual (i , j )
6274+ self .assertTrue (torch .equal (a , b ))
6275+ self .assertEqual (i , j )
62386276
6239- def test_serialization_offset (self ):
6240- self ._test_serialization_offset (tempfile .TemporaryFile )
6277+ def test_serialization_offset_gzip (self ):
6278+ a = torch .randn (5 , 5 )
6279+ i = 41
6280+ f1 = tempfile .NamedTemporaryFile (delete = False )
6281+ f2 = tempfile .NamedTemporaryFile (delete = False )
6282+ with open (f1 .name , 'wb' ) as f :
6283+ pickle .dump (i , f )
6284+ torch .save (a , f )
6285+ with open (f1 .name , 'rb' ) as f_in , gzip .open (f2 .name , 'wb' ) as f_out :
6286+ shutil .copyfileobj (f_in , f_out )
62416287
6242- def test_serialization_offset_filelike (self ):
6243- self ._test_serialization_offset (BytesIOContext )
6288+ with gzip .open (f2 .name , 'rb' ) as f :
6289+ j = pickle .load (f )
6290+ b = torch .load (f )
6291+ self .assertTrue (torch .equal (a , b ))
6292+ self .assertEqual (i , j )
62446293
62456294 def test_half_tensor (self ):
62466295 x = torch .randn (5 , 5 ).float ()
0 commit comments