Skip to content

Commit fe99460

Browse files
committed
update config dict logic
1 parent a61a961 commit fe99460

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

src/diffusers/configuration_utils.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
8989

9090
self.to_json_file(output_config_file)
9191
logger.info(f"ConfigMixinuration saved in {output_config_file}")
92+
9293

9394
@classmethod
9495
def get_config_dict(
@@ -182,35 +183,43 @@ def get_config_dict(
182183
logger.info(f"loading configuration file {config_file}")
183184
else:
184185
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
186+
187+
return config_dict
185188

189+
@classmethod
190+
def extract_init_dict(cls, config_dict, **kwargs):
186191
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
187192
expected_keys.remove("self")
188-
193+
import ipdb; ipdb.set_trace()
194+
init_dict = {}
189195
for key in expected_keys:
190196
if key in kwargs:
191197
# overwrite key
192-
config_dict[key] = kwargs.pop(key)
198+
init_dict[key] = kwargs.pop(key)
199+
elif key in config_dict:
200+
# use value from config dict
201+
init_dict[key] = config_dict.pop(key)
193202

194-
passed_keys = set(config_dict.keys())
195-
196-
unused_kwargs = kwargs
197-
for key in passed_keys - expected_keys:
198-
unused_kwargs[key] = config_dict.pop(key)
199203

204+
unused_kwargs = config_dict.update(kwargs)
205+
206+
passed_keys = set(init_dict.keys())
200207
if len(expected_keys - passed_keys) > 0:
201208
logger.warn(
202209
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
203210
)
204211

205-
return config_dict, unused_kwargs
212+
return init_dict, unused_kwargs
206213

207214
@classmethod
208215
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
209-
config_dict, unused_kwargs = cls.get_config_dict(
216+
config_dict = cls.get_config_dict(
210217
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
211218
)
212219

213-
model = cls(**config_dict)
220+
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
221+
222+
model = cls(**init_dict)
214223

215224
if return_unused_kwargs:
216225
return model, unused_kwargs

src/diffusers/pipeline_utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,16 +97,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
9797
else:
9898
cached_folder = pretrained_model_name_or_path
9999

100-
config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder)
100+
config_dict = cls.get_config_dict(cached_folder)
101+
module = config_dict["_module"]
102+
class_name_ = config_dict["_class_name"]
103+
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
101104

102-
module = pipeline_kwargs.pop("_module", None)
103-
# TODO(Suraj) - make from hub import work
104-
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
105-
# Add Sylvains code from transformers
105+
init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs)
106+
import ipdb; ipdb.set_trace()
106107

107108
init_kwargs = {}
108109

109-
for name, (library_name, class_name) in config_dict.items():
110+
for name, (library_name, class_name) in init_dict.items():
110111
importable_classes = LOADABLE_CLASSES[library_name]
111112

112113
if library_name == module:
@@ -131,6 +132,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
131132

132133
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
133134

134-
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
135+
135136
model = class_obj(**init_kwargs)
136137
return model

0 commit comments

Comments
 (0)