
Commit cab6656

cherry pick Dataloader issues pytorch#4643
1 parent 87e82d0 commit cab6656

4 files changed: +130 -45 lines changed

test/common.py

Lines changed: 2 additions & 0 deletions
@@ -332,6 +332,8 @@ def accept_output(update_type):
         self.assertEqual(s, expected)
 
     if sys.version_info < (3, 2):
+        # assertRegexpMatches renamed assertRegex in 3.2
+        assertRegex = unittest.TestCase.assertRegexpMatches
         # assertRaisesRegexp renamed assertRaisesRegex in 3.2
         assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
 
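For readers unfamiliar with the shim: this is the usual Python 2/3 compatibility aliasing trick, so test code can always call the 3.2+ assertion names. A minimal standalone sketch of the same pattern (the class name is illustrative, not from the commit):

import sys
import unittest

class CompatTestCase(unittest.TestCase):
    # On Python < 3.2 the new-style names don't exist yet, so alias them to
    # the old camel-case methods; on 3.2+ this block is skipped entirely.
    if sys.version_info < (3, 2):
        assertRegex = unittest.TestCase.assertRegexpMatches
        assertRaisesRegex = unittest.TestCase.assertRaisesRegexp

    def test_regex(self):
        # Always use the modern name; the alias makes it work on 2.7 as well.
        self.assertRegex('DataLoader worker (pid 123)', r'pid \d+')

if __name__ == '__main__':
    unittest.main()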

test/test_dataloader.py

Lines changed: 58 additions & 10 deletions
@@ -1,18 +1,22 @@
 import math
 import sys
+import errno
+import os
 import ctypes
+import signal
 import torch
 import time
 import traceback
 import unittest
 from torch import multiprocessing
 from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
 from torch.utils.data.dataset import random_split
-from torch.utils.data.dataloader import default_collate
-from common import TestCase, run_tests, TEST_NUMPY
+from torch.utils.data.dataloader import default_collate, ExceptionWrapper
+from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS
 from common_nn import TEST_CUDA
 
-JOIN_TIMEOUT = 14.0 if IS_WINDOWS else 1.5
+
+JOIN_TIMEOUT = 17.0 if IS_WINDOWS else 4.5
 
 
 class TestDatasetRandomSplit(TestCase):
@@ -102,6 +106,46 @@ def test_add_dataset(self):
         self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())
 
 
+# Stores the first encountered exception in .exception.
+# Inspired by https://stackoverflow.com/a/33599967
+class ErrorTrackingProcess(multiprocessing.Process):
+
+    def __init__(self, *args, **kwargs):
+        super(ErrorTrackingProcess, self).__init__(*args, **kwargs)
+        self._pconn, self._cconn = multiprocessing.Pipe()
+        self._exception = None
+
+    def run(self):
+        # Disable stderr printing at the os level, and make workers not print
+        # to stderr.
+        # Can't use sys.stderr.close, otherwise Python `raise` will error with
+        # ValueError: I/O operation on closed file.
+        os.close(sys.stderr.fileno())
+        try:
+            super(ErrorTrackingProcess, self).run()
+            self._cconn.send(None)
+        except Exception as e:
+            self._cconn.send(ExceptionWrapper(sys.exc_info()))
+            raise
+
+    @property
+    def exception(self):
+        if self._pconn.poll():
+            self._exception = self._pconn.recv()
+        if self._exception is None:
+            return None
+        else:
+            return self._exception.exc_type(self._exception.exc_msg)
+
+    # ESRCH means that os.kill can't find an alive proc
+    def send_signal(self, signum, ignore_ESRCH=False):
+        try:
+            os.kill(self.pid, signum)
+        except OSError as e:
+            if not ignore_ESRCH or e.errno != errno.ESRCH:
+                raise
+
+
 class ErrorDataset(Dataset):
 
     def __init__(self, size):
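The new ErrorTrackingProcess is the backbone of the tests below: the child sends either None or an ExceptionWrapper over a Pipe, and the parent reconstructs the exception from .exception after join(). A condensed, standalone sketch of the same pipe-based pattern (names are illustrative, and the traceback is sent as a plain string rather than an ExceptionWrapper):

import multiprocessing
import sys
import traceback

class ExceptionReportingProcess(multiprocessing.Process):

    def __init__(self, *args, **kwargs):
        super(ExceptionReportingProcess, self).__init__(*args, **kwargs)
        self._pconn, self._cconn = multiprocessing.Pipe()

    def run(self):
        try:
            super(ExceptionReportingProcess, self).run()
            self._cconn.send(None)  # clean exit
        except Exception:
            # Report the formatted traceback to the parent, then re-raise so
            # the child still exits with a non-zero code.
            self._cconn.send(''.join(traceback.format_exception(*sys.exc_info())))
            raise

    @property
    def exception(self):
        return self._pconn.recv() if self._pconn.poll() else None

def _failing_target():
    raise RuntimeError('worker failed')

if __name__ == '__main__':
    p = ExceptionReportingProcess(target=_failing_target)
    p.start()
    p.join()
    assert p.exitcode != 0
    assert 'worker failed' in p.exception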
@@ -173,16 +217,12 @@ def __len__(self):
 
 
 def _test_timeout():
-    os.close(sys.stderr.fileno())
-    sys.stderr.close()
     dataset = SleepDataset(10, 10)
     dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
     _ = next(iter(dataloader))
 
 
 def _test_segfault():
-    os.close(sys.stderr.fileno())
-    sys.stderr.close()
     dataset = SegfaultDataset(10)
     dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
     _ = next(iter(dataloader))
@@ -271,22 +311,30 @@ def test_multiple_dataloaders(self):
 
     @unittest.skipIf(True, "flaky test")
     def test_segfault(self):
-        p = multiprocessing.Process(target=_test_segfault)
+        p = ErrorTrackingProcess(target=_test_segfault)
         p.start()
         p.join(JOIN_TIMEOUT)
         try:
             self.assertFalse(p.is_alive())
             self.assertNotEqual(p.exitcode, 0)
+            if IS_WINDOWS:
+                self.assertIsInstance(p.exception, OSError)
+                self.assertRegex(str(p.exception), r'access violation reading ')
+            else:
+                self.assertIsInstance(p.exception, RuntimeError)
+                self.assertRegex(str(p.exception), r'DataLoader worker \(pid \d+\) is killed by signal: ')
         finally:
             p.terminate()
 
     def test_timeout(self):
-        p = multiprocessing.Process(target=_test_timeout)
+        p = ErrorTrackingProcess(target=_test_timeout)
         p.start()
-        p.join(3.0 + JOIN_TIMEOUT)
+        p.join(JOIN_TIMEOUT)
         try:
             self.assertFalse(p.is_alive())
             self.assertNotEqual(p.exitcode, 0)
+            self.assertIsInstance(p.exception, RuntimeError)
+            self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
         finally:
             p.terminate()

torch/csrc/DataLoader.cpp

Lines changed: 33 additions & 9 deletions
@@ -40,7 +40,7 @@ static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) \
 
 // signal(2) is really not portable. So use sigaction.
 // http://man7.org/linux/man-pages/man2/signal.2.html
-static void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr)
+static inline void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr)
 {
   struct sigaction sa;
   sa.sa_sigaction = handler;
@@ -56,10 +56,34 @@ SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered
                "This might be caused by insufficient shared memory (shm).\n");
 SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n");
 
+// When an error happened in DataLoader methods and Python starts to exit, the
+// error trace will keep the loader alive, and Python may kill the children
+// processes first before deleting the loader object. Then the cleaning up
+// methods in DataLoader.__del__ are not yet called, and SIGCHLD will print an
+// error saying a worker is killed by SIGTERM. So we suppress SIGTERM from the
+// main loader process here to avoid this by _exit(EXIT_SUCCESS). Note that if
+// we exit with nonzero code, the loader SIGCHLD handler may report RuntimeError
+// again, and then it defeats the whole purpose.
+static void handler_SIGTERM(int sig, siginfo_t *info, void *ctx)
+{
+  if (info->si_pid == getppid()) {
+    _exit(EXIT_SUCCESS);
+  }
+  struct sigaction sa;
+  sa.sa_handler = SIG_DFL;
+  sa.sa_flags = 0;
+  if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGTERM, &sa, NULL) != 0) {
+    _exit(EXIT_FAILURE);
+  } else {
+    raise(SIGTERM);
+  }
+}
+
 PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *arg) {
   HANDLE_TH_ERRORS
   setSignalHandler(SIGBUS, &handler_SIGBUS, NULL);
   setSignalHandler(SIGSEGV, &handler_SIGSEGV, NULL);
+  setSignalHandler(SIGTERM, &handler_SIGTERM, NULL);
   Py_RETURN_TRUE;
   END_HANDLE_TH_ERRORS
 }
@@ -70,33 +94,33 @@ PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
   HANDLE_TH_ERRORS
   int error;
   std::set<pid_t> *pid_set;
-  pid_t pid;
+  pid_t worker_pid;
   siginfo_t infop;
 
   // Only check the pids we care about
   for (auto it = worker_pids.begin(); it != worker_pids.end(); ++it) {
     pid_set = &(it->second);
     for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) {
-      pid = *pid_it;
+      worker_pid = *pid_it;
       // Use waitid rather than waitpid so that we can set NOWAIT, and that Python
       // and other handlers can get whatever info they want about the child.
       infop.si_pid = 0;
-      error = waitid(P_PID, pid, &infop, WEXITED|WNOHANG|WNOWAIT);
+      error = waitid(P_PID, worker_pid, &infop, WEXITED|WNOHANG|WNOWAIT);
       // ignore errors and case with no waitable child
       if (error < 0 || infop.si_pid == 0)
         continue;
-      if (infop.si_code == CLD_EXITED && infop.si_status != 0) {  // exit with error
+      if (infop.si_code == CLD_EXITED && infop.si_status != EXIT_SUCCESS) {  // exit with error
         std::ostringstream oss;
-        oss << "DataLoader worker (pid " << pid << ") exited unexpectedly "
-            << "with exit code " << infop.si_status << ".";
+        oss << "DataLoader worker (pid " << worker_pid << ") exited "
+            << "unexpectedly with exit code " << infop.si_status << ".";
         // This is necessary. Otherwise, the runtime error will kill the other
         // workers, and trigger this again.
         pid_set->clear();
         throw std::runtime_error(oss.str());
       } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) {  // killed by signal
         std::ostringstream oss;
-        oss << "DataLoader worker (pid " << pid << ") is killed by signal: "
-            << strsignal(infop.si_status) << ".";
+        oss << "DataLoader worker (pid " << worker_pid << ") is killed "
+            << "by signal: " << strsignal(infop.si_status) << ".";
         // This is necessary. Otherwise, the runtime error will kill the other
         // workers, and trigger this again.
         pid_set->clear();
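For readers who want the Python-level picture of what errorIfAnyWorkerFails does, roughly the same non-blocking, non-reaping status check can be expressed with os.waitid. This is an assumption-level sketch of the logic (POSIX-only, helper name is made up), not code added by the commit:

import os
import signal

def check_worker(worker_pid):
    # WNOHANG: don't block; WNOWAIT: leave the child waitable so Python and
    # other handlers can still collect whatever info they need about it.
    res = os.waitid(os.P_PID, worker_pid, os.WEXITED | os.WNOHANG | os.WNOWAIT)
    if res is None:
        return  # no state change: the worker is still running
    if res.si_code == os.CLD_EXITED:
        if res.si_status != 0:
            raise RuntimeError('DataLoader worker (pid %d) exited unexpectedly '
                               'with exit code %d.' % (worker_pid, res.si_status))
    else:
        # With only WEXITED requested, any other si_code means the worker was
        # killed by a signal (possibly dumping core).
        raise RuntimeError('DataLoader worker (pid %d) is killed by signal: %s.'
                           % (worker_pid, signal.Signals(res.si_status).name))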

torch/utils/data/dataloader.py

Lines changed: 37 additions & 26 deletions
@@ -8,29 +8,28 @@
 import collections
 import re
 import sys
-import traceback
 import threading
+import traceback
 from torch._six import string_classes, int_classes
 
-
 if sys.version_info[0] == 2:
     import Queue as queue
 else:
     import queue
 
 
-_use_shared_memory = False
-"""Whether to use shared memory in default_collate"""
-
-
 class ExceptionWrapper(object):
-    "Wraps an exception plus traceback to communicate across threads"
+    r"Wraps an exception plus traceback to communicate across threads"
 
     def __init__(self, exc_info):
         self.exc_type = exc_info[0]
         self.exc_msg = "".join(traceback.format_exception(*exc_info))
 
 
+_use_shared_memory = False
+"""Whether to use shared memory in default_collate"""
+
+
 def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
     global _use_shared_memory
     _use_shared_memory = True
@@ -157,7 +156,11 @@ def pin_memory_batch(batch):
 
 
 def _set_SIGCHLD_handler():
-    if sys.platform == 'win32':  # Windows doesn't support SIGCHLD handler
+    # Windows doesn't support SIGCHLD handler
+    if sys.platform == 'win32':
+        return
+    # can't set signal in child threads
+    if not isinstance(threading.current_thread(), threading._MainThread):
         return
     global _SIGCHLD_handler_set
     if _SIGCHLD_handler_set:
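The second guard reflects the fact that CPython only lets the main thread install signal handlers; calling signal.signal from any other thread raises ValueError. A minimal sketch of the same check, using the Python 3 threading.main_thread() helper instead of the private _MainThread class the 2/3-compatible code above relies on (the helper name is illustrative):

import signal
import threading

def install_handler_if_main_thread(signum, handler):
    # signal.signal raises ValueError outside the main thread, so skip
    # installation rather than crash the caller.
    if threading.current_thread() is not threading.main_thread():
        return False
    signal.signal(signum, handler)
    return True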
@@ -212,10 +215,15 @@ def __init__(self, loader):
 
         if self.pin_memory or self.timeout > 0:
             self.data_queue = queue.Queue()
+            if self.pin_memory:
+                maybe_device_id = torch.cuda.current_device()
+            else:
+                # do not initialize cuda context if not necessary
+                maybe_device_id = None
             self.worker_manager_thread = threading.Thread(
                 target=_worker_manager_loop,
                 args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
-                      torch.cuda.current_device()))
+                      maybe_device_id))
             self.worker_manager_thread.daemon = True
             self.worker_manager_thread.start()
         else:
@@ -239,7 +247,7 @@ def __len__(self):
     def _get_batch(self):
         if self.timeout > 0:
             try:
-                return self.data_queue.get(True, self.timeout)
+                return self.data_queue.get(timeout=self.timeout)
             except queue.Empty:
                 raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
         else:
@@ -302,20 +310,23 @@ def __getstate__(self):
         raise NotImplementedError("DataLoaderIterator cannot be pickled")
 
     def _shutdown_workers(self):
-        if not self.shutdown:
-            self.shutdown = True
-            self.done_event.set()
-            # if worker_manager_thread is waiting to put
-            while not self.data_queue.empty():
-                self.data_queue.get()
-            for _ in self.workers:
-                self.index_queue.put(None)
-            # done_event should be sufficient to exit worker_manager_thread, but
-            # be safe here and put another None
-            self.worker_result_queue.put(None)
-        if self.worker_pids_set:
-            _remove_worker_pids(id(self))
-            self.worker_pids_set = False
+        try:
+            if not self.shutdown:
+                self.shutdown = True
+                self.done_event.set()
+                # if worker_manager_thread is waiting to put
+                while not self.data_queue.empty():
+                    self.data_queue.get()
+                for _ in self.workers:
+                    self.index_queue.put(None)
+                # done_event should be sufficient to exit worker_manager_thread,
+                # but be safe here and put another None
+                self.worker_result_queue.put(None)
+        finally:
+            # removes pids no matter what
+            if self.worker_pids_set:
+                _remove_worker_pids(id(self))
+                self.worker_pids_set = False
 
     def __del__(self):
         if self.num_workers > 0:
@@ -351,8 +362,8 @@ class DataLoader(object):
         timeout (numeric, optional): if positive, the timeout value for collecting a batch
             from workers. Should always be non-negative. (default: 0)
         worker_init_fn (callable, optional): If not None, this will be called on each
-            worker subprocess with the worker id as input, after seeding and before data
-            loading. (default: None)
+            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
+            input, after seeding and before data loading. (default: None)
 
     .. note:: By default, each worker will have its PyTorch seed set to
         ``base_seed + worker_id``, where ``base_seed`` is a long generated
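Taken together, the user-facing knobs touched by this commit are the timeout argument and the worker_init_fn hook. A small usage sketch (the dataset, sizes, and printout are made up for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset

def init_worker(worker_id):
    # worker_id is an int in [0, num_workers - 1]; by the time this runs the
    # worker's seed has already been set to base_seed + worker_id.
    print('worker %d ready' % worker_id)

if __name__ == '__main__':
    dataset = TensorDataset(torch.randn(64, 3), torch.randn(64, 1))
    loader = DataLoader(dataset, batch_size=8, num_workers=2, timeout=5,
                        worker_init_fn=init_worker)
    for inputs, targets in loader:
        # If a batch takes more than 5 seconds to arrive from the workers, the
        # iterator raises RuntimeError('DataLoader timed out after 5 seconds').
        pass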
