[FSDP] Mixed precision enablement #74452
Conversation
Useful clarifications while reviewing the diff:

How fairscale implements MP for buffers:
- Accepts a buffer_dtype argument in the ctor that is the dtype used for buffer computation. By default this is the compute_dtype.
- During _lazy_init, for the root module, _cast_buffers is called, which casts buffers to buffer_dtype.
- During state_dict, buffers are cast to torch.float32, then the checkpoint is taken. They are restored back to buffer_dtype after that.

How PT FSDP implements MP for buffers in this diff:
- Rather than buffer_dtype in the ctor, we accept MixedPrecision.buffer_dtype, which is the compute type for buffers.
- During _lazy_init, similar to fairscale, we cast the buffers to the type given by the MP config. In the case of no mixed precision, the default behavior is maintained.
- In _cast_buffers, we remember the original buffer dtype in a member variable. We then may cast buffers to a new dtype if given by the user.
- During state_dict, we use the remembered type (stored as self._orig_buffer_dtype) and restore this type to the buffers prior to taking the checkpoint. After state_dict, we restore the casted type, as buffers remain in the mixed precision type even after forward/backward passes (this is done for consistency).
- The improvement here is that we remember and restore the correct dtype the model's buffers originally had. However, we assume all buffers are of the same dtype, which can be relaxed depending on use cases.

Why rebuild_full_params checks for the summon_full_params training state:
- summon_full_params needs to return the full module parameters in the original precision for checkpointing to work as expected (users generally don't want to checkpoint the fp16 params). Thus, _rebuild_full_params does this check. This is exactly the same reasoning as the "force_full_precision" arg in fairscale.
- Concretely, if we're in summon_full_params, we 1) don't cast shards to param_dtype, and 2) all_gather with a full precision input rather than _full_param_padded.

Test coverage:
- [ ] Test1

Differential Revision: [D35000703](https://our.internmc.facebook.com/intern/diff/D35000703/)

[ghstack-poisoned]
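To make the buffer bookkeeping above concrete, here is a minimal sketch. It is not the FSDP code: the wrapper class, method names, and the per-buffer `_orig_buffer_dtypes` mapping are illustrative (the description above keeps a single dtype; a later revision of this diff keeps a mapping, which the sketch follows).

```python
import torch
import torch.nn as nn


class BufferMixedPrecisionSketch(nn.Module):
    """Simplified stand-in for the buffer handling described above: remember
    original buffer dtypes, cast to a reduced dtype for compute, and restore
    full precision around state_dict."""

    def __init__(self, wrapped: nn.Module, buffer_dtype: torch.dtype = torch.float16):
        super().__init__()
        self.wrapped = wrapped
        self.buffer_dtype = buffer_dtype
        self._orig_buffer_dtypes = {}  # buffer name -> original dtype

    def _cast_buffers(self, dtype: torch.dtype) -> None:
        for name, buf in self.wrapped.named_buffers():
            # Remember the original dtype the first time we see the buffer.
            self._orig_buffer_dtypes.setdefault(name, buf.dtype)
            if buf.is_floating_point():  # integer buffers are left alone
                buf.data = buf.data.to(dtype)

    def _restore_buffers(self) -> None:
        for name, buf in self.wrapped.named_buffers():
            buf.data = buf.data.to(self._orig_buffer_dtypes.get(name, buf.dtype))

    def state_dict(self, *args, **kwargs):
        self._restore_buffers()                # checkpoint in the original precision
        sd = super().state_dict(*args, **kwargs)
        self._cast_buffers(self.buffer_dtype)  # back to the reduced dtype afterwards
        return sd
```

Keeping a per-buffer mapping rather than a single dtype is what lets buffers of different original types round-trip correctly through checkpointing.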
Enables mixed_precision training for PT FSDP.

### High level overview
- We add a `MixedPrecision` argument to the FSDP API that allows users to control the precision of inputs/parameters, gradient reduction, and buffers. We support any `torch.dtype`, and the dtype does not have to be the same across the three mixed precision flags (a minimal usage sketch appears after these notes). The goal of mixed precision training is peak memory reduction and faster training, since the full param and gradient are in a reduced precision during the forward/backward pass. Further, we decouple reduction precision from param/grad/input precision to allow additional experimentation with faster communication algorithms.

### Mixed precision for inputs
- The root module simply casts inputs to the reduced precision at the beginning of the forward pass.

### Mixed precision for parameters
- In _rebuild_full_params, if we need to cast parameters to reduced precision, we call `_cast_param_shards_to_dtype`. This allocates a `p._mp_shard` of the reduced precision type and copies into it in a separate stream, with synchronization taken care of. As a result, all_gather happens in the reduced precision and we end up with a reduced precision full parameter to run the user's forward pass with.
- After the forward/backward passes, we have the full precision parameter shard in memory, and the mixed precision shard has been freed.
- Full precision for parameters is restored when taking checkpoints and for summon_full_params.

### Mixed precision for gradients
- Backward computation occurs in the reduced precision since activations/params/inputs were in reduced precision. As a result, in `_post_backward_hook` we are left with a full/unsharded gradient in the reduced precision. Note that at the end of _post_backward_hook, we ensure the gradient is cast back to full precision so that the optimizer step can occur in full precision, and we handle all necessary stream synchronization.
- After the backward pass, we have the full precision gradient shard in memory and no reduced precision gradient shards.

### Communication mixed precision
- If the mixed_precision config indicates a different reduction type under which to run _reduce_scatter_base, we cast gradients to this type before communicating them.

### Buffers mixed precision
- Buffers are unsharded and are cast only once, by the root module in the forward pass, and remain in their reduced precision throughout training and in between forward and backward passes. Their full precision _is_ restored for checkpointing with full_state_dict, and is _not_ restored in summon_full_params.
- See the notes below for more details on the differences between how PT FSDP and FairScale implement support for buffer mixed precision.
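A minimal usage sketch of the config described in the overview. It assumes `torch.distributed` has already been initialized with a NCCL process group; the import path, the `mixed_precision=` keyword, and the field names follow the API added by this PR, but treat the exact surface as illustrative rather than final.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

# The three dtypes do not have to match; each can be any torch.dtype.
mp_config = MixedPrecision(
    param_dtype=torch.float16,   # params/inputs are cast to fp16 for fwd/bwd
    reduce_dtype=torch.float16,  # gradient reduce_scatter runs in fp16
    buffer_dtype=torch.float16,  # buffers are cast once by the root module
)

model = nn.Transformer(d_model=256, nhead=4).cuda()
fsdp_model = FSDP(model, mixed_precision=mp_config)

optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
src = torch.randn(10, 8, 256, device="cuda")
tgt = torch.randn(10, 8, 256, device="cuda")
loss = fsdp_model(src, tgt).sum()
loss.backward()   # gradients are reduced in reduce_dtype, then cast back to fp32
optim.step()      # the optimizer step happens on full precision shards
```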
### Changes to _rebuild_full_params
- _rebuild_full_params is changed to cast parameters to their reduced precision. The main complication is supporting `summon_full_params`, which must ignore mixed precision and summon in full precision. As a result, at the beginning of the function we set `force_full_precision` based on whether we are in summon_full_params.
- To further support summon_full_params, which also needs to free full parameters, we refactor _rebuild_full_params (similar to FairScale) to return a tuple (tensor, bool) indicating whether the tensor can be freed. The tensor possibly cannot be freed in the case of world_size == 1, when the parameter is not sharded and the resulting full param points to the original model parameter. Another case is when we're returning the full parameter and reshard_after_forward=False (because we need to ensure p._full_param_padded stays intact).
- One subtlety is that in the case of calling `update_p_data`, we need to update the above tuple _before_ and not after, because after `update_p_data` the full param has its padding trimmed, and that would cause issues with `writeback`.
- Finally, we don't necessarily call `all_gather` on `full_param_padded` anymore, particularly in the case of summon_full_params. This is because `full_param_padded` would be of the reduced precision type, but we need to all_gather in full precision. This is also why _rebuild_full_params returns a list of full params to summon_full_params, as we can no longer rely on `p._full_param_padded` being the full parameter. A simplified sketch of this return contract appears after the checkpoint notes below.

### Changes to summon_full_params
- `summon_full_params` mostly consumes the above return value from `_rebuild_full_params`, and the way `writeback` is done and full parameters are freed is refactored.
- For `writeback`, we can no longer assume that `p._full_param_padded` is the full shard that may have been modified (this is not the case for mixed_precision). As a result, we use the returned full parameters instead of hardcoding `p._full_param_padded` for writeback.
- For freeing full params, similarly we cannot assume that `p._full_param_padded` is the full parameter as `_collect_local_params` did. Instead, we consume the return value from `_rebuild_full_params`, which explicitly tells us whether we can free the parameter or not.

### How checkpoint works
- For full_state_dict checkpoint, parameters are checkpointed in full precision, which happens automatically because they are summoned in full precision as explained above.
- For buffers, in full_state_dict we explicitly cast buffers to their full precision before taking the checkpoint. One subtlety is that we need to do this after we've entered the summon_full_params context, as summon_full_params calls `_lazy_init`, which casts buffers to their reduced dtype.
- After checkpointing, buffers are restored back to their reduced precision.
- Note that buffer checkpointing for local_state_dict is not tested at the moment; this is left as follow-up work.
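The sketch below illustrates the (full_tensor, can_free) contract and the `force_full_precision` handling described above. It is not the actual implementation: attribute names like `p._local_shard` and `p._full_param_padded` are simplified stand-ins, and the casting/streams details are collapsed into a plain `.to()`.

```python
import torch
import torch.distributed as dist


def _rebuild_full_params_sketch(p, world_size, reshard_after_forward,
                                in_summon_full_params, param_dtype):
    """Simplified illustration of the (full_tensor, can_free) return contract."""
    force_full_precision = in_summon_full_params  # summon ignores mixed precision

    if world_size == 1:
        # Not sharded: the "full" param aliases the original model parameter,
        # so the caller must never free it.
        return p.data, False

    if param_dtype is not None and not force_full_precision:
        # Cast the local shard to the reduced dtype (the real code does this
        # into p._mp_shard on a separate stream) and gather into the
        # persistent reduced-precision buffer.
        shard = p._local_shard.to(param_dtype)
        output = p._full_param_padded
    else:
        # summon_full_params: gather in full precision into a fresh tensor,
        # since p._full_param_padded may hold the reduced dtype.
        shard = p._local_shard
        output = torch.empty(world_size * shard.numel(),
                             dtype=shard.dtype, device=shard.device)

    dist.all_gather_into_tensor(output, shard)

    # Freeable unless we handed back the persistent p._full_param_padded that
    # must stay intact for reshard_after_forward=False.
    can_free = force_full_precision or reshard_after_forward
    return output, can_free
```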
### Useful clarifications while reviewing the diff

##### How FairScale implements MP for buffers
- Accepts a buffer_dtype argument in the ctor that is the dtype used for buffer computation. By default this is the compute_dtype.
- During _lazy_init, for the root module, _cast_buffers is called, which casts buffers to buffer_dtype.
- During state_dict, buffers are cast to torch.float32, then the checkpoint is taken. They are restored back to buffer_dtype afterwards.

##### How PT FSDP implements MP for buffers in this diff
- Rather than buffer_dtype in the ctor, we accept MixedPrecision.buffer_dtype, which is the compute type for buffers.
- During _lazy_init, similar to FairScale, we cast the buffers to the type given by the MP config. In the case of no mixed precision, the default behavior is maintained.
- One subtlety is the recursive call. We need to make sure each submodule uses its own `self.mixed_precision` config instead of passing this arg into the recursive call, because different submodules may disable mixed precision. One example is that BatchNorm usually disables mixed precision.
- Similar to FairScale, integer buffers are not cast.
- In _cast_buffers, we remember the original buffer dtype in a member variable. We then may cast buffers to a new dtype if given by the user.
- During state_dict, we use the remembered types (stored as self._orig_buffer_dtypes) and restore them to the buffers prior to taking the checkpoint. After state_dict, buffers are cast back to the reduced type, since they remain in the mixed precision type even after forward/backward passes (this is done for consistency). FairScale seems to assume all buffers originally have type fp32, whereas we maintain a mapping that remembers the actual type.
- The improvement here is that we remember and restore the correct dtype that the model's buffers originally had.

### Test coverage
- [x] nested model with same param / reduce dtypes
- [x] nested model with distinct param / buffer / reduce dtypes
- [x] model where the buffer is a different type than the parameter
- [x] nested model checkpoint, with verification that buffers and params are checkpointed in full precision (checks that force_full_precision in summon_full_params is respected)
- [x] after taking a checkpoint, verified that buffers are back in the reduced precision
- [x] test that summon_full_params summons params in full precision
- [x] test that the gradient has the appropriate type in the backward pass. This is done by patching `_reduce_scatter_base` to run the mixed precision checks (a test sketch appears after this comment).
- [x] above test, but checks that we run reduce_scatter in the higher precision if specified by the mixed precision config
- [x] tests that after backward, gradient and param shards are in full precision (and on the correct device for the optimizer step to happen)
- [x] tests that after forward, the reduced precision param shard is freed
- [x] tests that after backward, the reduced precision param shard is freed
- [x] test that buffers remain in the reduced precision type after forward / backward, and are not affected by summon_full_params. Within summon_full_params the buffer is _not_ restored to the full type.
- [x] all of the above tests, but with reshard_after_forward=False, i.e. ZeRO-2
- [x] test that summon_full_params respects reshard_after_forward in the case of mixed precision as well
- [x] parametrize a few relevant tests in summon_full_params, as summon_full_params uses rebuild_full_params a bit differently. In particular we make sure things work as expected when the rebuilt parameter is not p._full_param_padded, which is the case in mixed precision.
- [x] tests for world_size == 1, i.e. when the parameter is not sharded. Not adding this for initial enablement as all use cases in question have world_size > 1.

Follow up work (#74515):
- [ ] Test that local_state_dict checkpoint works with mixed precision. In particular we have to be careful about buffer casting.
- [ ] Enhance test_fsdp_state_dict to checkpoint buffers and ensure dtypes are as expected (although note that this is also already tested in this PR).
- [ ] Test summon_full_params with reshard_after_forward (with and without mixed precision).

Differential Revision: [D35000703](https://our.internmc.facebook.com/intern/diff/D35000703/)

[ghstack-poisoned]
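A hedged sketch of the gradient-dtype check referenced in the test coverage list: intercept the reduce-scatter call and assert on the dtype of the tensor being communicated. The harness name is illustrative, and the exact attribute to patch depends on how FSDP imports the collective; `torch.distributed._reduce_scatter_base` is used here only for illustration.

```python
from unittest import mock

import torch
import torch.distributed as dist


def run_one_step_with_dtype_check(fsdp_model, batch, expected_dtype):
    """Run one forward/backward while intercepting reduce-scatter and
    asserting on the dtype of the gradient tensor being communicated."""
    orig_reduce_scatter = dist._reduce_scatter_base

    def checked_reduce_scatter(output, input_tensor, *args, **kwargs):
        assert input_tensor.dtype == expected_dtype, (
            f"expected reduction in {expected_dtype}, got {input_tensor.dtype}"
        )
        return orig_reduce_scatter(output, input_tensor, *args, **kwargs)

    with mock.patch.object(dist, "_reduce_scatter_base", new=checked_reduce_scatter):
        loss = fsdp_model(batch).sum()
        loss.backward()
```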
failure is unrelated: https://github.com/pytorch/pytorch/runs/5764193891?check_suite_focus=true
Summary:
Pull Request resolved: #74452

See the clarifications in the PR description above on how FairScale vs. PT FSDP implement mixed precision for buffers, and on why _rebuild_full_params checks for the summon_full_params training state.

ghstack-source-id: 152654758

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D35000703

fbshipit-source-id: 4bd7937ff36bdb3afd60eda981afc9d8731b823a
|
Hey @rohan-varma. |
|
Looks like it's caused a failure in a number of mixed precision tests (see full logs): |
|
This pull request has been reverted by 61e308974ec6c91df2ea6ebe894285635959b393. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk). |
|
This pull request has been reverted by a98d1a5. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk). |
Reland #74452. The issue was that older NCCL versions do not support bf16. Will take an approach similar to #67843 to ensure the test only runs with later NCCL versions. Original commit changeset: 99295ea4ff02 Original Phabricator Diff: D35000703 Differential Revision: [D35287501](https://our.internmc.facebook.com/intern/diff/D35287501/) [ghstack-poisoned]
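A minimal sketch of what such NCCL version gating could look like; the minimum version tuple, helper, and test class names below are assumptions for illustration, not the actual utilities used in the reland, and `torch.cuda.nccl.version()` is assumed to return a tuple as on recent releases.

```python
# Sketch only: skip bf16 FSDP tests when the installed NCCL is too old.
import unittest

import torch
import torch.distributed as dist
from torch.cuda import nccl

BF16_MIN_NCCL = (2, 10)  # assumed minimum NCCL version with bf16 support


def nccl_supports_bf16() -> bool:
    if not (torch.cuda.is_available() and dist.is_nccl_available()):
        return False
    # Assumes the tuple form of nccl.version(), e.g. (2, 10, 3).
    return tuple(nccl.version())[:2] >= BF16_MIN_NCCL


@unittest.skipIf(not nccl_supports_bf16(), "NCCL too old for bf16 collectives")
class TestFSDPMixedPrecisionBf16(unittest.TestCase):
    def test_bf16_reduce(self) -> None:
        ...  # distributed bf16 test body would go here
```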
|
Stack from ghstack (oldest at bottom):
Enables mixed_precision training for PT FSDP.
High level overview
- MixedPrecision argument to FSDP API that allows the user to control precision of inputs/parameters, gradient reduction, and buffers. We support any torch.dtype, and the dtype does not have to be the same across the 3 mixed precision flags we support. The goal of mixed precision training is to provide peak memory reduction and faster training due to the full param and gradient being in a reduced precision during forward/backward pass. Further, we decouple reduction precision from param/grad/input precision to allow for additional experimentation with faster communication algorithms.
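A rough sketch of the intended usage, not taken from this diff: the import paths and the reduce_dtype field name for the gradient-reduction precision are assumptions that may differ from the landed API, and an initialized process group (e.g. via torchrun) is assumed.

```python
# Sketch only: build the mixed precision config and pass it to FSDP.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

mp_config = MixedPrecision(
    param_dtype=torch.float16,    # precision of gathered params / compute
    reduce_dtype=torch.bfloat16,  # precision of gradient reduction (decoupled)
    buffer_dtype=torch.float16,   # precision buffers are cast to
)

model = torch.nn.Linear(8, 8).cuda()
fsdp_model = FSDP(model, mixed_precision=mp_config)  # requires an initialized process group
```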
Mixed precision for inputs

Mixed precision for parameters
- _cast_param_shards_to_dtype: this allocates a p._mp_shard of the reduced precision type and copies into this shard in a separate stream, with synchronization taken care of. As a result, all_gather will happen in the reduced precision and we'll have a reduced precision full parameter to run the user's forward pass with.
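A simplified, single-device sketch of the shard-casting idea; this is not the actual FSDP internals, and the function name and direct attribute assignment are illustrative only.

```python
# Toy version: allocate a low-precision copy of the (flat) parameter shard on a
# side stream, then make the current stream wait for the cast.
import torch


def cast_shard_to_dtype(p: torch.nn.Parameter, dtype: torch.dtype,
                        stream: torch.cuda.Stream) -> torch.Tensor:
    with torch.cuda.stream(stream):
        mp_shard = torch.empty_like(p.data, dtype=dtype)
        mp_shard.copy_(p.data, non_blocking=True)
        p._mp_shard = mp_shard  # low-precision shard used as all_gather input
    # Consumers on the current stream must wait for the cast to finish.
    torch.cuda.current_stream().wait_stream(stream)
    return p._mp_shard
```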
Mixed precision for gradients

- In _post_backward_hook, we are left with a full/unsharded gradient in the reduced precision. Note that at the end of _post_backward_hook, we ensure the gradient is cast back to full precision so that the optimizer step can occur in full precision, and we handle all necessary stream synchronization.
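A toy sketch of the dtype restoration step only; the real hook also reshards the gradient and synchronizes streams, which is omitted here.

```python
# Toy version: cast the gradient back to the parameter's original precision so
# the optimizer step runs in full precision.
import torch


def restore_grad_dtype(p: torch.nn.Parameter, orig_dtype: torch.dtype) -> None:
    if p.grad is not None and p.grad.dtype != orig_dtype:
        p.grad.data = p.grad.data.to(orig_dtype)
```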
Communication mixed precision

Buffers mixed precision
Changes to _rebuild_full_param
- _rebuild_full_params is also called by summon_full_params, which must actually ignore mixed precision and summon in full precision mode. As a result, at the beginning of the function we set force_full_precision based on whether we are in summon_full_params.
- In update_p_data, we need to update the above tuple before and not after, because after update_p_data the full param has padding trimmed, and this will cause issues with writeback.
- We cannot simply all_gather on full_param_padded anymore, particularly in the case of summon_full_params. This is because full_param_padded would be the reduced precision type but we need to all_gather in full precision. This is also why _rebuild_full_param returns a list of full params to summon_full_params, as we can no longer rely on assuming p._full_param_padded is the full parameter.
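A single-process toy sketch of the force_full_precision branching: torch.cat stands in for the real all_gather, and this function is not FSDP's actual _rebuild_full_params.

```python
# Toy stand-in: return (full_param, can_free), as the description above says
# the real rebuild step now does.
from typing import List, Tuple

import torch


def rebuild_full_param(shards: List[torch.Tensor], low_prec_dtype: torch.dtype,
                       force_full_precision: bool) -> Tuple[torch.Tensor, bool]:
    if force_full_precision:
        # summon_full_params path: gather in the original precision into a
        # temporary buffer, which the caller is allowed to free.
        return torch.cat(shards), True
    # Regular forward/backward path: gather the reduced-precision shards into
    # what corresponds to the persistent p._full_param_padded (do not free).
    return torch.cat([s.to(low_prec_dtype) for s in shards]), False
```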
Changes to summon_full_params

- summon_full_params mostly consumes the above return value from _rebuild_full_params, and the way writeback is done and full parameters are freed is refactored.
- For writeback, we can no longer assume that p._full_param_padded is the full shard that may have been modified (i.e. this is not the case for mixed_precision). As a result, we use the returned full parameters instead of hardcoding p._full_param_padded to writeback.
- When freeing, we can no longer assume p._full_param_padded is the full parameter as _collect_local_params did. Instead, we consume the return value from _rebuild_full_params which explicitly tells us whether we can free the parameter or not.
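A toy sketch of consuming the returned (full_param, can_free) pairs for writeback and freeing; writeback_and_free is a hypothetical helper, and rank offsets and padding handling are ignored (it copies from the start of the flattened full parameter).

```python
# Toy stand-in: use the returned pairs instead of hardcoding p._full_param_padded.
from typing import List, Tuple

import torch


def writeback_and_free(params: List[torch.nn.Parameter],
                       rebuilt: List[Tuple[torch.Tensor, bool]]) -> None:
    for p, (full_param, can_free) in zip(params, rebuilt):
        shard_numel = p.data.numel()
        # Writeback: copy possibly-modified full data back into the local shard.
        p.data.copy_(full_param.reshape(-1)[:shard_numel].view_as(p.data))
        if can_free:
            del full_param  # temporary full-precision buffer, safe to release
```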
How checkpoint works

- _lazy_init casts buffers to their reduced dtype.
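A standalone toy sketch of the buffer-dtype bookkeeping around checkpointing (remember the original dtype, cast for compute, restore for state_dict); BufferCaster is a made-up stand-in, not FSDP's API, and it assumes all buffers share one dtype.

```python
# Toy stand-in: cast buffers to a compute dtype, remember the original dtype,
# and restore it around state_dict so checkpoints keep the original precision.
import torch
import torch.nn as nn


class BufferCaster:
    def __init__(self, module: nn.Module, buffer_dtype: torch.dtype) -> None:
        self.module = module
        self.buffer_dtype = buffer_dtype
        first = next(module.buffers(), None)
        # Assumes all buffers share one dtype, as noted in the description.
        self._orig_buffer_dtype = first.dtype if first is not None else buffer_dtype

    def cast_for_compute(self) -> None:
        self._cast_all(self.buffer_dtype)

    def state_dict(self):
        self._cast_all(self._orig_buffer_dtype)  # restore before checkpointing
        sd = self.module.state_dict()
        self._cast_all(self.buffer_dtype)        # cast back for consistency
        return sd

    def _cast_all(self, dtype: torch.dtype) -> None:
        for name, buf in self.module.named_buffers():
            *path, leaf = name.split(".")
            owner = self.module
            for part in path:
                owner = getattr(owner, part)
            setattr(owner, leaf, buf.to(dtype))
```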
Useful clarifications while reviewing the diff:

How fairscale implements MP for buffers:
How PT FSDP implements MP for buffers in this diff:
- Each wrapped submodule reads its own self.mixed_precision config instead of passing in this arg to the recursive call, because different submodules may disable mixed precision. One example is that BatchNorm usually disables mixed precision.
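A hedged sketch of keeping BatchNorm in full precision by wrapping it in its own FSDP instance with mixed_precision=None; module names and import paths are assumptions and may differ from the landed code.

```python
# Sketch only: the inner BatchNorm unit sees self.mixed_precision = None, while
# the outer wrapper uses the mixed precision config.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision


class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(8, 8)
        self.bn = nn.BatchNorm1d(8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.bn(self.lin(x))


def wrap_with_fp_batchnorm(block: Block, mp: MixedPrecision) -> FSDP:
    block.bn = FSDP(block.bn, mixed_precision=None)  # keep BN in full precision
    return FSDP(block, mixed_precision=mp)           # rest uses mixed precision
```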
Test coverage:

- Tests patch _reduce_scatter_base to run the mixed precision checks.
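A hedged sketch of that test pattern: intercept the reduce-scatter collective and assert the reduction dtype. The exact patch target and helper below are assumptions about the test, not the actual test code.

```python
# Sketch only: wrap torch.distributed._reduce_scatter_base with a dtype check.
from unittest import mock

import torch
import torch.distributed as dist


def checked_reduce_scatter(expected_dtype, orig=dist._reduce_scatter_base):
    def wrapper(output, input, *args, **kwargs):
        assert input.dtype == expected_dtype, (
            f"expected reduce-scatter in {expected_dtype}, got {input.dtype}")
        return orig(output, input, *args, **kwargs)
    return wrapper


# Inside a distributed test body (assuming fsdp_model and inp exist):
# with mock.patch.object(dist, "_reduce_scatter_base",
#                        checked_reduce_scatter(torch.float16)):
#     fsdp_model(inp).sum().backward()
```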
Follow up work (#74515):

- [ ] Test local_state_dict checkpoint works with mixed precision. In particular we have to be careful about buffer casting.
Differential Revision: D35000703