Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ def accept_output(update_type):
self.assertEqual(s, expected)

if sys.version_info < (3, 2):
# assertRegexpMatches renamed assertRegex in 3.2
assertRegex = unittest.TestCase.assertRegexpMatches
# assertRaisesRegexp renamed assertRaisesRegex in 3.2
assertRaisesRegex = unittest.TestCase.assertRaisesRegexp

Expand Down
65 changes: 56 additions & 9 deletions test/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
import math
import sys
import errno
import os
import ctypes
import signal
import torch
import time
import traceback
import unittest
from torch import multiprocessing
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
from torch.utils.data.dataset import random_split
from torch.utils.data.dataloader import default_collate
from torch.utils.data.dataloader import default_collate, ExceptionWrapper
from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS
from common_nn import TEST_CUDA

JOIN_TIMEOUT = 14.0 if IS_WINDOWS else 1.5

JOIN_TIMEOUT = 17.0 if IS_WINDOWS else 4.5


class TestDatasetRandomSplit(TestCase):
Expand Down Expand Up @@ -103,6 +106,46 @@ def test_add_dataset(self):
self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())


# Stores the first encountered exception in .exception.
# Inspired by https://stackoverflow.com/a/33599967
class ErrorTrackingProcess(multiprocessing.Process):

def __init__(self, *args, **kwargs):
super(ErrorTrackingProcess, self).__init__(*args, **kwargs)
self._pconn, self._cconn = multiprocessing.Pipe()
self._exception = None

def run(self):
# Disable stderr printing from os level, and make workers not printing
# to stderr.
# Can't use sys.stderr.close, otherwise Python `raise` will error with
# ValueError: I/O operation on closed file.
os.close(sys.stderr.fileno())
try:
super(ErrorTrackingProcess, self).run()
self._cconn.send(None)
except Exception as e:
self._cconn.send(ExceptionWrapper(sys.exc_info()))
raise

@property
def exception(self):
if self._pconn.poll():
self._exception = self._pconn.recv()
if self._exception is None:
return None
else:
return self._exception.exc_type(self._exception.exc_msg)

# ESRCH means that os.kill can't finds alive proc
def send_signal(self, signum, ignore_ESRCH=False):
try:
os.kill(self.pid, signum)
except OSError as e:
if not ignore_ESRCH or e.errno != errno.ESRCH:
raise


class ErrorDataset(Dataset):

def __init__(self, size):
Expand Down Expand Up @@ -175,16 +218,12 @@ def __len__(self):


def _test_timeout():
os.close(sys.stderr.fileno())
sys.stderr.close()
dataset = SleepDataset(10, 10)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
_ = next(iter(dataloader))


def _test_segfault():
os.close(sys.stderr.fileno())
sys.stderr.close()
dataset = SegfaultDataset(10)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
_ = next(iter(dataloader))
Expand Down Expand Up @@ -272,22 +311,30 @@ def test_multiple_dataloaders(self):
next(loader2_it)

def test_segfault(self):
p = multiprocessing.Process(target=_test_segfault)
p = ErrorTrackingProcess(target=_test_segfault)
p.start()
p.join(JOIN_TIMEOUT)
try:
self.assertFalse(p.is_alive())
self.assertNotEqual(p.exitcode, 0)
if IS_WINDOWS:
self.assertIsInstance(p.exception, OSError)
self.assertRegex(str(p.exception), r'access violation reading ')
else:
self.assertIsInstance(p.exception, RuntimeError)
self.assertRegex(str(p.exception), r'DataLoader worker \(pid \d+\) is killed by signal: ')
finally:
p.terminate()

def test_timeout(self):
p = multiprocessing.Process(target=_test_timeout)
p = ErrorTrackingProcess(target=_test_timeout)
p.start()
p.join(3.0 + JOIN_TIMEOUT)
p.join(JOIN_TIMEOUT)
try:
self.assertFalse(p.is_alive())
self.assertNotEqual(p.exitcode, 0)
self.assertIsInstance(p.exception, RuntimeError)
self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
finally:
p.terminate()

Expand Down
42 changes: 33 additions & 9 deletions torch/csrc/DataLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) \

// signal(2) is really not portable. So use sigaction.
// http://man7.org/linux/man-pages/man2/signal.2.html
static void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr)
static inline void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr)
{
struct sigaction sa;
sa.sa_sigaction = handler;
Expand All @@ -59,10 +59,34 @@ SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered
"This might be caused by insufficient shared memory (shm).\n");
SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n");

// When an error happend in DataLoader methods and Python starts to exit, the
// error trace will keep the loader alive, and Python may kill the children
// processes first before deleting the loader object. Then the cleaning up
// methods in DataLoader.__del__ are not yet called, and SIGCHILD will print an
// error saying a worker is killed by SIGTERM. So we suppress SIGTERM from main
// loader process here to avoid this by _exit(EXIT_SUCCESS). Note that if we
// exit with nonzero code, the loader SIGCHLD handler may report RuntimeError
// again, and then it defeats the whole purpose.
static void handler_SIGTERM(int sig, siginfo_t *info, void *ctx)
{
if (info->si_pid == getppid()) {
_exit(EXIT_SUCCESS);
}
struct sigaction sa;
sa.sa_handler = SIG_DFL;
sa.sa_flags = 0;
if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGTERM, &sa, NULL) != 0) {
_exit(EXIT_FAILURE);
} else {
raise(SIGTERM);
}
}

PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *arg) {
HANDLE_TH_ERRORS
setSignalHandler(SIGBUS, &handler_SIGBUS, NULL);
setSignalHandler(SIGSEGV, &handler_SIGSEGV, NULL);
setSignalHandler(SIGTERM, &handler_SIGTERM, NULL);
Py_RETURN_TRUE;
END_HANDLE_TH_ERRORS
}
Expand All @@ -73,33 +97,33 @@ PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
HANDLE_TH_ERRORS
int error;
std::set<pid_t> *pid_set;
pid_t pid;
pid_t worker_pid;
siginfo_t infop;

// Only check the pids we care about
for (auto it = worker_pids.begin(); it != worker_pids.end(); ++it) {
pid_set = &(it->second);
for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) {
pid = *pid_it;
worker_pid = *pid_it;
// Use waitid rather than waitpid so that we can set NOWAIT, and that Python
// and other handlers can get whatever info they want about the child.
infop.si_pid = 0;
error = waitid(P_PID, pid, &infop, WEXITED|WNOHANG|WNOWAIT);
error = waitid(P_PID, worker_pid, &infop, WEXITED|WNOHANG|WNOWAIT);
// ignore errors and case with no waitable child
if (error < 0 || infop.si_pid == 0)
continue;
if (infop.si_code == CLD_EXITED && infop.si_status != 0) { // exit with error
if (infop.si_code == CLD_EXITED && infop.si_status != EXIT_SUCCESS) { // exit with error
std::ostringstream oss;
oss << "DataLoader worker (pid " << pid << ") exited unexpectedly "
<< "with exit code " << infop.si_status << ".";
oss << "DataLoader worker (pid " << worker_pid << ") exited "
<< "unexpectedly with exit code " << infop.si_status << ".";
// This is necessary. Otherwise, the runtime error will kill the other
// workers, and trigger this again.
pid_set->clear();
throw std::runtime_error(oss.str());
} else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal
std::ostringstream oss;
oss << "DataLoader worker (pid " << pid << ") is killed by signal: "
<< strsignal(infop.si_status) << ".";
oss << "DataLoader worker (pid " << worker_pid << ") is killed "
<< "by signal: " << strsignal(infop.si_status) << ".";
// This is necessary. Otherwise, the runtime error will kill the other
// workers, and trigger this again.
pid_set->clear();
Expand Down
57 changes: 34 additions & 23 deletions torch/utils/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,28 @@
import collections
import re
import sys
import traceback
import threading
import traceback
from torch._six import string_classes, int_classes


if sys.version_info[0] == 2:
import Queue as queue
else:
import queue


_use_shared_memory = False
"""Whether to use shared memory in default_collate"""


class ExceptionWrapper(object):
"Wraps an exception plus traceback to communicate across threads"
r"Wraps an exception plus traceback to communicate across threads"

def __init__(self, exc_info):
self.exc_type = exc_info[0]
self.exc_msg = "".join(traceback.format_exception(*exc_info))


_use_shared_memory = False
"""Whether to use shared memory in default_collate"""


def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
global _use_shared_memory
_use_shared_memory = True
Expand Down Expand Up @@ -157,7 +156,11 @@ def pin_memory_batch(batch):


def _set_SIGCHLD_handler():
if sys.platform == 'win32': # Windows doesn't support SIGCHLD handler
# Windows doesn't support SIGCHLD handler
if sys.platform == 'win32':
return
# can't set signal in child threads
if not isinstance(threading.current_thread(), threading._MainThread):
return
global _SIGCHLD_handler_set
if _SIGCHLD_handler_set:
Expand Down Expand Up @@ -212,10 +215,15 @@ def __init__(self, loader):

if self.pin_memory or self.timeout > 0:
self.data_queue = queue.Queue()
if self.pin_memory:
maybe_device_id = torch.cuda.current_device()
else:
# do not initialize cuda context if not necessary
maybe_device_id = None
self.worker_manager_thread = threading.Thread(
target=_worker_manager_loop,
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
torch.cuda.current_device()))
maybe_device_id))
self.worker_manager_thread.daemon = True
self.worker_manager_thread.start()
else:
Expand All @@ -239,7 +247,7 @@ def __len__(self):
def _get_batch(self):
if self.timeout > 0:
try:
return self.data_queue.get(True, self.timeout)
return self.data_queue.get(timeout=self.timeout)
except queue.Empty:
raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
else:
Expand Down Expand Up @@ -302,17 +310,20 @@ def __getstate__(self):
raise NotImplementedError("DataLoaderIterator cannot be pickled")

def _shutdown_workers(self):
if not self.shutdown:
self.shutdown = True
self.done_event.set()
# if worker_manager_thread is waiting to put
while not self.data_queue.empty():
self.data_queue.get()
for _ in self.workers:
self.index_queue.put(None)
# done_event should be sufficient to exit worker_manager_thread, but
# be safe here and put another None
self.worker_result_queue.put(None)
try:
if not self.shutdown:
self.shutdown = True
self.done_event.set()
# if worker_manager_thread is waiting to put
while not self.data_queue.empty():
self.data_queue.get()
for _ in self.workers:
self.index_queue.put(None)
# done_event should be sufficient to exit worker_manager_thread,
# but be safe here and put another None
self.worker_result_queue.put(None)
finally:
# removes pids no matter what
if self.worker_pids_set:
_remove_worker_pids(id(self))
self.worker_pids_set = False
Expand Down Expand Up @@ -351,8 +362,8 @@ class DataLoader(object):
timeout (numeric, optional): if positive, the timeout value for collecting a batch
from workers. Should always be non-negative. (default: 0)
worker_init_fn (callable, optional): If not None, this will be called on each
worker subprocess with the worker id as input, after seeding and before data
loading. (default: None)
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
input, after seeding and before data loading. (default: None)

.. note:: By default, each worker will have its PyTorch seed set to
``base_seed + worker_id``, where ``base_seed`` is a long generated
Expand Down