[Dynamo x FSDP][2/x] Small changes to distributed to make it dynamo friendly #106886
Conversation
…riendly [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106886
Note: Links to docs will display an error until the docs builds have been completed.
✅ 4 Unrelated Failures
As of commit 44907c4: UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
registry_key = getattr(module, REGISTRY_KEY, None)
if registry_key is None:
    default_registry: Dict[str, RegistryItem] = OrderedDict()
    setattr(module, REGISTRY_KEY, default_registry)
    return default_registry
else:
    return registry_key
Changed because setdefault NYI
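The rewrite above can be sketched stand-alone. `FakeModule`, `get_registry`, and the key name below are hypothetical stand-ins for the real FSDP code, assuming only that the registry lives as an ordinary attribute on the module:

```python
from collections import OrderedDict
from typing import Dict

REGISTRY_KEY = "_registry"  # hypothetical key name, not the real one


class FakeModule:
    """Plain stand-in for nn.Module; only attribute storage matters here."""


def get_registry(module) -> Dict[str, object]:
    # Dynamo-friendly expansion of module.__dict__.setdefault(...):
    # an explicit getattr-with-default plus a branch traces cleanly,
    # whereas dict.setdefault was not yet implemented in Dynamo.
    registry = getattr(module, REGISTRY_KEY, None)
    if registry is None:
        registry = OrderedDict()
        setattr(module, REGISTRY_KEY, registry)
    return registry


m = FakeModule()
assert get_registry(m) is get_registry(m)  # second call reuses the registry
```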
-return _module_state_mapping.get(module, None)
+if module in _module_state_mapping:
+    return _module_state_mapping[module]
+else:
+    return None
get w/ default - NYI
this one LGTM
The new code is worse perf though, you do the dict lookup twice, whereas with .get() you only need to do it once
 class GroupMember(metaclass=_WorldMeta):
-    NON_GROUP_MEMBER = object()
+    NON_GROUP_MEMBER = -100
comparison w/ object NYI - -100 is spiritually okay (any identity)
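A quick illustration of why the swap works, and of the caveat the "spiritually okay" remark glosses over (no legitimate value may ever collide with -100):

```python
# An object() sentinel matches only by identity; Dynamo could not model
# that comparison at the time. An int sentinel matches by value instead,
# which is safe only because no real rank value can equal -100.
OBJ_SENTINEL = object()
INT_SENTINEL = -100

assert object() != OBJ_SENTINEL      # a fresh object never equals the sentinel
assert OBJ_SENTINEL == OBJ_SENTINEL  # only the sentinel object itself matches
assert (-100) == INT_SENTINEL        # any -100, from anywhere, matches
```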
hmm.. yes, looks ok 👍
This one also seems simple to fix directly in Dynamo
# TODO(voz): Don't graph break on this
warnings.warn(
    "An unexpected prefix is detected. This case "
    " should only happen when using DMP with FSDP. "
    f"prefix = {prefix}, "
    f"submodule_name = {submodule_name}"
)
warn NYI
wconstab
left a comment
mostly LGTM. couple of questions
| """ | ||
| default_registry: Dict[str, RegistryItem] = OrderedDict() | ||
| return module.__dict__.setdefault(REGISTRY_KEY, default_registry) # type: ignore[call-overload] | ||
| registry_key = getattr(module, REGISTRY_KEY, None) |
nit, but IIUC the LHS var here would be more aptly named registry, as it is 'the registry' itself, accessed by using REGISTRY_KEY in the module's dict
agreed, good nit
registry_key = getattr(module, REGISTRY_KEY, None)
if registry_key is None:
    default_registry: Dict[str, RegistryItem] = OrderedDict()
    setattr(module, REGISTRY_KEY, default_registry)
Does your new version go through module's setattr method, whereas the previous one didn't? I am not sure if that's important, but sometimes subtle bugs creep in that way
I think they are identical, will test.
This will make perf worse: getattr/setattr is ordinarily slower than a dict access, and nn.Module has a fairly complicated __setattr__ handler that makes it worse. getattr access on nn.Module is famously slow (remember the conversation about CSE?)
I'm also just not that keen on this change; we should be in the business of modeling __dict__ properly...
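The __setattr__ question raised in this thread can be demonstrated with a toy class standing in for nn.Module's custom handler: `setattr` routes through the class's `__setattr__`, while writing via `__dict__` bypasses it entirely.

```python
class Tracked:
    """Toy stand-in for a class with a custom (and costly) __setattr__."""

    def __init__(self):
        object.__setattr__(self, "calls", 0)

    def __setattr__(self, name, value):
        # count every attribute write that goes through the handler
        object.__setattr__(self, "calls", self.calls + 1)
        object.__setattr__(self, name, value)


t = Tracked()
setattr(t, "a", 1)              # routed through __setattr__
t.__dict__.setdefault("b", 2)   # direct dict write: handler never runs
assert t.calls == 1
assert t.a == 1 and t.b == 2
```

So the two spellings are not identical: switching from `__dict__.setdefault` to `setattr` starts invoking the handler.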
 if _rank_not_in_group(pg):
     raise RuntimeError("Invalid process group specified")
-pg_store = _world.pg_map.get(pg, None)
+pg_store = _world.pg_map[pg] if pg in _world.pg_map else None
ok
| f"submodule_name = {submodule_name}" | ||
| ) | ||
| if ( | ||
| not torch.distributed._functional_collectives.is_torchdynamo_compiling() |
nit: where should we move this util? seems like it will take on a bigger life than functional_collectives
Not there, that's for sure! There's 2 options I think:

1. A top-level distributed-only flag in utils or the top-level module __init__, like torch.distributed.utils.is_compiling().
2. Remove it entirely from distributed and move it to a torch-level util, since it's a generally useful "import -> check flag" pattern that other modules will want as well.
yea, i'd be in favor of 2. i am wary of import cycles, that may be the very reason we (I?) put it there in the first place. So probably don't try to do it in this PR unless you're feeling lucky :)
if (
    not torch.distributed._functional_collectives.is_torchdynamo_compiling()
):
    # TODO(voz): Don't graph break on this
do you have a plan to not graph break on warn? (not in this PR, what you did here looks good for now)
Kind of - it's the same as print. The TODO is perhaps too promise-y? There are a few strategies we can do here.
I like to file an issue and then link to the issue, makes it easier to track
    with no_dispatch():
        tensor.record_stream(stream)
else:
    tensor.record_stream(stream)
is no_dispatch just a problematic graph-break under compile? or something else?
what does compile do with tensor.record_stream anyway?
The no_dispatch was added in #88014 cc @fegin
Looking over the PR, it looks like this is because we don't actually support Stream arguments in torch dispatch, so it just chokes. If Dynamo is able to answer whether any torch dispatch modes are active (it should answer False), a better version of this would be to check for active modes before disabling dispatch.
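A minimal sketch of that suggestion, with toy stand-ins (`dispatch_guard`, `fake_no_dispatch`) rather than the real torch APIs:

```python
from contextlib import contextmanager, nullcontext


def dispatch_guard(modes_active: bool, no_dispatch_factory):
    # Take the escape hatch only when a dispatch mode is actually live;
    # otherwise use a no-op context, so under compile (where the answer
    # should be False) no_dispatch is never entered at all.
    return no_dispatch_factory() if modes_active else nullcontext()


entered = []


@contextmanager
def fake_no_dispatch():  # toy stand-in for torch's no_dispatch
    entered.append(True)
    yield


with dispatch_guard(False, fake_no_dispatch):
    pass
assert entered == []       # no mode active: escape hatch untouched

with dispatch_guard(True, fake_no_dispatch):
    pass
assert entered == [True]   # mode active: real no_dispatch used
```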
# Check that all ranks plan to all-gather the same index parameters
for (r1, i1), (r2, i2) in itertools.combinations(
    (
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
is this error-checking logic really safe to skip under compile?
in the practical sense, probably. but technically we're supposed to throw the same set of errors in PT2 as eager.
It's not amazing to skip; this also plagued me. One thing we can do is potentially pull this out into an op. It just didn't meet the importance bar for the first MVP.
I agree that this is not important for the MVP. I do not know of any case where this error was raised (but that is also a good thing).
ok.
expect_true will work here too
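A sketch of keeping the check as plain runtime assertions (the expect_true idea: it is pure error-checking, so it can fire at runtime rather than become traced control flow). The helper name and plan shape below are hypothetical:

```python
import itertools


def check_ranks_agree(plans):
    # Check that all ranks plan to all-gather the same index parameters.
    # Pure error-checking: nothing downstream depends on the branch, so
    # under compile this can be deferred to a runtime assert instead of
    # causing a graph break.
    for (r1, i1), (r2, i2) in itertools.combinations(enumerate(plans), 2):
        assert i1 == i2, f"rank {r1} plans {i1} but rank {r2} plans {i2}"


check_ranks_agree([(0, 1, 2)] * 3)  # all ranks agree: no error

mismatch_caught = False
try:
    check_ranks_agree([(0, 1), (0, 2)])
except AssertionError:
    mismatch_caught = True
assert mismatch_caught
```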
# We don't run a even queue for freeing under torch compile atm
# But maybe we need to? TODO(voz): Look into this
i think its fine to diverge here, insofar as compile captures a coherent program that it can optimize. e.g. I don't care if compile uses the same sets of streams/events or if it totally ignores eager prefetching and does its own prefetching.
but I care that compiling fsdp doesn't fall on its face if we turn prefetching on and some other code hangs waiting for an event that will never be enqueued.
awgu
left a comment
The changes look good to me!
…it dynamo friendly" [ghstack-poisoned]
wconstab
left a comment
LGTM if you fix the nits, thanks!
…it dynamo friendly" [ghstack-poisoned]
@pytorchbot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
…it dynamo friendly" [ghstack-poisoned]
Successfully rebased
    free_event.record()
    state._free_event_queue.enqueue(free_event)
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
    # We don't run a even queue for freeing under torch compile atm
s/even/event/?
    )
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
    # TODO(voz): Don't graph break on this - dynamo hates the n1 != n2
    # tensor comparison control flow.
You kind of want something like expect_true here; you can defer the equality check to runtime because it's purely for error checking.
@awgu calls the shots here, my comments are non blocking

These are good comments, I will address
…it dynamo friendly" [ghstack-poisoned]
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):