Skip to content

Commit c4b0db5

Browse files
philippslangsoumith
authored andcommitted
Remove hard file offset reset in load() (#3695)
* improved file offset logic * load offset test * whitespace * needless exception handling * test integer in binary
1 parent 2453bc2 commit c4b0db5

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

test/test_torch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import tempfile
99
import unittest
1010
import warnings
11+
import pickle
1112
from torch.utils.dlpack import from_dlpack, to_dlpack
1213
from itertools import product, combinations
1314
from common import TestCase, iter_indices, TEST_NUMPY, run_tests, download_file, skipIfNoLapack, \
@@ -4134,6 +4135,18 @@ def test_serialization(self):
41344135
rootview = c[8]
41354136
self.assertEqual(rootview.data_ptr(), c[0].data_ptr())
41364137

4138+
def test_serialization_offset(self):
4139+
a = torch.randn(5, 5)
4140+
i = 41
4141+
with tempfile.TemporaryFile() as f:
4142+
pickle.dump(i, f)
4143+
torch.save(a, f)
4144+
f.seek(0)
4145+
j = pickle.load(f)
4146+
b = torch.load(f)
4147+
self.assertTrue(torch.equal(a, b))
4148+
self.assertEqual(i, j)
4149+
41374150
def test_half_tensor(self):
41384151
x = torch.randn(5, 5).float()
41394152
y = torch.randn(5, 5).float()

torch/serialization.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -389,13 +389,15 @@ def persistent_load(saved_id):
389389
else:
390390
raise RuntimeError("Unknown saved id type: %s" % saved_id[0])
391391

392-
# try the legacy loader first, which only works if f is a tarfile
393-
try:
394-
return legacy_load(f)
395-
except tarfile.TarError:
396-
pass
392+
foffset = f.tell()
393+
if foffset == 0:
394+
# only if offset is zero we can attempt the legacy tar file loader
395+
try:
396+
return legacy_load(f)
397+
except tarfile.TarError:
398+
# if not a tarfile, reset file offset and proceed
399+
f.seek(foffset)
397400

398-
f.seek(0)
399401
magic_number = pickle_module.load(f)
400402
if magic_number != MAGIC_NUMBER:
401403
raise RuntimeError("Invalid magic number; corrupt file?")

0 commit comments

Comments
 (0)