Check correct model type is passed to from_pretrained#10189
Check correct model type is passed to from_pretrained#10189hlky merged 26 commits intohuggingface:mainfrom
from_pretrained#10189Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
| for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()): | ||
| if key not in passed_class_obj: | ||
| continue | ||
| class_name = passed_class_obj[key].__class__.__name__ | ||
| if class_name != expected_class_name: | ||
| raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.") |
There was a problem hiding this comment.
Let's add a test for this too?
|
Thanks @hlky! This looks leaner than I had thought. thank you! |
|
Thanks @sayakpaul. I've added a test, trimmed |
| if key not in passed_class_obj or key == "scheduler": | ||
| continue |
There was a problem hiding this comment.
If we pass scheduler=text_encoder that should be errored out as well, right?
There was a problem hiding this comment.
Added some special handling for scheduler
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
tests/pipelines/test_pipelines.py
Outdated
| tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") | ||
| with self.assertRaises(ValueError) as error_context: | ||
| _ = StableDiffusionPipeline.from_pretrained( | ||
| "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer | ||
| ) | ||
|
|
||
| assert "Expected" in str(error_context.exception) | ||
| assert "text_encoder" in str(error_context.exception) | ||
| assert f"{tokenizer.__class__.__name}" in str(error_context.exception) |
There was a problem hiding this comment.
Maybe also a check for the scheduler as that is handled slightly differently?
There was a problem hiding this comment.
Will add it. For context this is what we're handling:
diffusers/src/diffusers/schedulers/scheduling_utils.py
Lines 33 to 48 in d041dd5
There was a problem hiding this comment.
That's cool. But I don't see the flow matching schedulers here. So, if I do assign a text encoder to scheduler in an RF pipeline (FluxPipeline, for example), would it still work as expected?
There was a problem hiding this comment.
Yes that also works, for pipelines like Flux we're getting the type
<class 'diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler'>For SD etc we get the enum
<enum 'KarrasDiffusionSchedulers'>[<KarrasDiffusionSchedulers.DDIMScheduler: 1>, <KarrasDiffusionSchedulers.DDPMScheduler: 2>, <KarrasDiffusionSchedulers.PNDMScheduler: 3>, <KarrasDiffusionSchedulers.LMSDiscreteScheduler: 4>, <KarrasDiffusionSchedulers.EulerDiscreteScheduler: 5>, <KarrasDiffusionSchedulers.HeunDiscreteScheduler: 6>, <KarrasDiffusionSchedulers.EulerAncestralDiscreteScheduler: 7>, <KarrasDiffusionSchedulers.DPMSolverMultistepScheduler: 8>, <KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler: 9>, <KarrasDiffusionSchedulers.KDPM2DiscreteScheduler: 10>, <KarrasDiffusionSchedulers.KDPM2AncestralDiscreteScheduler: 11>, <KarrasDiffusionSchedulers.DEISMultistepScheduler: 12>, <KarrasDiffusionSchedulers.UniPCMultistepScheduler: 13>, <KarrasDiffusionSchedulers.DPMSolverSDEScheduler: 14>, <KarrasDiffusionSchedulers.EDMEulerScheduler: 15>]So we apply the same processing (str, split, strip applies for type case) to get a list of scheduler
['FlowMatchEulerDiscreteScheduler']['DDIMScheduler', 'DDPMScheduler', 'PNDMScheduler', 'LMSDiscreteScheduler', 'EulerDiscreteScheduler', 'HeunDiscreteScheduler', 'EulerAncestralDiscreteScheduler', 'DPMSolverMultistepScheduler', 'DPMSolverSinglestepScheduler', 'KDPM2DiscreteScheduler', 'KDPM2AncestralDiscreteScheduler', 'DEISMultistepScheduler', 'UniPCMultistepScheduler', 'DPMSolverSDEScheduler', 'EDMEulerScheduler']If it's not a scheduler it will raise or if it's the wrong type of scheduler.
There was a problem hiding this comment.
Thanks for explaining! Works for me.
There was a problem hiding this comment.
We now also support Union, context is failed test test_load_connected_checkpoint_with_passed_obj for KandinskyV22CombinedPipeline, we also change scheduler type to Union[DDPMScheduler, UnCLIPScheduler], the test is actually for passing obj to submodels, but changing the scheduler is how that test works.
There was a problem hiding this comment.
Tests for wrong scheduler are added.
sayakpaul
left a comment
There was a problem hiding this comment.
Go to go from my side once the scheduler related tests are added. Thanks!
|
Note that we add |
| scheduler_types.extend([str(scheduler_type)]) | ||
| scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types] | ||
|
|
||
| for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()): |
There was a problem hiding this comment.
Since we already have the types extracted in expected_types can't we fetch them using the key and then check if the passed object is an instance of the type? If the expected type is an enum then we can check if the passed obj class name exists in the keys?
There was a problem hiding this comment.
I think it might be better to make this check more agnostic to the component names.
We have a few pipelines with Union types on non-scheduler components (mostly AnimateDiff). So this snippet would fail even though it's valid, because init_dict is based on the model_index.json which doesn't support multiple types.
from diffusers import (
AnimateDiffPipeline,
UNetMotionModel,
)
unet = UNetMotionModel()
pipe = AnimateDiffPipeline.from_pretrained(
"hf-internal-testing/tiny-sd-pipe", unet=unet
)Enforcing scheduler types might be a breaking change cc: @yiyixuxu . e.g. Using DDIM with Kandinsky is currently valid, but with this change any downstream code doing this it would break. It would be good to enforce on the pipelines with Flow based schedulers though? (perhaps via a new Enum)
I would try something like:
for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
if key not in passed_class_obj:
continue
class_obj = passed_class_obj[key]
_expected_class_types = []
for expected_type in expected_types[key]:
if isinstance(expected_type, enum.EnumMeta):
_expected_class_types.extend(expected_type.__members__.keys())
else:
_expected_class_types.append(expected_type.__name__)
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
if isinstance(class_obj, SchedulerMixin) and not _is_valid_type:
# Handle case where scheduler is still valid
# raise if scheduler is meant to be a Flow based scheduler?
elif not _is_valid_type:
raise ValueError(f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}.")
There was a problem hiding this comment.
Added this for scheduler
_requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types)
_is_flow_match = "FlowMatch" in class_obj.__class__.__name__
if _requires_flow_match and not _is_flow_match:
raise ValueError(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.")
elif not _requires_flow_match and _is_flow_match:
raise ValueError(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.")
There was a problem hiding this comment.
I think we don't need a value error here, a warning is enough, no?
There was a problem hiding this comment.
A warning should be sufficient, it's mainly for the situation here #10093 (comment) where the wrong text encoder is given because the resulting error is uninformative.
There was a problem hiding this comment.
let's do a warning then:)
There was a problem hiding this comment.
Just chiming here a bit to share a perspective as a user (not a strong opinion). Related to #10189 (comment).
Here
if there's an unexpected module passed we raise a value error. I think the check is almost along similar lines -- users are passing assigning components that are unexpected / incompatible. We probably cannot predict the consequences of allowing the loading without raising any errors but if we raise an error, users would know what to do to fix the in correct behaviour.
| tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") | ||
| with self.assertRaises(ValueError) as error_context: | ||
| _ = StableDiffusionPipeline.from_pretrained( | ||
| "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer | ||
| ) | ||
|
|
||
| assert "is of type" in str(error_context.exception) | ||
| assert "but should be" in str(error_context.exception) |
There was a problem hiding this comment.
We're now using warning, but for this case, CLIPTokenizer in text_encoder we still get a ValueError later on from here
diffusers/src/diffusers/pipelines/pipeline_utils.py
Lines 893 to 897 in 6324340
So it's a little inconsistent and needs further testing to determine which other cases this already applies to.
|
|
||
| _is_valid_type = class_obj.__class__.__name__ in _expected_class_types | ||
| if isinstance(class_obj, SchedulerMixin) and not _is_valid_type: | ||
| _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types) |
There was a problem hiding this comment.
I think checking against a FlowMatchSchedulers enum would be better in case we end up not using "FlowMatch" in the class name.
_requires_flow_match = any(class_type in FlowMatchSchedulers.__members__ for class_type in _expected_class_types)
_is_flow_match = class_obj.__class__.__name__ in FlowMatchSchedulerscc: @yiyixuxu
| _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types) | ||
| _is_flow_match = "FlowMatch" in class_obj.__class__.__name__ | ||
| if _requires_flow_match and not _is_flow_match: | ||
| logger.warning(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.") |
There was a problem hiding this comment.
Probably okay to raise an error here because scheduler.scale_noise would raise an error in the flow matching pipelines if a non-FlowMatch scheduler is used.
There was a problem hiding this comment.
I really don't want to raise any error here because type hint was not something enforced in this library and it is hard even for us to tell which schedulers can be used/cannot.
There was a problem hiding this comment.
e.g kandinsky if memory serves I think ddim may also works with some of the pipelines, and the compatibility may change
| elif not _requires_flow_match and _is_flow_match: | ||
| logger.warning(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.") | ||
| elif not _is_valid_type: | ||
| logger.warning( |
There was a problem hiding this comment.
I think if it's not a scheduler and the types don't match it's okay to raise an error. I think it would break in the model loading step anyway in this case. wdyt @yiyixuxu?
There was a problem hiding this comment.
I prefer a warning because:
- I think there is very little /no benefits in raising an error vs a warning here
- in case we make a mistake in type hint, we will throw an error by mistake
There was a problem hiding this comment.
we just added use_flow_sigma to a few non-flow match schedulers with the SANA pr, and also we plan to refactor them but don't have a design finalized yet
given that, I think maybe we can skip checking for scheduler altogether for now, and revisit later. let me know what you guys think!
There was a problem hiding this comment.
I've removed scheduler related changes for now, I think we can revisit that later, as @yiyixuxu mentioned above type hints haven't been strictly enforced there are probably some missing/wrong, especially for schedulers. Warning is better because of that too, if there is some wrong type hint that makes its way into a release we'd have to issue a hotfix release to fix it, that just creates headaches and issue reports.
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
|
Thanks for taking it up, @hlky! To me it's a really nice QoL improvement from a DX perspective. |
…10189) * Check correct model type is passed to `from_pretrained` * Flax, skip scheduler * test_wrong_model * Fix for scheduler * Update tests/pipelines/test_pipelines.py Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * EnumMeta * Flax * scheduler in expected types * make * type object 'CLIPTokenizer' has no attribute '_PipelineFastTests__name' * support union * fix typing in kandinsky * make * add LCMScheduler * 'LCMScheduler' object has no attribute 'sigmas' * tests for wrong scheduler * make * update * warning * tests * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> * import FlaxSchedulerMixin * skip scheduler --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* Check correct model type is passed to `from_pretrained` * Flax, skip scheduler * test_wrong_model * Fix for scheduler * Update tests/pipelines/test_pipelines.py Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * EnumMeta * Flax * scheduler in expected types * make * type object 'CLIPTokenizer' has no attribute '_PipelineFastTests__name' * support union * fix typing in kandinsky * make * add LCMScheduler * 'LCMScheduler' object has no attribute 'sigmas' * tests for wrong scheduler * make * update * warning * tests * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> * import FlaxSchedulerMixin * skip scheduler --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
What does this PR do?
Example
Fixes #10093
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
cc @sayakpaul