
Commit bafec16

li-roy authored and soumith committed
support loading gzip (#6490)
* support loading gzip
* address comments
* address comments
* fix lint
* fix test for python2
1 parent 3481c6c commit bafec16
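
In practical terms, the change teaches torch.load to read a checkpoint through a gzip file object rather than only through a plain file handle. A minimal sketch of that workflow, mirroring the new tests below (the tensor and temporary-file names are illustrative only):

import gzip
import shutil
import tempfile

import torch

x = torch.randn(5, 5)

# Save normally, then gzip-compress the saved file.
plain = tempfile.NamedTemporaryFile(delete=False)
compressed = tempfile.NamedTemporaryFile(delete=False, suffix='.gz')
torch.save(x, plain)
plain.close()
with open(plain.name, 'rb') as f_in, gzip.open(compressed.name, 'wb') as f_out:
    shutil.copyfileobj(f_in, f_out)

# With this commit, torch.load recognizes the gzip file object and reads it
# through its (decompressing) read() method instead of its raw fileno().
with gzip.open(compressed.name, 'rb') as f:
    y = torch.load(f)
assert torch.equal(x, y)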

File tree: 2 files changed, +112 -49 lines changed

test/test_torch.py

Lines changed: 91 additions & 42 deletions
@@ -5,12 +5,14 @@
 import random
 import operator
 import copy
+import shutil
 import torch
 import torch.cuda
 import tempfile
 import unittest
 import warnings
 import pickle
+import gzip
 from torch.utils.dlpack import from_dlpack, to_dlpack
 from torch._utils import _rebuild_tensor
 from itertools import product, combinations
@@ -6169,7 +6171,7 @@ def test_parsing_intlist(self):
         self.assertRaises(TypeError, lambda: torch.ones(np.array(3, 3)))
         self.assertRaises(TypeError, lambda: torch.ones((np.array(3, 3))))

-    def _test_serialization(self, filecontext_lambda, test_use_filename=True):
+    def _test_serialization_data(self):
         a = [torch.randn(5, 5).float() for i in range(2)]
         b = [a[i % 2] for i in range(4)]
         b += [a[0].storage()]
@@ -6179,68 +6181,115 @@ def _test_serialization(self, filecontext_lambda, test_use_filename=True):
         t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
         b += [(t1.storage(), t1.storage(), t2.storage())]
         b += [a[0].storage()[0:2]]
-        if test_use_filename:
-            use_name_options = (False, True)
-        else:
-            use_name_options = (False,)
-        for use_name in use_name_options:
+        return b
+
+    def _test_serialization_assert(self, b, c):
+        self.assertEqual(b, c, 0)
+        self.assertTrue(isinstance(c[0], torch.FloatTensor))
+        self.assertTrue(isinstance(c[1], torch.FloatTensor))
+        self.assertTrue(isinstance(c[2], torch.FloatTensor))
+        self.assertTrue(isinstance(c[3], torch.FloatTensor))
+        self.assertTrue(isinstance(c[4], torch.FloatStorage))
+        c[0].fill_(10)
+        self.assertEqual(c[0], c[2], 0)
+        self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
+        c[1].fill_(20)
+        self.assertEqual(c[1], c[3], 0)
+        self.assertEqual(c[4][1:4], c[5], 0)
+
+        # check that serializing the same storage view object unpickles
+        # it as one object not two (and vice versa)
+        views = c[7]
+        self.assertEqual(views[0]._cdata, views[1]._cdata)
+        self.assertEqual(views[0], views[2])
+        self.assertNotEqual(views[0]._cdata, views[2]._cdata)
+
+        rootview = c[8]
+        self.assertEqual(rootview.data_ptr(), c[0].data_ptr())
+
+    def test_serialization(self):
+        # Test serialization with a real file
+        b = self._test_serialization_data()
+        for use_name in (False, True):
             # Passing filename to torch.save(...) will cause the file to be opened twice,
             # which is not supported on Windows
             if sys.platform == "win32" and use_name:
                 continue
-            with filecontext_lambda() as f:
+            with tempfile.NamedTemporaryFile() as f:
                 handle = f if not use_name else f.name
                 torch.save(b, handle)
                 f.seek(0)
                 c = torch.load(handle)
-            self.assertEqual(b, c, 0)
-            self.assertTrue(isinstance(c[0], torch.FloatTensor))
-            self.assertTrue(isinstance(c[1], torch.FloatTensor))
-            self.assertTrue(isinstance(c[2], torch.FloatTensor))
-            self.assertTrue(isinstance(c[3], torch.FloatTensor))
-            self.assertTrue(isinstance(c[4], torch.FloatStorage))
-            c[0].fill_(10)
-            self.assertEqual(c[0], c[2], 0)
-            self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
-            c[1].fill_(20)
-            self.assertEqual(c[1], c[3], 0)
-            self.assertEqual(c[4][1:4], c[5], 0)
-
-            # check that serializing the same storage view object unpickles
-            # it as one object not two (and vice versa)
-            views = c[7]
-            self.assertEqual(views[0]._cdata, views[1]._cdata)
-            self.assertEqual(views[0], views[2])
-            self.assertNotEqual(views[0]._cdata, views[2]._cdata)
-
-            rootview = c[8]
-            self.assertEqual(rootview.data_ptr(), c[0].data_ptr())
-
-    def test_serialization(self):
-        # Test serialization with a real file
-        self._test_serialization(tempfile.NamedTemporaryFile)
+            self._test_serialization_assert(b, c)

     def test_serialization_filelike(self):
         # Test serialization (load and save) with a filelike object
-        self._test_serialization(BytesIOContext, test_use_filename=False)
+        b = self._test_serialization_data()
+        with BytesIOContext() as f:
+            torch.save(b, f)
+            f.seek(0)
+            c = torch.load(f)
+        self._test_serialization_assert(b, c)
+
+    def test_serialization_gzip(self):
+        # Test serialization with gzip file
+        b = self._test_serialization_data()
+        f1 = tempfile.NamedTemporaryFile(delete=False)
+        f2 = tempfile.NamedTemporaryFile(delete=False)
+        torch.save(b, f1)
+        with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out:
+            shutil.copyfileobj(f_in, f_out)
+
+        with gzip.open(f2.name, 'rb') as f:
+            c = torch.load(f)
+        self._test_serialization_assert(b, c)
+
+    def test_serialization_offset(self):
+        a = torch.randn(5, 5)
+        i = 41
+        for use_name in (False, True):
+            # Passing filename to torch.save(...) will cause the file to be opened twice,
+            # which is not supported on Windows
+            if sys.platform == "win32" and use_name:
+                continue
+            with tempfile.NamedTemporaryFile() as f:
+                handle = f if not use_name else f.name
+                pickle.dump(i, f)
+                torch.save(a, f)
+                f.seek(0)
+                j = pickle.load(f)
+                b = torch.load(f)
+            self.assertTrue(torch.equal(a, b))
+            self.assertEqual(i, j)

-    def _test_serialization_offset(self, filecontext_lambda):
+    def test_serialization_offset_filelike(self):
         a = torch.randn(5, 5)
         i = 41
-        with tempfile.TemporaryFile() as f:
+        with BytesIOContext() as f:
             pickle.dump(i, f)
             torch.save(a, f)
             f.seek(0)
             j = pickle.load(f)
             b = torch.load(f)
-        self.assertTrue(torch.equal(a, b))
-        self.assertEqual(i, j)
+            self.assertTrue(torch.equal(a, b))
+            self.assertEqual(i, j)

-    def test_serialization_offset(self):
-        self._test_serialization_offset(tempfile.TemporaryFile)
+    def test_serialization_offset_gzip(self):
+        a = torch.randn(5, 5)
+        i = 41
+        f1 = tempfile.NamedTemporaryFile(delete=False)
+        f2 = tempfile.NamedTemporaryFile(delete=False)
+        with open(f1.name, 'wb') as f:
+            pickle.dump(i, f)
+            torch.save(a, f)
+        with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out:
+            shutil.copyfileobj(f_in, f_out)

-    def test_serialization_offset_filelike(self):
-        self._test_serialization_offset(BytesIOContext)
+        with gzip.open(f2.name, 'rb') as f:
+            j = pickle.load(f)
+            b = torch.load(f)
+        self.assertTrue(torch.equal(a, b))
+        self.assertEqual(i, j)

     def test_half_tensor(self):
         x = torch.randn(5, 5).float()
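
The filelike tests above rely on BytesIOContext, a helper defined elsewhere in test_torch.py and not shown in this diff. A minimal stand-in with the behavior these tests need (purely illustrative, not the repository's actual definition) could be:

import io


class BytesIOContext(io.BytesIO):
    # An in-memory buffer usable as a context manager, so tests can use it in
    # the same `with ... as f:` pattern as tempfile.NamedTemporaryFile.
    def __enter__(self):
        return self

    def __exit__(self, *args):
        pass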

torch/serialization.py

Lines changed: 21 additions & 7 deletions
@@ -137,8 +137,22 @@ def _with_file_like(f, mode, body):
     f.close()


-def _is_real_file(f):
-    """Checks if f is backed by a real file (has a fileno)"""
+def _is_compressed_file(f):
+    compress_modules = ['gzip']
+    try:
+        return f.__module__ in compress_modules
+    except AttributeError:
+        return False
+
+
+def _should_read_directly(f):
+    """
+    Checks if f is a file that should be read directly. It should be read
+    directly if it is backed by a real file (has a fileno) and is not a
+    a compressed file (e.g. gzip)
+    """
+    if _is_compressed_file(f):
+        return False
     try:
         return f.fileno() >= 0
     except io.UnsupportedOperation:
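
For intuition, here is a small standalone sketch of how the new check classifies common file objects. The two helper bodies follow the hunk above, except that the tail of the try/except (cut off by the hunk) is an assumption here, and the prints are illustrative:

import gzip
import io
import tempfile


def _is_compressed_file(f):
    compress_modules = ['gzip']
    try:
        return f.__module__ in compress_modules
    except AttributeError:
        return False


def _should_read_directly(f):
    if _is_compressed_file(f):
        return False
    try:
        return f.fileno() >= 0
    except (io.UnsupportedOperation, AttributeError):
        # objects such as io.BytesIO expose no usable fileno()
        return False


with tempfile.NamedTemporaryFile() as real_file:
    print(_should_read_directly(real_file))  # True: real fd, not compressed

print(_should_read_directly(io.BytesIO()))   # False: fileno() is unsupported

with tempfile.NamedTemporaryFile(suffix='.gz') as tmp:
    with gzip.open(tmp.name, 'wb') as gz:
        # GzipFile delegates fileno() to the underlying file, so the old
        # _is_real_file check accepted it and the loader read raw compressed
        # bytes; the new check returns False so reads go through gzip.
        print(_should_read_directly(gz))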
@@ -251,7 +265,7 @@ def persistent_id(obj):
     pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
     f.flush()
     for key in serialized_storage_keys:
-        serialized_storages[key]._write_file(f, _is_real_file(f))
+        serialized_storages[key]._write_file(f, _should_read_directly(f))


 def load(f, map_location=None, pickle_module=pickle):
@@ -465,8 +479,8 @@ def persistent_load(saved_id):
         else:
             raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

-    f_is_real_file = _is_real_file(f)
-    if f_is_real_file and f.tell() == 0:
+    f_should_read_directly = _should_read_directly(f)
+    if f_should_read_directly and f.tell() == 0:
         # legacy_load requires that f has fileno()
         # only if offset is zero we can attempt the legacy tar file loader
         try:
@@ -489,10 +503,10 @@ def persistent_load(saved_id):

     deserialized_storage_keys = pickle_module.load(f)

-    offset = f.tell() if f_is_real_file else None
+    offset = f.tell() if f_should_read_directly else None
     for key in deserialized_storage_keys:
         assert key in deserialized_objects
-        deserialized_objects[key]._set_from_file(f, offset, f_is_real_file)
+        deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
         offset = None

     return result
