Skip to content

Commit 64a9eca

Browse files
ssnl authored and soumith committed
Dataloader issues (#4643)
* EINTR and kill by loader fix * addressed @apaszke's comments * remove EINTR handling and add a check that we are in the main thread before setting the SIGCHLD handler
1 parent 967bceb commit 64a9eca

File tree

4 files changed

+125
-41
lines changed

4 files changed

+125
-41
lines changed

test/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,8 @@ def accept_output(update_type):
378378
self.assertEqual(s, expected)
379379

380380
if sys.version_info < (3, 2):
381+
# assertRegexpMatches renamed assertRegex in 3.2
382+
assertRegex = unittest.TestCase.assertRegexpMatches
381383
# assertRaisesRegexp renamed assertRaisesRegex in 3.2
382384
assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
383385

test/test_dataloader.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
import math
22
import sys
3+
import errno
34
import os
45
import ctypes
6+
import signal
57
import torch
68
import time
79
import traceback
810
import unittest
911
from torch import multiprocessing
1012
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
1113
from torch.utils.data.dataset import random_split
12-
from torch.utils.data.dataloader import default_collate
14+
from torch.utils.data.dataloader import default_collate, ExceptionWrapper
1315
from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS
1416
from common_nn import TEST_CUDA
1517

16-
JOIN_TIMEOUT = 14.0 if IS_WINDOWS else 1.5
18+
19+
JOIN_TIMEOUT = 17.0 if IS_WINDOWS else 4.5
1720

1821

1922
class TestDatasetRandomSplit(TestCase):
@@ -103,6 +106,46 @@ def test_add_dataset(self):
103106
self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())
104107

105108

109+
# Stores the first encountered exception in .exception.
110+
# Inspired by https://stackoverflow.com/a/33599967
111+
class ErrorTrackingProcess(multiprocessing.Process):
112+
113+
def __init__(self, *args, **kwargs):
114+
super(ErrorTrackingProcess, self).__init__(*args, **kwargs)
115+
self._pconn, self._cconn = multiprocessing.Pipe()
116+
self._exception = None
117+
118+
def run(self):
119+
# Disable stderr printing at the OS level, so that workers do not print
120+
# to stderr.
121+
# Can't use sys.stderr.close, otherwise Python `raise` will error with
122+
# ValueError: I/O operation on closed file.
123+
os.close(sys.stderr.fileno())
124+
try:
125+
super(ErrorTrackingProcess, self).run()
126+
self._cconn.send(None)
127+
except Exception as e:
128+
self._cconn.send(ExceptionWrapper(sys.exc_info()))
129+
raise
130+
131+
@property
132+
def exception(self):
133+
if self._pconn.poll():
134+
self._exception = self._pconn.recv()
135+
if self._exception is None:
136+
return None
137+
else:
138+
return self._exception.exc_type(self._exception.exc_msg)
139+
140+
# ESRCH means that os.kill can't find a live process with the given pid
141+
def send_signal(self, signum, ignore_ESRCH=False):
142+
try:
143+
os.kill(self.pid, signum)
144+
except OSError as e:
145+
if not ignore_ESRCH or e.errno != errno.ESRCH:
146+
raise
147+
148+
106149
class ErrorDataset(Dataset):
107150

108151
def __init__(self, size):
@@ -175,16 +218,12 @@ def __len__(self):
175218

176219

177220
def _test_timeout():
178-
os.close(sys.stderr.fileno())
179-
sys.stderr.close()
180221
dataset = SleepDataset(10, 10)
181222
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
182223
_ = next(iter(dataloader))
183224

184225

185226
def _test_segfault():
186-
os.close(sys.stderr.fileno())
187-
sys.stderr.close()
188227
dataset = SegfaultDataset(10)
189228
dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
190229
_ = next(iter(dataloader))
@@ -272,22 +311,30 @@ def test_multiple_dataloaders(self):
272311
next(loader2_it)
273312

274313
def test_segfault(self):
275-
p = multiprocessing.Process(target=_test_segfault)
314+
p = ErrorTrackingProcess(target=_test_segfault)
276315
p.start()
277316
p.join(JOIN_TIMEOUT)
278317
try:
279318
self.assertFalse(p.is_alive())
280319
self.assertNotEqual(p.exitcode, 0)
320+
if IS_WINDOWS:
321+
self.assertIsInstance(p.exception, OSError)
322+
self.assertRegex(str(p.exception), r'access violation reading ')
323+
else:
324+
self.assertIsInstance(p.exception, RuntimeError)
325+
self.assertRegex(str(p.exception), r'DataLoader worker \(pid \d+\) is killed by signal: ')
281326
finally:
282327
p.terminate()
283328

284329
def test_timeout(self):
285-
p = multiprocessing.Process(target=_test_timeout)
330+
p = ErrorTrackingProcess(target=_test_timeout)
286331
p.start()
287-
p.join(3.0 + JOIN_TIMEOUT)
332+
p.join(JOIN_TIMEOUT)
288333
try:
289334
self.assertFalse(p.is_alive())
290335
self.assertNotEqual(p.exitcode, 0)
336+
self.assertIsInstance(p.exception, RuntimeError)
337+
self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
291338
finally:
292339
p.terminate()
293340

torch/csrc/DataLoader.cpp

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) \
4343

4444
// signal(2) is really not portable. So use sigaction.
4545
// http://man7.org/linux/man-pages/man2/signal.2.html
46-
static void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr)
46+
static inline void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr)
4747
{
4848
struct sigaction sa;
4949
sa.sa_sigaction = handler;
@@ -59,10 +59,34 @@ SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered
5959
"This might be caused by insufficient shared memory (shm).\n");
6060
SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n");
6161

62+
// When an error happens in DataLoader methods and Python starts to exit, the
63+
// error trace will keep the loader alive, and Python may kill the children
64+
// processes first before deleting the loader object. Then the cleaning up
65+
// methods in DataLoader.__del__ are not yet called, and SIGCHLD will print an
66+
// error saying a worker is killed by SIGTERM. So we suppress SIGTERM from main
67+
// loader process here to avoid this by _exit(EXIT_SUCCESS). Note that if we
68+
// exit with nonzero code, the loader SIGCHLD handler may report RuntimeError
69+
// again, and then it defeats the whole purpose.
70+
static void handler_SIGTERM(int sig, siginfo_t *info, void *ctx)
71+
{
72+
if (info->si_pid == getppid()) {
73+
_exit(EXIT_SUCCESS);
74+
}
75+
struct sigaction sa;
76+
sa.sa_handler = SIG_DFL;
77+
sa.sa_flags = 0;
78+
if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGTERM, &sa, NULL) != 0) {
79+
_exit(EXIT_FAILURE);
80+
} else {
81+
raise(SIGTERM);
82+
}
83+
}
84+
6285
PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *arg) {
6386
HANDLE_TH_ERRORS
6487
setSignalHandler(SIGBUS, &handler_SIGBUS, NULL);
6588
setSignalHandler(SIGSEGV, &handler_SIGSEGV, NULL);
89+
setSignalHandler(SIGTERM, &handler_SIGTERM, NULL);
6690
Py_RETURN_TRUE;
6791
END_HANDLE_TH_ERRORS
6892
}
@@ -73,33 +97,33 @@ PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
7397
HANDLE_TH_ERRORS
7498
int error;
7599
std::set<pid_t> *pid_set;
76-
pid_t pid;
100+
pid_t worker_pid;
77101
siginfo_t infop;
78102

79103
// Only check the pids we care about
80104
for (auto it = worker_pids.begin(); it != worker_pids.end(); ++it) {
81105
pid_set = &(it->second);
82106
for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) {
83-
pid = *pid_it;
107+
worker_pid = *pid_it;
84108
// Use waitid rather than waitpid so that we can set WNOWAIT, and that Python
85109
// and other handlers can get whatever info they want about the child.
86110
infop.si_pid = 0;
87-
error = waitid(P_PID, pid, &infop, WEXITED|WNOHANG|WNOWAIT);
111+
error = waitid(P_PID, worker_pid, &infop, WEXITED|WNOHANG|WNOWAIT);
88112
// ignore errors and case with no waitable child
89113
if (error < 0 || infop.si_pid == 0)
90114
continue;
91-
if (infop.si_code == CLD_EXITED && infop.si_status != 0) { // exit with error
115+
if (infop.si_code == CLD_EXITED && infop.si_status != EXIT_SUCCESS) { // exit with error
92116
std::ostringstream oss;
93-
oss << "DataLoader worker (pid " << pid << ") exited unexpectedly "
94-
<< "with exit code " << infop.si_status << ".";
117+
oss << "DataLoader worker (pid " << worker_pid << ") exited "
118+
<< "unexpectedly with exit code " << infop.si_status << ".";
95119
// This is necessary. Otherwise, the runtime error will kill the other
96120
// workers, and trigger this again.
97121
pid_set->clear();
98122
throw std::runtime_error(oss.str());
99123
} else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal
100124
std::ostringstream oss;
101-
oss << "DataLoader worker (pid " << pid << ") is killed by signal: "
102-
<< strsignal(infop.si_status) << ".";
125+
oss << "DataLoader worker (pid " << worker_pid << ") is killed "
126+
<< "by signal: " << strsignal(infop.si_status) << ".";
103127
// This is necessary. Otherwise, the runtime error will kill the other
104128
// workers, and trigger this again.
105129
pid_set->clear();

torch/utils/data/dataloader.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,28 @@
88
import collections
99
import re
1010
import sys
11-
import traceback
1211
import threading
12+
import traceback
1313
from torch._six import string_classes, int_classes
1414

15-
1615
if sys.version_info[0] == 2:
1716
import Queue as queue
1817
else:
1918
import queue
2019

2120

22-
_use_shared_memory = False
23-
"""Whether to use shared memory in default_collate"""
24-
25-
2621
class ExceptionWrapper(object):
27-
"Wraps an exception plus traceback to communicate across threads"
22+
r"Wraps an exception plus traceback to communicate across threads"
2823

2924
def __init__(self, exc_info):
3025
self.exc_type = exc_info[0]
3126
self.exc_msg = "".join(traceback.format_exception(*exc_info))
3227

3328

29+
_use_shared_memory = False
30+
"""Whether to use shared memory in default_collate"""
31+
32+
3433
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
3534
global _use_shared_memory
3635
_use_shared_memory = True
@@ -157,7 +156,11 @@ def pin_memory_batch(batch):
157156

158157

159158
def _set_SIGCHLD_handler():
160-
if sys.platform == 'win32': # Windows doesn't support SIGCHLD handler
159+
# Windows doesn't support SIGCHLD handler
160+
if sys.platform == 'win32':
161+
return
162+
# signal handlers can only be set from the main thread
163+
if not isinstance(threading.current_thread(), threading._MainThread):
161164
return
162165
global _SIGCHLD_handler_set
163166
if _SIGCHLD_handler_set:
@@ -212,10 +215,15 @@ def __init__(self, loader):
212215

213216
if self.pin_memory or self.timeout > 0:
214217
self.data_queue = queue.Queue()
218+
if self.pin_memory:
219+
maybe_device_id = torch.cuda.current_device()
220+
else:
221+
# do not initialize cuda context if not necessary
222+
maybe_device_id = None
215223
self.worker_manager_thread = threading.Thread(
216224
target=_worker_manager_loop,
217225
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
218-
torch.cuda.current_device()))
226+
maybe_device_id))
219227
self.worker_manager_thread.daemon = True
220228
self.worker_manager_thread.start()
221229
else:
@@ -239,7 +247,7 @@ def __len__(self):
239247
def _get_batch(self):
240248
if self.timeout > 0:
241249
try:
242-
return self.data_queue.get(True, self.timeout)
250+
return self.data_queue.get(timeout=self.timeout)
243251
except queue.Empty:
244252
raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
245253
else:
@@ -302,17 +310,20 @@ def __getstate__(self):
302310
raise NotImplementedError("DataLoaderIterator cannot be pickled")
303311

304312
def _shutdown_workers(self):
305-
if not self.shutdown:
306-
self.shutdown = True
307-
self.done_event.set()
308-
# if worker_manager_thread is waiting to put
309-
while not self.data_queue.empty():
310-
self.data_queue.get()
311-
for _ in self.workers:
312-
self.index_queue.put(None)
313-
# done_event should be sufficient to exit worker_manager_thread, but
314-
# be safe here and put another None
315-
self.worker_result_queue.put(None)
313+
try:
314+
if not self.shutdown:
315+
self.shutdown = True
316+
self.done_event.set()
317+
# if worker_manager_thread is waiting to put
318+
while not self.data_queue.empty():
319+
self.data_queue.get()
320+
for _ in self.workers:
321+
self.index_queue.put(None)
322+
# done_event should be sufficient to exit worker_manager_thread,
323+
# but be safe here and put another None
324+
self.worker_result_queue.put(None)
325+
finally:
326+
# removes pids no matter what
316327
if self.worker_pids_set:
317328
_remove_worker_pids(id(self))
318329
self.worker_pids_set = False
@@ -351,8 +362,8 @@ class DataLoader(object):
351362
timeout (numeric, optional): if positive, the timeout value for collecting a batch
352363
from workers. Should always be non-negative. (default: 0)
353364
worker_init_fn (callable, optional): If not None, this will be called on each
354-
worker subprocess with the worker id as input, after seeding and before data
355-
loading. (default: None)
365+
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
366+
input, after seeding and before data loading. (default: None)
356367
357368
.. note:: By default, each worker will have its PyTorch seed set to
358369
``base_seed + worker_id``, where ``base_seed`` is a long generated

0 commit comments

Comments
 (0)