Enable Gradient Checkpointing for UNet2DModel (New)#7201
Enable Gradient Checkpointing for UNet2DModel (New)#7201yiyixuxu merged 14 commits intohuggingface:mainfrom
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
|
sorry @dg845 |
…ggingface#3675) * fix: assertion. * assertion fix.
* Fixing the global_step key not found * Apply suggestions from code review --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
…point check and remove references to deprecated scale/lora_scale
… UNet2DConditionModel gradient checkpointing tests
|
@yiyixuxu resolved the merge conflicts and updated the code to match the way gradient checkpointing is currently done in |
|
@hlky can you take a look here and help merge in? |
hlky
left a comment
There was a problem hiding this comment.
Thanks! I've left a general comment and I'll look into the test failures, should be good to merge after CI turns green.
| def _set_gradient_checkpointing(self, module, value=False): | ||
| if hasattr(module, "gradient_checkpointing"): | ||
| module.gradient_checkpointing = value |
There was a problem hiding this comment.
I notice we use hasattr(module, "gradient_checkpointing") in around half the cases of _set_gradient_checkpointing and module names in the other half, no need to change it for now but we could look at making it uniform at some point.
| ) | ||
|
|
||
| def test_effective_gradient_checkpointing(self): | ||
| super().test_effective_gradient_checkpointing(skip={"time_proj.weight"}) |
There was a problem hiding this comment.
cc @sayakpaul time_proj.weight has no grad, other parameters are ok
* Port UNet2DModel gradient checkpointing code from huggingface#6718. --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Vincent Neemie <92559302+VincentNeemie@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> Co-authored-by: hlky <hlky@hlky.ac>
* Port UNet2DModel gradient checkpointing code from #6718. --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Vincent Neemie <92559302+VincentNeemie@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> Co-authored-by: hlky <hlky@hlky.ac>
What does this PR do?
This PR supports gradient checkpointing for
UNet2DModel.Successor to PR #6718.
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul
@yiyixuxu