Enable Gradient Checkpointing for UNet2DModel #6718
dg845 wants to merge 35 commits into huggingface:main from …
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

You're right, I missed this 😅.
… and AttnUpBlock2D.
…ard for gradient checkpointing in AttnDownBlock2D and AttnUpBlock2D.
As a note, in their current … (see diffusers/src/diffusers/models/unets/unet_2d_blocks.py, lines 1045 to 1046 in d4c7ab7). So I have written the … (see diffusers/src/diffusers/models/unets/unet_2d_blocks.py, lines 1072 to 1079 in e837857). This has the potential to cause problems if …

(see diffusers/src/diffusers/models/unets/unet_2d_blocks.py, lines 1183 to 1188 in d4c7ab7), which seems wrong when … Since …

I think this is still fine because …
sayakpaul left a comment
Just some nits, but looks very good. Nice test, too.
… positional arg when gradient checkpointing for AttnDownBlock2D/AttnUpBlock2D.
…checkpointing for CrossAttnDownBlock2D/CrossAttnUpBlock2D as well.
Regarding #6718 (comment): I think in this case the best short term solution is to use the standard … In the long term, at least in …,

```python
def create_custom_forward(module):
    def custom_forward(*inputs, **kwargs):
        return module(*inputs, **kwargs)

    return custom_forward
```

could be used, and

```python
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet),
    hidden_states,
    temb,
    scale=lora_scale,
    return_dict=True,
    **ckpt_kwargs,
)
```

Note that …
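(Side note: the reason the `**kwargs`-forwarding wrapper is typically paired with `use_reentrant=False` is that, as far as I know, `torch.utils.checkpoint.checkpoint` only passes extra keyword arguments through to the wrapped function in non-reentrant mode. A minimal, self-contained sketch; the toy `Block` module and all names below are hypothetical illustrations, not diffusers code.)

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class Block(nn.Module):
    # Toy stand-in for a resnet/attention block whose forward takes a keyword argument.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x, scale=1.0):
        return self.linear(x) * scale


def create_custom_forward(module):
    # Same wrapper shape as above: forward both positional and keyword args.
    def custom_forward(*inputs, **kwargs):
        return module(*inputs, **kwargs)

    return custom_forward


block = Block()
x = torch.randn(2, 4, requires_grad=True)

# Non-reentrant checkpointing accepts extra kwargs and forwards them to the function,
# so `scale` actually reaches Block.forward here.
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), x, scale=0.5, use_reentrant=False
)
out.sum().backward()
```

As far as I can tell, the legacy reentrant variant rejects extra keyword arguments on recent PyTorch versions, which is presumably why the snippet above only opts into `use_reentrant=False` for torch >= 1.11.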
|
@dg845 I think I don't quite follow the concern fully. Could you maybe try to demonstrate the issue with a simpler example?
Would like to see when this case arises. From what I understand, gradient checkpointing is used during training, and I would like to keep the legacy blocks as is until and unless absolutely necessary. This is why I am asking for a simpler example to understand the consequences.
Sorry, I should have made it clear that the above follows from my belief that the … My understanding is that in the original LoRA paper the LoRA scale parameter …

I think in practice … Similarly, if we look at …, unlike for something like dropout, where the forward pass would be different depending on whether … So in my view the discrepancy between the gradient checkpointing code and non-gradient checkpointing code in e.g. …

Practically speaking, we might not consider the train-test mismatch that arises to be that bad, since we may want to tune the scaling of the LoRA update during inference anyway (e.g. if we are performing inference with multiple LoRAs simultaneously).
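(To make the kind of train/inference mismatch being discussed concrete, here is a toy sketch. The `LoRALinear` layer, its default `scale`, and the wrapper below are hypothetical illustrations rather than diffusers code; the point is only that a checkpointed path which drops the `scale` keyword silently falls back to the module default, so training and inference end up using different effective LoRA scales.)

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class LoRALinear(nn.Module):
    # Hypothetical toy layer: base weight plus a scaled low-rank update.
    def __init__(self, dim=4, rank=2):
        super().__init__()
        self.base = nn.Linear(dim, dim)
        self.lora_a = nn.Linear(dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, dim, bias=False)

    def forward(self, x, scale=1.0):
        return self.base(x) + scale * self.lora_b(self.lora_a(x))


def create_custom_forward(module):
    # Wrapper that only forwards positional args: any `scale=...` kwarg is lost.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


layer = LoRALinear()
x = torch.randn(2, 4)

# Intended scale at inference time.
ref = layer(x, scale=0.5)

# Checkpointed training path that forgets to forward `scale`:
# the layer silently runs with its default scale=1.0.
ckpt = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x, use_reentrant=False)

print(torch.allclose(ref, ckpt))  # False in general: the two paths used different scales
```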
That being said, perhaps it's better if I move the changes (especially to …)
> But this discussion is starting to deviate from the original topic of the PR a bit IMO.

^ This I agree with. And maybe this could be handled first in a separate PR, and then we revisit this PR. Does that work?
Sounds good :). To be more precise, would something like this sound good to you?
…radient checkpointing for CrossAttnDownBlock2D/CrossAttnUpBlock2D as well." This reverts commit 8756be5.
…ions exactly parallel to CrossAttnDownBlock2D/CrossAttnUpBlock2D implementations.
Yeah, that is right.
I have updated the gradient checkpointing implementation in this PR to be exactly parallel to that of CrossAttnDownBlock2D/CrossAttnUpBlock2D.
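(For readers skimming the thread, the structure being mirrored is roughly the standard diffusers pattern of wrapping each resnet/attention stage in `torch.utils.checkpoint.checkpoint` when `self.training and self.gradient_checkpointing` is set. The toy block below is a hypothetical, self-contained illustration of that branching pattern, not the actual AttnDownBlock2D diff.)

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyAttnDownBlock(nn.Module):
    # Toy analogue of an Attn*Block2D: alternating "resnet" and "attention" stages.
    # Real diffusers blocks take more arguments; this only shows the control flow.
    def __init__(self, dim=8, num_layers=2):
        super().__init__()
        self.resnets = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.attentions = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        for resnet, attn in zip(self.resnets, self.attentions):
            if self.training and self.gradient_checkpointing:
                # Checkpointed path: activations are recomputed during backward.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, use_reentrant=False
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn), hidden_states, use_reentrant=False
                )
            else:
                # Plain path: identical math, no recomputation.
                hidden_states = resnet(hidden_states)
                hidden_states = attn(hidden_states)
        return hidden_states


block = ToyAttnDownBlock()
block.gradient_checkpointing = True
block.train()
x = torch.randn(2, 8, requires_grad=True)
block(x).sum().backward()
```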
…a movement. (huggingface#6704) * load cumprod tensor to device Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com> * fixing ci Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com> * make fix-copies Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com> --------- Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
…uggingface#6736) Fix bug in ResnetBlock2D.forward when not USE_PEFT_BACKEND and using scale_shift for time emb where the lora scale gets overwritten. Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Update train_diffusion_dpo.py Address huggingface#6702 * Update train_diffusion_dpo_sdxl.py * Empty-Commit --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* update * update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
…ss (huggingface#6762) * add is_flaky to test_model_cpu_offload_forward_pass * style * update --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
add ipo and hinge loss to dpo trainer
* update * update * updaet * add tests and docs * clean up * add to toctree * fix copies * pr review feedback * fix copies * fix tests * update docs * update * update * update docs * update * update * update * update
add missing param
--------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Alvaro Somoza <somoza.alvaro@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
move sigma to device Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
--------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
* add * remove transformer --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
…gface#6738) * harmonize the module structure for models in tests * make the folders modules. --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Update testing_utils.py * Update testing_utils.py
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I think the PR is borked. Should we open a new PR instead? @dg845 |
Created a new PR with the changes at #7201. Will close this PR.
* Port UNet2DModel gradient checkpointing code from #6718. --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Vincent Neemie <92559302+VincentNeemie@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> Co-authored-by: hlky <hlky@hlky.ac>
What does this PR do?
This PR enables gradient checkpointing for `UNet2DModel` by setting the `_supports_gradient_checkpointing` flag to `True`. Since `UNet2DConditionModel` has `_supports_gradient_checkpointing = True`, it seems like `UNet2DModel` should support gradient checkpointing as well, unless I'm missing something. (A minimal usage sketch follows the checklist below.)
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
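A minimal usage sketch of what this flag enables, assuming the standard `enable_gradient_checkpointing()` API from `ModelMixin`; the config values below are arbitrary placeholders chosen for illustration:

```python
import torch
from diffusers import UNet2DModel

# Small, arbitrary config purely for illustration.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)

# With _supports_gradient_checkpointing = True, this call should no longer raise.
model.enable_gradient_checkpointing()
model.train()

sample = torch.randn(1, 3, 32, 32)
timestep = torch.tensor([10])
out = model(sample, timestep).sample
out.sum().backward()
```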
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@patrickvonplaten
@sayakpaul