Skip to content

Commit f23fb66

Browse files
philippslangfacebook-github-bot
authored andcommitted
Fix in file position logic: file descriptor and Python-side handle (#20270)
Summary: This addresses #18436 The logic replicates the essence of closing file descriptors in numpy: https://github.com/numpy/numpy/blob/bf20e3034085716c4559ec4bf31b23b6016f266c/numpy/core/include/numpy/npy_3kcompat.h#L278 This stores the position of the file descriptor before resetting it to the Python handle offset, then resets to the original position before exit. The Python-side handle is then updated to reflect the new position. Also added somewhat more demanding tests to cover this. Pull Request resolved: #20270 Differential Revision: D15275902 Pulled By: soumith fbshipit-source-id: 5ca8a52b61c7718d2e69571f72f80b1350b0acdb
1 parent c406bf2 commit f23fb66

File tree

3 files changed

+48
-23
lines changed

3 files changed

+48
-23
lines changed

test/test_torch.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9287,33 +9287,45 @@ def test_serialization_gzip(self):
92879287

92889288
def test_serialization_offset(self):
92899289
a = torch.randn(5, 5)
9290-
i = 41
9291-
for use_name in (False, True):
9292-
# Passing filename to torch.save(...) will cause the file to be opened twice,
9293-
# which is not supported on Windows
9294-
if sys.platform == "win32" and use_name:
9295-
continue
9296-
with tempfile.NamedTemporaryFile() as f:
9297-
handle = f if not use_name else f.name
9298-
pickle.dump(i, f)
9299-
torch.save(a, f)
9300-
f.seek(0)
9301-
j = pickle.load(f)
9302-
b = torch.load(f)
9303-
self.assertTrue(torch.equal(a, b))
9304-
self.assertEqual(i, j)
9290+
b = torch.randn(2, 2)
9291+
m = torch.nn.Conv2d(1, 1, (1, 3))
9292+
i, j = 41, 43
9293+
with tempfile.NamedTemporaryFile() as f:
9294+
pickle.dump(i, f)
9295+
torch.save(a, f)
9296+
pickle.dump(j, f)
9297+
torch.save(b, f)
9298+
torch.save(m, f)
9299+
f.seek(0)
9300+
i_loaded = pickle.load(f)
9301+
a_loaded = torch.load(f)
9302+
j_loaded = pickle.load(f)
9303+
b_loaded = torch.load(f)
9304+
m_loaded = torch.load(f)
9305+
self.assertTrue(torch.equal(a, a_loaded))
9306+
self.assertTrue(torch.equal(b, b_loaded))
9307+
self.assertTrue(m.kernel_size == m_loaded.kernel_size)
9308+
self.assertEqual(i, i_loaded)
9309+
self.assertEqual(j, j_loaded)
93059310

93069311
def test_serialization_offset_filelike(self):
93079312
a = torch.randn(5, 5)
9308-
i = 41
9313+
b = torch.randn(2, 3)
9314+
i, j = 41, 43
93099315
with BytesIOContext() as f:
93109316
pickle.dump(i, f)
9311-
torch.save(a, f)
9317+
torch.save(a, f)
9318+
pickle.dump(j, f)
9319+
torch.save(b, f)
93129320
f.seek(0)
9313-
j = pickle.load(f)
9314-
b = torch.load(f)
9315-
self.assertTrue(torch.equal(a, b))
9316-
self.assertEqual(i, j)
9321+
i_loaded = pickle.load(f)
9322+
a_loaded = torch.load(f)
9323+
j_loaded = pickle.load(f)
9324+
b_loaded = torch.load(f)
9325+
self.assertTrue(torch.equal(a, a_loaded))
9326+
self.assertTrue(torch.equal(b, b_loaded))
9327+
self.assertEqual(i, i_loaded)
9328+
self.assertEqual(j, j_loaded)
93179329

93189330
def test_serialization_offset_gzip(self):
93199331
a = torch.randn(5, 5)

torch/csrc/generic/StorageMethods.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,8 @@ static PyObject *THPStorage_(setFromFile)(THPStorage *self, PyObject *args)
261261
}
262262

263263
// file is backed by a fd
264-
int fd = PyObject_AsFileDescriptor(file);
264+
const int fd = PyObject_AsFileDescriptor(file);
265+
const auto fd_original_pos = lseek(fd, 0, SEEK_CUR);
265266
if (offset != Py_None) {
266267
lseek(fd, THPUtils_unpackLong(offset), SEEK_SET);
267268
}
@@ -272,6 +273,17 @@ static PyObject *THPStorage_(setFromFile)(THPStorage *self, PyObject *args)
272273
return nullptr;
273274
Py_INCREF(self);
274275

276+
// the file descriptor is returned to original position and
277+
// the file handle at python call-site needs updating to the
278+
// advanced postion
279+
const auto fd_current_pos = lseek(fd, 0, SEEK_CUR);
280+
lseek(fd, fd_original_pos, SEEK_SET);
281+
const auto seek_return = PyObject_CallMethod(file, "seek", "li", fd_current_pos, 0);
282+
if (seek_return == nullptr) {
283+
return nullptr;
284+
}
285+
Py_DECREF(seek_return);
286+
275287
return (PyObject *) self;
276288
END_HANDLE_TH_ERRORS
277289
}

torch/serialization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ def persistent_load(saved_id):
579579
for key in deserialized_storage_keys:
580580
assert key in deserialized_objects
581581
deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
582-
offset = None
582+
if offset is not None:
583+
offset = f.tell()
583584

584585
return result

0 commit comments

Comments
 (0)