Conversation

@rohan-varma rohan-varma commented Mar 21, 2022

Stack from ghstack (oldest at bottom):

Enables mixed_precision training for PT FSDP.

High level overview

  • We add a MixedPrecision argument to the FSDP API that allows users to control the precision of inputs/parameters, gradient reduction, and buffers. We support any torch.dtype, and the dtype does not have to be the same across the three mixed precision settings. The goal of mixed precision training is peak memory reduction and faster training, since the full parameters and gradients are in a reduced precision during the forward/backward pass. Further, we decouple reduction precision from param/grad/input precision to allow additional experimentation with faster communication algorithms.
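
For reference, a minimal usage sketch of the config described above (import paths follow later releases and may differ; the field names param_dtype/reduce_dtype/buffer_dtype and the mixed_precision keyword are the assumed spelling of the new API):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Assumes torch.distributed is already initialized (e.g. via
# init_process_group("nccl")) and one GPU per rank.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1)).cuda()

# Params/inputs and buffers run in fp16; gradient reduction stays in fp32.
# The three dtypes are independent and any torch.dtype is allowed.
mp_config = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float16,
)

fsdp_model = FSDP(model, mixed_precision=mp_config)
```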

Mixed precision for inputs

  • The root module simply casts inputs to the reduced precision at the beginning of the forward pass.
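
Roughly, the root-module input cast amounts to the following (a sketch with a hypothetical helper name; the real implementation walks args/kwargs and only touches floating-point tensors):

```python
import torch

def _cast_floats_to_dtype(obj, dtype):
    """Recursively cast floating-point tensors in (nested) forward inputs
    to `dtype`; integer tensors and non-tensors pass through unchanged."""
    if isinstance(obj, torch.Tensor) and obj.is_floating_point():
        return obj.to(dtype)
    if isinstance(obj, (list, tuple)):
        return type(obj)(_cast_floats_to_dtype(x, dtype) for x in obj)
    if isinstance(obj, dict):
        return {k: _cast_floats_to_dtype(v, dtype) for k, v in obj.items()}
    return obj

# e.g. at the start of the root module's forward:
# args = _cast_floats_to_dtype(args, torch.float16)
```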

Mixed precision for parameters

  • In _rebuild_full_params, if we need to cast parameters to reduced precision, we call _cast_param_shards_to_dtype. This allocates a p._mp_shard of the reduced precision dtype and copies the local shard into it on a separate stream, with synchronization taken care of (see the sketch after this list). As a result, the all_gather happens in the reduced precision and we end up with a reduced precision full parameter to run the user's forward pass with.
  • After forwards/backwards passes, we have the full precision parameter shard in memory, and the mixed precision shard has been freed.
  • Full precision for parameters is restored when taking checkpoints and for summon_full_params.
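
A rough sketch of the shard-casting idea referenced in the first bullet (only p._mp_shard is taken from the diff; the function and stream names are illustrative, and the real _cast_param_shards_to_dtype also reuses allocations and records finer-grained stream dependencies):

```python
import torch

def cast_param_shards_to_dtype(params, dtype, cast_stream):
    """Copy each local full-precision shard into a reduced-precision
    `_mp_shard` on a side stream, so the subsequent all_gather runs in
    the reduced dtype."""
    with torch.cuda.stream(cast_stream):
        for p in params:
            # Allocate the low-precision shard buffer and fill it.
            p._mp_shard = torch.empty_like(p.data, dtype=dtype)
            p._mp_shard.copy_(p.data, non_blocking=True)
    # The compute stream must wait for the casts before all-gathering.
    torch.cuda.current_stream().wait_stream(cast_stream)
```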

Mixed precision for gradients

  • Backward computation occurs in the reduced precision since activations/params/inputs were in reduced precision. As a result, in _post_backward_hook, we are left with a full/unsharded gradient in the reduced precision. At the end of _post_backward_hook we ensure the gradient is cast back to full precision so that the optimizer step can occur in full precision, and we handle all necessary stream synchronization (a sketch follows this list).
  • After the backwards pass, we have the full precision gradient shard in memory and no reduced precision gradient shards.
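
Conceptually, the tail of the post-backward hook does something like this (a simplified sketch; the real hook also reshards, reduce-scatters, and synchronizes the relevant streams):

```python
import torch

def cast_grad_back_to_full_precision(p, full_precision_dtype=torch.float32):
    """After the reduced-precision backward, bring the (sharded) gradient
    back to full precision so the optimizer step sees fp32."""
    if p.grad is not None and p.grad.dtype != full_precision_dtype:
        p.grad.data = p.grad.data.to(full_precision_dtype)
```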

Communication mixed precision

  • If the mixed_precision config specifies a different dtype for the gradient reduction, we cast gradients to that dtype before communicating them via _reduce_scatter_base (sketched below).
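
In other words, the pre-communication cast is roughly the following (a sketch; _reduce_scatter_base is the collective named above, and the reduce_dtype value comes from the MixedPrecision config):

```python
import torch
import torch.distributed as dist

def reduce_scatter_grad(flat_grad, output_shard, reduce_dtype, group=None):
    """Cast the flattened unsharded gradient (and its output shard) to the
    configured reduction dtype, if it differs, before reduce-scattering."""
    if reduce_dtype is not None and flat_grad.dtype != reduce_dtype:
        flat_grad = flat_grad.to(reduce_dtype)
        output_shard = output_shard.to(reduce_dtype)
    dist._reduce_scatter_base(output_shard, flat_grad, group=group)
    return output_shard
```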

Buffers mixed precision

  • Buffers are unsharded and are cast only once, by the root module in the forward pass, and they remain in their reduced precision throughout training / in between forward and backward passes. Their full precision is restored for checkpointing with full_state_dict, and is not restored in summon_full_params.
  • See notes below for more details around differences on how PT FSDP vs FairScale implements support for buffer mixed precision.

Changes to _rebuild_full_params

  • Changes are made to _rebuild_full_params to cast parameters to their reduced precision. The main complication is supporting summon_full_params, which must ignore mixed precision and summon parameters in full precision. As a result, at the beginning of the function we set force_full_precision based on whether we are in summon_full_params.
  • To further support summon_full_params, which also needs to free full parameters, we refactor _rebuild_full_params, similar to FairScale, to return a (tensor, bool) tuple indicating whether the tensor can be freed (see the sketch after this list). The tensor cannot be freed when world_size == 1, since the parameter is not sharded and the resulting full param aliases the original model parameter. Another case is when we return the full parameter and reshard_after_forward=False (because we need p._full_param_padded to stay intact).
  • One subtlety: when calling update_p_data, we need to record the above tuple before the call, not after, because after update_p_data the full param has its padding trimmed, which would cause issues with writeback.
  • Finally, we don't necessarily call all_gather on full_param_padded anymore, particularly in the case of summon_full_params. This is because full_param_padded would be of the reduced precision type, but we need to all_gather in full precision. This is also why _rebuild_full_params returns a list of full params to summon_full_params: we can no longer rely on p._full_param_padded being the full parameter.
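
To make the new return contract concrete, a simplified sketch of how the (tensor, bool) pairs might be produced (p._full_param_padded is from the diff; everything else is illustrative, and the mixed-precision all_gather and padding handling are elided):

```python
import torch

def rebuild_full_params_sketch(params, world_size, reshard_after_forward):
    """Return a list of (full_tensor, can_free) pairs, one per parameter."""
    output = []
    for p in params:
        if world_size == 1:
            # Unsharded case: the "full param" aliases the original
            # parameter storage, so it must not be freed.
            output.append((p.data, False))
        elif not reshard_after_forward:
            # p._full_param_padded must stay intact for later use.
            output.append((p._full_param_padded, False))
        else:
            # Normal sharded case: the gathered full parameter can be freed
            # once the caller is done with it. (Under summon_full_params
            # with mixed precision this may be a separately gathered
            # full-precision tensor rather than p._full_param_padded.)
            output.append((p._full_param_padded, True))
    return output
```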

Changes to summon_full_params

  • summon_full_params mostly consumes the above return value from _rebuild_full_params; the way writeback is done and full parameters are freed is refactored accordingly (sketched below).
  • For writeback, we can no longer assume that p._full_param_padded is the full tensor that may have been modified (this is not the case with mixed_precision). As a result, we write back from the returned full parameters instead of hardcoding p._full_param_padded.
  • For freeing full params, similarly we cannot assume that p._full_param_padded is the full parameter, as _collect_local_params did. Instead, we consume the return value from _rebuild_full_params, which explicitly tells us whether the parameter can be freed.
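
A rough sketch of how summon_full_params might consume that return value for writeback and freeing (variable names are illustrative; rank-0-only mode, CPU offload, and padding handling are elided, and CUDA tensors are assumed):

```python
import torch

def writeback_and_free(params, rebuilt, rank, world_size):
    """`rebuilt` holds the (full_tensor, can_free) pairs returned by the
    refactored _rebuild_full_params."""
    for p, (full_tensor, can_free) in zip(params, rebuilt):
        # Writeback uses the tensor we actually summoned, not a hardcoded
        # p._full_param_padded (which may be reduced precision). Copy this
        # rank's chunk of the possibly-modified full tensor back into the
        # local shard.
        chunk = full_tensor.flatten().chunk(world_size)[rank]
        p._local_shard.data[: chunk.numel()].copy_(chunk)
        # Free the gathered full parameter only if it is safe to do so.
        if can_free:
            if full_tensor.is_cuda:
                full_tensor.record_stream(torch.cuda.current_stream())
            full_tensor.storage().resize_(0)
```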

How checkpoint works

  • For full_state_dict checkpoint, parameters are checkpointed in full precision which happens automatically due to summoning them in full precision as explained above.
  • For buffers, in full_state_dict we explicitly cast buffers back to their full precision before taking the checkpoint. One subtlety is that we need to do this after entering the summon_full_params context, since summon_full_params calls _lazy_init, which casts buffers to their reduced dtype.
  • After checkpointing, buffers are restored back to their reduced precision.
  • Note that buffer checkpointing for local_state_dict is not tested at the moment; this is left as follow-up work.

Useful clarifications while reviewing the diff:

How fairscale implements MP for buffers:
  • Accepts buffer_dtype argument in ctor that is the dtype for computation for buffers. By default this is the compute_dtype.
  • During _lazy_init, for root module, _cast_buffers is called which casts buffers to buffer_dtype.
  • During state_dict, buffers are cast to torch.float32, then checkpoint is taken. They are restored back to buffer_dtype after that.
How PT FSDP implements MP for buffers in this diff:
  • Rather than buffer_dtype in ctor, we accept MixedPrecision.buffer_dtype which is the compute type for buffers.
  • During lazy_init, similar to fairscale we cast the buffers to the type given by the MP config. In the case of no mixed precision the default behavior is maintained.
  • Note that one subtlety is the recursive call. We need to make sure that each submodule uses its own self.mixed_precision config instead of passing this argument down the recursive call, because different submodules may disable mixed precision; for example, BatchNorm usually disables mixed precision.
  • Similar to FairScale, integer buffers are not cast.
  • In _cast_buffers, we remember the original buffer dtypes in a member variable. We then may cast buffers to a new dtype if one is given by the user (see the sketch below).
  • During state_dict, we use the remembered types (stored as self._orig_buffer_dtypes) and restore them on the buffers prior to taking the checkpoint. After state_dict, we cast back to the reduced type, since buffers remain in this mixed precision type even after forward/backward passes (this is done for consistency). FairScale appears to assume all buffers are originally fp32, whereas we maintain a mapping that remembers the actual types.
  • The improvement here is that we remember and restore the dtypes the model's buffers originally had.
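
A minimal sketch of the buffer dtype bookkeeping described above (free functions with illustrative names; the real logic lives on the FSDP wrapper, skips integer buffers, and respects per-submodule mixed precision configs):

```python
import torch
import torch.nn as nn

def _set_buffer(module: nn.Module, dotted_name: str, new_buf: torch.Tensor):
    """Reassign a (possibly nested) registered buffer by dotted name."""
    *path, leaf = dotted_name.split(".")
    for part in path:
        module = getattr(module, part)
    module._buffers[leaf] = new_buf

def cast_buffers(module: nn.Module, dtype, orig_dtypes: dict):
    """Remember each buffer's original dtype, then cast floating-point
    buffers to `dtype` (no-op if dtype is None)."""
    for name, buf in module.named_buffers():
        orig_dtypes.setdefault(name, buf.dtype)
        if dtype is not None and buf.is_floating_point():
            _set_buffer(module, name, buf.to(dtype))

def restore_buffer_dtypes(module: nn.Module, orig_dtypes: dict):
    """Before a full_state_dict checkpoint, restore each buffer to the
    dtype the model originally had (not a blanket fp32 cast)."""
    for name, buf in module.named_buffers():
        _set_buffer(module, name, buf.to(orig_dtypes[name]))
```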

Test coverage:

  • nested model with same param / reduce dtypes
  • nested model with distinct param / buffer / reduce dtypes
  • model where buffer is a different type than parameter
  • nested model checkpoint, with verification that buffers and params are checkpointed in full precision (checks that force_full_precision in summon_full_params is respected)
  • After taking checkpoint, verified that buffers are back in the reduced precision
  • test that summon_full_params summons params in full precision
  • tests that the gradient has the appropriate dtype in the backward pass. This is done by patching _reduce_scatter_base to run the mixed precision checks (see the sketch after this list).
  • Above test, but checks that we run reduce_scatter in the higher precision if specified by the mixed precision config.
  • Tests that after backward, gradient and param shards are in the full precision (and on correct device for optimizer step to happen).
  • Tests that after forward, the reduced precision param shard is freed
  • Tests that after backward, the reduced precision param shard is freed
  • Test that buffers remain in the reduced precision type after forward / backward and are not affected by summon_full_params. Within summon_full_params the buffer is not restored to the full type.
  • all of the above tests, but with reshard_after_forward=False, i.e. ZeRO-2
  • test that summon_full_params respects reshard_after_forward in the case of mixed precision as well
  • parametrize a few relevant tests in summon_full_params, as summon_full_params uses _rebuild_full_params a bit differently. In particular, we make sure things work as expected if the rebuilt parameter is not p._full_param_padded, which is the case with mixed precision.
  • tests for world_size == 1, i.e. when the parameter is not sharded. Not added for initial enablement, as all use cases in question have world_size > 1.
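
The dtype-checking trick from the reduce-scatter tests above looks roughly like this (a sketch of the test idea, not the actual test file; the model/input setup and expected dtype are placeholders):

```python
from unittest import mock

import torch
import torch.distributed as dist

def run_one_step_with_dtype_check(fsdp_model, inp, expected_dtype=torch.float16):
    """Run one forward/backward with _reduce_scatter_base patched to assert
    the gradient dtype seen by the collective."""
    orig = dist._reduce_scatter_base

    def checking_reduce_scatter(output, grad_input, *args, **kwargs):
        assert grad_input.dtype == expected_dtype, (
            f"unexpected grad dtype {grad_input.dtype}")
        return orig(output, grad_input, *args, **kwargs)

    with mock.patch.object(dist, "_reduce_scatter_base", checking_reduce_scatter):
        fsdp_model(inp).sum().backward()
```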

Follow up work (FSDP: Mixed Precision follow up work, #74515):

  • Test that local_state_dict checkpointing works with mixed precision. In particular we have to be careful about buffer casting.
  • Enhance test_fsdp_state_dict to checkpoint buffers and ensure dtypes are as expected (note that this is also already tested in this PR).
  • Test summon_full_params with reshard_after_forward (with and without mixed precision).

Differential Revision: D35000703

@facebook-github-bot
Contributor

facebook-github-bot commented Mar 21, 2022

💊 CI failures summary and remediations

As of commit bf797d1 (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



❄️ 1 failure tentatively classified as flaky

but reruns have not yet been triggered to confirm:

See GitHub Actions build pull / linux-xenial-py3.7-clang7-asan / test (default, 2, 3, linux.2xlarge) (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun) ❄️

2022-03-31T03:13:26.6544143Z [       OK ] Kernel.CatInputTypesPromotion (71 ms)
2022-03-31T03:13:26.6544520Z [ RUN      ] Kernel.CatAndInlineWithAConstantDim
2022-03-31T03:13:27.0318180Z [       OK ] Kernel.CatAndInlineWithAConstantDim (377 ms)
2022-03-31T03:13:27.0319569Z [ RUN      ] Kernel.CatWithEmptyInputs
2022-03-31T03:13:27.2631815Z [       OK ] Kernel.CatWithEmptyInputs (231 ms)
2022-03-31T03:13:27.2632173Z [ RUN      ] Kernel.CatWoConditionals
2022-03-31T03:13:27.3310139Z [       OK ] Kernel.CatWoConditionals (67 ms)
2022-03-31T03:13:27.3310500Z [ RUN      ] Kernel.OptimizeConditionals
2022-03-31T03:13:27.3887916Z [       OK ] Kernel.OptimizeConditionals (57 ms)
2022-03-31T03:13:27.3888234Z [ RUN      ] Kernel.Stack
2022-03-31T03:13:27.4002551Z unknown file: Failure
2022-03-31T03:13:27.4002948Z C++ exception with description "Expected to not find "\n" but found it
2022-03-31T03:13:27.4003432Z       for (int64_t k = 0ll; k < 2ll; k++) {
2022-03-31T03:13:27.4003769Z         for (int64_t l = 0ll; l < 3ll; l++) {
2022-03-31T03:13:27.4003986Z           for (int64_t m = 0ll; m < 6ll; m++) {
2022-03-31T03:13:27.4004288Z             aten_stack[(((108ll * i + 18ll * k) + m) + 36ll * j) + 6ll * l] = k==1ll ? (ty_1[((18ll * j + m) + 54ll * i) + 6ll * l]) : (tx_1[((18ll * j + m) + 54ll * i) + 6ll * l]);
2022-03-31T03:13:27.4004539Z           }
2022-03-31T03:13:27.4006659Z         }
2022-03-31T03:13:27.4007015Z From CHECK-NEXT: aten_stack
2022-03-31T03:13:27.4007236Z " thrown in the test body.
2022-03-31T03:13:27.4007503Z [  FAILED  ] Kernel.Stack (11 ms)

This comment was automatically generated by Dr. CI.

@pytorch-bot

pytorch-bot bot commented Mar 21, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/fa9262d9f540e933ee7fa0384b9ca0c6d47f4930/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

@facebook-github-bot facebook-github-bot added cla signed oncall: distributed Add this issue/PR to distributed oncall triage queue labels Mar 21, 2022
rohan-varma added a commit that referenced this pull request Mar 21, 2022
rohan-varma added a commit that referenced this pull request Mar 21, 2022
rohan-varma added a commit that referenced this pull request Mar 21, 2022
@rohan-varma rohan-varma changed the title [WIP][FSDP] Mixed precision enablement [FSDP] Mixed precision enablement Mar 21, 2022
@rohan-varma rohan-varma marked this pull request as ready for review March 21, 2022 16:57
rohan-varma added a commit that referenced this pull request Mar 29, 2022
Pull Request resolved: #74452



Useful clarifications while reviewing the diff:

How fairscale implements MP for buffers:
- Accepts buffer_dtype argument in ctor that is the dtype for computation for buffers. By default this is the compute_dtype.
- During _lazy_init, for root module, _cast_buffers is called which casts buffers to buffer_dtype.
- During state_dict, buffers are cast to torch.float32, then checkpoint is taken. They are restored back to buffer_dtype after that.

How PT FSDP implements MP for buffers in this diff:
- Rather than buffer_dtype in ctor, we accept MixedPrecision.buffer_dtype which is the compute type for buffers.
- During lazy_init, similar to fairsacle we cast the buffers to the type given by the MP config. In the case of no mixed precision the default behavior is maintained.
- In _cast_buffers, we remember the original buffer dtype into a member variable. We then may cast them to a new dtype if given by the user.
- During state_dict, we use the above remembered type (stored as self._orig_buffer_dtype) and restore this type to the buffers prior to taking checkpoint. After state_dict, we restore it back to the casted type as buffers remain in this mixed precision type even after forward/backwards passes (so this is done for consistency).
- The improvement here is that we remember and restore the correct dtype of buffer the model originally had. However we assume all buffers are of the same dtype, which can be relaxed depending on use cases.

Why rebuild_full_params checks for summon_full_params training state:
- summon_full_params needs to return the full module parameters in the original precision for checkpoint to work as expected (users don't want to checkpoint the fp16 params generally). Thus, _rebuild_full_params will do this check. This is exactly the same reasoning as "force_full_precision" arg in fairscale.
- Concretely, if we're in summon_full_params, we 1) Don't cast shards to param_dtype, 2) all_gather with a full precision input rather than _full_param_paded.


Test coverage:
[ ] Test1
ghstack-source-id: 152471974

Differential Revision: [D35000703](https://our.internmc.facebook.com/intern/diff/D35000703/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D35000703/)!
Enables mixed_precision training for PT FSDP.

### High level overview

- We add a `MixedPrecision` argument to FSDP API that allows user to control precision of inputs/parameters, gradient reduction, and buffers. We support any `torch.dtype` and the `torch.dtype` does not have to be the same between the 3 mixed precision flags we support. The goal of mixed precision training is to provide peak memory reduction and faster training due to the full param and gradient being in a reduced precision during forward/backward pass. Further, we decouple reduction precision from param/grad/input precision to allow for additional experimentation for faster communication algorithms.

### Mixed precision for inputs
- The root module simply casts inputs to the reduced precision at the beginning of the forward pass.

### Mixed precision for parameters
- In _rebuild_full_params, if we need to cast parameters to reduced precision, we call `_cast_param_shards_to_dtype`. This allocates a `p._mp_shard` of the reduced precision type and copies into this shard in a separate stream, with synchronization taken care of. As a result, all_gather will thus happen in the reduced precision and we'll have a reduced precision full parameter to run the user's forward pass with.
- After forwards/backwards passes, we have the full precision parameter shard in memory, and the mixed precision shard has been freed.
- Full precision for parameters is restored when taking checkpoints and for summon_full_params.

### Mixed precision for gradients
- Backward computation will occur in the reduced precision since activations/params/inputs were in reduced precision. As a result, in `_post_backward_hook`, we are left with a full/unsharded gradient in the reduced precision. Note that at the end of _post_backward_hook, we ensure the gradient is cast back to full precision so that the optimizer step can occur in full precision, and we handle all necessary stream synchronization.
- After the backwards pass, we have the full precision gradient shard in memory and no reduced precision gradient shards.

### Communication mixed precision
- If the mixed_precision config indicates a different reduction type under which to run _reduce_scatter_base, we cast gradients to this type before communicating them.

### Buffers mixed precision
- Buffers are unsharded and are cast only once by the root module in forward pass, and remain in their reduced precision throughout the training / in between forward and backward passes. Their full precision _is_ restored for checkpoint with full_state_dict, and _not_ restored in summon_full_params.
- See notes below for more details around differences on how PT FSDP vs FairScale implements support for buffer mixed precision.

### Changes to _rebuild_full_param
- Changes are made to _rebuild_full_param to cast parameters to their reduced precision. The main complication is supporting `summon_full_params` which must actually ignore mixed precision and summon in full precision mode. As a result, at the beginning of the function we set `force_full_precision` based on whether we are in summon_full_params.
- To further support summon_full_params which also needs to free full parameters, we refactor _rebuild_full_params similar to FairScale to return a tuple(tensor, bool) which indicates if the tensor can be freed or not. The tensor possibly cannot be freed in the case of world_size == 1 when the parameter is not sharded as the resulting full param points to the original model parameter. Another case is when we're returning the full parameter and reshard_after_forward=False (because we need to ensure p._full_param_padded stays intact)
- One subtlety is in the case of calling `update_p_data`, we need to update above tuple _before_ and not after, because after `update_p_data` the full param has padding trimmed, and this will cause issues with `writeback`.
- Finally, we don't necessarily call `all_gather` on `full_param_padded` anymore, i.e. particularly in the case of summon_full_param. This is because `full_param_padded` would be the reduced precision type but we need to all_gather in full precision. This is also why _rebuild_full_param returns a list of full params to summon_full_params as we can no longer rely on assuming `p._full_param_padded` is the full parameter.

### Changes to summon_full_params
- ``summon_full_params`` mostly consumes the above return value from `_rebuild_full_param` and the way `writeback` is done and full parameters are freed is refactored.
- For `writeback`, we can no longer assume that `p._full_param_padded` is the full shard that may have been modified (i.e. this is not the case for mixed_precision). As a result, we use the returned full parameters instead of hardcoding `p._full_param_padded` to writeback. 
- For freeing full params, similar to above we cannot assume that `p._full_param_padded` is the full parameter as `_collect_local_params` did. Instead, we consume the return value from `_rebuild_full_params` which explicitly tells us whether we can free the parameter or not.

### How checkpoint works
- For full_state_dict checkpoint, parameters are checkpointed in full precision which happens automatically due to summoning them in full precision as explained above.
- For buffers, in full_state_dict we explicitly cast buffers to their full precision before taking checkpoint. One subtlety is that we need to do this after we've entered summon_full_params context as summon_full_params calls `_lazy_init` which casts buffers to their reduced dtype.
- After checkpointing, buffers are restored back to their reduced precision 
- Note that buffer checkpointing for local_state_dict is not tested at the moment and this is left as follow up work.
 
### Useful clarifications while reviewing the diff:

##### How FairScale implements MP for buffers:
- Accepts a buffer_dtype argument in the ctor, which is the dtype used for buffer computation; by default this is the compute_dtype.
- During _lazy_init, _cast_buffers is called for the root module, casting buffers to buffer_dtype.
- During state_dict, buffers are cast to torch.float32, the checkpoint is taken, and they are restored to buffer_dtype afterwards.

##### How PT FSDP implements MP for buffers in this diff:
- Rather than a buffer_dtype ctor argument, we accept MixedPrecision.buffer_dtype, which is the compute dtype for buffers.
- During _lazy_init, similar to FairScale, we cast the buffers to the type given by the MP config. With no mixed precision, the default behavior is maintained.
- One subtlety is the recursive call: each submodule must use its own `self.mixed_precision` config rather than having this arg passed into the recursive call, because different submodules may disable mixed precision (for example, BatchNorm usually does).
- Similar to FairScale, integer buffers are not cast.
- In _cast_buffers, we remember each buffer's original dtype in a member variable, and may then cast buffers to a new dtype if one is given by the user (see the sketch after this list).
- During state_dict, we restore the remembered dtypes (stored as self._orig_buffer_dtypes) to the buffers before taking the checkpoint. After state_dict, we cast back to the reduced type, since buffers remain in the mixed precision type even after forward/backward passes (this is done for consistency). FairScale seems to assume all buffers are originally fp32, whereas we maintain a mapping that remembers each buffer's actual type.
- The improvement here is that we remember and restore the dtype each buffer originally had.
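
A short sketch of the cast-and-remember step described in this list; the function and dict are illustrative stand-ins (cf. `_cast_buffers` and `self._orig_buffer_dtypes`), and each FSDP-wrapped submodule would apply its own mixed_precision config:

```python
from typing import Dict

import torch
import torch.nn as nn


def cast_buffers_sketch(
    module: nn.Module,
    buffer_dtype: torch.dtype,
    orig_buffer_dtypes: Dict[str, torch.dtype],
) -> None:
    """Per-module cast; nested FSDP instances handle their own buffers with their own config."""
    for name, buf in module.named_buffers(recurse=False):
        # Remember the dtype the model originally used for this buffer.
        orig_buffer_dtypes.setdefault(name, buf.dtype)
        # Integer buffers (e.g. BatchNorm's num_batches_tracked) are not cast.
        if buf.is_floating_point():
            setattr(module, name, buf.to(buffer_dtype))
```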


### Test coverage:
- [x] nested model with same param / reduce dtypes
- [x] nested model with distinct param / buffer / reduce dtypes
- [x] model where buffer is a different type than parameter
- [x] nested model checkpoint, with verification that buffers and params are checkpointed in full precision (checks that force_full_precision in summon_full_params is respected)
- [x] After taking checkpoint, verified that buffers are back in the reduced precision
- [x] test that summon_full_params summons params in full precision
- [x] tests that the gradient is the appropriate type in the backward pass; this is done by patching `_reduce_scatter_base` to run the mixed precision checks (a sketch of this patching pattern follows the list).
- [x] Above test, but checks that we run reduce_scatter in the higher precision if specified by the mixed precision config.
- [x] Tests that after backward, gradient and param shards are in full precision (and on the correct device for the optimizer step to happen).
- [x] Tests that after forward, the reduced precision param shard is freed
- [x] Tests that after backward, the reduced precision param shard is freed
- [x] Test that buffers remain in the reduced precision type after forward / backward, and are not affected by summon_full_params. Within summon_full_params the buffer is _not_ restored to the full type.
- [x] all of the above tests, but with reshard_after_forward=False, i.e. ZeRO-2
- [x] test that summon_full_params respects reshard_after_forward in the case of mixed precision as well
- [x] parametrize a few relevant tests in summon_full_params, since summon_full_params uses _rebuild_full_params a bit differently. In particular, we make sure things work as expected when the rebuilt parameter is not p._full_param_padded, which is the case with mixed precision.
- [x] tests for world_size == 1 i.e. when the parameter is not sharded. Not adding this for initial enablement as all use cases in question have world_size > 1
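
As referenced in the reduce-scatter test item above, a sketch of the patching pattern (the real checks live in test_fsdp_mixed_precision.py; the wrapper below is an illustrative stand-in):

```python
from unittest import mock

import torch
import torch.distributed as dist


def make_validating_reduce_scatter(expected_dtype: torch.dtype):
    """Wrap dist._reduce_scatter_base to assert the dtype used for gradient communication."""
    orig_reduce_scatter = dist._reduce_scatter_base

    def wrapper(output, input_tensor, *args, **kwargs):
        assert input_tensor.dtype == expected_dtype, (
            f"expected reduce_scatter in {expected_dtype}, got {input_tensor.dtype}"
        )
        return orig_reduce_scatter(output, input_tensor, *args, **kwargs)

    return wrapper


# Usage sketch inside a test: patch the collective, run forward/backward, and the
# assertion fires if FSDP communicated gradients in an unexpected dtype.
# with mock.patch.object(dist, "_reduce_scatter_base",
#                        make_validating_reduce_scatter(torch.float16)):
#     fsdp_model(inp).sum().backward()
```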


Follow up work (#74515):
- [ ] Test local_state_dict checkpoint works with mixed precision. In particular we have to be careful about buffer casting.
- [ ] Enhance test_fsdp_state_dict to checkpoint buffers and ensure dtypes are as expected. Although note that this is also already tested in this PR.
- [ ] Test summon_full_params with reshard_after_forward (with and without mixed precision)


Differential Revision: [D35000703](https://our.internmc.facebook.com/intern/diff/D35000703/)

[ghstack-poisoned]
rohan-varma added a commit that referenced this pull request Mar 30, 2022
Pull Request resolved: #74452



Useful clarifications while reviewing the diff:

How fairscale implements MP for buffers:
- Accepts buffer_dtype argument in ctor that is the dtype for computation for buffers. By default this is the compute_dtype.
- During _lazy_init, for root module, _cast_buffers is called which casts buffers to buffer_dtype.
- During state_dict, buffers are cast to torch.float32, then checkpoint is taken. They are restored back to buffer_dtype after that.

How PT FSDP implements MP for buffers in this diff:
- Rather than buffer_dtype in ctor, we accept MixedPrecision.buffer_dtype which is the compute type for buffers.
- During lazy_init, similar to FairScale, we cast the buffers to the type given by the MP config. In the case of no mixed precision the default behavior is maintained.
- In _cast_buffers, we remember the original buffer dtype into a member variable. We then may cast them to a new dtype if given by the user.
- During state_dict, we use the above remembered type (stored as self._orig_buffer_dtype) and restore this type to the buffers prior to taking checkpoint. After state_dict, we restore it back to the casted type as buffers remain in this mixed precision type even after forward/backwards passes (so this is done for consistency).
- The improvement here is that we remember and restore the correct dtype of buffer the model originally had. However we assume all buffers are of the same dtype, which can be relaxed depending on use cases.

Why rebuild_full_params checks for summon_full_params training state:
- summon_full_params needs to return the full module parameters in the original precision for checkpoint to work as expected (users don't want to checkpoint the fp16 params generally). Thus, _rebuild_full_params will do this check. This is exactly the same reasoning as "force_full_precision" arg in fairscale.
- Concretely, if we're in summon_full_params, we 1) don't cast shards to param_dtype, and 2) all_gather with a full precision input rather than _full_param_padded.


Test coverage:
[ ] Test1
ghstack-source-id: 152544162

Differential Revision: [D35000703](https://our.internmc.facebook.com/intern/diff/D35000703/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D35000703/)!
rohan-varma added a commit that referenced this pull request Mar 31, 2022
Pull Request resolved: #74452



facebook-github-bot pushed a commit that referenced this pull request Mar 31, 2022
Summary:
Pull Request resolved: #74452


Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D35000703

fbshipit-source-id: 4bd7937ff36bdb3afd60eda981afc9d8731b823a
@github-actions
Contributor

Hey @rohan-varma.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@malfet
Contributor

malfet commented Mar 31, 2022

Looks like it's caused a failure in a number of mixed precision tests (see full logs):

======================================================================
ERROR [2.832s]: test_mixed_precision_e2e_full_shard_mp_diff_reduce_offload_true_prefetch_post_fp64 (__main__.TestFSDPMixedPrecisionSharded)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 484, in wrapper
    self._join_processes(fn)
  File "/opt/conda/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 703, in _join_processes
    self._check_return_codes(elapsed_time)
  File "/opt/conda/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 748, in _check_return_codes
    raise RuntimeError(error)
RuntimeError: Process 0 exited with error code 10 and exception:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 601, in run_test
    getattr(self, test_name)()
  File "/opt/conda/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 486, in wrapper
    fn()
  File "/opt/conda/lib/python3.9/site-packages/torch/testing/_internal/common_utils.py", line 208, in instantiated_test
    test(self, **param_kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/testing/_internal/common_distributed.py", line 131, in wrapper
    return func(*args, **kwargs)
  File "/var/lib/jenkins/workspace/test/distributed/fsdp/test_fsdp_mixed_precision.py", line 349, in test_mixed_precision_e2e_full_shard
    self._run_test_mixed_precision_e2e(
  File "/var/lib/jenkins/workspace/test/distributed/fsdp/test_fsdp_mixed_precision.py", line 271, in _run_test_mixed_precision_e2e
    loss.backward()
  File "/opt/conda/lib/python3.9/site-packages/torch/_tensor.py", line 395, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/lib/python3.9/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1869, in _post_backward_hook
    dist._reduce_scatter_base(
  File "/var/lib/jenkins/workspace/test/distributed/fsdp/test_fsdp_mixed_precision.py", line 211, in _reduce_scatter_base_validate_mp
    return orig_reduce_scatter(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2486, in _reduce_scatter_base
    work = group._reduce_scatter_base(output, input, opts)
RuntimeError: Input tensor data type is not supported for NCCL process group: BFloat16
Exception raised from getNcclDataType at /var/lib/jenkins/workspace/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:91 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6b (0x7fefc6aba07b in /opt/conda/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xce (0x7fefc6ab5a4e in /opt/conda/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0xe122fb (0x7fefc7d5a2fb in /opt/conda/lib/python3.9/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::_reduce_scatter_base(at::Tensor&, at::Tensor&, c10d::ReduceScatterOptions const&) + 0x889 (0x7fefc7d76029 in /opt/conda/lib/python3.9/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x87ab5c (0x7fefdc585b5c in /opt/conda/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x20167a (0x7fefdbf0c67a in /opt/conda/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x181f64 (0x55689188cf64 in /opt/conda/bin/python)
frame #7: _PyObject_MakeTpCall + 0x2df (0x5568918471bf in /opt/conda/bin/python)
frame #8: <unknown function> + 0xb62f8 (0x5568917c12f8 in /opt/conda/bin/python)
frame #9: <unknown function> + 0xfd9bf (0x5568918089bf in /opt/conda/bin/python)
frame #10: <unknown function> + 0x196593 (0x5568918a1593 in /opt/conda/bin/python)
frame #11: _PyFunction_Vectorcall + 0x244 (0x5568918a22d4 in /opt/conda/bin/python)
frame #12: _PyObject_Call + 0xba (0x556891850cfa in /opt/conda/bin/python)
frame #13: _PyEval_EvalFrameDefault + 0x25ff (0x5568918e278f in /opt/conda/bin/python)
frame #14: <unknown function> + 0x196593 (0x5568918a1593 in /opt/conda/bin/python)
frame #15: _PyFunction_Vectorcall + 0x244 (0x5568918a22d4 in /opt/conda/bin/python)
frame #16: <unknown function> + 0x197d47 (0x5568918a2d47 in /opt/conda/bin/python)
frame #17: <unknown function> + 0x7b5ee (0x5568917865ee in /opt/conda/bin/python)
frame #18: <unknown function> + 0x216ca3 (0x556891921ca3 in /opt/conda/bin/python)
frame #19: <unknown function> + 0xfdb46 (0x556891808b46 in /opt/conda/bin/python)
frame #20: <unknown function> + 0x196593 (0x5568918a1593 in /opt/conda/bin/python)
frame #21: _PyFunction_Vectorcall + 0x1d4 (0x5568918a2264 in /opt/conda/bin/python)
frame #22: _PyObject_Call + 0x1da (0x556891850e1a in /opt/conda/bin/python)
frame #23: _PyEval_EvalFrameDefault + 0x25ff (0x5568918e278f in /opt/conda/bin/python)
frame #24: <unknown function> + 0x196593 (0x5568918a1593 in /opt/conda/bin/python)
frame #25: _PyFunction_Vectorcall + 0x1d4 (0x5568918a2264 in /opt/conda/bin/python)
frame #26: <unknown function> + 0x197d47 (0x5568918a2d47 in /opt/conda/bin/python)
frame #27: <unknown function> + 0x7b5ee (0x5568917865ee in /opt/conda/bin/python)
frame #28: <unknown function> + 0x216ca3 (0x556891921ca3 in /opt/conda/bin/python)
frame #29: <unknown function> + 0xd7796 (0x5568917e2796 in /opt/conda/bin/python)
frame #30: PyObject_CallFunctionObjArgs + 0xa8 (0x556891963748 in /opt/conda/bin/python)
frame #31: torch::autograd::PyFunctionPostHook::operator()(std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&) + 0xe2 (0x7fefdc1fd822 in /opt/conda/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #32: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x6ea (0x7fefd3fd62aa in /opt/conda/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #33: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x5c4 (0x7fefd3fd7c94 in /opt/conda/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #34: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x99 (0x7fefd3fcf199 in /opt/conda/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #35: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x6c (0x7fefdc1f01cc in /opt/conda/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #36: <unknown function> + 0xc9039 (0x7fefdf5f0039 in /opt/conda/bin/../lib/libstdc++.so.6)
frame #37: <unknown function> + 0x76db (0x7ff014b746db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #38: clone + 0x3f (0x7ff01489d61f in /lib/x86_64-linux-gnu/libc.so.6)


Process 1 exited with error code 10 and exception:
RuntimeError: Input tensor data type is not supported for NCCL process group: BFloat16

@facebook-github-bot
Contributor

This pull request has been reverted by 61e308974ec6c91df2ea6ebe894285635959b393. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).

@facebook-github-bot
Contributor

This pull request has been reverted by a98d1a5. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).

@facebook-github-bot facebook-github-bot deleted the gh/rohan-varma/525/head branch April 4, 2022 14:17
rohan-varma added a commit that referenced this pull request Apr 4, 2022
Reland #74452

Issue was that older NCCL versions do not support bf16. Will take an approach similar to #67843 to ensure the test only runs with newer NCCL versions.

Original commit changeset: 99295ea4ff02

Original Phabricator Diff: D35000703

Differential Revision: [D35287501](https://our.internmc.facebook.com/intern/diff/D35287501/)

[ghstack-poisoned]
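
A hedged sketch of the kind of version guard referred to above; the 2.10 cutoff (where NCCL added BFloat16 support) and the helper name are assumptions, not the exact check used in the reland:

```python
import unittest

import torch.cuda.nccl
import torch.distributed as dist


def nccl_supports_bf16() -> bool:
    """Illustrative guard: skip bf16 collective tests on NCCL builds that lack BFloat16."""
    if not (dist.is_available() and dist.is_nccl_available()):
        return False
    # torch.cuda.nccl.version() reports NCCL's version as a (major, minor, patch) tuple.
    return torch.cuda.nccl.version() >= (2, 10)


# Usage sketch in a test class:
# @unittest.skipIf(not nccl_supports_bf16(), "NCCL too old for BFloat16 reduce_scatter")
# def test_mixed_precision_bf16(self):
#     ...
```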