
Conversation

@wconstab (Contributor) commented Jan 25, 2023

Stack from ghstack (oldest at bottom):

When running compiled submods to produce outputs to pass to the compilation step for the next submod, we use fake parameters and assume fake inputs, but we forgot to activate our fake_mode during execution.

This broke certain edge cases where tensors other than activations or parameters get created during execution, such as the scalar->tensor expansion that happens when executing torch.where(tensor, scalar, scalar).

Also add a test and clarify behavior of DDPOptimizer via comments.

Fixes #92941
cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire
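
A minimal illustration of the edge case, with assumed shapes/dtypes and a standalone FakeTensorMode rather than the one DDPOptimizer actually uses:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
with fake_mode:
    cond = torch.ones(4, dtype=torch.bool)  # stands in for a fake activation

# Outside fake_mode, the scalar arguments get materialized as real tensors,
# which is how tensors other than activations/parameters sneak in and clash
# with the fake ones (the failure mode described above):
# torch.where(cond, 1.0, 0.0)

# With fake_mode active (the fix), the scalar->tensor expansion is also fake:
with fake_mode:
    out = torch.where(cond, 1.0, 0.0)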

@pytorch-bot bot commented Jan 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/92986

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cac8faa:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wconstab added a commit that referenced this pull request Jan 25, 2023

ghstack-source-id: ce20aa8
Pull Request resolved: #92986
@wconstab added the release notes: distributed (ddp) label (release notes category) Jan 25, 2023
@wconstab (Contributor, Author) commented:
@voznesenskym Semi-related to this PR, is the code that infers a FakeMode from input tensors or else creates a new one (https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/distributed.py#L148) good enough going forward, or do we want to align this FakeMode better with one that lives in dynamo? I'm not sure what the current design is.

# Finally, we have to produce inputs for use in compiling the next submodule,
# and these need to be FakeTensors, so we execute the module under fake_mode
with fake_mode:
    return curr_submod(*new_args, **kwargs)
Collaborator:
This code is executed at runtime, right?

@wconstab (Contributor, Author):
Not sure what you mean.

Since TorchDynamo defers its compilation until the first execution, in a way, yes, this code happens "at runtime".

But this code only happens as part of the compilation flow, which in a simple (static model) scenario only happens once. The second time a user calls their compiled DDP model, none of this code should run, since we're not recompiling.
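
A minimal sketch of that compile-time flow, assuming the submods already hold fake parameters, the example inputs are FakeTensors created under fake_mode, and compile_fn stands in for the configured backend compiler (illustrative only, not the actual DDPOptimizer source):

def compile_submods_in_order(submods, example_inputs, fake_mode, compile_fn):
    compiled = []
    args = list(example_inputs)
    for submod in submods:
        # Compile this submodule against the (fake) example inputs.
        compiled.append(compile_fn(submod, args))
        # Then run it under fake_mode so its fake outputs become the example
        # inputs for compiling the next submodule. This only happens during
        # compilation, not on later calls of the compiled DDP model.
        with fake_mode:
            out = submod(*args)
        args = list(out) if isinstance(out, (tuple, list)) else [out]
    return compiled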

@wconstab (Contributor, Author):
Maybe you confused it with `WrapperModule.forward` - that's the only piece of code in the whole `ddp_optimizer` file that I'd expect to run repeatedly at runtime. (All it does is unwrap the tuple output from the compiled subgraph.)
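
Schematically, that wrapper amounts to the following (an illustrative sketch; the attribute names here are assumptions, not the exact source):

import torch

class WrapperModule(torch.nn.Module):
    def __init__(self, compiled_submod, unwrap_singleton_tuple):
        super().__init__()
        self.compiled_submod = compiled_submod
        self.unwrap_singleton_tuple = unwrap_singleton_tuple

    def forward(self, *args):
        # Runs on every call of the compiled DDP model; it just calls the
        # compiled subgraph and unwraps its tuple/list output.
        out = self.compiled_submod(*args)
        if self.unwrap_singleton_tuple and isinstance(out, (tuple, list)):
            return out[0]
        return out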

Collaborator:
Yeah, `WrapperModule.forward` is the place I was thinking of.

This looks fine to me.

When Ed and I were working on it - it was very confusing which part of this was compile time, and which was runtime.

@voznesenskym (Collaborator) commented:
> @voznesenskym Semi-related to this PR, is the code that infers a FakeMode from input tensors or else creates a new one (https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/distributed.py#L148) good enough going forward, or do we want to align this FakeMode better with one that lives in dynamo? I'm not sure what the current design is.

fake_mode = fake_mode_from_tensors(example_inputs)

Is the right way to do things :) I added it for this purpose. The idea is that you do a best effort to get the current fake mode, and if there isn't one, you can make it. There are a few useful comments in that definition; one should say something along the lines of maybe having it always provide a fake_mode...
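
The best-effort logic described here is roughly the following (a sketch; the function name below is made up for illustration, only fake_mode_from_tensors is the real helper referenced above):

from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

def fake_mode_from_tensors_or_fresh(example_inputs):
    # Best effort: reuse the FakeTensorMode already attached to the inputs.
    for t in example_inputs:
        if isinstance(t, FakeTensor):
            return t.fake_mode
    # If there isn't one, make a new one.
    return FakeTensorMode()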

@wconstab (Contributor, Author) commented:
@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Jan 25, 2023
@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

