
Conversation

@ssnl (Collaborator) commented Sep 23, 2018

Test plan:

Trained 16000 LeNets (total) on MNIST across 160 processes. Previously, some of them would hang either during training or during program exit. Now every process finishes successfully.

    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
    #
    # Preliminary:
    #
    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # P.S. The `worker_result_queue` and `pin_memory_thread` parts may be
    #      omitted if `pin_memory=False`.
    #
    #
    # Terminating multiprocessing logic requires very careful design. In
    # particular, we need to make sure that
    #
    #   1. The iterator gracefully exits the workers when its last reference is
    #      gone.
    #
    #      In this case, the workers should be gracefully exited because the
    #      main process may still need to continue to run, and we want the
    #      cleanup code in the workers (e.g., releasing GPU memory) to be
    #      executed. Naturally, we implement the shutdown logic in `__del__`
    #      of DataLoaderIter.
    #
    #      We delay the discussion on the logic in this case until later.
    #
    #   2. The iterator exits the workers when the program ends.
    #
    #      We set all workers and `pin_memory_thread` to have `daemon=True`.
    #      Doing so means that when the program ends, it shuts the workers down
    #      with a SIGTERM. `pin_memory_thread` will exit too, but by a different
    #      mechanism.
    #
    #      You may ask, why don't we just not set the workers as daemonic, and
    #      gracefully exit using the same logic as we have in `__del__` when the
    #      iterator gets deleted (see 1 above)? The answer requires a bit of
    #      understanding of Python multiprocessing design. As of Python 3.7, for
    #      reasons I have yet to understand, in a subprocess, Python runs the
    #      given function (e.g., the `target` argument passed to a `mp.Process`)
    #      using this pattern (unrelated code removed for clarity):
    #
    #          # These are run in the sub-process
    #          try:
    #              user_provided_function()
    #          finally:
    #              multiprocessing.util._exit_function()
    #
    #      In `_exit_function`, Python joins all non-daemonic subprocesses of
    #      this process (which is a subprocess of a Python process itself), and
    #      sends SIGTERM to the daemonic ones. Therefore, if a DataLoader is
    #      used in a subprocess (i.e., used in `user_provided_function` above),
    #      and an error is raised whose traceback contains frames that
    #      reference the DataLoaderIter (Python exception tracebacks keep local
    #      objects in the relevant frames alive), then the workers will be
    #      joined in `_exit_function` before `__del__` is called (which is what
    #      starts the shutdown logic), and unfortunately the process holding
    #      the DataLoaderIter will hang. Such an error can be, e.g., a timeout,
    #      or an arbitrary error raised while the user holds a reference to the
    #      iterator.
    #
    #      For context, `_exit_function` is also registered as an `atexit` call.
    #      So I really don't understand the need to do this in a `finally`
    #      block. The code dates back to 2008, and there is no comment on the
    #      original PEP 371 or the patch https://bugs.python.org/issue3050
    #      (containing both the `finally` block and the `atexit` registration)
    #      that explains this.
    #
    #      Another choice is to just shut down the workers with the logic in 1
    #      above whenever we see an error in `next`. This isn't ideal because
    #        a. It prevents users from using try-except to resume data loading.
    #        b. It doesn't prevent hanging if users have references to the
    #           iterator.
    #
    #   3. All processes exit if any of them die unexpectedly (e.g., error,
    #      SIGKILL).
    #
    #      As shown above, the workers are set as daemonic children of the main
    #      process. However, automatic cleaning-up of such child processes
    #      only happens if the parent process exits gracefully (e.g., on
    #      SIGTERM). So we must ensure that each process will exit even if the
    #      process that should send/receive data to/from it was killed, i.e.,
    #
    #        a. A process won't hang when getting from a queue.
    #
    #           Even with carefully designed data dependencies (i.e., a `put()`
    #           always corresponding to a `get()`), hanging on `get()` can still
    #           happen when the data in the queue is corrupted (e.g., due to
    #           `cancel_join_thread` or an unexpected exit).
    #
    #           For worker exits, we register a SIGCHLD handler in the main
    #           process, whose (Python) handler checks whether any of the
    #           workers has failed. See DataLoader.cpp.
    #
    #           For `.get()` calls where the sender is not a worker, we guard
    #           them with timeouts, and check the status of the sender when a
    #           timeout happens:
    #             + in the workers, the `ManagerWatchdog` class checks the
    #               status of the main process.
    #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
    #               we check the `pin_memory_thread` status periodically until
    #               `.get()` returns or we see that `pin_memory_thread` died.
    #
    #        b. A process won't hang when putting into a queue.
    #
    #           We use `mp.Queue` which has a separate background thread to put
    #           objects. The background thread is usually automatically joined
    #           when the process exits.
    #
    #           However, if the receiver has exited abruptly while reading
    #           from the pipe, the join will hang forever. Therefore, for both
    #           `worker_result_queue` (worker -> main process/pin_memory_thread)
    #           and each `index_queue` (main process -> worker), we call
    #           `q.cancel_join_thread()` in the sender process before any
    #           `q.put` to prevent this automatic join.
    #
    #           Moreover, having called `cancel_join_thread` on all queues
    #           makes implementing the graceful shutdown logic in `__del__`
    #           much easier: it won't need to get from any queue, which would
    #           otherwise also need to be guarded by periodic status checks.
    #
    #           Note that this may leave corrupted data in the queue, but we
    #           don't care about the data anyways once we are shutting down.
    #
    #
    # Now let's get back to 1:
    #   how we gracefully exit the workers when the last reference to the
    #   iterator is gone.
    #
    # To achieve this, we implement the following logic along with the design
    # choices mentioned above:
    #
    # [pin_memory_thread] and [worker processes]
    #   When getting from queues,
    #     if we get a `None`, exit;
    #     if we get anything else or time out, check `done_event`:
    #        if it is set, keep getting until we see the `None`, then exit;
    #        otherwise, process the data.
    #
    # [main process]
    #   In the DataLoaderIter's `__del__`:
    #     a. Set `done_event` (shared with `pin_memory_thread` and workers).
    #
    #        Note: from here on, the workers & `pin_memory_thread` may exit at
    #              any time after they receive `None`.
    #
    #     b. Exit `pin_memory_thread`
    #          i.   Put `None` in `worker_result_queue`.
    #          ii.  Join the `pin_memory_thread`.
    #
    #     c. Exit the workers.
    #          i.   Put `None` in each worker's `index_queue`.
    #          ii.  Join the workers.
    #
    #        Note: This has to be after (b) because it may leave corrupted data
    #              in `worker_result_queue`, which `pin_memory_thread` reads
    #              from.
    #
    #   Note: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
    #         can be omitted.
    #
    # NB: `done_event` isn't strictly needed. E.g., we could just check for
    #     `None` from `index_queue`, but having it allows us to skip wasting
    #     resources on processing indices that are already in `index_queue`
    #     when we are shutting down.
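
The flow above can be condensed into a small runnable sketch. This is only an illustration of the shutdown ordering, not the actual dataloader.py implementation: the names (`done_event`, `index_queue`, `worker_result_queue`, `data_queue`, `MP_STATUS_CHECK_INTERVAL`) mirror the note, but the SIGCHLD handler, the `ManagerWatchdog`, memory pinning, and all error handling are omitted, and the "loaded sample" is a stand-in.

    import multiprocessing as mp
    import queue
    import threading

    MP_STATUS_CHECK_INTERVAL = 5.0  # illustrative polling interval for guarded get()


    def _worker_loop(index_queue, worker_result_queue, done_event):
        # Sender side: cancel the feeder-thread join so this process cannot hang
        # on exit if the receiver died while reading from the pipe (see 3b).
        worker_result_queue.cancel_join_thread()
        while True:
            try:
                idx = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue  # the real worker also checks a ManagerWatchdog here
            if idx is None:
                break  # the main process asked us to exit
            if done_event.is_set():
                continue  # shutting down: drain until we see the `None`
            worker_result_queue.put((idx, idx * 2))  # stand-in for loading a sample


    def _pin_memory_loop(worker_result_queue, data_queue, done_event):
        while True:
            try:
                r = worker_result_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if r is None:
                break
            if done_event.is_set():
                continue
            data_queue.put(r)  # the real thread would pin the tensors here


    def _shutdown(done_event, worker_result_queue, pin_memory_thread,
                  index_queues, workers):
        # (a) From here on, the workers / pin_memory_thread may exit once they
        #     see their `None`.
        done_event.set()
        # (b) Exit `pin_memory_thread` first ...
        worker_result_queue.put(None)
        pin_memory_thread.join()
        # (c) ... then the workers. The order matters: exiting a worker may
        #     leave corrupted data in `worker_result_queue`, which
        #     `pin_memory_thread` reads from.
        for q in index_queues:
            q.put(None)
        for w in workers:
            w.join()


    if __name__ == "__main__":
        done_event = mp.Event()
        index_queue = mp.Queue()
        worker_result_queue = mp.Queue()
        data_queue = queue.Queue()  # thread-level queue: pin_memory_thread -> main

        worker = mp.Process(target=_worker_loop,
                            args=(index_queue, worker_result_queue, done_event),
                            daemon=True)
        worker.start()
        pin_memory_thread = threading.Thread(target=_pin_memory_loop,
                                             args=(worker_result_queue, data_queue, done_event),
                                             daemon=True)
        pin_memory_thread.start()

        index_queue.cancel_join_thread()  # sender side: main process -> worker
        index_queue.put(3)
        print(data_queue.get())  # (3, 6)

        _shutdown(done_event, worker_result_queue, pin_memory_thread,
                  [index_queue], [worker])

Note that (b) strictly precedes (c), matching the ordering constraint in the note, and that every blocking `get()` uses a timeout so neither side can hang forever.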

Original description:

In `DataLoaderIter.__del__`, ensure that `None` is sent to `pin_memory_thread` before joining the workers.

Trace when interrupted at such a hang:

 Exception ignored in: <function _DataLoaderIter.__del__ at 0x7facf66760d0>
 Traceback (most recent call last):
   File "/private/home/ssnl/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 412, in __del__
     self._shutdown_workers()
   File "/private/home/ssnl/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 408, in _shutdown_workers
     self.pin_memory_thread.join()
   File "/private/home/ssnl/miniconda3/lib/python3.7/threading.py", line 1032, in join
     self._wait_for_tstate_lock()
   File "/private/home/ssnl/miniconda3/lib/python3.7/threading.py", line 1048, in _wait_for_tstate_lock
     elif lock.acquire(block, timeout):
 KeyboardInterrupt

The 1st commit solves the majority of the hangs, but uncovers another problem:

36: Exception ignored in: <function _DataLoaderIter.__del__ at 0x7f214fa412f0>
36: Traceback (most recent call last):
36:   File "/private/home/ssnl/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 416, in __del__
36:     self._shutdown_workers()
36:   File "/private/home/ssnl/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 401, in _shutdown_workers
36:     self.worker_result_queue.join_thread()
36:   File "/private/home/ssnl/miniconda3/lib/python3.7/multiprocessing/queues.py", line 145, in join_thread
36:     self._jointhread()
36:   File "/private/home/ssnl/miniconda3/lib/python3.7/multiprocessing/util.py", line 189, in __call__
36:     res = self._callback(*self._args, **self._kwargs)
36:   File "/private/home/ssnl/miniconda3/lib/python3.7/multiprocessing/queues.py", line 192, in _finalize_join
36:     thread.join()
36:   File "/private/home/ssnl/miniconda3/lib/python3.7/threading.py", line 1032, in join
36:     self._wait_for_tstate_lock()
36:   File "/private/home/ssnl/miniconda3/lib/python3.7/threading.py", line 1048, in _wait_for_tstate_lock
36:     elif lock.acquire(block, timeout):
36: KeyboardInterrupt

@ssnl (Collaborator, Author) commented Sep 23, 2018

cc @csarofeen I remember that you said there is still an occasional hang. This might fix it.

@csarofeen (Contributor) commented Sep 24, 2018

Thanks for keeping me posted; I'm checking now to see if I can still get any hangs. I am getting quite a few errors on cleanup (disclaimer: this is on a ~month-old master + most of your changes, so take it with a grain of salt for the moment).

Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/opt/conda/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 134, in _pin_memory_loop
    r = in_queue.get()
  File "/opt/conda/lib/python3.6/multiprocessing/queues.py", line 113, in get
    return _ForkingPickler.loads(res)
  File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 201, in rebuild_storage_fd
    fd = df.detach()
  File "/opt/conda/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/opt/conda/lib/python3.6/multiprocessing/resource_sharer.py", line 87, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 487, in Client
    c = SocketClient(address)
  File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 614, in SocketClient
    s.connect(address)
FileNotFoundError: [Errno 2] No such file or directory

@csarofeen (Contributor) commented:
Besides the error, I am seeing that one of the big hangs I was fighting with is gone. Thanks @ssnl! I will also check on a newer master commit.

@ssnl (Collaborator, Author) commented Sep 24, 2018

@csarofeen Thanks for testing it out. I am aware of that error. I think it is because, during Python exit, when the main process tries to get from a worker in `__del__`, the worker may have already exited and the pipe/fd isn't valid anymore.

There is another hang related to this. I think it is because we set `worker.daemon = True`, so when Python exits, the workers may (and very likely do) terminate before sending the `None`s that signal the end.
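
For context, point 3a of the note above describes the general defense against this class of hang: never block forever on a `get()`; use a timeout and re-check whether the sender is still alive each time the timeout fires (via the SIGCHLD handler for dead workers, and via `ManagerWatchdog` inside the workers for a dead main process). Below is a rough sketch of the worker-side check, with hypothetical names (`SimpleWatchdog`, `guarded_get`, `MP_STATUS_CHECK_INTERVAL`) rather than the actual dataloader.py implementation:

    import os
    import queue

    MP_STATUS_CHECK_INTERVAL = 5.0  # hypothetical polling interval


    class SimpleWatchdog:
        """Checks whether the process that spawned us (the main process) is alive."""

        def __init__(self):
            self.manager_pid = os.getppid()

        def is_alive(self):
            # On POSIX, a child whose parent dies is re-parented, so the parent
            # pid changes. (The real check is platform-specific.)
            return os.getppid() == self.manager_pid


    def guarded_get(index_queue):
        """Get from `index_queue` without ever hanging if the main process dies."""
        watchdog = SimpleWatchdog()
        while watchdog.is_alive():
            try:
                return index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
        return None  # treat a dead main process the same as being told to exit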

@csarofeen (Contributor) commented Sep 24, 2018

I'm getting a test failure now...

======================================================================
FAIL: test_timeout (__main__.TestDataLoader)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test_dataloader.py", line 379, in test_timeout
    self.assertFalse(p.is_alive())
AssertionError: True is not false

----------------------------------------------------------------------
Ran 44 tests in 23.531s

Sorry, I see this is already showing up in the CI.

@ssnl (Collaborator, Author) commented Sep 25, 2018

@csarofeen Yep, I will fix that. There are still some tricky parts to be fixed.

@ssnl ssnl added the blocker label Sep 29, 2018
@ssnl ssnl changed the title from "Prevent hanging in data loader dtor pin_memory_thread.join()" to "Prevent hanging in data loader altogether" Oct 2, 2018
@ssnl ssnl force-pushed the join_queue branch 3 times, most recently from e08b452 to 3771ee4 on October 3, 2018 08:00
@ssnl (Collaborator, Author) commented Oct 3, 2018

I firmly believe that this should work in all cases. Empirically, it has even lowered the data loading shutdown overhead a bit for me when training LeNets on MNIST. I'm training 9000 LeNets on multiple processes to see if any still hang. (Previously I couldn't train 200 without seeing a hang on one rank.)

@ssnl (Collaborator, Author) commented Oct 4, 2018

@pytorchbot retest this please

@yf225 (Contributor) commented Oct 4, 2018

@pytorchbot retest this please

1 similar comment


@facebook-github-bot (Contributor) left a comment

SsnL has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@apaszke (Contributor) left a comment

Haven't finished a review, but I have a bunch of comments and need to board a plane now.

@facebook-github-bot (Contributor) left a comment

SsnL is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@bombs-kim (Contributor) commented Dec 9, 2018

@ssnl I'm sorry, but the NOTE [ Data Loader Multiprocessing Shutdown Logic ] that you wrote is somewhat confusing to me... Firstly, did you mean this process.py code in CPython when you said this?

in a subprocess, Python runs the given function (e.g., the `target` argument passed to a `mp.Process`) using this pattern (unrelated code removed for clarity):

    # These are run in the sub-process
    try:
        user_provided_function()
    finally:
        multiprocessing.util._exit_function()

And if that's the case, is it really specific to Python 3.7 or later? I can see similar code in previous versions.

Thank you in advance!

@ssnl (Collaborator, Author) commented Dec 10, 2018

@bombs-kim Yes, I mean that code. It is not specific to Python 3.7 or later.
