Add lightweight reparametrization for _stateless calls #68969
Conversation
💊 CI failures summary: as of commit 4be2c41, Dr. CI recognized 14 new failures by patterns; they do not appear to be due to upstream breakages (more details on the Dr. CI page). ci.pytorch.org: 1 failed.
albanD left a comment:
Very interesting!
How does this compose with other parametrizations now?
(force-pushed b6103e8 to 1a9c2b9)
@albanD thanks!!!
I just fixed a small bug related to this. Currently the module maintains the original behavior: the parameter passed as an argument to the functional call is used instead of the original parametrization for that attribute. We could also add a mode in which the parametrizations are kept as they are and never replaced, or we could try to introspect the parametrizations and replace the parameter inside, but that feels very hacky.
@emcastillo I'll leave the actual review to Alban, but just wanted to say that this is awesome, and we'd be glad to change functorch to use this after our performance concerns have been resolved :)
albanD left a comment:
Thanks for looking into this.
I agree this is a solution that will work, even though we lose a little bit of flexibility: in particular, we now enforce that stateless ignores all parametrizations, which was not the case before. It is an open question whether we want that, though.
I am still not convinced that parametrization cannot be sped up to a similar level, and I think that would be a generally useful thing to do.
@Chillee is this perf improvement enough that you can use this for functorch? If so, we can add this as a temporary fix and then move back to parametrization when the perf gap there has been solved.
torch/nn/utils/_stateless.py (Outdated)
Is this a TODO?

yes! I should fix this comment
torch/nn/utils/_stateless.py (Outdated)
super().__getattribute__ here, right?

that led to infinite recursion :D, so I went with the base object method (probably I just did something wrong, as usual).

tried again to fix it and it worked now; seems I was messing with something else originally :D
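The recursion pitfall discussed above is easy to reproduce outside torch. Here is a minimal sketch (all class and attribute names are hypothetical, not the PR's actual code) of why delegating to the base `object.__getattribute__` terminates the lookup instead of recursing:

```python
class ParamProxy:
    """Toy stand-in for an object whose attribute lookups are intercepted."""

    def __init__(self, params):
        # `self._params = params` would re-enter our hooks, so bypass them
        object.__setattr__(self, "_params", params)

    def __getattribute__(self, name):
        params = object.__getattribute__(self, "_params")
        if name in params:
            return params[name]
        # Using `getattr(self, name)` here would re-enter __getattribute__
        # and recurse forever; delegating to the base object implementation
        # breaks the cycle.
        return object.__getattribute__(self, name)


proxy = ParamProxy({"weight": 3.0})
```

Looking up `proxy.weight` is served from the dict, while `proxy._params` falls through to the normal instance attribute.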
Actually, this is the current behavior on the master branch, if I am not mistaken: we just register a parametrization that acts like an identity function for the value of the parameter passed as a tensor. So in essence, we are replacing the previously registered parametrizations and returning the parameter that we pass in the state dict. I think it is possible to call the actual parametrizations with this PR's approach; do we want that?
I tried to change the parametrization code to support sparse tensors along with modules, but all the typing annotations made the code a kludge; I ended up with this design because it does the same thing and is cleaner. BTW, thanks for the comments! While writing this, I realized several alternatives to improve this that I hadn't considered before.
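To make the "identity parametrization" point concrete, here is a hedged pure-Python model of the behavior described above (this is not the actual `torch.nn.utils.parametrize` code; both classes are hypothetical): the override registered by the stateless call ignores the stored raw value and always returns the externally supplied one.

```python
class DoubleParametrization:
    """Toy parametrization already registered on a module: real = 2 * raw."""
    def forward(self, raw):
        return 2 * raw


class IdentityOverride:
    """What the stateless call effectively registers: ignore the raw value
    and always return the externally passed parameter."""
    def __init__(self, value):
        self.value = value

    def forward(self, raw):
        return self.value


raw_weight = 5.0
existing = DoubleParametrization()
override = IdentityOverride(7.0)

before = existing.forward(raw_weight)   # attribute computed by the parametrization
during = override.forward(raw_weight)   # raw value ignored, supplied value returned
```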
(force-pushed 1a9c2b9 to 6be2f19)
@albanD, I just pushed support to apply the existing parametrizations to a parameter via a kwarg. I think this should solve your main concern! If we have a parametrization over … Also, I cleaned up old comments and added type declarations.
(force-pushed e6ff61e to dab2ca5)
Just measured the same example with functorch, and these are the results.
It seems that this PR is slightly faster. This PR re-creates the functional module on every call to avoid dealing with shared state, an initial requirement from #61447.

```python
import cupyx
import torch
import torchvision.models as models
from functorch import make_functional_with_buffers

def main():
    model = models.resnet50(pretrained=True).cuda()
    func, params, buffers = make_functional_with_buffers(model)
    func.__name__ = 'resnet_func'
    x = torch.rand((128, 3, 224, 224)).cuda()
    print('Non functional call')
    print(cupyx.time.repeat(lambda: model(x), n_repeat=20))
    print('functional call')
    print(cupyx.time.repeat(func, (params, buffers, x), n_repeat=20))

if __name__ == "__main__":
    main()
```
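For readers without CUDA/CuPy, a rough CPU-only stand-in for `cupyx.time.repeat` can be built from the standard library (the helper name is hypothetical; note it performs no GPU synchronization, so it is only meaningful for CPU workloads):

```python
import timeit

def repeat_us(fn, n_repeat=20):
    """Return (best, mean) wall time of fn in microseconds over n_repeat runs."""
    times = timeit.repeat(fn, number=1, repeat=n_repeat)
    return min(times) * 1e6, (sum(times) / len(times)) * 1e6

best, mean = repeat_us(lambda: sum(range(10_000)))
```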
Sorry, for the above comment I took the master …
albanD left a comment:
The change looks good to me.
Just one question on whether we want the apply_parametrizations, but that's it.
torch/nn/utils/_stateless.py (Outdated)
nit: false -> False
btw, this is a repro for the issue I was running into before (needs this PR to handle duplicate params: #71542). Currently it throws an error. Would be great if we could verify this works now, but not blocking (since it's broken anyway right now on master...)
For posterity: this is broken on master because of the weight tying: calling … Since this PR changes the logic to avoid using the parametrization mechanism in torch.nn.utils underneath, I expect it to fix this issue. Definitely agree we should have a test for the weight-tied case once #71542 lands.
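Weight tying means two parameter names alias the same underlying tensor object. A small sketch (the helper is hypothetical) of detecting tied parameters by object identity, which is what a duplicate-aware traversal like the one in #71542 has to account for:

```python
def find_tied(named_params):
    """Group parameter names that refer to the same underlying object."""
    by_id = {}
    for name, p in named_params:
        by_id.setdefault(id(p), []).append(name)
    return [names for names in by_id.values() if len(names) > 1]


shared = [1.0, 2.0]  # stand-in for a shared weight tensor
params = [("embed.weight", shared), ("decoder.weight", shared), ("bias", [0.0])]
tied = find_tied(params)
```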
@Chillee @jbschlosser I just tested this PR together with #71542 in the code snippet above, and I confirm the error is gone!
(force-pushed dab2ca5 to 58a45c1)
@albanD, review comment addressed! This should be ready to ship :)
albanD left a comment:
Thanks for the update!
@albanD has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorchbot ciflow rerun

This command didn't do anything.
Summary: #61447 introduced a mechanism for performing functional calls on a model using the reparametrization API. However, the overhead introduced in a single call was too large. I tried to address this by modifying the reparametrization code to support sparse tensors, but the changes needed were too large due to type checking and several parts of the code expecting actual `nn.Module` objects, so this option was not feasible.

Benchmark: resnet50, calling the functional API with a parameters dict covering 0, 25, 50, 75, and 100% of the model's total parameters. Used script: https://gist.github.com/emcastillo/f344a58638bd71d130c71c45f86f0c3a

| % of parameters passed | CPU Time (us) | GPU Time (us) |
|------------------------|---------------|---------------|
| regular call | 5539 | 184909 |
| 0 | 5561 | 184843 |
| 25 | 11363 | 189236 |
| 50 | 18716 | 195378 |
| 75 | 22851 | 198641 |
| 100 | 27441 | 202281 |

This PR just swaps the `__getattr__` of the submodules to look into a dict holding only the parameters when called, greatly reducing the burden of having to instantiate custom modules and call forward just to retrieve a tensor. The execution times now are as follows:

| % of parameters passed | CPU Time (us) | GPU Time (us) |
|------------------------|---------------|---------------|
| regular call | 5939 | 187533 |
| 0 | 5899 | 187570 |
| 25 | 8541 | 188953 |
| 50 | 10045 | 189826 |
| 75 | 11049 | 190344 |
| 100 | 11911 | 190800 |
| functorch with 100% params | 14014 | 191727 |

Now we see that the CPU time overhead is greatly reduced, and the GPU time barely increases due to the effective overlap.

cc albanD zou3519

Pull Request resolved: #68969
Reviewed By: george-qi
Differential Revision: D33836360
Pulled By: albanD
fbshipit-source-id: 532561f64b18ca14c6ae2d77dcacb339397a589d
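The core idea of the summary — temporarily redirecting attribute lookup so parameters come from a caller-supplied dict — can be sketched without torch. All names below are hypothetical; the real implementation patches submodules' attribute access with more care (per-module, handling buffers, etc.), so treat this as a minimal model, not the PR's code:

```python
import contextlib

class TinyModule:
    """Minimal nn.Module stand-in: parameters live in a dict (like _parameters)."""
    def __init__(self):
        self._parameters = {"weight": 1.0}

    def __getattr__(self, name):
        # invoked only when normal lookup fails, mirroring nn.Module
        params = object.__getattribute__(self, "_parameters")
        if name in params:
            return params[name]
        raise AttributeError(name)

    def forward(self, x):
        return x * self.weight

@contextlib.contextmanager
def reparametrize(module, overrides):
    """Serve attribute lookups from `overrides` for the duration of one call.

    Note this patches the class, so it affects all instances of that class
    while active; a sketch-level simplification."""
    cls = type(module)
    original = cls.__getattr__
    def patched(self, name):
        if name in overrides:
            return overrides[name]
        return original(self, name)
    cls.__getattr__ = patched
    try:
        yield module
    finally:
        cls.__getattr__ = original  # restore, so no shared state leaks out

m = TinyModule()
before = m.forward(2)                      # uses the stored parameter
with reparametrize(m, {"weight": 10.0}):
    during = m.forward(2)                  # uses the supplied parameter
after = m.forward(2)                       # original behavior restored
```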