Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/run_test.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
'indexing',
'jit',
'multiprocessing',
'multiprocessing_spawn',
'nccl',
'nn',
'numba_integration',
Expand Down
124 changes: 124 additions & 0 deletions test/test_multiprocessing_spawn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import multiprocessing
import os
import random
import signal
import sys
import time
import unittest

from common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN)
import torch.multiprocessing as mp


def test_success_func(i):
pass


def test_success_single_arg_func(i, arg):
if arg:
arg.put(i)


def test_exception_single_func(i, arg):
if i == arg:
raise ValueError("legitimate exception from process %d" % i)
time.sleep(1.0)


def test_exception_all_func(i):
time.sleep(random.random() / 10)
raise ValueError("legitimate exception from process %d" % i)


def test_terminate_signal_func(i):
    # Worker 0 aborts itself via a real signal delivery; the rest linger so
    # the parent has live processes to clean up.
    if not i:
        os.kill(os.getpid(), signal.SIGABRT)
    time.sleep(1.0)


def test_terminate_exit_func(i, arg):
if i == 0:
sys.exit(arg)
time.sleep(1.0)


def test_success_first_then_exception_func(i, arg):
if i == 0:
return
time.sleep(0.1)
raise ValueError("legitimate exception")


@unittest.skipIf(
    NO_MULTIPROCESSING_SPAWN,
    "Disabled for environments that don't support the spawn start method")
class SpawnTest(TestCase):
    """End-to-end tests for the ``mp.spawn`` helper covering both the
    success path and the various ways a child process can fail."""

    def test_success(self):
        mp.spawn(test_success_func, nprocs=2)

    def test_success_non_blocking(self):
        context = mp.spawn(test_success_func, nprocs=2, join=False)

        # Once both workers (nprocs=2) have exited, every subsequent join
        # must keep returning True.
        context.join(timeout=None)
        context.join(timeout=None)
        self.assertTrue(context.join(timeout=None))

    def test_first_argument_index(self):
        ctx = mp.get_context("spawn")
        queue = ctx.SimpleQueue()
        mp.spawn(test_success_single_arg_func, args=(queue,), nprocs=2)
        self.assertEqual([0, 1], sorted(queue.get() for _ in range(2)))

    def test_exception_single(self):
        nprocs = 2
        for rank in range(nprocs):
            pattern = "\nValueError: legitimate exception from process %d$" % rank
            with self.assertRaisesRegex(Exception, pattern):
                mp.spawn(test_exception_single_func, args=(rank,), nprocs=nprocs)

    def test_exception_all(self):
        pattern = "\nValueError: legitimate exception from process (0|1)$"
        with self.assertRaisesRegex(Exception, pattern):
            mp.spawn(test_exception_all_func, nprocs=2)

    def test_terminate_signal(self):
        # SIGABRT is aliased with SIGIOT
        message = "process 0 terminated with signal (SIGABRT|SIGIOT)"

        # Termination by a signal shows up as a negative exit code in
        # multiprocessing, which is how we know a signal caused the exit.
        # Windows has no such convention -- exit codes there are always
        # positive -- so the message differs. Exit code 22 means
        # "ERROR_BAD_COMMAND".
        if IS_WINDOWS:
            message = "process 0 terminated with exit code 22"

        with self.assertRaisesRegex(Exception, message):
            mp.spawn(test_terminate_signal_func, nprocs=2)

    def test_terminate_exit(self):
        exitcode = 123
        pattern = "process 0 terminated with exit code %d" % exitcode
        with self.assertRaisesRegex(Exception, pattern):
            mp.spawn(test_terminate_exit_func, args=(exitcode,), nprocs=2)

    def test_success_first_then_exception(self):
        exitcode = 123
        with self.assertRaisesRegex(Exception, "ValueError: legitimate exception"):
            mp.spawn(test_success_first_then_exception_func, args=(exitcode,), nprocs=2)


if __name__ == '__main__':
run_tests()
6 changes: 6 additions & 0 deletions torch/multiprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@
from .pool import Pool


if sys.version_info >= (3, 4):
    # Add a helper function to spawn N processes and wait for completion of
    # any of them. This depends on `mp.get_context`, which was added in
    # Python 3.4.
    from .spawn import spawn


if sys.platform == 'darwin' or sys.platform == 'win32':
_sharing_strategy = 'file_system'
_all_sharing_strategies = {'file_system'}
Expand Down
138 changes: 138 additions & 0 deletions torch/multiprocessing/spawn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import multiprocessing
import multiprocessing.connection
import signal
import sys


def _wrap(fn, i, args, error_queue):
    """Entry point of a spawned worker: run ``fn(i, *args)`` and report errors.

    Arguments:
        fn (function): user function; called with the process index first.
        i (int): index of this process within the spawn group.
        args (tuple): extra positional arguments forwarded to ``fn``.
        error_queue (SimpleQueue): channel used to ship the formatted
            traceback back to the parent when ``fn`` raises.
    """
    try:
        fn(i, *args)
    except KeyboardInterrupt:
        pass  # SIGINT; Killed by parent, do nothing
    except Exception:
        # Propagate exception to parent process, keeping original traceback
        import traceback
        error_queue.put(traceback.format_exc())
        sys.exit(1)


class SpawnContext:
    """Tracks a group of spawned worker processes and their error queues,
    allowing the parent to join them and surface the first failure."""

    def __init__(self, processes, error_queues):
        self.error_queues = error_queues
        self.processes = processes
        # Map each process sentinel to its index for O(1) lookup on wakeup.
        self.sentinels = {
            proc.sentinel: idx
            for idx, proc in enumerate(processes)
        }

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Arguments:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Safe to call repeatedly after everything has already been joined.
        if not self.sentinels:
            return True

        # Block until at least one process exits (or the timeout elapses).
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        failed_index = None
        for sentinel in ready:
            idx = self.sentinels.pop(sentinel)
            proc = self.processes[idx]
            proc.join()
            if proc.exitcode != 0:
                failed_index = idx
                break

        if failed_index is None:
            # No failures observed; report whether every process is joined.
            return not self.sentinels

        # A process failed: tear down whatever is still running.
        for proc in self.processes:
            if proc.is_alive():
                proc.terminate()
            proc.join()

        failed_process = self.processes[failed_index]
        if self.error_queues[failed_index].empty():
            # The child crashed before it could report a traceback; derive
            # the cause from its exit code (negative means killed by signal).
            exitcode = failed_process.exitcode
            if exitcode < 0:
                name = signal.Signals(-exitcode).name
                raise Exception(
                    "process %d terminated with signal %s" %
                    (failed_index, name)
                )
            else:
                raise Exception(
                    "process %d terminated with exit code %d" %
                    (failed_index, exitcode)
                )

        original_trace = self.error_queues[failed_index].get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % failed_index
        msg += original_trace
        raise Exception(msg)


def spawn(fn, args=(), nprocs=1, join=True):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If one of the processes exits with a non-zero exit status, the
    remaining processes are killed and an exception is raised with the
    cause of termination. In the case an exception was caught in the
    child process, it is forwarded and its traceback is included in
    the exception raised in the parent process.

    Arguments:
        fn (function): Function is called as the entrypoint of the
            spawned process. This function must be defined at the top
            level of a module so it can be pickled and spawned. This
            is a requirement imposed by multiprocessing.

            The function is called as ``fn(i, *args)``, where ``i`` is
            the process index and ``args`` is the passed through tuple
            of arguments.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.

    Returns:
        A :class:`SpawnContext` if ``join`` is ``False``; ``None`` once
        all processes have exited successfully if ``join`` is ``True``.
    """
    mp = multiprocessing.get_context('spawn')
    error_queues = []
    processes = []
    for i in range(nprocs):
        # One error queue per child so tracebacks can't interleave.
        error_queue = mp.SimpleQueue()
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, error_queue),
            daemon=True,
        )
        process.start()
        error_queues.append(error_queue)
        processes.append(process)

    spawn_context = SpawnContext(processes, error_queues)
    if not join:
        return spawn_context

    # Loop on join until it returns True or raises an exception.
    while not spawn_context.join():
        pass