Skip to content

Conversation

@wconstab
Copy link
Contributor

@wconstab wconstab commented Feb 13, 2023

Stack from ghstack (oldest at bottom):

GraphModules that were created during DDPOptimizer graph breaking
lacked compile_subgraph_reason, which caused an exception when
running .explain().

Now the reason is provided and users can use .explain() to find out
that DDPOptimizer is causing graph breaks.

Fixes #94579

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

GraphModules that were created during DDPOptimizer graph breaking
lacked `compile_subgraph_reason`, which caused an exception when
running .explain().

Now the reason is provided and users can use .explain() to find out
that DDPOptimizer is causing graph breaks.

Fixes #94579

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Feb 13, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94749

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 Failures

As of commit f8025ad:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Comment on lines 403 to 404
_, _, _, _, break_reasons, _ = torch._dynamo.explain(ddp_m, inputs)
self.assertEqual(len(break_reasons), 3)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, this might be a bit saner:

Suggested change
_, _, _, _, break_reasons, _ = torch._dynamo.explain(ddp_m, inputs)
self.assertEqual(len(break_reasons), 3)
explain_out = torch._dynamo.explain(ddp_m, inputs)
break_reasons = explain_out[4]
self.assertEqual(len(break_reasons), 3)

input_mod.recompile()
input_mod.compile_subgraph_reason = GraphCompileReason(
"DDPOptimizer intentional graph-break to improve communication overlap."
" Set `torch._dynamo.config.optimize_ddp = False` to disable",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a longer a note or comment on this flag?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea just added a pointer. good idea.

GraphModules that were created during DDPOptimizer graph breaking
lacked `compile_subgraph_reason`, which caused an exception when
running .explain().

Now the reason is provided and users can use .explain() to find out
that DDPOptimizer is causing graph breaks.

Fixes #94579

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
@wconstab wconstab added the topic: not user facing topic category label Feb 13, 2023
@wconstab
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 13, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-12-py3-arm64 / test (default, 2, 2, macos-m1-12)

Details for Dev Infra team Raised by workflow job

GraphModules that were created during DDPOptimizer graph breaking
lacked `compile_subgraph_reason`, which caused an exception when
running .explain().

Now the reason is provided and users can use .explain() to find out
that DDPOptimizer is causing graph breaks.

Fixes #94579

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Feb 13, 2023
GraphModules that were created during DDPOptimizer graph breaking
lacked `compile_subgraph_reason`, which caused an exception when
running .explain().

Now the reason is provided and users can use .explain() to find out
that DDPOptimizer is causing graph breaks.

Fixes #94579

ghstack-source-id: fa51cd1
Pull Request resolved: #94749
@wconstab
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-binary-manywheel / manywheel-py3_8-cuda11_7-with-pypi-cudnn-test / test

Details for Dev Infra team Raised by workflow job

@wconstab
Copy link
Contributor Author

@pytorchbot merge -f "Broken trunk"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/wconstab/97/head branch June 8, 2023 19:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants