[LoRA] feat: support unload_lora_weights() for Flux Control. #10206
current_param_weight = overwritten_params[f"{name}.weight"]
in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
with torch.device("meta"):
Since we already pin the torch version, this is safe enough.
Also cc: @a-r-r-o-w. Something we should consider doing in:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_pipeline.py#L2351-L2354
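For readers following along, here is a self-contained sketch of the pattern referenced above (names and shapes are illustrative, not the PR's actual code): construct the replacement module under a "meta" device context so its parameters allocate no real storage, then materialize it from the saved tensors with `load_state_dict(..., assign=True)`. This needs torch >= 2.1, which is why the pinned torch version matters.

```python
import torch

# Minimal sketch (illustrative names/shapes, not the PR's exact code):
# build the replacement module on "meta" so no real memory is allocated,
# then materialize it from the saved tensors.
out_features, in_features, bias = 128, 64, True
saved_weight = torch.randn(out_features, in_features)
saved_bias = torch.zeros(out_features)

with torch.device("meta"):
    original_module = torch.nn.Linear(
        in_features, out_features, bias=bias, dtype=saved_weight.dtype
    )

state_dict = {"weight": saved_weight, "bias": saved_bias}
# assign=True (torch >= 2.1) replaces the meta parameters with the saved
# tensors; a plain copy_() into meta parameters would not materialize real data.
original_module.load_state_dict(state_dict, assign=True, strict=True)
print(original_module.weight.device)  # cpu: the module now holds the saved tensors
```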
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
BenjaminBossan left a comment:
Looks pretty good, well tested, no issues from my side.
@yiyixuxu @a-r-r-o-w could you give this a look?
a-r-r-o-w left a comment:
Thanks for supporting this! Changes look good to me.
logger.info(f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}.")

# For `unload_lora_weights()`.
overwritten_params[f"{current_module_name}.weight"] = module_weight
I think this would have a small but significant memory overhead. For inference-only use with LoRAs, maybe this could be made opt-out if we know we will never call unload_lora_weights(). Not a blocker though, and it can be tackled in a different PR, but let me know your thoughts.
Yeah this could be tackled with discard_original_layers. For now, I have added a note as a comment about it.
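A rough sketch of the trade-off being discussed; `store_originals` is a hypothetical opt-out flag (not an existing diffusers argument), in the spirit of `discard_original_layers`:

```python
import torch

overwritten_params = {}

def maybe_cache_original(
    module_name: str, module_weight: torch.Tensor, store_originals: bool = True
) -> None:
    # Keeping a copy of every overwritten weight lets unload_lora_weights()
    # restore the original (pre-expansion) shapes later, at the cost of an
    # extra full copy in memory. Opting out saves memory but makes the
    # expansion irreversible.
    if store_originals:
        overwritten_params[f"{module_name}.weight"] = module_weight
```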
original_module = torch.nn.Linear(
    in_features,
    out_features,
    bias=bias,
    dtype=module_weight.dtype,
)
tmp_state_dict = {"weight": current_param_weight}
if module_bias is not None:
    tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
setattr(parent_module, current_module_name, original_module)
@a-r-r-o-w thanks for flagging the device assignment while initializing original_module. The explicit device argument takes priority, so original_module was not getting initialized on "meta", rendering the previous copy_() ops ineffective.
Let me know what you think about the current changes (I have run the corresponding tests on a GPU and they pass).
@DN6 let me know your comments here too.
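For anyone hitting the same thing, a standalone illustration of the device-priority behaviour mentioned above (not the PR's code):

```python
import torch

# An explicit device= argument overrides the enclosing
# torch.device(...) context manager.
with torch.device("meta"):
    on_meta = torch.nn.Linear(4, 4)                   # parameters end up on "meta"
    overridden = torch.nn.Linear(4, 4, device="cpu")  # explicit device wins -> "cpu"

print(on_meta.weight.device)     # meta
print(overridden.weight.device)  # cpu
```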
* feat: support unload_lora_weights() for Flux Control.
* tighten test
* minor
* updates
* meta device fixes.
What does this PR do?
Fixes: #10202.
Will request for reviews from others later.
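For reference, a rough usage sketch of what this PR enables; the model IDs are illustrative examples of a Flux base checkpoint and a Flux Control LoRA:

```python
import torch
from diffusers import FluxControlPipeline

# Load a Flux Control LoRA that expands the transformer's input projection,
# run inference, then restore the original model with unload_lora_weights().
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

# ... control-guided inference here ...

# With this PR, unloading also restores the transformer's original
# (pre-expansion) layer shapes.
pipe.unload_lora_weights()
```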