Skip to content

Commit fe7d136

Browse files
correct dict
1 parent e660a05 commit fe7d136

File tree

4 files changed

+52
-11
lines changed

4 files changed

+52
-11
lines changed

src/diffusers/configuration_utils.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,11 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
""" ConfigMixinuration base class and utilities."""
17-
18-
19-
import copy
2017
import inspect
2118
import json
2219
import os
2320
import re
21+
from collections import OrderedDict
2422
from typing import Any, Dict, Tuple, Union
2523

2624
from huggingface_hub import hf_hub_download
@@ -63,10 +61,14 @@ def register_to_config(self, **kwargs):
6361
logger.error(f"Can't set {key} with value {value} for {self}")
6462
raise err
6563

66-
if not hasattr(self, "_dict_to_save"):
67-
self._dict_to_save = {}
64+
if not hasattr(self, "_internal_dict"):
65+
internal_dict = kwargs
66+
else:
67+
previous_dict = dict(self._internal_dict)
68+
internal_dict = {**self._internal_dict, **kwargs}
69+
logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
6870

69-
self._dict_to_save.update(kwargs)
71+
self._internal_dict = FrozenDict(internal_dict)
7072

7173
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
7274
"""
@@ -230,8 +232,7 @@ def __repr__(self):
230232

231233
@property
232234
def config(self) -> Dict[str, Any]:
233-
output = copy.deepcopy(self._dict_to_save)
234-
return output
235+
return self._internal_dict
235236

236237
def to_json_string(self) -> str:
237238
"""
@@ -240,7 +241,7 @@ def to_json_string(self) -> str:
240241
Returns:
241242
`str`: String containing all the attributes that make up this configuration instance in JSON format.
242243
"""
243-
config_dict = self._dict_to_save
244+
config_dict = self._internal_dict
244245
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
245246

246247
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
@@ -253,3 +254,39 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
253254
"""
254255
with open(json_file_path, "w", encoding="utf-8") as writer:
255256
writer.write(self.to_json_string())
257+
258+
259+
class FrozenDict(OrderedDict):
260+
def __init__(self, *args, **kwargs):
261+
# remove `None`
262+
args = (a for a in args if a is not None)
263+
kwargs = {k: v for k, v in kwargs if v is not None}
264+
265+
super().__init__(*args, **kwargs)
266+
267+
for key, value in self.items():
268+
setattr(self, key, value)
269+
270+
self.__frozen = True
271+
272+
def __delitem__(self, *args, **kwargs):
273+
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
274+
275+
def setdefault(self, *args, **kwargs):
276+
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
277+
278+
def pop(self, *args, **kwargs):
279+
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
280+
281+
def update(self, *args, **kwargs):
282+
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
283+
284+
def __setattr__(self, name, value):
285+
if hasattr(self, "__frozen") and self.__frozen:
286+
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
287+
super().__setattr__(name, value)
288+
289+
def __setitem__(self, name, value):
290+
if hasattr(self, "__frozen") and self.__frozen:
291+
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
292+
super().__setitem__(name, value)

src/diffusers/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
338338
revision=revision,
339339
**kwargs,
340340
)
341-
model.register(name_or_path=pretrained_model_name_or_path)
341+
model.register_to_config(name_or_path=pretrained_model_name_or_path)
342342
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
343343
# Load model
344344
pretrained_model_name_or_path = str(pretrained_model_name_or_path)

src/diffusers/pipeline_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def register_modules(self, **kwargs):
8888
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
8989
self.save_config(save_directory)
9090

91-
model_index_dict = self.config
91+
model_index_dict = dict(self.config)
9292
model_index_dict.pop("_class_name")
9393
model_index_dict.pop("_diffusers_version")
9494
model_index_dict.pop("_module")

tests/test_modeling_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def __init__(
7373
new_obj = SampleObject.from_config(tmpdirname)
7474
new_config = new_obj.config
7575

76+
# unfreeze configs
77+
config = dict(config)
78+
new_config = dict(new_config)
79+
7680
assert config.pop("c") == (2, 5) # instantiated as tuple
7781
assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json
7882
assert config == new_config

0 commit comments

Comments
 (0)