Commit 08cf7b5
committed
Join-based API to support DDP uneven inputs
Pull Request resolved: #42577
Closes #38174. Implements a join-based API to support training with the DDP module in the scenario where different processes have different no. of inputs. The implementation follows the description in #38174. Details are available in the RFC, but as a summary, we make the following changes:
#### Approach
1) Add a context manager `torch.nn.parallel.distributed.join`
2) In the forward pass, we schedule a "present" allreduce where non-joined process contribute 1 and joined processes contribute 0. This lets us keep track of joined processes and know when all procs are joined.
3) When a process depletes its input and exits the context manager, it enters "joining" mode and attempts to "shadow" the collective comm. calls made in the model's forward and backward pass. For example we schedule the same allreduces in the same order as the backward pass, but with zeros
4) We adjust the allreduce division logic to divide by the effective world size (no. of non-joined procs) rather than the absolute world size to maintain correctness.
5) At the end of training, the last joined process is selected to be the "authoritative" model copy
We also make some misc. changes such as adding a `rank` argument to `_distributed_broadcast_coalesced` and exposing some getters/setters on `Reducer` to support the above changes.
#### How is it tested?
We have tests covering the following models/scenarios:
- [x] Simple linear model
- [x] Large convolutional model
- [x] Large model with module buffers that are broadcast in the forward pass (resnet). We verify this with a helper function `will_sync_module_buffers` and ensure this is true for ResNet (due to batchnorm)
- [x] Scenario where a rank calls join() without iterating at all, so without rebuilding buckets (which requires collective comm)
- [x] Model with unused params (with find unused parameters=True)
- [x] Scenarios where different processes iterate for a varying number of different iterations.
- [x] Test consistency in tie-breaking when multiple ranks are the last ones to join
- [x] Test that we divide by the effective world_size (no. of unjoined processes)
#### Limitations
1) This is only implemented for MPSD, not SPMD. Per a discussion with @mrshenli we want to encourage the use of MPSD over SPMD for DDP.
2) This does not currently work with SyncBN or custom collective calls made in the model's forward pass. This is because the `join` class only shadows the `broadcast` for buffers in the forward pass, the gradient allreduces in the bwd pass, unused parameters reduction, and (optionally) the rebuild buckets broadcasting in the backwards pass. Supporting this will require additional design thought.
3) Has not been tested with the [DDP comm. hook](#39272) as this feature is still being finalized/in progress. We will add support for this in follow up PRs.
ghstack-source-id: 109227369
Differential Revision: [D22893859](https://our.internmc.facebook.com/intern/diff/D22893859/)
**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22893859/)!1 parent 8850fd1 commit 08cf7b5
File tree
9 files changed
+705
-51
lines changed- test/distributed
- torch
- csrc/distributed/c10d
- nn/parallel
- testing/_internal
9 files changed
+705
-51
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3544 | 3544 | | |
3545 | 3545 | | |
3546 | 3546 | | |
3547 | | - | |
| 3547 | + | |
3548 | 3548 | | |
3549 | 3549 | | |
3550 | 3550 | | |
| |||
3560 | 3560 | | |
3561 | 3561 | | |
3562 | 3562 | | |
3563 | | - | |
| 3563 | + | |
3564 | 3564 | | |
3565 | 3565 | | |
3566 | | - | |
| 3566 | + | |
| 3567 | + | |
| 3568 | + | |
| 3569 | + | |
3567 | 3570 | | |
3568 | 3571 | | |
3569 | 3572 | | |
3570 | 3573 | | |
3571 | | - | |
| 3574 | + | |
| 3575 | + | |
3572 | 3576 | | |
3573 | | - | |
| 3577 | + | |
| 3578 | + | |
3574 | 3579 | | |
3575 | 3580 | | |
3576 | 3581 | | |
3577 | 3582 | | |
3578 | 3583 | | |
3579 | 3584 | | |
3580 | 3585 | | |
3581 | | - | |
| 3586 | + | |
| 3587 | + | |
| 3588 | + | |
3582 | 3589 | | |
3583 | 3590 | | |
3584 | 3591 | | |
| |||
3588 | 3595 | | |
3589 | 3596 | | |
3590 | 3597 | | |
3591 | | - | |
| 3598 | + | |
| 3599 | + | |
| 3600 | + | |
3592 | 3601 | | |
3593 | 3602 | | |
3594 | 3603 | | |
| |||
3597 | 3606 | | |
3598 | 3607 | | |
3599 | 3608 | | |
3600 | | - | |
| 3609 | + | |
| 3610 | + | |
| 3611 | + | |
3601 | 3612 | | |
3602 | 3613 | | |
3603 | 3614 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
| 3 | + | |
3 | 4 | | |
4 | 5 | | |
| 6 | + | |
5 | 7 | | |
6 | 8 | | |
7 | 9 | | |
| |||
13 | 15 | | |
14 | 16 | | |
15 | 17 | | |
| 18 | + | |
16 | 19 | | |
17 | 20 | | |
18 | 21 | | |
| |||
85 | 88 | | |
86 | 89 | | |
87 | 90 | | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
88 | 99 | | |
89 | 100 | | |
90 | 101 | | |
| |||
2437 | 2448 | | |
2438 | 2449 | | |
2439 | 2450 | | |
| 2451 | + | |
| 2452 | + | |
| 2453 | + | |
| 2454 | + | |
| 2455 | + | |
| 2456 | + | |
| 2457 | + | |
| 2458 | + | |
| 2459 | + | |
| 2460 | + | |
| 2461 | + | |
| 2462 | + | |
| 2463 | + | |
| 2464 | + | |
| 2465 | + | |
| 2466 | + | |
| 2467 | + | |
| 2468 | + | |
| 2469 | + | |
| 2470 | + | |
| 2471 | + | |
| 2472 | + | |
| 2473 | + | |
| 2474 | + | |
| 2475 | + | |
| 2476 | + | |
| 2477 | + | |
| 2478 | + | |
| 2479 | + | |
| 2480 | + | |
| 2481 | + | |
| 2482 | + | |
| 2483 | + | |
| 2484 | + | |
| 2485 | + | |
| 2486 | + | |
| 2487 | + | |
| 2488 | + | |
| 2489 | + | |
| 2490 | + | |
| 2491 | + | |
| 2492 | + | |
| 2493 | + | |
| 2494 | + | |
| 2495 | + | |
| 2496 | + | |
| 2497 | + | |
| 2498 | + | |
| 2499 | + | |
| 2500 | + | |
| 2501 | + | |
| 2502 | + | |
| 2503 | + | |
| 2504 | + | |
| 2505 | + | |
| 2506 | + | |
| 2507 | + | |
| 2508 | + | |
| 2509 | + | |
| 2510 | + | |
| 2511 | + | |
| 2512 | + | |
| 2513 | + | |
| 2514 | + | |
| 2515 | + | |
| 2516 | + | |
| 2517 | + | |
| 2518 | + | |
| 2519 | + | |
| 2520 | + | |
| 2521 | + | |
| 2522 | + | |
| 2523 | + | |
| 2524 | + | |
| 2525 | + | |
| 2526 | + | |
| 2527 | + | |
| 2528 | + | |
| 2529 | + | |
| 2530 | + | |
| 2531 | + | |
| 2532 | + | |
| 2533 | + | |
| 2534 | + | |
| 2535 | + | |
| 2536 | + | |
| 2537 | + | |
| 2538 | + | |
| 2539 | + | |
| 2540 | + | |
| 2541 | + | |
| 2542 | + | |
| 2543 | + | |
| 2544 | + | |
| 2545 | + | |
| 2546 | + | |
| 2547 | + | |
| 2548 | + | |
| 2549 | + | |
| 2550 | + | |
| 2551 | + | |
| 2552 | + | |
| 2553 | + | |
| 2554 | + | |
| 2555 | + | |
| 2556 | + | |
| 2557 | + | |
| 2558 | + | |
| 2559 | + | |
| 2560 | + | |
| 2561 | + | |
| 2562 | + | |
| 2563 | + | |
| 2564 | + | |
| 2565 | + | |
| 2566 | + | |
| 2567 | + | |
| 2568 | + | |
| 2569 | + | |
| 2570 | + | |
| 2571 | + | |
| 2572 | + | |
| 2573 | + | |
| 2574 | + | |
| 2575 | + | |
| 2576 | + | |
| 2577 | + | |
| 2578 | + | |
| 2579 | + | |
| 2580 | + | |
| 2581 | + | |
| 2582 | + | |
| 2583 | + | |
| 2584 | + | |
| 2585 | + | |
| 2586 | + | |
| 2587 | + | |
| 2588 | + | |
| 2589 | + | |
| 2590 | + | |
| 2591 | + | |
| 2592 | + | |
| 2593 | + | |
| 2594 | + | |
| 2595 | + | |
| 2596 | + | |
| 2597 | + | |
| 2598 | + | |
| 2599 | + | |
| 2600 | + | |
| 2601 | + | |
| 2602 | + | |
| 2603 | + | |
| 2604 | + | |
| 2605 | + | |
| 2606 | + | |
| 2607 | + | |
| 2608 | + | |
| 2609 | + | |
| 2610 | + | |
| 2611 | + | |
| 2612 | + | |
| 2613 | + | |
| 2614 | + | |
| 2615 | + | |
| 2616 | + | |
| 2617 | + | |
| 2618 | + | |
| 2619 | + | |
| 2620 | + | |
| 2621 | + | |
| 2622 | + | |
| 2623 | + | |
| 2624 | + | |
| 2625 | + | |
| 2626 | + | |
| 2627 | + | |
| 2628 | + | |
| 2629 | + | |
| 2630 | + | |
| 2631 | + | |
| 2632 | + | |
| 2633 | + | |
| 2634 | + | |
| 2635 | + | |
| 2636 | + | |
| 2637 | + | |
| 2638 | + | |
| 2639 | + | |
| 2640 | + | |
| 2641 | + | |
| 2642 | + | |
| 2643 | + | |
| 2644 | + | |
| 2645 | + | |
| 2646 | + | |
| 2647 | + | |
| 2648 | + | |
| 2649 | + | |
| 2650 | + | |
| 2651 | + | |
| 2652 | + | |
| 2653 | + | |
| 2654 | + | |
| 2655 | + | |
| 2656 | + | |
| 2657 | + | |
| 2658 | + | |
| 2659 | + | |
| 2660 | + | |
| 2661 | + | |
| 2662 | + | |
| 2663 | + | |
| 2664 | + | |
| 2665 | + | |
| 2666 | + | |
| 2667 | + | |
| 2668 | + | |
| 2669 | + | |
| 2670 | + | |
| 2671 | + | |
| 2672 | + | |
| 2673 | + | |
| 2674 | + | |
| 2675 | + | |
| 2676 | + | |
| 2677 | + | |
| 2678 | + | |
| 2679 | + | |
| 2680 | + | |
| 2681 | + | |
| 2682 | + | |
| 2683 | + | |
| 2684 | + | |
| 2685 | + | |
| 2686 | + | |
| 2687 | + | |
| 2688 | + | |
| 2689 | + | |
2440 | 2690 | | |
2441 | 2691 | | |
2442 | 2692 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
14 | 14 | | |
15 | 15 | | |
16 | 16 | | |
17 | | - | |
| 17 | + | |
| 18 | + | |
18 | 19 | | |
19 | | - | |
20 | | - | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
21 | 25 | | |
22 | 26 | | |
23 | 27 | | |
| |||
51 | 55 | | |
52 | 56 | | |
53 | 57 | | |
54 | | - | |
| 58 | + | |
| 59 | + | |
55 | 60 | | |
56 | 61 | | |
57 | 62 | | |
| |||
71 | 76 | | |
72 | 77 | | |
73 | 78 | | |
74 | | - | |
| 79 | + | |
75 | 80 | | |
76 | 81 | | |
77 | 82 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
15 | | - | |
| 15 | + | |
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
| |||
0 commit comments