Merged
77 changes: 74 additions & 3 deletions test/test_dataloader.py
@@ -1,8 +1,11 @@
import math
import sys
import ctypes
import torch
import time
import traceback
import unittest
from torch import multiprocessing
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
from torch.utils.data.dataloader import default_collate
from common import TestCase, run_tests, TEST_NUMPY
@@ -83,6 +86,32 @@ def __len__(self):
return self.size


class SegfaultDataset(Dataset):

def __init__(self, size):
self.size = size

def __getitem__(self, idx):
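# Dereference address 0 (a NULL pointer) so the worker dies with SIGSEGV.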
return ctypes.string_at(0)

def __len__(self):
return self.size


class SleepDataset(Dataset):

def __init__(self, size, sleep_sec):
self.size = size
self.sleep_sec = sleep_sec

def __getitem__(self, idx):
time.sleep(self.sleep_sec)
return idx

def __len__(self):
return self.size


class TestDataLoader(TestCase):

def setUp(self):
@@ -149,6 +178,48 @@ def test_sequential_pin_memory(self):
self.assertTrue(input.is_pinned())
self.assertTrue(target.is_pinned())

def test_multiple_dataloaders(self):
loader1_it = iter(DataLoader(self.dataset, num_workers=1))
loader2_it = iter(DataLoader(self.dataset, num_workers=2))
next(loader1_it)
next(loader1_it)
next(loader2_it)
next(loader2_it)
next(loader1_it)
next(loader2_it)

def test_segfault(self):
def _test_segfault():
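# Close stderr so the expected crash message from the worker does not clutter test output.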
sys.stderr.close()
dataset = SegfaultDataset(10)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
_ = next(iter(dataloader))

p = multiprocessing.Process(target=_test_segfault)
p.start()
p.join(1.0)
try:
self.assertFalse(p.is_alive())
self.assertNotEqual(p.exitcode, 0)
finally:
p.terminate()

def test_timeout(self):
def _test_timeout():
sys.stderr.close()
dataset = SleepDataset(10, 10)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
_ = next(iter(dataloader))

p = multiprocessing.Process(target=_test_timeout)
p.start()
p.join(3.0)
try:
self.assertFalse(p.is_alive())
self.assertNotEqual(p.exitcode, 0)
finally:
p.terminate()

def test_shuffle(self):
self._test_shuffle(DataLoader(self.dataset, shuffle=True))

@@ -224,7 +295,7 @@ def test_partial_workers(self):
"check that workers exit even if the iterator is not exhausted"
loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=True))
workers = loader.workers
pin_thread = loader.pin_thread
worker_manager_thread = loader.worker_manager_thread
for i, sample in enumerate(loader):
if i == 3:
break
@@ -233,8 +304,8 @@ def test_partial_workers(self):
w.join(1.0) # timeout of one second
self.assertFalse(w.is_alive(), 'subprocess not terminated')
self.assertEqual(w.exitcode, 0)
pin_thread.join(1.0)
self.assertFalse(pin_thread.is_alive())
worker_manager_thread.join(1.0)
self.assertFalse(worker_manager_thread.is_alive())

def test_len(self):
def check_len(dl, expected):
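For context, a minimal sketch (not part of this diff) of the user-visible behavior the new test_timeout case exercises; it assumes the SleepDataset defined above and that the new timeout argument surfaces a RuntimeError in the main process rather than letting it hang:

from torch.utils.data import DataLoader

# SleepDataset(10, 10) sleeps 10 s per item, so the 1 s timeout must trip first.
loader = DataLoader(SleepDataset(10, 10), batch_size=2, num_workers=2, timeout=1)
try:
    next(iter(loader))
except RuntimeError as exc:  # the exact exception type is an assumption here
    print("worker timed out:", exc)  # raised instead of blocking forever
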
179 changes: 179 additions & 0 deletions torch/csrc/DataLoader.cpp
@@ -0,0 +1,179 @@
#include <sys/wait.h>
#include <map>
#include <set>
#include <atomic>
#include <signal.h>
#include "THP.h"

// In cases like DataLoader, if a worker process dies due to a bus error/segfault
// or just hangs, the main process, if implemented with
// multiprocessing.queues.SimpleQueue, will hang waiting for data. This is
// difficult to avoid on the PyTorch side, as it can be caused by limited shm or
// by other libraries users call in the workers. The following methods are our
// best effort to provide some error message to users when such unfortunate
// events happen.

// TODO: The following does not work on Windows. Specifically, the sigaction and
// waitid calls, and the SIGCHLD handler. Currently, dummy implementations are
// provided for Windows.

#ifndef _WIN32

// Critical signal handlers should be registered in worker processes before
// doing any work.
// Each handler restores the default handler and re-raises the signal, so that
// the kill information can be retrieved by the main process.
// The Python entry point is _set_worker_signal_handlers().
#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) \
static void HANDLER_NAME(int sig, siginfo_t *info, void *ctx) \
{ \
write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char)); \
struct sigaction sa; \
sa.sa_handler = SIG_DFL; \
sa.sa_flags = 0; \
if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGNAL, &sa, NULL) != 0) { \
_exit(EXIT_FAILURE); \
} else { \
raise(SIGNAL); \
} \
}

// signal(2) is really not portable. So use sigaction.
// http://man7.org/linux/man-pages/man2/signal.2.html
static void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, void *), struct sigaction *old_sa_ptr)
{
struct sigaction sa;
sa.sa_sigaction = handler;
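// SA_RESTART: restart syscalls interrupted by this handler; SA_SIGINFO: use
// sa_sigaction; SA_NOCLDSTOP: no SIGCHLD when a child merely stops;
// SA_NODEFER: do not block the signal inside its own handler.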
sa.sa_flags = SA_RESTART|SA_SIGINFO|SA_NOCLDSTOP|SA_NODEFER;
if (sigemptyset(&sa.sa_mask) != 0 || sigaction(signal, &sa, old_sa_ptr) != 0) {
std::ostringstream oss;
oss << "An error occurred while setting handler for " << strsignal(signal) << ".";
throw std::runtime_error(oss.str());
}
}

SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered in worker. "
"This might be caused by insufficient shared memory (shm).\n");
SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n");

PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *arg) {
HANDLE_TH_ERRORS
setSignalHandler(SIGBUS, &handler_SIGBUS, NULL);
setSignalHandler(SIGSEGV, &handler_SIGSEGV, NULL);
Py_RETURN_TRUE;
END_HANDLE_TH_ERRORS
}

static std::map<int64_t, std::set<pid_t>> worker_pids = {};

PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
HANDLE_TH_ERRORS
int error;
std::set<pid_t> *pid_set;
pid_t pid;
siginfo_t infop;

// Only check the pids we care about
for (auto it = worker_pids.begin(); it != worker_pids.end(); ++it) {
pid_set = &(it->second);
for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) {
pid = *pid_it;
// Use waitid rather than waitpid so that we can set WNOWAIT, and so that Python
// and other handlers can still get whatever info they want about the child.
infop.si_pid = 0;
error = waitid(P_PID, pid, &infop, WEXITED|WNOHANG|WNOWAIT);
// ignore errors and the case of no waitable child
if (error < 0 || infop.si_pid == 0)
continue;
if (infop.si_code == CLD_EXITED && infop.si_status != 0) { // exit with error
std::ostringstream oss;
oss << "DataLoader worker (pid " << pid << ") exited unexpectedly "
<< "with exit code " << infop.si_status << ".";
// This is necessary. Otherwise, the runtime error will kill the other
// workers, and trigger this again.
pid_set->clear();
throw std::runtime_error(oss.str());
} else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal
std::ostringstream oss;
oss << "DataLoader worker (pid " << pid << ") is killed by signal: "
<< strsignal(infop.si_status) << ".";
// This is necessary. Otherwise, the runtime error will kill the other
// workers, and trigger this again.
pid_set->clear();
throw std::runtime_error(oss.str());
}
}
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}

// We don't want to exit on any SIGCHLD from any child. child_pids is a tuple
// of pids we are interested in.
PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *args) {
HANDLE_TH_ERRORS
Py_ssize_t num_args = args ? (Py_ssize_t) PyTuple_Size(args) : 0;
THPUtils_assert(num_args == 2, "_update_worker_pids expects exactly 2 arguments.");
int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
THPUtils_assert(worker_pids.find(key) == worker_pids.end(), "_update_worker_pids "
"should be called only once for each DataLoader.");
PyObject *child_pids = PyTuple_GET_ITEM(args, 1);
THPUtils_assert(PyTuple_Check(child_pids), "_update_worker_pids "
"expects a tuple for child_pids, but got %s.", THPUtils_typename(child_pids));

std::set<pid_t> pids_set = {};
auto size = PyTuple_GET_SIZE(child_pids);
for (int idx = 0; idx < size; idx++) {
PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);
pids_set.insert((pid_t) THPUtils_unpackLong(obj));
}

worker_pids[key] = pids_set;

Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}

PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *loader_id) {
HANDLE_TH_ERRORS

int64_t key = THPUtils_unpackLong(loader_id);
THPUtils_assert(worker_pids.find(key) != worker_pids.end(), "Cannot find worker "
"information for DataLoader with id %ld.", key);

worker_pids.erase(key);

Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}

#undef SIGNAL_HANDLER

#else
// dummy implementations for windows

PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *_ignored) {
Py_RETURN_TRUE;
}

PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *_ignored) {
Py_RETURN_TRUE;
}

PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *_ignored) {
Py_RETURN_NONE;
}

PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module, PyObject *_ignored) {
Py_RETURN_NONE;
}

#endif

PyMethodDef DataLoaderMethods[] = {
{"_set_worker_signal_handlers", (PyCFunction)THPModule_setWorkerSignalHandlers, METH_NOARGS, NULL},
{"_update_worker_pids", (PyCFunction)THPModule_updateWorkerPIDs, METH_VARARGS, NULL},
{"_remove_worker_pids", (PyCFunction)THPModule_removeWorkerPIDs, METH_O, NULL},
{"_error_if_any_worker_fails", (PyCFunction)THPModule_errorIfAnyWorkerFails, METH_NOARGS, NULL},
{NULL, NULL, 0, NULL}
};
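
The methods above are exposed as private hooks on the torch extension module (Module.cpp below registers DataLoaderMethods on it). A rough sketch of how the Python-side DataLoader is expected to drive them; the torch._C module path, the use of id(loader) as the key, and the helper names are assumptions for illustration, not code from this PR:

import torch

def worker_startup():
    # In each worker process, before doing any work: install the SIGBUS/SIGSEGV
    # handlers defined above so a crashing worker re-raises with an error message.
    torch._C._set_worker_signal_handlers()

def register_workers(loader, worker_processes):
    # In the main process, after starting the workers: remember this loader's
    # worker pids so _error_if_any_worker_fails() only inspects children we own.
    torch._C._update_worker_pids(id(loader), tuple(w.pid for w in worker_processes))

def check_workers():
    # Call periodically while waiting for data; raises if any registered worker
    # exited with a non-zero code or was killed by a signal.
    torch._C._error_if_any_worker_fails()

def unregister_workers(loader):
    # On iterator shutdown: drop the pids registered above.
    torch._C._remove_worker_pids(id(loader))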
2 changes: 2 additions & 0 deletions torch/csrc/Module.cpp
@@ -28,6 +28,7 @@
#include "THP.h"

#include "ModuleSparse.cpp"
#include "DataLoader.cpp"

PyObject* module;
PyObject* tensor_classes;
@@ -797,6 +798,7 @@ static PyObject* initModule() {
#define ASSERT_TRUE(cmd) if (!(cmd)) return NULL

THPUtils_addPyMethodDefs(methods, TorchMethods);
THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
#ifdef WITH_CUDA
THPUtils_addPyMethodDefs(methods, THCPModule_methods());
#endif