Skip to content

Commit a6d35cc

Browse files
author
Roy Li
committed
fix test for python2
1 parent 20bf41c commit a6d35cc

File tree

1 file changed

+90
-69
lines changed

1 file changed

+90
-69
lines changed

test/test_torch.py

Lines changed: 90 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import random
66
import operator
77
import copy
8+
import shutil
89
import torch
910
import torch.cuda
1011
import tempfile
@@ -6169,7 +6170,7 @@ def test_parsing_intlist(self):
61696170
self.assertRaises(TypeError, lambda: torch.ones(np.array(3, 3)))
61706171
self.assertRaises(TypeError, lambda: torch.ones((np.array(3, 3))))
61716172

6172-
def _test_serialization(self, filecontext_lambda, filecontext_read_lambda=None, test_use_filename=True):
6173+
def _test_serialization_data(self):
61736174
a = [torch.randn(5, 5).float() for i in range(2)]
61746175
b = [a[i % 2] for i in range(4)]
61756176
b += [a[0].storage()]
@@ -6179,95 +6180,115 @@ def _test_serialization(self, filecontext_lambda, filecontext_read_lambda=None,
61796180
t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
61806181
b += [(t1.storage(), t1.storage(), t2.storage())]
61816182
b += [a[0].storage()[0:2]]
6182-
if test_use_filename:
6183-
use_name_options = (False, True)
6184-
else:
6185-
use_name_options = (False,)
6186-
for use_name in use_name_options:
6183+
return b
6184+
6185+
def _test_serialization_assert(self, b, c):
6186+
self.assertEqual(b, c, 0)
6187+
self.assertTrue(isinstance(c[0], torch.FloatTensor))
6188+
self.assertTrue(isinstance(c[1], torch.FloatTensor))
6189+
self.assertTrue(isinstance(c[2], torch.FloatTensor))
6190+
self.assertTrue(isinstance(c[3], torch.FloatTensor))
6191+
self.assertTrue(isinstance(c[4], torch.FloatStorage))
6192+
c[0].fill_(10)
6193+
self.assertEqual(c[0], c[2], 0)
6194+
self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
6195+
c[1].fill_(20)
6196+
self.assertEqual(c[1], c[3], 0)
6197+
self.assertEqual(c[4][1:4], c[5], 0)
6198+
6199+
# check that serializing the same storage view object unpickles
6200+
# it as one object not two (and vice versa)
6201+
views = c[7]
6202+
self.assertEqual(views[0]._cdata, views[1]._cdata)
6203+
self.assertEqual(views[0], views[2])
6204+
self.assertNotEqual(views[0]._cdata, views[2]._cdata)
6205+
6206+
rootview = c[8]
6207+
self.assertEqual(rootview.data_ptr(), c[0].data_ptr())
6208+
6209+
def test_serialization(self):
6210+
# Test serialization with a real file
6211+
b = self._test_serialization_data()
6212+
for use_name in (False, True):
61876213
# Passing filename to torch.save(...) will cause the file to be opened twice,
61886214
# which is not supported on Windows
61896215
if sys.platform == "win32" and use_name:
61906216
continue
6191-
if filecontext_read_lambda:
6192-
with filecontext_lambda() as f:
6193-
handle = f if not use_name else f.name
6194-
torch.save(b, handle)
6195-
with filecontext_read_lambda() as f:
6196-
handle = f if not use_name else f.name
6197-
c = torch.load(handle)
6198-
else:
6199-
with filecontext_lambda() as f:
6200-
handle = f if not use_name else f.name
6201-
torch.save(b, handle)
6202-
f.seek(0)
6203-
c = torch.load(handle)
6204-
self.assertEqual(b, c, 0)
6205-
self.assertTrue(isinstance(c[0], torch.FloatTensor))
6206-
self.assertTrue(isinstance(c[1], torch.FloatTensor))
6207-
self.assertTrue(isinstance(c[2], torch.FloatTensor))
6208-
self.assertTrue(isinstance(c[3], torch.FloatTensor))
6209-
self.assertTrue(isinstance(c[4], torch.FloatStorage))
6210-
c[0].fill_(10)
6211-
self.assertEqual(c[0], c[2], 0)
6212-
self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
6213-
c[1].fill_(20)
6214-
self.assertEqual(c[1], c[3], 0)
6215-
self.assertEqual(c[4][1:4], c[5], 0)
6216-
6217-
# check that serializing the same storage view object unpickles
6218-
# it as one object not two (and vice versa)
6219-
views = c[7]
6220-
self.assertEqual(views[0]._cdata, views[1]._cdata)
6221-
self.assertEqual(views[0], views[2])
6222-
self.assertNotEqual(views[0]._cdata, views[2]._cdata)
6223-
6224-
rootview = c[8]
6225-
self.assertEqual(rootview.data_ptr(), c[0].data_ptr())
6226-
6227-
def test_serialization(self):
6228-
# Test serialization with a real file
6229-
self._test_serialization(tempfile.NamedTemporaryFile)
6217+
with tempfile.NamedTemporaryFile() as f:
6218+
handle = f if not use_name else f.name
6219+
torch.save(b, handle)
6220+
f.seek(0)
6221+
c = torch.load(handle)
6222+
self._test_serialization_assert(b, c)
62306223

62316224
def test_serialization_filelike(self):
62326225
# Test serialization (load and save) with a filelike object
6233-
self._test_serialization(BytesIOContext, test_use_filename=False)
6226+
b = self._test_serialization_data()
6227+
with BytesIOContext() as f:
6228+
torch.save(b, f)
6229+
f.seek(0)
6230+
c = torch.load(f)
6231+
self._test_serialization_assert(b, c)
62346232

62356233
def test_serialization_gzip(self):
6236-
f = tempfile.NamedTemporaryFile(delete=False)
6237-
self._test_serialization(lambda: gzip.open(f.name, mode='w+b'),
6238-
filecontext_read_lambda=lambda: gzip.open(f.name, mode='r+b'),
6239-
test_use_filename=False)
6234+
# Test serialization with gzip file
6235+
b = self._test_serialization_data()
6236+
f1 = tempfile.NamedTemporaryFile(delete=False)
6237+
f2 = tempfile.NamedTemporaryFile(delete=False)
6238+
torch.save(b, f1)
6239+
with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out:
6240+
shutil.copyfileobj(f_in, f_out)
6241+
6242+
with gzip.open(f2.name, 'rb') as f:
6243+
c = torch.load(f)
6244+
self._test_serialization_assert(b, c)
62406245

6241-
def _test_serialization_offset(self, filecontext_lambda, filecontext_read_lambda=None):
6246+
def test_serialization_offset(self):
62426247
a = torch.randn(5, 5)
62436248
i = 41
6244-
if filecontext_read_lambda:
6245-
with filecontext_lambda() as f:
6246-
pickle.dump(i, f)
6247-
torch.save(a, f)
6248-
with filecontext_read_lambda() as f:
6249-
j = pickle.load(f)
6250-
b = torch.load(f)
6251-
else:
6252-
with filecontext_lambda() as f:
6249+
for use_name in (False, True):
6250+
# Passing filename to torch.save(...) will cause the file to be opened twice,
6251+
# which is not supported on Windows
6252+
if sys.platform == "win32" and use_name:
6253+
continue
6254+
with tempfile.NamedTemporaryFile() as f:
6255+
handle = f if not use_name else f.name
62536256
pickle.dump(i, f)
62546257
torch.save(a, f)
62556258
f.seek(0)
62566259
j = pickle.load(f)
62576260
b = torch.load(f)
6258-
self.assertTrue(torch.equal(a, b))
6259-
self.assertEqual(i, j)
6260-
6261-
def test_serialization_offset(self):
6262-
self._test_serialization_offset(tempfile.TemporaryFile)
6261+
self.assertTrue(torch.equal(a, b))
6262+
self.assertEqual(i, j)
62636263

62646264
def test_serialization_offset_filelike(self):
6265-
self._test_serialization_offset(BytesIOContext)
6265+
a = torch.randn(5, 5)
6266+
i = 41
6267+
with BytesIOContext() as f:
6268+
pickle.dump(i, f)
6269+
torch.save(a, f)
6270+
f.seek(0)
6271+
j = pickle.load(f)
6272+
b = torch.load(f)
6273+
self.assertTrue(torch.equal(a, b))
6274+
self.assertEqual(i, j)
62666275

62676276
def test_serialization_offset_gzip(self):
6268-
f = tempfile.NamedTemporaryFile(delete=False)
6269-
self._test_serialization_offset(lambda: gzip.open(f.name, mode='w+b'),
6270-
filecontext_read_lambda=lambda: gzip.open(f.name, mode='r+b'))
6277+
a = torch.randn(5, 5)
6278+
i = 41
6279+
f1 = tempfile.NamedTemporaryFile(delete=False)
6280+
f2 = tempfile.NamedTemporaryFile(delete=False)
6281+
with open(f1.name, 'wb') as f:
6282+
pickle.dump(i, f)
6283+
torch.save(a, f)
6284+
with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out:
6285+
shutil.copyfileobj(f_in, f_out)
6286+
6287+
with gzip.open(f2.name, 'rb') as f:
6288+
j = pickle.load(f)
6289+
b = torch.load(f)
6290+
self.assertTrue(torch.equal(a, b))
6291+
self.assertEqual(i, j)
62716292

62726293
def test_half_tensor(self):
62736294
x = torch.randn(5, 5).float()

0 commit comments

Comments
 (0)