Skip to content

Commit 902d57b

Browse files
ssnl, soumith
authored and committed
Cherry pick dataloader issue fix to 0.3.1 (#5140)
* cherry pick Fix multiprocessing and dataloader tests on Windows (#4453)
* cherry pick Dataloader issues #4643
* fix common IS_WINDOWS
1 parent db9a700 commit 902d57b

File tree

5 files changed

+156
-61
lines changed

5 files changed

+156
-61
lines changed

test/common.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
def run_tests():
3232
unittest.main(argv=UNITTEST_ARGS)
3333

34+
IS_WINDOWS = sys.platform == "win32"
3435

3536
TEST_NUMPY = True
3637
try:
@@ -332,6 +333,8 @@ def accept_output(update_type):
332333
self.assertEqual(s, expected)
333334

334335
if sys.version_info < (3, 2):
336+
# assertRegexpMatches renamed assertRegex in 3.2
337+
assertRegex = unittest.TestCase.assertRegexpMatches
335338
# assertRaisesRegexp renamed assertRaisesRegex in 3.2
336339
assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
337340

test/test_dataloader.py

Lines changed: 79 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,24 @@
11
import math
22
import sys
3+
import errno
4+
import os
35
import ctypes
6+
import signal
47
import torch
58
import time
69
import traceback
710
import unittest
811
from torch import multiprocessing
912
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
1013
from torch.utils.data.dataset import random_split
11-
from torch.utils.data.dataloader import default_collate
12-
from common import TestCase, run_tests, TEST_NUMPY
14+
from torch.utils.data.dataloader import default_collate, ExceptionWrapper
15+
from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS
1316
from common_nn import TEST_CUDA
1417

1518

19+
JOIN_TIMEOUT = 17.0 if IS_WINDOWS else 4.5
20+
21+
1622
class TestDatasetRandomSplit(TestCase):
1723
def test_lengths_must_equal_datset_size(self):
1824
with self.assertRaises(ValueError):
@@ -100,6 +106,46 @@ def test_add_dataset(self):
100106
self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())
101107

102108

109+
# Stores the first encountered exception in .exception.
110+
# Inspired by https://stackoverflow.com/a/33599967
111+
class ErrorTrackingProcess(multiprocessing.Process):
112+
113+
def __init__(self, *args, **kwargs):
114+
super(ErrorTrackingProcess, self).__init__(*args, **kwargs)
115+
self._pconn, self._cconn = multiprocessing.Pipe()
116+
self._exception = None
117+
118+
def run(self):
119+
# Disable stderr printing at the OS level, so that workers do not print
120+
# to stderr.
121+
# Can't use sys.stderr.close, otherwise Python `raise` will error with
122+
# ValueError: I/O operation on closed file.
123+
os.close(sys.stderr.fileno())
124+
try:
125+
super(ErrorTrackingProcess, self).run()
126+
self._cconn.send(None)
127+
except Exception as e:
128+
self._cconn.send(ExceptionWrapper(sys.exc_info()))
129+
raise
130+
131+
@property
132+
def exception(self):
133+
if self._pconn.poll():
134+
self._exception = self._pconn.recv()
135+
if self._exception is None:
136+
return None
137+
else:
138+
return self._exception.exc_type(self._exception.exc_msg)
139+
140+
# ESRCH means that os.kill can't find a live process with the given pid
141+
def send_signal(self, signum, ignore_ESRCH=False):
142+
try:
143+
os.kill(self.pid, signum)
144+
except OSError as e:
145+
if not ignore_ESRCH or e.errno != errno.ESRCH:
146+
raise
147+
148+
103149
class ErrorDataset(Dataset):
104150

105151
def __init__(self, size):
@@ -170,6 +216,23 @@ def __len__(self):
170216
return self.size
171217

172218

219+
def _test_timeout():
220+
dataset = SleepDataset(10, 10)
221+
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
222+
_ = next(iter(dataloader))
223+
224+
225+
def _test_segfault():
226+
dataset = SegfaultDataset(10)
227+
dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
228+
_ = next(iter(dataloader))
229+
230+
231+
# test custom init function
232+
def init_fn(worker_id):
233+
torch.manual_seed(12345)
234+
235+
173236
class TestDataLoader(TestCase):
174237

175238
def setUp(self):
@@ -248,34 +311,30 @@ def test_multiple_dataloaders(self):
248311

249312
@unittest.skipIf(True, "flaky test")
250313
def test_segfault(self):
251-
def _test_segfault():
252-
sys.stderr.close()
253-
dataset = SegfaultDataset(10)
254-
dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
255-
_ = next(iter(dataloader))
256-
257-
p = multiprocessing.Process(target=_test_segfault)
314+
p = ErrorTrackingProcess(target=_test_segfault)
258315
p.start()
259-
p.join(1.0)
316+
p.join(JOIN_TIMEOUT)
260317
try:
261318
self.assertFalse(p.is_alive())
262319
self.assertNotEqual(p.exitcode, 0)
320+
if IS_WINDOWS:
321+
self.assertIsInstance(p.exception, OSError)
322+
self.assertRegex(str(p.exception), r'access violation reading ')
323+
else:
324+
self.assertIsInstance(p.exception, RuntimeError)
325+
self.assertRegex(str(p.exception), r'DataLoader worker \(pid \d+\) is killed by signal: ')
263326
finally:
264327
p.terminate()
265328

266329
def test_timeout(self):
267-
def _test_timeout():
268-
sys.stderr.close()
269-
dataset = SleepDataset(10, 10)
270-
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
271-
_ = next(iter(dataloader))
272-
273-
p = multiprocessing.Process(target=_test_timeout)
330+
p = ErrorTrackingProcess(target=_test_timeout)
274331
p.start()
275-
p.join(3.0)
332+
p.join(JOIN_TIMEOUT)
276333
try:
277334
self.assertFalse(p.is_alive())
278335
self.assertNotEqual(p.exitcode, 0)
336+
self.assertIsInstance(p.exception, RuntimeError)
337+
self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
279338
finally:
280339
p.terminate()
281340

@@ -289,10 +348,6 @@ def test_worker_seed(self):
289348
self.assertEqual(len(seeds), num_workers)
290349

291350
def test_worker_init_fn(self):
292-
# test custom init function
293-
def init_fn(worker_id):
294-
torch.manual_seed(12345)
295-
296351
dataset = SeedDataset(4)
297352
dataloader = DataLoader(dataset, batch_size=2, num_workers=2,
298353
worker_init_fn=init_fn)
@@ -381,10 +436,10 @@ def test_partial_workers(self):
381436
break
382437
del loader
383438
for w in workers:
384-
w.join(1.0) # timeout of one second
439+
w.join(JOIN_TIMEOUT)
385440
self.assertFalse(w.is_alive(), 'subprocess not terminated')
386441
self.assertEqual(w.exitcode, 0)
387-
worker_manager_thread.join(1.0)
442+
worker_manager_thread.join(JOIN_TIMEOUT)
388443
self.assertFalse(worker_manager_thread.is_alive())
389444

390445
def test_len(self):

test/test_multiprocessing.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,14 +11,15 @@
1111
import torch.multiprocessing as mp
1212
from torch.autograd import Variable
1313
from torch.nn import Parameter
14-
from common import TestCase, run_tests
14+
from common import TestCase, run_tests, IS_WINDOWS
1515

1616

1717
TEST_REPEATS = 30
1818
HAS_SHM_FILES = os.path.isdir('/dev/shm')
1919
TEST_CUDA_IPC = torch.cuda.is_available() and \
2020
sys.version_info[0] == 3 and \
21-
sys.platform != 'darwin'
21+
sys.platform != 'darwin' and \
22+
sys.platform != 'win32'
2223
TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
2324

2425

@@ -318,6 +319,7 @@ def test_cuda_small_tensors(self):
318319
self.assertEqual(tensor_size, 5)
319320
self.assertEqual(storage_size, 5)
320321

322+
@unittest.skipIf(IS_WINDOWS, 'not applicable to Windows (only fails with fork)')
321323
@unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
322324
def test_cuda_bad_call(self):
323325
# Initialize CUDA

torch/csrc/DataLoader.cpp

Lines changed: 33 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) \
4040

4141
// signal(2) is really not portable. So use sigaction.
4242
// http://man7.org/linux/man-pages/man2/signal.2.html
43-
static void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr)
43+
static inline void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr)
4444
{
4545
struct sigaction sa;
4646
sa.sa_sigaction = handler;
@@ -56,10 +56,34 @@ SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered
5656
"This might be caused by insufficient shared memory (shm).\n");
5757
SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n");
5858

59+
// When an error happens in DataLoader methods and Python starts to exit, the
60+
// error trace will keep the loader alive, and Python may kill the children
61+
// processes first before deleting the loader object. Then the cleaning up
62+
// methods in DataLoader.__del__ are not yet called, and SIGCHLD will print an
63+
// error saying a worker is killed by SIGTERM. So we suppress SIGTERM from main
64+
// loader process here to avoid this by _exit(EXIT_SUCCESS). Note that if we
65+
// exit with nonzero code, the loader SIGCHLD handler may report RuntimeError
66+
// again, and then it defeats the whole purpose.
67+
static void handler_SIGTERM(int sig, siginfo_t *info, void *ctx)
68+
{
69+
if (info->si_pid == getppid()) {
70+
_exit(EXIT_SUCCESS);
71+
}
72+
struct sigaction sa;
73+
sa.sa_handler = SIG_DFL;
74+
sa.sa_flags = 0;
75+
if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGTERM, &sa, NULL) != 0) {
76+
_exit(EXIT_FAILURE);
77+
} else {
78+
raise(SIGTERM);
79+
}
80+
}
81+
5982
PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *arg) {
6083
HANDLE_TH_ERRORS
6184
setSignalHandler(SIGBUS, &handler_SIGBUS, NULL);
6285
setSignalHandler(SIGSEGV, &handler_SIGSEGV, NULL);
86+
setSignalHandler(SIGTERM, &handler_SIGTERM, NULL);
6387
Py_RETURN_TRUE;
6488
END_HANDLE_TH_ERRORS
6589
}
@@ -70,33 +94,33 @@ PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
7094
HANDLE_TH_ERRORS
7195
int error;
7296
std::set<pid_t> *pid_set;
73-
pid_t pid;
97+
pid_t worker_pid;
7498
siginfo_t infop;
7599

76100
// Only check the pids we care about
77101
for (auto it = worker_pids.begin(); it != worker_pids.end(); ++it) {
78102
pid_set = &(it->second);
79103
for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) {
80-
pid = *pid_it;
104+
worker_pid = *pid_it;
81105
// Use waitid rather than waitpid so that we can set NOWAIT, and that Python
82106
// and other handlers can get whatever info they want about the child.
83107
infop.si_pid = 0;
84-
error = waitid(P_PID, pid, &infop, WEXITED|WNOHANG|WNOWAIT);
108+
error = waitid(P_PID, worker_pid, &infop, WEXITED|WNOHANG|WNOWAIT);
85109
// ignore errors and case with no waitable child
86110
if (error < 0 || infop.si_pid == 0)
87111
continue;
88-
if (infop.si_code == CLD_EXITED && infop.si_status != 0) { // exit with error
112+
if (infop.si_code == CLD_EXITED && infop.si_status != EXIT_SUCCESS) { // exit with error
89113
std::ostringstream oss;
90-
oss << "DataLoader worker (pid " << pid << ") exited unexpectedly "
91-
<< "with exit code " << infop.si_status << ".";
114+
oss << "DataLoader worker (pid " << worker_pid << ") exited "
115+
<< "unexpectedly with exit code " << infop.si_status << ".";
92116
// This is necessary. Otherwise, the runtime error will kill the other
93117
// workers, and trigger this again.
94118
pid_set->clear();
95119
throw std::runtime_error(oss.str());
96120
} else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal
97121
std::ostringstream oss;
98-
oss << "DataLoader worker (pid " << pid << ") is killed by signal: "
99-
<< strsignal(infop.si_status) << ".";
122+
oss << "DataLoader worker (pid " << worker_pid << ") is killed "
123+
<< "by signal: " << strsignal(infop.si_status) << ".";
100124
// This is necessary. Otherwise, the runtime error will kill the other
101125
// workers, and trigger this again.
102126
pid_set->clear();

0 commit comments

Comments
 (0)