-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add torch.multiprocessing.spawn helper #13518
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,7 @@ | |
| 'indexing', | ||
| 'jit', | ||
| 'multiprocessing', | ||
| 'multiprocessing_spawn', | ||
| 'nccl', | ||
| 'nn', | ||
| 'numba_integration', | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| from __future__ import absolute_import, division, print_function, unicode_literals | ||
|
|
||
| import multiprocessing | ||
| import os | ||
| import random | ||
| import signal | ||
| import sys | ||
| import time | ||
| import unittest | ||
|
|
||
| from common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN) | ||
| import torch.multiprocessing as mp | ||
|
|
||
|
|
||
# Worker entry points for the spawn tests below. They are defined at module
# level because ``spawn`` requires its target to be picklable. The leading
# underscore keeps test collectors that pick up every ``test_*`` callable
# (e.g. pytest) from mistaking these workers for test cases.


def _test_success_func(i):
    """Worker that returns immediately without doing anything."""
    pass


def _test_success_single_arg_func(i, arg):
    """Worker that reports its process index ``i`` on the queue ``arg``."""
    if arg:
        arg.put(i)


def _test_exception_single_func(i, arg):
    """Worker that raises only in the process whose index equals ``arg``."""
    if i == arg:
        raise ValueError("legitimate exception from process %d" % i)
    # Every other process stays alive for a while so the failing process is
    # the first one to exit.
    time.sleep(1.0)


def _test_exception_all_func(i):
    """Worker that raises in every process after a short random delay."""
    time.sleep(random.random() / 10)
    raise ValueError("legitimate exception from process %d" % i)


def _test_terminate_signal_func(i):
    """Worker whose process 0 kills itself with SIGABRT."""
    if i == 0:
        os.kill(os.getpid(), signal.SIGABRT)
    time.sleep(1.0)


def _test_terminate_exit_func(i, arg):
    """Worker whose process 0 exits with exit code ``arg``."""
    if i == 0:
        sys.exit(arg)
    time.sleep(1.0)


def _test_success_first_then_exception_func(i, arg):
    """Worker where process 0 succeeds and every other process raises.

    ``arg`` is accepted but not used.
    """
    if i == 0:
        return
    time.sleep(0.1)
    raise ValueError("legitimate exception")


@unittest.skipIf(
    NO_MULTIPROCESSING_SPAWN,
    "Disabled for environments that don't support the spawn start method")
class SpawnTest(TestCase):
    """Tests for ``torch.multiprocessing.spawn``."""

    def test_success(self):
        mp.spawn(_test_success_func, nprocs=2)

    def test_success_non_blocking(self):
        spawn_context = mp.spawn(_test_success_func, nprocs=2, join=False)

        # After all processes (nprocs=2) have been joined, join() must
        # return True. Calling it again once everything is joined is allowed.
        spawn_context.join(timeout=None)
        spawn_context.join(timeout=None)
        self.assertTrue(spawn_context.join(timeout=None))

    def test_first_argument_index(self):
        context = mp.get_context("spawn")
        queue = context.SimpleQueue()
        mp.spawn(_test_success_single_arg_func, args=(queue,), nprocs=2)
        self.assertEqual([0, 1], sorted([queue.get(), queue.get()]))

    def test_exception_single(self):
        nprocs = 2
        for i in range(nprocs):
            with self.assertRaisesRegex(
                Exception,
                "\nValueError: legitimate exception from process %d$" % i,
            ):
                mp.spawn(_test_exception_single_func, args=(i,), nprocs=nprocs)

    def test_exception_all(self):
        with self.assertRaisesRegex(
            Exception,
            "\nValueError: legitimate exception from process (0|1)$",
        ):
            mp.spawn(_test_exception_all_func, nprocs=2)

    def test_terminate_signal(self):
        # SIGABRT is aliased with SIGIOT
        message = "process 0 terminated with signal (SIGABRT|SIGIOT)"

        # Termination by a signal is expressed as a negative exit code
        # in multiprocessing, so we know it was a signal that caused the exit.
        # This doesn't appear to exist on Windows, where the exit code is always
        # positive, and therefore results in a different exception message.
        # Exit code 22 means "ERROR_BAD_COMMAND".
        if IS_WINDOWS:
            message = "process 0 terminated with exit code 22"

        with self.assertRaisesRegex(Exception, message):
            mp.spawn(_test_terminate_signal_func, nprocs=2)

    def test_terminate_exit(self):
        exitcode = 123
        with self.assertRaisesRegex(
            Exception,
            "process 0 terminated with exit code %d" % exitcode,
        ):
            mp.spawn(_test_terminate_exit_func, args=(exitcode,), nprocs=2)

    def test_success_first_then_exception(self):
        # The worker ignores its extra argument; a value is still passed
        # because the worker's signature expects one.
        exitcode = 123
        with self.assertRaisesRegex(
            Exception,
            "ValueError: legitimate exception",
        ):
            mp.spawn(_test_success_first_then_exception_func, args=(exitcode,), nprocs=2)


if __name__ == '__main__':
    run_tests()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,138 @@ | ||
| from __future__ import absolute_import, division, print_function, unicode_literals | ||
|
|
||
| import multiprocessing | ||
| import multiprocessing.connection | ||
| import signal | ||
| import sys | ||
|
|
||
|
|
||
| def _wrap(fn, i, args, error_queue): | ||
| try: | ||
| fn(i, *args) | ||
|
||
| except KeyboardInterrupt: | ||
| pass # SIGINT; Killed by parent, do nothing | ||
| except Exception: | ||
| # Propagate exception to parent process, keeping original traceback | ||
| import traceback | ||
| error_queue.put(traceback.format_exc()) | ||
| sys.exit(1) | ||
|
|
||
|
|
||
class SpawnContext:
    r"""Tracks a group of processes together with their error queues.

    Arguments:
        processes (list): Started process objects; each must expose
            ``sentinel``, ``join``, ``exitcode``, ``is_alive`` and
            ``terminate``.
        error_queues (list): One queue per process, in the same order as
            ``processes``. Each queue is read with ``empty()`` and ``get()``.
    """

    def __init__(self, processes, error_queues):
        self.error_queues = error_queues
        self.processes = processes
        # Maps each process sentinel to the index of its process. Entries
        # are removed as processes are joined, so an empty mapping means
        # every process has been joined.
        self.sentinels = {
            process.sentinel: index
            for index, process in enumerate(processes)
        }

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Arguments:
            timeout (float): Wait this long before giving up on waiting.
                ``None`` waits until at least one process has exited.

        Raises:
            Exception: If a joined process exited with a non-zero exit code.
                The message contains the traceback forwarded by the process
                if there is one, and otherwise names the terminating signal
                or the exit code.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()

        # There won't be an error on the queue if the process crashed.
        if self.error_queues[error_index].empty():
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                # A negative exit code is the negated number of the signal
                # that terminated the process. Not every signal number has
                # an entry in ``signal.Signals``, so fall back to a generic
                # name rather than failing with an unrelated ValueError.
                try:
                    name = signal.Signals(-exitcode).name
                except ValueError:
                    name = "<Unknown signal %d>" % -exitcode
                raise Exception(
                    "process %d terminated with signal %s" %
                    (error_index, name)
                )
            else:
                raise Exception(
                    "process %d terminated with exit code %d" %
                    (error_index, exitcode)
                )

        original_trace = self.error_queues[error_index].get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise Exception(msg)
|
|
||
|
|
||
def spawn(fn, args=(), nprocs=1, join=True, daemon=True):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If one of the processes exits with a non-zero exit status, the
    remaining processes are killed and an exception is raised with the
    cause of termination. In the case an exception was caught in the
    child process, it is forwarded and its traceback is included in
    the exception raised in the parent process.

    Arguments:
        fn (function): Function is called as the entrypoint of the
            spawned process. This function must be defined at the top
            level of a module so it can be pickled and spawned. This
            is a requirement imposed by multiprocessing.

            The function is called as ``fn(i, *args)``, where ``i`` is
            the process index and ``args`` is the passed through tuple
            of arguments.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.
        daemon (bool): The daemon flag of the spawned processes. Defaults
            to ``True``. Pass ``False`` if ``fn`` needs to start processes
            of its own, because multiprocessing does not allow daemonic
            processes to create children.

    Returns:
        ``None`` if ``join`` is ``True``, a :class:`SpawnContext` for the
        started processes if ``join`` is ``False``.

    """
    mp = multiprocessing.get_context('spawn')
    error_queues = []
    processes = []
    for i in range(nprocs):
        # Each process gets its own queue for reporting a traceback.
        error_queue = mp.SimpleQueue()
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, error_queue),
            daemon=daemon,
        )
        process.start()
        error_queues.append(error_queue)
        processes.append(process)

    spawn_context = SpawnContext(processes, error_queues)
    if not join:
        return spawn_context

    # Loop on join until it returns True or raises an exception.
    while not spawn_context.join():
        pass
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.