Skip to content

Commit be424de

Browse files
pietern and facebook-github-bot
authored and committed
Add torch.multiprocessing.spawn helper (#13518)
Summary: This helper addresses a common pattern where one spawns N processes to work on some common task (e.g. parallel preprocessing or multiple training loops). A straightforward approach is to use the multiprocessing API directly and then consecutively call join on the resulting processes. This pattern breaks down in the face of errors. If one of the processes terminates with an exception or via some signal, and it is not the first process that was launched, the join call on the first process won't be affected. This helper seeks to solve this by waiting on termination from any of the spawned processes. When any process terminates with a non-zero exit status, it terminates the remaining processes, and raises an exception in the parent process. If the process terminated with an exception, it is propagated to the parent. If the process terminated via a signal (e.g. SIGINT, SIGSEGV), this is mentioned in the exception as well. Requires Python >= 3.4. Pull Request resolved: #13518 Reviewed By: orionr Differential Revision: D12929045 Pulled By: pietern fbshipit-source-id: 00df19fa16a568d1e22f37a2ba65677ab0cce3fd
1 parent 056f2cd commit be424de

File tree

4 files changed

+268
-0
lines changed

4 files changed

+268
-0
lines changed

test/run_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
'indexing',
3131
'jit',
3232
'multiprocessing',
33+
'multiprocessing_spawn',
3334
'nccl',
3435
'nn',
3536
'numba_integration',

test/test_multiprocessing_spawn.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
3+
import os
4+
import random
5+
import signal
6+
import sys
7+
import time
8+
import unittest
9+
10+
from common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN)
11+
import torch.multiprocessing as mp
12+
13+
14+
# Worker entry points handed to ``mp.spawn``. They are defined at module level
# so the ``spawn`` start method can pickle them by name. The leading underscore
# keeps name-based test collectors (anything matching ``test_*``, e.g. pytest)
# from mistaking these helpers for test cases.


def _test_success_func(i):
    """Worker that returns immediately without error."""
    pass


def _test_success_single_arg_func(i, arg):
    """Worker that puts its process index ``i`` on ``arg`` when ``arg`` is truthy."""
    if arg:
        arg.put(i)


def _test_exception_single_func(i, arg):
    """Worker that raises only in the process whose index equals ``arg``."""
    if i == arg:
        raise ValueError("legitimate exception from process %d" % i)
    time.sleep(1.0)


def _test_exception_all_func(i):
    """Worker that raises in every process after a short random delay."""
    time.sleep(random.random() / 10)
    raise ValueError("legitimate exception from process %d" % i)


def _test_terminate_signal_func(i):
    """Worker whose process 0 kills itself with SIGABRT; the others sleep."""
    if i == 0:
        os.kill(os.getpid(), signal.SIGABRT)
    time.sleep(1.0)


def _test_terminate_exit_func(i, arg):
    """Worker whose process 0 exits with status ``arg``; the others sleep."""
    if i == 0:
        sys.exit(arg)
    time.sleep(1.0)


def _test_success_first_then_exception_func(i, arg):
    """Worker whose process 0 succeeds at once while the others raise later.

    ``arg`` is accepted but not used.
    """
    if i == 0:
        return
    time.sleep(0.1)
    raise ValueError("legitimate exception")


@unittest.skipIf(
    NO_MULTIPROCESSING_SPAWN,
    "Disabled for environments that don't support the spawn start method")
class SpawnTest(TestCase):
    """Tests for ``torch.multiprocessing.spawn`` success and failure paths."""

    def test_success(self):
        mp.spawn(_test_success_func, nprocs=2)

    def test_success_non_blocking(self):
        spawn_context = mp.spawn(_test_success_func, nprocs=2, join=False)

        # After all processes (nproc=2) have joined it must return True
        spawn_context.join(timeout=None)
        spawn_context.join(timeout=None)
        self.assertTrue(spawn_context.join(timeout=None))

    def test_first_argument_index(self):
        context = mp.get_context("spawn")
        queue = context.SimpleQueue()
        mp.spawn(_test_success_single_arg_func, args=(queue,), nprocs=2)
        self.assertEqual([0, 1], sorted([queue.get(), queue.get()]))

    def test_exception_single(self):
        nprocs = 2
        for i in range(nprocs):
            with self.assertRaisesRegex(
                Exception,
                "\nValueError: legitimate exception from process %d$" % i,
            ):
                mp.spawn(_test_exception_single_func, args=(i,), nprocs=nprocs)

    def test_exception_all(self):
        with self.assertRaisesRegex(
            Exception,
            "\nValueError: legitimate exception from process (0|1)$",
        ):
            mp.spawn(_test_exception_all_func, nprocs=2)

    def test_terminate_signal(self):
        # SIGABRT is aliased with SIGIOT
        message = "process 0 terminated with signal (SIGABRT|SIGIOT)"

        # Termination through a signal is expressed as a negative exit code
        # in multiprocessing, so we know it was a signal that caused the exit.
        # This doesn't appear to exist on Windows, where the exit code is always
        # positive, and therefore results in a different exception message.
        # Exit code 22 means "ERROR_BAD_COMMAND".
        if IS_WINDOWS:
            message = "process 0 terminated with exit code 22"

        with self.assertRaisesRegex(Exception, message):
            mp.spawn(_test_terminate_signal_func, nprocs=2)

    def test_terminate_exit(self):
        exitcode = 123
        with self.assertRaisesRegex(
            Exception,
            "process 0 terminated with exit code %d" % exitcode,
        ):
            mp.spawn(_test_terminate_exit_func, args=(exitcode,), nprocs=2)

    def test_success_first_then_exception(self):
        # The worker ignores its extra argument; a placeholder is passed only
        # to satisfy its ``(i, arg)`` signature.
        unused_arg = 123
        with self.assertRaisesRegex(
            Exception,
            "ValueError: legitimate exception",
        ):
            mp.spawn(
                _test_success_first_then_exception_func,
                args=(unused_arg,),
                nprocs=2,
            )
120+
121+
122+
# Hand control to the shared test runner only when executed as a script.
if '__main__' == __name__:
    run_tests()

torch/multiprocessing/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@
3434
from .pool import Pool
3535

3636

37+
if sys.version_info >= (3, 4):
    # Add helper function to spawn N processes and wait for completion of any
    # of them. This depends on `mp.get_context`, which was added in
    # Python 3.4.
    from .spawn import spawn
41+
42+
3743
if sys.platform == 'darwin' or sys.platform == 'win32':
3844
_sharing_strategy = 'file_system'
3945
_all_sharing_strategies = {'file_system'}

torch/multiprocessing/spawn.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
3+
import multiprocessing
4+
import multiprocessing.connection
5+
import signal
6+
import sys
7+
8+
9+
def _wrap(fn, i, args, error_queue):
10+
try:
11+
fn(i, *args)
12+
except KeyboardInterrupt:
13+
pass # SIGINT; Killed by parent, do nothing
14+
except Exception:
15+
# Propagate exception to parent process, keeping original traceback
16+
import traceback
17+
error_queue.put(traceback.format_exc())
18+
sys.exit(1)
19+
20+
21+
class SpawnContext:
    """Tracks a group of started processes so they can be joined as a unit.

    Arguments:
        processes (list): Process objects; each must expose ``sentinel``,
            ``exitcode``, ``join()``, ``is_alive()`` and ``terminate()``.
        error_queues (list): One queue per process, at the same index as the
            process it belongs to. A process that failed with an exception
            is expected to have put its formatted traceback on its queue.
    """

    def __init__(self, processes, error_queues):
        self.error_queues = error_queues
        self.processes = processes
        # Map each process sentinel to the index of its process. Entries are
        # removed as processes are joined, so an empty dict means "all done".
        self.sentinels = {
            process.sentinel: index
            for index, process in enumerate(processes)
        }

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined
        (including the case where ``timeout`` elapsed before any process
        terminated).

        Arguments:
            timeout (float): Wait this long before giving up on waiting.
                ``None`` waits until at least one process terminates.

        Raises:
            Exception: If a process exited with a non-zero exit status. The
                message contains the traceback reported by the process if
                there is one, otherwise the signal name or exit code.
        """
        # Ensure this function can be called even when we're done.
        if not self.sentinels:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return not self.sentinels

        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()

        # There won't be an error on the queue if the process crashed.
        if self.error_queues[error_index].empty():
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                # A negative exit code means termination by signal. Not every
                # signal number has an entry in ``signal.Signals``; fall back
                # to a placeholder so the lookup failure doesn't replace the
                # exception we actually want to raise.
                try:
                    name = signal.Signals(-exitcode).name
                except ValueError:
                    name = "<Unknown signal %d>" % -exitcode
                raise Exception(
                    "process %d terminated with signal %s" %
                    (error_index, name)
                )
            else:
                raise Exception(
                    "process %d terminated with exit code %d" %
                    (error_index, exitcode)
                )

        original_trace = self.error_queues[error_index].get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise Exception(msg)
92+
93+
94+
def spawn(fn, args=(), nprocs=1, join=True):
    r"""Start ``nprocs`` processes, each running ``fn(i, *args)``.

    Processes are created with the ``spawn`` start method and marked as
    daemons. Each one gets its own error queue, through which a traceback is
    sent back to the parent if ``fn`` raises.

    When joining, a process that exits with a non-zero exit status causes
    the remaining processes to be killed and an exception describing the
    cause of termination to be raised. If the child failed with an
    exception, its traceback is included in the exception raised in the
    parent.

    Arguments:
        fn (function): Entrypoint of each spawned process. It must be
            defined at the top level of a module so it can be pickled and
            spawned; this is a requirement imposed by multiprocessing.

            It is called as ``fn(i, *args)``, where ``i`` is the process
            index and ``args`` is the tuple passed through unchanged.

        args (tuple): Arguments passed to ``fn`` after the process index.
        nprocs (int): Number of processes to spawn.
        join (bool): If true, block until all processes have been joined
            and return ``None``. Otherwise return the spawn context right
            after starting the processes.

    """
    ctx = multiprocessing.get_context('spawn')
    workers = []
    queues = []
    for rank in range(nprocs):
        queue = ctx.SimpleQueue()
        worker = ctx.Process(
            target=_wrap,
            args=(fn, rank, args, queue),
            daemon=True,
        )
        worker.start()
        queues.append(queue)
        workers.append(worker)

    context = SpawnContext(workers, queues)
    if join:
        # Keep joining until every process is done or a failure is raised.
        while not context.join():
            continue
        return None
    return context

0 commit comments

Comments
 (0)