Enabling gradient checkpointing in eval() mode #9878
yiyixuxu merged 7 commits into huggingface:main from MikeTkachuk:enable_grckpt_in_eval
Conversation
Since all of the module implementations used the same `if self.training and self.gradient_checkpointing:` check, this PR removes the `self.training` condition so that gradient checkpointing can also be used in eval() mode.
```diff
         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if self.gradient_checkpointing:
```
oh thanks! why did we also remove the `torch.is_grad_enabled()` check? gradient checkpointing isn't meaningful without gradients being computed, no?
added it back, thanks for pointing it out.
it does not break anything, but I found that it throws an annoying warning when `use_reentrant=True`.
> but found that it throws an annoying warning when use_reentrant=True

what do you mean by that?
`use_reentrant` is an argument passed to `torch.utils.checkpoint.checkpoint`. if it is True, one of the internal checks will print this to stderr:

```python
warnings.warn(
    "None of the inputs have requires_grad=True. Gradients will be None"
)
```
but diffusers is using `use_reentrant=False` anyway
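To make the warning concrete, here is a minimal repro sketch (assuming a recent PyTorch release where `use_reentrant` is an explicit keyword argument; the layer and tensor here are purely illustrative, not diffusers code):

```python
import warnings

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(4, 4)
x = torch.randn(2, 4)  # typical inference input: requires_grad is False

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # reentrant path validates the inputs and warns that none of them require grad
    checkpoint(layer, x, use_reentrant=True)
    # non-reentrant path (the value diffusers passes) does not hit that check
    checkpoint(layer, x, use_reentrant=False)

print([str(w.message) for w in caught])
```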
oh got it, thanks, so the warning is specific to using gradient checkpointing when gradients are not enabled
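Putting the thread together, here is a minimal, self-contained sketch of the gating discussed above (an illustrative toy module, not the actual diffusers transformer): checkpointing is gated on the `gradient_checkpointing` flag and on `torch.is_grad_enabled()`, but no longer on `self.training`, so calling `.eval()` does not silently disable it.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
        self.gradient_checkpointing = True

    def forward(self, hidden_states):
        for block in self.transformer_blocks:
            # gate on grad mode and the flag, but not on self.training
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = checkpoint(block, hidden_states, use_reentrant=False)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


model = TinyTransformer().eval()           # eval() no longer turns checkpointing off
x = torch.randn(1, 8, requires_grad=True)
model(x).sum().backward()                  # checkpointed forward, gradients still flow
with torch.no_grad():
    model(x)                               # plain inference forward, no checkpointing
```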
hi @MikeTkachuk unfortunately we have to update the branch and resolve conflicts now...
…le_grckpt_in_eval
# Conflicts:
#   src/diffusers/models/controlnet_flux.py
#   src/diffusers/models/controlnet_sd3.py
done, also fixed in a few other places I missed
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
can you rebase again?
fixed
Removed unnecessary `if self.training ...` check when using gradient checkpointing #9850