Support DDP ignored parameters in DDPOptimizer #88460
DDP._set_params_and_buffers_to_ignore_for_model(m, parameters_to_ignore)
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
parameter_ids_to_ignore = [
    id(ddp_m.module.get_parameter(p))
So, this seems better than the hacky FQN/mangled-name approach. But is it totally reliable? I wondered if there could be edge cases, or if dynamo could make copies, etc.
torch/_dynamo/eval_frame.py
parameter_ids_to_ignore=[
    id(ddp_module.module.get_parameter(p))
    for p in ddp_module.parameters_to_ignore
],
If these parameters are id-stable, why not just annotate them directly? At the DDP level, whenever I add to parameters_to_ignore it should be something like:
def mark_parameter_as_ignored(module, name):
    params = dict(module.named_parameters())
    assert name in params
    ignored_parameter_list.append(name)
    params[name]._ignored = True
And then you don't need to leak your bookkeeping of ignored_parameter_list anywhere else (you could even get rid of it, potentially).
And in dynamo, you would just do:
if p.requires_grad and not getattr(p, "_ignored", False):
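To make the suggestion concrete, here is a small, self-contained sketch; the function name and the _ignored attribute are illustrative of the proposal above, not actual Dynamo or DDP code:

```python
import torch.nn as nn

def trainable_parameters(module: nn.Module):
    # Skip parameters explicitly marked as ignored, in addition to the
    # usual requires_grad filter.
    return [
        p
        for p in module.parameters()
        if p.requires_grad and not getattr(p, "_ignored", False)
    ]

m = nn.Linear(4, 4)
m.bias._ignored = True  # mark a parameter directly, as proposed above

kept = {id(p) for p in trainable_parameters(m)}
assert id(m.weight) in kept
assert id(m.bias) not in kept
```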
Thanks @voznesenskym, I think that's a great idea.
cc @aazzolini @mrshenli any issues with this approach?
SGTM. This is also what we are proposing for the new annotation-based API: https://fb.quip.com/bpvPA6f2dtrA
The only thing is that we might want this at the parameter level (instead of the module level) for DDP to have parity.
> The only thing is that we might want this at the parameter level (instead of the module level) for DDP to have parity.

I thought it already was at the parameter level? See my latest code; I think it's what you want.
But now I'm confused: in DDPOptimizer I simply ignore all buffers, since I thought the implication was that they never require grad and thus wouldn't be allreduced by DDP. If some buffers do get allreduced by DDP, then I'd want to follow this up with another PR that tests buffers and gets that behavior right.
For now I've marked both params and buffers that are in the parameters_and_buffers_to_ignore list with the same marker on the DDP side, since that seems consistent with the convention there.
@albanD says we can't rely on parameter ids being stable: that mostly works, but there are a few edge cases where it doesn't. In particular, reparametrization cannot always preserve the original parameter id.
What does 'reparametrization' mean exactly?
I'm thinking it might be worth sticking with the current approach (marking params), as it is simple and the consequences of getting it wrong are relatively minor (graph breaks wouldn't exactly match DDP's buckets, so perf would degrade anywhere from a little bit to matching dynamo+DDP without graph breaks).
But if there is another scheme that is not too complex, I'd be open to it.
Reparametrization is when you register a rule to recompute a parameter every time before it's used: https://pytorch.org/tutorials/intermediate/parametrizations.html
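For reference, a minimal sketch of why this breaks identity-based bookkeeping, using the parametrization API from the linked tutorial (the Symmetric parametrization is adapted from that page):

```python
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    # Recompute a symmetric weight from the underlying tensor on every access.
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)

m = nn.Linear(4, 4)
weight_id_before = id(m.weight)
parametrize.register_parametrization(m, "weight", Symmetric())

# m.weight is now a freshly computed tensor rather than the original
# Parameter object, so tracking parameters by id(m.weight) no longer works.
print(id(m.weight) == weight_id_before)  # False
```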
OK, I think I want to propose that we just land this as is:
- I'm not sure any of today's users of DDP's ignored-parameters flag are also using parametrizations.
- It wouldn't be catastrophic if ignored_parameters were not honored in DDPOptimizer (that is the de facto behavior today).
- We could potentially revisit this later.
Also, I'm curious: does dynamo+AOT handle parametrization on its own currently?
for name, param in module.named_parameters():
    if name in params_and_buffers_to_ignore:
        param._ddp_ignored = True
for name, buffer in module.named_buffers():
    if name in params_and_buffers_to_ignore:
        buffer._ddp_ignored = True
@mrshenli I thought buffers by definition do not require grad, and therefore DDP ignores them by default?
If not, I should update the logic in DDPOptimizer accordingly.
Buffers are by default broadcast right before the forward pass if broadcast_buffers=True is passed to the DDP constructor (and it is True by default). But if a buffer appears in the ignored parameters/buffers list, it's not part of that broadcast.
I think buffers shouldn't count for the purposes of splitting the model, since we're not syncing them after the backward pass.
OK, this makes sense. Then my change is fine: we mark them, but we still ignore them in DDPOptimizer.
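To make the buffer behavior above concrete, a minimal single-process sketch (hypothetical example; the gloo group, address, and port exist only so DDP can be constructed locally and are not part of this PR):

```python
import os
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process process group so DDP can be constructed for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.BatchNorm1d(8)  # has buffers: running_mean, running_var

# Buffers named here are skipped by the pre-forward broadcast; buffers are
# never part of the gradient allreduce either way, so DDPOptimizer's bucket
# splitting can keep ignoring them.
DDP._set_params_and_buffers_to_ignore_for_model(
    model, ["running_mean", "running_var"]
)
ddp_model = DDP(model, broadcast_buffers=True)  # True is also the default

dist.destroy_process_group()
```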
buckets[0].size += p.storage().nbytes()
# TODO correct FQ name?
- buckets[0].params.append(f"{node}_{name}")
+ buckets[0].params.append(f"{node.target}_{name}")
Do we still need this?
It is just useful for visualization purposes; see the debug output in the next PR in this stack. The buckets table is printed using this string.
aazzolini left a comment:
Let's land as is, but can you please add plenty of comments in the code explaining where the logic breaks down and how we could solve it?
@pytorchbot merge -f "Flaky CI, no gpus available on gpu runner"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
Pull Request resolved: pytorch#88460 Approved by: https://github.com/aazzolini
Stack from ghstack (oldest at bottom):
cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx