@awgu (Collaborator) commented Oct 27, 2022

Stack from ghstack:

This PR refactors and fixes `_cast_buffers()`.

**Before**
Buffers were not correctly cast back to their original dtypes for submodules when using buffer mixed precision.

- `_cast_buffers(recurse=False)` incorrectly casts all buffers, including those in submodules, because of this outer loop over `self.modules()` (see the sketch after this list):
  https://github.com/pytorch/pytorch/blob/c40033be162db0f94d37e7ccbd2a89d67f8b8e47/torch/distributed/fsdp/fully_sharded_data_parallel.py#L700
- There was a unit test that checked that buffers were cast as expected (`test_mixed_precision_e2e_full_shard()`). The unit test coincidentally passed because all modules shared the same buffer name `"buffer"`. In `_cast_buffers()`, the `dict` mapping buffer name to original dtype is populated lazily (during `_lazy_init()`). However, the keys are unprefixed:
  https://github.com/pytorch/pytorch/blob/c40033be162db0f94d37e7ccbd2a89d67f8b8e47/torch/distributed/fsdp/fully_sharded_data_parallel.py#L712-L717

      for name, buf in module.named_buffers(recurse=False):
          if buf is None:
              continue
          buf = buf.to(device=device or self.compute_device)
          if name not in self._buffer_name_to_orig_dtype:
              self._buffer_name_to_orig_dtype[name] = buf.dtype
- Thus, even though (1) `_cast_buffers(recurse=False)` was only called on the root and (2) `self._buffer_name_to_orig_dtype` had unprefixed names as keys, the unit test still passed because (1) `_cast_buffers()` still looped over all buffers despite `recurse=False` and (2) all submodules' buffers were named `"buffer"` and had the same original and low-precision dtypes and hence were cast correctly.
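To make the failure mode above concrete, here is a minimal standalone sketch (a toy two-module model, not the FSDP implementation) showing why the outer loop over `self.modules()` defeats `recurse=False` and why the unprefixed buffer names collide:

```python
# Toy sketch (not the FSDP code itself): root.modules() yields the root AND
# every submodule, so calling named_buffers(recurse=False) inside that loop
# still visits every buffer in the model, and the unprefixed names collide.
import torch
import torch.nn as nn

class Sub(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.zeros(1))

class Root(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.ones(1))
        self.sub = Sub()

root = Root()
seen = []
for module in root.modules():  # outer loop: Root and Sub
    for name, buf in module.named_buffers(recurse=False):  # unprefixed names
        seen.append(name)
print(seen)  # ['buffer', 'buffer'] -- all buffers are visited, names collide
```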

If we change each submodule to have its own distinct buffer name, the unit test fails. This PR makes that change to showcase the fix (see the sketch below).
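For illustration only, a hypothetical sketch of the kind of change meant here; the actual model used by `test_mixed_precision_e2e_full_shard()` may be structured differently. Giving each submodule a uniquely named buffer means the unprefixed-key bookkeeping can no longer pass by coincidence:

```python
# Hypothetical example of distinct per-submodule buffer names; the real unit
# test model may differ.
import torch
import torch.nn as nn

class LinearWithBuffer(nn.Module):
    def __init__(self, index: int):
        super().__init__()
        self.lin = nn.Linear(4, 4)
        self.register_buffer(f"buffer_{index}", torch.randn(4))  # distinct name

model = nn.Sequential(*[LinearWithBuffer(i) for i in range(3)])
print([name for name, _ in model.named_buffers()])
# ['0.buffer_0', '1.buffer_1', '2.buffer_2']
```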

**After**
This PR separates `_cast_buffers()` into three methods: `_get_buffers_and_dtypes_for_computation()`, `_get_buffers_and_dtypes_for_checkpoint()`, and `_cast_buffers_to_dtype_and_device()`. This separates the different use cases (casting for computation and casting for checkpointing) and their corresponding code paths. Plus, the signature of `_cast_buffers_to_dtype_and_device()` makes it clear exactly which buffers are being cast and to what dtype.

Both `_get_...()` functions assume that they are called on the root only for now. This coincides with the construction of `_buffer_name_to_orig_dtype` in the FSDP constructor, which loops over all submodules. (This means that for non-root modules, their `_buffer_name_to_orig_dtype` is populated but not used.) The `dict`'s keys are clean since the cast of buffers back to their original dtype happens inside a `summon_full_params()` context, which cleans the names.
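As a rough illustration of the intended separation: the method names above come from this PR, but the signature and call sites sketched below are assumptions for illustration, not the actual FSDP implementation.

```python
# Hedged sketch: pairing each buffer with an explicit target dtype is what
# makes the cast unambiguous; None means "keep the buffer's current dtype".
# The signature and the commented call sites are assumptions, not real code.
from typing import Iterable, Optional
import torch

def _cast_buffers_to_dtype_and_device(
    buffers: Iterable[torch.Tensor],
    buffer_dtypes: Iterable[Optional[torch.dtype]],
    device: torch.device,
) -> None:
    for buffer, dtype in zip(buffers, buffer_dtypes):
        buffer.data = buffer.data.to(device=device, dtype=dtype or buffer.dtype)

# Assumed call sites:
#   buffers, dtypes = _get_buffers_and_dtypes_for_computation(root_module)
#   _cast_buffers_to_dtype_and_device(buffers, dtypes, compute_device)
# and, before saving a state dict:
#   buffers, dtypes = _get_buffers_and_dtypes_for_checkpoint(root_module)
#   _cast_buffers_to_dtype_and_device(buffers, dtypes, compute_device)
```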

**Follow-Ups**

- We can try to move `_get_buffers_and_dtypes_for_checkpoint()` into `_state_dict_utils.py`.
- We may want to move to per-module buffer casting (i.e., not have the root module cast buffers for all submodules).

@pytorch-bot (bot) commented Oct 27, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87935

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 88d3c09:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@awgu changed the title from "[FSDP()][21/N] Refactor _buffer_name_to_orig_dtype computation" to "[FSDP()][21/N] Refactor and fix _cast_buffers()" on Nov 1, 2022
awgu pushed a commit to awgu/pytorch that referenced this pull request Nov 1, 2022
ghstack-source-id: 6d46f28
Pull Request resolved: pytorch#87935
awgu pushed a commit to awgu/pytorch that referenced this pull request Nov 2, 2022
ghstack-source-id: 0633b82
Pull Request resolved: pytorch#87935
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
Pull Request resolved: pytorch#87935
Approved by: https://github.com/mrshenli
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
@facebook-github-bot deleted the gh/awgu/165/head branch on June 8, 2023

Labels

ciflow/trunk, release notes: distributed (fsdp), topic: not user facing
