
Conversation


@awgu awgu commented Oct 29, 2022

Stack from ghstack:

This PR adds the forward hook registration to composable FSDP and adds a unit test for the runtime.
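The hook registration described above can be sketched with PyTorch's standard `nn.Module` hook APIs. This is a hypothetical simplification, not the actual FSDP implementation: the `_pre_forward` / `_post_forward` names and their bodies are illustrative stand-ins for the unshard/reshard logic that composable FSDP attaches without wrapping the module.

```python
import torch
import torch.nn as nn

calls = []  # record hook firing order for illustration

def _pre_forward(module, args):
    # In real FSDP this would, e.g., unshard parameters before forward.
    calls.append("pre")
    return args

def _post_forward(module, args, output):
    # In real FSDP this would, e.g., reshard parameters and register
    # pre-backward hooks on the output tensors.
    calls.append("post")
    return output

model = nn.Linear(4, 4)
pre_handle = model.register_forward_pre_hook(_pre_forward)
post_handle = model.register_forward_hook(_post_forward)

out = model(torch.randn(2, 4))  # both hooks fire around this forward

pre_handle.remove()
post_handle.remove()
```

Registering hooks on the module (rather than subclassing or wrapping it) is what keeps the API "composable": the user's module type is unchanged.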


pytorch-bot bot commented Oct 29, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88040

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit d930488:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

awgu pushed a commit to awgu/pytorch that referenced this pull request Nov 2, 2022
ghstack-source-id: 019499e
Pull Request resolved: pytorch#88040
awgu pushed a commit to awgu/pytorch that referenced this pull request Nov 2, 2022
ghstack-source-id: 35fae68
Pull Request resolved: pytorch#88040
Contributor

@rohan-varma rohan-varma left a comment


A couple of minor questions; thanks for adding this!

def test_training(self):
    """Tests training (forward, backward, optimizer)."""
    device = torch.device("cuda")
    local_model = Model(device=device)
Contributor

do we have support for composable FSDP + meta device? Is there a source of truth where we can find the feature set covered by composable?

Collaborator Author

do we have support for composable FSDP + meta device?

I believe there should be because the composable FSDP constructor includes the same module materialization logic as the normal FSDP constructor.

Is there a source of truth where we can find the feature set covered by composable?

This is difficult to document right now since we are still prototyping. As I continue testing and thinking about the design, I may realize some sharp edges that prevent feature parity. I will try to stabilize soon. (same for use_orig_params=True)
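The meta-device flow being discussed (construct the module without real storage, then materialize it) can be sketched as follows. This is a minimal illustration of the general PyTorch mechanism, not the FSDP constructor's actual materialization code; the `nn.Linear` model is a placeholder.

```python
import torch
import torch.nn as nn

# Allocate parameters on the meta device: shapes and dtypes exist,
# but no real storage is backing them.
model = nn.Linear(8, 8, device="meta")
assert model.weight.is_meta

# Materialize on a real device. to_empty() allocates uninitialized
# storage; the caller (or, per the discussion above, FSDP's constructor
# logic) is then responsible for re-initializing the weights.
model = model.to_empty(device="cpu")
model.reset_parameters()
```

The point of the question is whether composable FSDP runs this same materialization path as wrapper FSDP, which the author believes it does since the constructors share that logic.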

elif _handles_key:
    _assert_in_training_states(state, [TrainingState.IDLE])
    allowed_states = [TrainingState.IDLE]
    if _is_composable(state):
Contributor

why are allowed states different in composable vs non-composable?

Collaborator Author

First, note that this is TrainingState, which is per state: _FSDPState object, not HandleTrainingState, which is per FlatParamHandle / FlatParameter. (_FSDPState is Union[_State, FullyShardedDataParallel], where _State comes from torch/distributed/_composable/contract.py.)

For composable, state represents the local FSDP root (without wrapping). Upon the first FlatParameter's pre-backward hook, the state transitions to FORWARD_BACKWARD. For any subsequent FlatParameter's pre-backward hook, state.training_state is already FORWARD_BACKWARD. However, each FlatParamHandle's training state transitions as you would expect (i.e. IDLE -> BACKWARD_PRE here).

This shows why I had to refactor TrainingState in an earlier PR: we have to stratify the states to accommodate the difference between state: _FSDPState and FlatParamHandle / FlatParameter.
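The stratification described above can be sketched with two enums and an idempotent root transition. This is a hypothetical simplification for illustration; the enum members and the `pre_backward` helper are stand-ins chosen to mirror the discussion, not the actual FSDP internals.

```python
from enum import Enum, auto

class TrainingState(Enum):
    """Coarse state, one per FSDP root (_FSDPState)."""
    IDLE = auto()
    FORWARD_BACKWARD = auto()

class HandleTrainingState(Enum):
    """Fine-grained state, one per FlatParamHandle."""
    IDLE = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()

def pre_backward(root_state, handle_state):
    # The root transition is idempotent: only the first FlatParameter's
    # pre-backward hook moves it; later hooks find it already in
    # FORWARD_BACKWARD. Each handle still transitions individually.
    if root_state is TrainingState.IDLE:
        root_state = TrainingState.FORWARD_BACKWARD
    handle_state = HandleTrainingState.BACKWARD_PRE
    return root_state, handle_state

# First hook: root IDLE -> FORWARD_BACKWARD, handle IDLE -> BACKWARD_PRE.
root, h1 = pre_backward(TrainingState.IDLE, HandleTrainingState.IDLE)
# Second hook: root is unchanged, but the second handle still transitions.
root, h2 = pre_backward(root, HandleTrainingState.IDLE)
```

This is exactly why a single flat state enum could not serve both levels: the root-level check must tolerate FORWARD_BACKWARD on later hooks, while the handle-level check must still see IDLE.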

@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 2, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
This PR adds the forward hook registration to composable FSDP and adds a unit test for the runtime.
Pull Request resolved: pytorch#88040
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
This PR adds the forward hook registration to composable FSDP and adds a unit test for the runtime.
Pull Request resolved: pytorch#88040
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
@facebook-github-bot facebook-github-bot deleted the gh/awgu/176/head branch June 8, 2023 15:25

Labels

ciflow/trunk (Trigger trunk jobs on your pull request)
release notes: distributed (fsdp)
topic: not user facing

4 participants