
Conversation

@mrshenli
Contributor

@mrshenli commented Sep 21, 2020

Stack from ghstack:

Fixes #45082

Found a few problems while working on #44983

  1. We deliberately swallow RPC timeouts during shutdown, as we haven't
    found a good way to handle them. When we converted `_wait_all_workers`
    into `_all_gather`, the same logic was inherited. However, since
    `_all_gather` is meant to be used in more general scenarios, we should
    no longer stay silent about errors. This commit lets errors propagate
    out of `_all_gather` and has `shutdown()` catch and log them instead
    (see the sketch below).
  2. After fixing (1), I found that `UnpickledPythonCall` needs to
    acquire the GIL on destruction, which can lead to a deadlock when used
    in conjunction with `ProcessGroup`, because the `ProcessGroup` ctor is
    a synchronization point that holds the GIL. In `init_rpc`, followers
    (`rank != 0`) can exit before the leader (`rank == 0`). When that
    happens, a follower can exit `init_rpc` after running
    `_broadcast_to_followers` but before reaching the dtor of
    `UnpickledPythonCall`. It then runs the `ProcessGroup` ctor, which
    holds the GIL and waits for the leader to join. However, the leader
    is waiting for the response from `_broadcast_to_followers`, which is
    blocked by the dtor of `UnpickledPythonCall`, and hence the deadlock.
    This commit drops the GIL in the `ProcessGroup` ctor.
  3. After fixing (2), I found that the TensorPipe backend
    nondeterministically fails `test_local_shutdown`, for a reason similar
    to (2), except that this time `shutdown()` on a follower runs before
    the leader finishes `init_rpc`. This commit adds a join for the
    TensorPipe backend's `init_rpc` after `_all_gather`.

The 3rd fix should be able to solve the 2nd problem as well. But since
I didn't see a reason to hold the GIL during the `ProcessGroup` ctor, I
made that change too.
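
Below is a minimal, self-contained sketch of the behavior change in (1). It is not the actual implementation: `_collect_worker_errors` is a made-up placeholder for the gathered per-worker states, and the real `_all_gather`/`shutdown()` do more work.

```python
import logging

logger = logging.getLogger(__name__)


def _collect_worker_errors():
    # Placeholder: in the real code the per-worker errors come from the
    # gathered sequence states; here we simply report none.
    return {}


def _all_gather(obj, timeout=5.0):
    errors = _collect_worker_errors()
    if errors:
        # _all_gather is now a general-purpose primitive, so it surfaces
        # failures to its caller instead of swallowing them.
        raise RuntimeError(f"Failed to gather from workers: {sorted(errors)}")
    return obj


def shutdown(graceful=True):
    if graceful:
        try:
            _all_gather(None)
        except RuntimeError as ex:
            # shutdown() keeps the old lenient behavior: log the error and
            # continue with the rest of the teardown.
            logger.error("Error while waiting for all workers: %s", ex)
    # ... agent teardown would follow here ...
```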

Differential Revision: D23825592

mrshenli added a commit that referenced this pull request Sep 21, 2020

ghstack-source-id: aab6baa
Pull Request resolved: #45088
@dr-ci

dr-ci bot commented Sep 21, 2020

💊 CI failures summary and remediations

As of commit 82615bf (more details on the Dr. CI page):



🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_macos_10_13_py3_test (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun) <confirmed not flaky by 2 failures>

Sep 21 20:01:16 [E request_callback_no_python.cpp:618] Received error while processing request type 2: RuntimeError: Can not pickle torch.futures.Future
Sep 21 20:01:16 At: 
Sep 21 20:01:16   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(93): serialize 
Sep 21 20:01:16   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(145): serialize 
Sep 21 20:01:16  
Sep 21 20:01:16 [E request_callback_no_python.cpp:618] Received error while processing request type 2: RuntimeError: Can not pickle torch.futures.Future 
Sep 21 20:01:16  
Sep 21 20:01:16 At: 
Sep 21 20:01:16   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(93): serialize 
Sep 21 20:01:16   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(145): serialize 
Sep 21 20:01:16  
Sep 21 20:01:16 [E request_callback_no_python.cpp:618] Received error while processing request type 2: RuntimeError: Can not pickle torch.futures.Future 
Sep 21 20:01:16  
Sep 21 20:01:16 At: 
Sep 21 20:01:16   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(93): serialize 
Sep 21 20:01:16   /Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/distributed/rpc/internal.py(145): serialize 
Sep 21 20:01:16  
Sep 21 20:01:16 ok (1.569s) 
Sep 21 20:01:18   test_return_future_remote (__main__.ProcessGroupRpcTestWithSpawn) ... ok (1.546s) 
Sep 21 20:01:19   test_return_local_rrefs (__main__.ProcessGroupRpcTestWithSpawn) ... ok (1.485s) 
Sep 21 20:01:21   test_rpc_profiling_remote_record_function (__main__.ProcessGroupRpcTestWithSpawn) ... ok (1.523s) 
Sep 21 20:01:22   test_rpc_return_rref (__main__.ProcessGroupRpcTestWithSpawn) ... ok (1.482s) 

❄️ 2 failures tentatively classified as flaky

but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test1 (1/2)

Step: "Test" (full log | diagnosis details | 🔁 rerun) ❄️

ModuleNotFoundError: No module named 'torch'
[ERROR:VsDevCmd.bat] Where [value] is: 
[ERROR:VsDevCmd.bat]    1 : basic debug logging 
[ERROR:VsDevCmd.bat]    2 : detailed debug logging 
[ERROR:VsDevCmd.bat]    3 : trace level logging. Redirection of output to a file when using this level is recommended. 
[ERROR:VsDevCmd.bat] Example: set VSCMD_DEBUG=3 
[ERROR:VsDevCmd.bat]          vsdevcmd.bat > vsdevcmd.trace.txt 2>&1 
Run jit_profiling tests 
Traceback (most recent call last): 
  File "run_test.py", line 13, in <module> 
    import torch 
ModuleNotFoundError: No module named 'torch' 
+ cleanup
+ retcode=1
+ set +x

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test2 (2/2)

Step: "Test" (full log | diagnosis details | 🔁 rerun) ❄️

ModuleNotFoundError: No module named 'torch'
[ERROR:VsDevCmd.bat] vsdevcmd.bat [args] for additional details. 
[ERROR:VsDevCmd.bat] Where [value] is: 
[ERROR:VsDevCmd.bat]    1 : basic debug logging 
[ERROR:VsDevCmd.bat]    2 : detailed debug logging 
[ERROR:VsDevCmd.bat]    3 : trace level logging. Redirection of output to a file when using this level is recommended. 
[ERROR:VsDevCmd.bat] Example: set VSCMD_DEBUG=3 
[ERROR:VsDevCmd.bat]          vsdevcmd.bat > vsdevcmd.trace.txt 2>&1 
Traceback (most recent call last): 
  File "run_test.py", line 13, in <module> 
    import torch 
ModuleNotFoundError: No module named 'torch' 
+ cleanup
+ retcode=1
+ set +x

ci.pytorch.org: 1 failed



Contributor

@pritamdamania87 left a comment

Is it possible to add a unit test that would've caught this issue? Even a unit test that would fail once in 100 without this fix would be useful to catch any regressions we have in the future.

worker_name=follower_name, timeout=timeout
)
)
fut.wait()
Contributor

Previously we would call wait() on all futures during shutdown, but now we throw on the first one and probably don't wait for the other futures.

It would be nice to keep the same behavior as before. Can we add a try-catch here and save the exception? We could then throw only the first/last exception we see and catch that in _wait_all_workers.
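
A rough sketch of this suggestion (not the PR's actual code): wait on every future, remember any failure, and re-raise one at the end so that `_wait_all_workers` can still catch it.

```python
def _wait_on_futures(futs):
    # Wait on every future so all workers are drained, remember any failure,
    # and re-raise one at the end so callers such as _wait_all_workers can
    # still catch it.
    last_exception = None
    for fut in futs:
        try:
            fut.wait()
        except Exception as ex:
            last_exception = ex
    if last_exception is not None:
        raise last_exception
```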

@rohan-varma
Contributor

Is it possible to add a unit test that would've caught this issue? Even a unit test that would fail once in 100 without this fix would be useful to catch any regressions we have in the future.

I think the test described here: #45089 should reproduce the RPC/ProcessGroup GIL deadlock issue, so we should probably be able to use it as a unit test for (2).

@mrshenli
Contributor Author

Is it possible to add a unit test that would've caught this issue? Even a unit test that would fail once in 100 without this fix would be useful to catch any regressions we have in the future.

After fixing (1), multiple existing tests failed due to (2) and (3). These problems didn't surface before because the existing code swallowed the failure, timed out after 5s, and only logged the error message.

@mrshenli
Contributor Author

I think the test described here: #45089 should reproduce the RPC/ProcessGroup GIL deadlock issue, so we should probably be able to use it as a unit test for (2).

Thanks, let me pull that in as a test.

mrshenli added a commit that referenced this pull request Sep 22, 2020

ghstack-source-id: 3184900
Pull Request resolved: #45088
@mrshenli
Contributor Author

Failures in the macOS tests are already present on master, see:

#45105 #40434 #40378

@facebook-github-bot
Contributor

@mrshenli merged this pull request in 09e7f62.

@lw
Contributor

lw commented Sep 22, 2020

The three known failures you linked to were only affecting the ProcessGroup agent. After this diff, those same tests have also started failing for the TensorPipe agent, with the same logs. I opened #45116, #45117 and #45118 to disable them.

@mrshenli
Contributor Author

The three known failures you linked to were only affecting the ProcessGroup agent. After this diff, those same tests have also started failing for the TensorPipe agent, with the same logs. I opened #45116, #45117 and #45118 to disable them.

It's actually failing for both ProcessGroup and TensorPipe, but we didn't have an issue tracking it. See the test below. Thanks for opening those issues.

https://app.circleci.com/pipelines/github/pytorch/pytorch/215941/workflows/f4bc8989-0016-438b-b029-fa872c487791/jobs/7586991/steps

@mrshenli
Contributor Author

Trying a fix for those 6 issues: #45126


try:
    _tensorpipe_check_device_maps(agent, rpc_backend_options.device_maps)
    agent.join()
Contributor

This makes me a bit sad, as I thought in principle we wanted to eventually get rid of the sync and join methods of the agent, and adding more usages will just make it harder. What is the reason for doing this actually? By this time there shouldn't be any RPCs in flight so joining seems to be unnecessary. The only "side effect" I see is that it adds a barrier between all agents. If that's why we use it, can we just make an explicit barrier instead of using join?

Contributor Author

This makes me a bit sad, as I thought in principle we wanted to eventually get rid of the sync and join methods of the agent, and adding more usages will just make it harder.

Agree, I am sad too.

What is the reason for doing this actually? By this time there shouldn't be any RPCs in flight so joining seems to be unnecessary.

There can still be RPCs in flight at this point. Assume there are two processes, X and Y. Y could finish init_rpc earlier than X and then start sending RPCs to X. Without this join, it is possible that X has not yet configured the reverse device map, so those RPCs hit failures.

The only "side effect" I see is that it adds a barrier between all agents. If that's why we use it, can we just make an explicit barrier instead of using join?

Which barrier are you referring to? The dist.barrier()? Will that add another dependency on c10d? I cannot tell which one is worse. :)
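
For illustration, here is a minimal self-contained script showing the kind of race being described. The worker names, payload, and rendezvous settings are assumptions for the example, and it does not configure device maps, so it only demonstrates the ordering, not the actual failure.

```python
import os

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc(name=f"worker{rank}", rank=rank, world_size=world_size)
    if rank != 0:
        # Without a trailing join/barrier inside init_rpc, this RPC can race
        # with worker0 still finishing its own init_rpc, e.g. before it has
        # configured the reverse device map.
        rpc.rpc_sync("worker0", torch.add, args=(torch.ones(2), 1))
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
```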

Contributor

I see, thanks for clarifying that there could be RPCs already in flight, I hadn't realized.

So the problem, at a high level, seems to be that we run into trouble if one rank exits init_rpc early and starts doing stuff while other ranks are potentially still inside init_rpc, right? If that's the case, can't we fix it by doing a barrier at the very end of init_rpc?

On one hand, that would nicely match what @pritamdamania is doing for DDP in #45181 (adding a barrier at the end of init_process_group and new_group). Also, since a barrier is a "weaker" operation than joining the agent (because the latter performs multiple barriers), I'd consider a barrier to be more readable as it makes it clear what the core issue is and what the essential solution is, without going "overkill".

And, if we're doing a barrier, we could use _all_gather, which avoids both the (deprecated?) agent.join() and a new dependency on c10d.
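
A rough sketch of that idea. It assumes the private helper `torch.distributed.rpc.api._all_gather` introduced in this stack can be called with an arbitrary picklable object and returns only once every worker has contributed; the helper name below is hypothetical and the exact signature may differ.

```python
from torch.distributed.rpc import api


def _tensorpipe_exit_barrier():
    # Gathering a dummy value from every worker acts as a barrier: no rank
    # can leave init_rpc until all ranks reach this point, without calling
    # agent.join() and without depending on c10d.
    api._all_gather(None)
```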

Contributor Author

Yep, I would expect _all_gather to work too. I actually have a PR for that: #44990. Let me land that one too.

Contributor Author

Using _all_gather makes test_local_shutdown flaky. I will come back to debug this after the branch cut.

Contributor

Sounds good, thanks!

