Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()

print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")
Unrelated but my hands were itching.
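For readers unfamiliar with it, the `=` specifier in f-strings (Python 3.8+) prints both the expression and its value, which is what makes this debug print useful on test failures. A tiny standalone illustration (the variable values here are made up):

```python
modules_with_gc_enabled = {"down_blocks.0": True}
expected_set = {"down_blocks.0", "mid_block"}

# The trailing `=` makes the f-string render "expression=value" instead of just the value.
print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")
# set(modules_with_gc_enabled.keys())={'down_blocks.0'}, expected_set={...}  (set order may vary)
```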
BenjaminBossan
left a comment
Thanks for adding a method to save the LoRA adapter. Overall, this looks good; I have a few comments but no blockers.
if prefix is not None:
    keys = list(state_dict.keys())
    model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
Just a better and more robust way to filter the state-dict keys.
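For context, a tiny standalone illustration (with made-up keys, not from the PR) of why matching `f"{prefix}."` is safer than a bare `startswith(prefix)`:

```python
state_dict = {
    "unet.down_blocks.0.lora_A.weight": 1,
    "unet_extra.proj.lora_A.weight": 2,  # hypothetical key that should NOT match prefix "unet"
    "text_encoder.lora_A.weight": 3,
}
prefix = "unet"

loose = [k for k in state_dict if k.startswith(prefix)]         # also catches "unet_extra..."
strict = [k for k in state_dict if k.startswith(f"{prefix}.")]  # only genuine "unet." keys

print(loose)   # ['unet.down_blocks.0.lora_A.weight', 'unet_extra.proj.lora_A.weight']
print(strict)  # ['unet.down_blocks.0.lora_A.weight']
```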
if len(state_dict) > 0:
    if adapter_name in getattr(self, "peft_config", {}):
        raise ValueError(
            f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
        )
Catching this error earlier than before.
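As a standalone sketch (a simplified, hypothetical class rather than the actual diffusers mixin) of the failure mode this earlier check catches — loading a second adapter under a name that is already registered in `peft_config`:

```python
class DummyLoraHost:
    """Minimal stand-in for a model that tracks its adapters in `peft_config`."""

    def load_adapter(self, state_dict, adapter_name):
        if len(state_dict) > 0:
            if adapter_name in getattr(self, "peft_config", {}):
                raise ValueError(
                    f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
                )
            # Register the adapter only after the name check has passed.
            self.peft_config = {**getattr(self, "peft_config", {}), adapter_name: state_dict}


host = DummyLoraHost()
host.load_adapter({"w": 0.0}, "default")
host.load_adapter({"w": 0.0}, "default")  # raises ValueError before any loading work is done
```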
if network_alphas is not None and len(network_alphas) >= 1:
-    alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+    alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
Removing redundant conditions.
if lora_config_kwargs["use_dora"]:
    if is_peft_version("<", "0.9.0"):
        raise ValueError(
            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
        )
else:
-    lora_config_kwargs.pop("use_dora")
+    if is_peft_version("<", "0.9.0"):
+        lora_config_kwargs.pop("use_dora")
Breaking up the conditionals to make them more explicit.
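To make the resulting behavior concrete, here is a small self-contained sketch (with a stubbed version check standing in for `is_peft_version`, and made-up inputs): raise only when DoRA is requested on an old `peft`, and silently drop the flag when DoRA is off and `peft` is too old to know the kwarg.

```python
def _is_old_peft(version: str) -> bool:
    # Stub for illustration; the real code calls is_peft_version("<", "0.9.0").
    major, minor, *_ = (int(p) for p in version.split("."))
    return (major, minor) < (0, 9)


def gate_use_dora(lora_config_kwargs: dict, peft_version: str) -> dict:
    """Explicit branching: error for DoRA on old peft, drop the flag when DoRA is off."""
    if lora_config_kwargs["use_dora"]:
        if _is_old_peft(peft_version):
            raise ValueError("You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs.")
    else:
        if _is_old_peft(peft_version):
            lora_config_kwargs.pop("use_dora")
    return lora_config_kwargs


print(gate_use_dora({"use_dora": False, "r": 4}, "0.8.2"))   # {'r': 4}
print(gate_use_dora({"use_dora": True, "r": 4}, "0.10.0"))   # {'use_dora': True, 'r': 4}
```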
assert new_output.sample.shape == (4, 4, 16, 16)

@require_peft_backend
def test_lora(self):
This no longer needs to be tested here.
BenjaminBossan
left a comment
Thanks for this PR, LGTM. I assume the docstring will be completed before merging.
@yiyixuxu could you give this a review?

why isn't this merged yet?

Pending reviews from Yiyi. It will take a bit as she is off for some days.

thanks!
What does this PR do?

Complementing #9712, this PR adds a `save_lora_adapter()` method for the models that support LoRA loading. It also adds tests to ensure things don't break.

Additionally, this PR:

* Deprecates the `load_attn_procs()` method when it tries to load a LoRA state dict (and adds tests for it).
* Replaces the `load_attn_procs()` method with `load_lora_adapter()` in `src/diffusers/loaders/lora_pipeline.py`. I have run the integration tests for the SD and SDXL LoRAs (as those are impacted by this change), and the tests passed.

Additionally, just to be sure, I have run the integration tests under `tests/lora` to ensure they pass.
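For readers who want to see the intended workflow, here is a minimal sketch of the save/load round trip, assuming the `save_lora_adapter()` / `load_lora_adapter()` pair added by this PR with their default arguments (exact signatures, file names, and the default `prefix` handling are defined in the PR, not here):

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Attach a fresh LoRA adapter to the attention projections.
unet.add_adapter(LoraConfig(r=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))

# New in this PR: persist only the LoRA parameters.
unet.save_lora_adapter("unet-lora")

# Load the saved adapter back into a fresh copy of the model.
fresh_unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
fresh_unet.load_lora_adapter("unet-lora")
```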