[feat] add load_lora_adapter() for compatible models #9712
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
text_encoder._hf_peft_config_loaded = None

def _fetch_state_dict(
Just taking it out of the class so it can be reused more easily.
return (is_model_cpu_offload, is_sequential_cpu_offload)

@classmethod
def _fetch_state_dict(
These are internal methods, so it should be okay to move them around. But it would be good to run a quick GitHub search to check whether they are being used directly somewhere, just to sanity-check that we don't break backwards compatibility.
Valid. I deprecated the old classmethods and added tests.
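A minimal sketch of that approach (an assumption for illustration, not the PR's exact diff): the classmethod becomes a thin wrapper that emits a deprecation warning and forwards to the module-level function. The removal version and message below are placeholders.

```python
from diffusers.utils import deprecate


def _fetch_state_dict(*args, **kwargs):
    # Module-level function that now owns the actual state-dict fetching logic.
    ...


class LoraBaseMixin:
    @classmethod
    def _fetch_state_dict(cls, *args, **kwargs):
        # Warn, then forward to the module-level helper; the name lookup below
        # resolves to the global function, not this classmethod.
        message = "Using `_fetch_state_dict` from the class is deprecated; use the module-level function instead."
        deprecate("_fetch_state_dict", "1.0.0", message)  # placeholder removal version
        return _fetch_state_dict(*args, **kwargs)
```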
@DN6 LMK what you think of the latest changes. Additionally, what do you think about the
@BenjaminBossan could you give this a look too?
BenjaminBossan left a comment
Thanks for this refactor, always happy to see more lines being removed than added. I didn't check the functions that were moved around, as I assume they were left identical. Regarding the rest, just some smaller comments.
    _pipeline=self,
    low_cpu_mem_usage=low_cpu_mem_usage,
)
transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
For my understanding: This is a fix independent of the main change of this PR, right? Would it be possible to move this check inside of load_lora_into_transformer or would that not be a good idea?
I think it kind of depends on the entrypoint to the underlying method. We already have a similar check within load_lora_adapter() (see src/diffusers/loaders/peft.py, line 153 at e187b70). So, I think it should be fine without it.
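For context, here is a small illustrative sketch of the key-prefix filtering being discussed (the key names are made up; this is not the PR's exact code): a pipeline-level LoRA state dict mixes keys for several sub-models, and each model-level loader should only receive the keys that belong to it.

```python
import torch

# Toy pipeline-level LoRA state dict with the usual "<component>." prefixes.
state_dict = {
    "transformer.blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 4),
    "text_encoder.layers.0.q_proj.lora_A.weight": torch.zeros(4, 4),
}

# Keep only the transformer entries before handing them to the model-level loader,
# mirroring the check quoted above.
transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
print(list(transformer_state_dict))  # ['transformer.blocks.0.attn.to_q.lora_A.weight']
```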
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
I saw that too, thx for fixing.
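For reference, the corrected ordering presumably reads something like the following, with the description sitting under its own parameter heading (reconstructed from the snippet above, not copied from the merged docstring):

```
adapter_name (`str`, *optional*):
    Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
    `default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
    Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
```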
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
BenjaminBossan left a comment
Thanks for addressing my comments, LGTM.
Ran the Flux integration tests and they pass. Failing tests are unrelated.
* add first draft.
* fix
* updates.
* updates.
* updates
* updates
* updates.
* fix-copies
* lora constants.
* add tests
* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* docstrings.

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
What does this PR do?
Similar to load_attn_procs(), we want a method for loading LoRAs directly into models, since the LoRA loading logic is generic. This way, we can reduce the LoC and have better maintainability. I am not too fixated on the load_lora_adapter() name; we could also do load_adapter().

@DN6 as discussed via Slack, could you give this a check? We could also add a save_lora_adapter() method to complement it.
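A hedged usage sketch of what the new entry point enables: loading a LoRA directly into a model class rather than through a pipeline-level loader. The checkpoint path and adapter name below are placeholders, and the exact keyword arguments may differ from the merged signature; adapter_name and low_cpu_mem_usage are taken from the docstring discussed above.

```python
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Model-level LoRA loading; the file path and adapter name are illustrative only.
transformer.load_lora_adapter(
    "path/to/lora.safetensors",
    adapter_name="my_lora",
    low_cpu_mem_usage=True,
)
```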