[stateless] add weight tying support #90477
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90477. Note: links to docs will display an error until the docs builds have completed.
❌ 2 Failures as of commit d07d26f: FLAKY. The following jobs failed but were likely due to flakiness present on master.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This changed the algorithm from make_functional and from the original PR #87079, so here are some new numbers. Running a benchmark script on resnet-18 (no tied parameters), functional_call went from 58% slower than vanilla (not using functional_call at all) to 63% slower, which is a 3.5% slowdown relative to functional_call before this PR. Raw numbers were attached in case anyone wants to see them.
If I'm understanding the numbers, they are measuring some fixed number of iterations of resnet18 using the functional_call API. Shouldn't resnet be faster on GPU than CPU?
Yep! They're running it 10 times, and this is the average time as reported.
Yeah, I didn't explain these tables well. From the docs, my understanding is that the profiler reports the amount of time used by the CPU and the GPU separately. In this case, what we're worried about is the CPU time, since the code for the stateless call runs on the CPU.
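For anyone who wants to reproduce this kind of comparison, a minimal sketch follows. It is not the script from the thread (which used the profiler's separate CPU/GPU times); it only mirrors the setup described above (resnet-18, 10 iterations, averaged), uses simple wall-clock timing, and assumes torchvision is available.

```python
import time

import torch
import torchvision.models as models
from torch.nn.utils.stateless import functional_call

model = models.resnet18().eval()
# functional_call takes a flat name -> tensor mapping of parameters and buffers
params_and_buffers = {
    **dict(model.named_parameters()),
    **dict(model.named_buffers()),
}
x = torch.randn(4, 3, 224, 224)

def bench(fn, iters=10):
    # average wall-clock seconds per iteration
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

with torch.no_grad():
    vanilla = bench(lambda: model(x))
    stateless = bench(lambda: functional_call(model, params_and_buffers, (x,)))

print(f"vanilla: {vanilla:.4f}s/iter, functional_call: {stateless:.4f}s/iter, "
      f"overhead: {100 * (stateless / vanilla - 1):.1f}%")
```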
zou3519 left a comment:
Code looks correct, some questions
```python
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
tie_weights: bool = True,
```
This change is technically BC-breaking (does it matter?). What's our deprecation plan for nn.utils.stateless.functional_call?
The easiest thing to do seems to be:
- nn.utils.stateless.functional_call retains the same behavior as before (tie_weights=False?)
- we deprecate nn.utils.stateless.functional_call in the next version of PyTorch (so, now on master)
- we introduce a new torch.func.functional_call (probably needs a better name), with our preferred default (tie_weights=True), to replace it in the next version of PyTorch
> This change is technically BC-breaking (does it matter?). What's our deprecation plan for nn.utils.stateless.functional_call?
Yep, this is BC-breaking on a beta feature. It came out of earlier talks with @albanD, where we figured the default behavior should be what most people expect (which, judging from the make_functional requests, is to have the tied weights change together).
Since we are moving it to torch.func anyway, I'm fine with changing the default back to match the old behavior for now and breaking it when we do the move, unless @albanD has other thoughts?
Both can work, really. But it might be simplest to have the torch.func version be an alias of the old version.
And yes, from earlier discussions, I think this BC break is OK.
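For context on what the flag changes, here is a minimal sketch of the two behaviors being debated. The toy module and tensor names are invented for illustration; only the tie_weights flag itself comes from this PR.

```python
import torch
import torch.nn as nn
from torch.nn.utils.stateless import functional_call

class Tied(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(2, 2, bias=False)
        self.b = nn.Linear(2, 2, bias=False)
        self.b.weight = self.a.weight  # "a.weight" and "b.weight" are one tensor

    def forward(self, x):
        return self.a(x) + self.b(x)

m = Tied()
x = torch.randn(3, 2)
zeros = torch.zeros(2, 2)

# tie_weights=True: substituting either tied name substitutes the shared
# tensor everywhere it appears, so the whole output is zero.
out_tied = functional_call(m, {"a.weight": zeros}, (x,), tie_weights=True)
print(out_tied.abs().sum())  # tensor(0.)

# tie_weights=False (the pre-PR behavior): only a.weight is replaced; b still
# uses the original tensor, so the output is generally nonzero.
out_untied = functional_call(m, {"a.weight": zeros}, (x,), tie_weights=False)
print(out_untied.abs().sum())  # > 0
```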
zou3519 left a comment:
Code LGTM, minor comments. The algorithm feels a bit weird, but I can't come up with anything simpler.
torch/nn/utils/stateless.py (outdated):
```python
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
tie_weights: bool = False,
```
nit: "weights" can refer to only parameters, "tie_weights" effectively ties both parameters and buffers. Is that a problem? If we think it's a problem, we can rename it tie_weights_and_buffers. I kind of like the shorter name (and we can document that it applies to both parameters and buffers), but open to suggestions.
Docstring for nn.utils.stateless.functional_call needs to be updated with new flag and details on what it does (unless the plan was to just update torch.func.functional_call)
I'm somewhat attached to just "weights", since the original paper calls it weight tying (granted, there they only tie the weights/parameters). However, I hear you that it's ambiguous, so the docstring explicitly calls out that it ties both parameters and buffers.
SGTM
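To illustrate that the flag really covers buffers too, a small sketch (modules and buffer names invented for illustration):

```python
import torch
import torch.nn as nn
from torch.nn.utils.stateless import functional_call

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(1))

    def forward(self, x):
        return x * self.scale

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = Inner()
        self.b = Inner()
        self.b.scale = self.a.scale  # "a.scale" and "b.scale" alias one buffer

    def forward(self, x):
        return self.a(x) + self.b(x)

m = Outer()
x = torch.ones(3)

# Substituting one tied buffer name affects both when tie_weights=True...
print(functional_call(m, {"a.scale": torch.zeros(1)}, (x,), tie_weights=True))   # tensor([0., 0., 0.])
# ...but only the named one when tie_weights=False (b still contributes x * 1).
print(functional_call(m, {"a.scale": torch.zeros(1)}, (x,), tie_weights=False))  # tensor([1., 1., 1.])
```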
```diff
-def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
+def _create_swap_params(params_and_buffers):
+    def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Optional[Tensor]) -> None:
```
What is tensor being used for?
It's unused, but previously it was the original weight. Since the original weight was harder to get for tied weights, I passed None and had to update the signature to match. I can add a patch on top of this to remove that parameter, since it appears unused.
SGTM
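For readers without the diff open, a rough sketch of the closure shape being discussed follows. This is a simplification, not the actual PyTorch source (the real code also restores the original tensors after the call); it just shows why the tensor argument ends up unused once the replacement is looked up by full_path.

```python
from typing import Dict, Optional

import torch.nn as nn
from torch import Tensor

def _create_swap_params(params_and_buffers: Dict[str, Tensor]):
    # The factory captures params_and_buffers so the inner callback can find
    # the replacement by its dotted name (full_path) instead of receiving it.
    def _swap_parameters(module: nn.Module, tensor_name: str,
                         full_path: str, tensor: Optional[Tensor]) -> None:
        # `tensor` (the module's original attribute) is kept only to satisfy
        # the callback signature; it is never read.
        replacement = params_and_buffers[full_path]
        if tensor_name in module._parameters:
            # assign through the dict so a plain Tensor can stand in where a
            # Parameter normally lives
            module._parameters[tensor_name] = replacement  # type: ignore[assignment]
        elif tensor_name in module._buffers:
            module._buffers[tensor_name] = replacement
        else:
            setattr(module, tensor_name, replacement)
    return _swap_parameters
```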
albanD left a comment:
Thanks a lot!
```python
def add_to_name_map(n: str, t: torch.Tensor):
    # if the tensor hasn't been seen before, add it to the map
    if t not in weight_to_name_and_tied_names:
        weight_to_name_and_tied_names[t] = (n, set()) if n in names else (None, {n})
```
For n in names: I am not sure what exact kind of object the dict_keys object is. What is the complexity of this lookup?
O(1) according to the internet: https://stackoverflow.com/questions/17539367/python-dictionary-keys-in-complexity
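To make the surrounding bookkeeping concrete, here is a hedged sketch of the map this helper builds (the wrapper function and loop structure are assumptions; only add_to_name_map's first branch mirrors the diff). Each underlying tensor maps to the user-given name, if any, plus the set of other names tied to it; membership tests on tensor keys are the O(1) lookups discussed above, since tensors hash by identity.

```python
from typing import Dict, Optional, Set, Tuple

import torch
from torch import Tensor

def build_tied_map(module: torch.nn.Module,
                   names: Set[str]) -> Dict[Tensor, Tuple[Optional[str], Set[str]]]:
    # tensor -> (name the user passed in, set of other names tied to it)
    weight_to_name_and_tied_names: Dict[Tensor, Tuple[Optional[str], Set[str]]] = {}

    def add_to_name_map(n: str, t: Tensor) -> None:
        # if the tensor hasn't been seen before, add it to the map
        if t not in weight_to_name_and_tied_names:
            weight_to_name_and_tied_names[t] = (n, set()) if n in names else (None, {n})
        else:
            given, tied = weight_to_name_and_tied_names[t]
            if given is None and n in names:
                # a later name turns out to be the one the user passed in
                weight_to_name_and_tied_names[t] = (n, tied)
            else:
                tied.add(n)

    # remove_duplicate=False is what surfaces tied names in the first place
    for n, t in module.named_parameters(remove_duplicate=False):
        add_to_name_map(n, t)
    for n, t in module.named_buffers(remove_duplicate=False):
        add_to_name_map(n, t)
    return weight_to_name_and_tied_names
```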
@pytorchbot merge -f "failures from flaky test and unrelated mps test"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
cc @zou3519 @Chillee @soumith