forked from huggingface/diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpipeline_utils.py
More file actions
189 lines (150 loc) · 7.46 KB
/
pipeline_utils.py
File metadata and controls
189 lines (150 loc) · 7.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from typing import Optional, Union
from huggingface_hub import snapshot_download
from .configuration_utils import ConfigMixin
from .utils import DIFFUSERS_CACHE, logging
# File-name constant. It is not referenced anywhere else in the visible part of this module.
# NOTE(review): presumably a default checkpoint file name used by other modules — confirm against callers.
INDEX_FILE = "diffusion_model.pt"
# Module-level logger, created through the package's own `logging` helper (imported from `.utils`),
# keyed by this module's dotted name.
logger = logging.get_logger(__name__)
# Registry of the base classes a pipeline component may derive from, grouped by
# the name of the library that provides them. Each entry maps a base-class name
# to a two-element list: index 0 is the name of the method used to save an
# instance, index 1 the name of the method used to load one.
LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
        "SchedulerMixin": ["save_config", "from_config"],
        "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
        "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
    },
}

# Flattened view of LOADABLE_CLASSES: base-class name -> [save, load] method
# names, independent of the library the base class comes from.
ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
class DiffusionPipeline(ConfigMixin):
    """Base class that bundles several components into one saveable/loadable pipeline.

    Components are attached with `register_modules`, which also records a
    ``(library, class_name)`` pair per component in the pipeline config. The
    config is stored under the file name given by `config_name`.
    `save_pretrained` writes the config plus every component into its own
    sub-folder; `from_pretrained` reads the config back, loads every component
    and instantiates the pipeline from them.
    """

    config_name = "model_index.json"

    def register_modules(self, **kwargs):
        """Attach components to the pipeline and record them in the config.

        Args:
            **kwargs: ``name=module`` pairs. Each module is set as attribute
                ``name`` on the pipeline, and ``name: (library, class_name)`` is
                passed to `register_to_config`.
        """
        # import it here to avoid circular import
        from diffusers import pipelines

        for name, module in kwargs.items():
            module_path = module.__module__.split(".")

            # retrieve library
            library = module_path[0]

            # check if the module is a pipeline module
            pipeline_file = module_path[-1]
            # A class defined in a top-level module (e.g. ``__main__``) has no
            # parent package; indexing [-2] unconditionally would raise IndexError.
            pipeline_dir = module_path[-2] if len(module_path) > 1 else None
            is_pipeline_module = (
                pipeline_dir is not None
                and pipeline_file == "pipeline_" + pipeline_dir
                and hasattr(pipelines, pipeline_dir)
            )

            # if library is not in LOADABLE_CLASSES, then it is a custom module.
            # Or if it's a pipeline module, then the module is inside the pipeline
            # folder so we set the library to module name.
            # Without a parent package the top-level module name is kept as library.
            if pipeline_dir is not None and (library not in LOADABLE_CLASSES or is_pipeline_module):
                library = pipeline_dir

            # retrieve class_name
            class_name = module.__class__.__name__

            register_dict = {name: (library, class_name)}

            # save model index config
            self.register_to_config(**register_dict)

            # set models
            setattr(self, name, module)

    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
        """Save the pipeline config and every registered component.

        The config is written via `save_config`. Each remaining config key is
        treated as a component name; the component is saved into
        ``save_directory/<component name>`` using the save method registered in
        LOADABLE_CLASSES for the first base class it derives from.

        Args:
            save_directory: Directory to write the config and components to.

        Raises:
            ValueError: If a component derives from none of the base classes
                listed in LOADABLE_CLASSES, so no save method is known for it.
        """
        self.save_config(save_directory)

        model_index_dict = dict(self.config)
        # drop the bookkeeping entries so that only component names remain
        model_index_dict.pop("_class_name")
        model_index_dict.pop("_diffusers_version")
        model_index_dict.pop("_module", None)

        for pipeline_component_name in model_index_dict:
            sub_model = getattr(self, pipeline_component_name)
            model_cls = sub_model.__class__

            save_method_name = None
            # search for the model's base class in LOADABLE_CLASSES
            for library_name, library_classes in LOADABLE_CLASSES.items():
                library = importlib.import_module(library_name)
                for base_class, save_load_methods in library_classes.items():
                    class_candidate = getattr(library, base_class)
                    if issubclass(model_cls, class_candidate):
                        # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
                        save_method_name = save_load_methods[0]
                        break
                if save_method_name is not None:
                    break

            if save_method_name is None:
                # Fail with an explicit message instead of the opaque TypeError
                # that `getattr(sub_model, None)` would raise.
                raise ValueError(
                    f"The component `{pipeline_component_name}` of type {model_cls} cannot be saved: it does not"
                    f" derive from any of the base classes defined in {LOADABLE_CLASSES}."
                )

            save_method = getattr(sub_model, save_method_name)
            save_method(os.path.join(save_directory, pipeline_component_name))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a pipeline from a local directory or from a downloaded snapshot.

        Args:
            pretrained_model_name_or_path: Path of an existing local directory, or otherwise an identifier
                that is handed to `snapshot_download` to obtain a local folder.
            **kwargs: `cache_dir`, `resume_download`, `proxies`, `local_files_only` and `use_auth_token`
                are popped and forwarded to `snapshot_download`. All remaining keyword arguments are
                forwarded to `extract_init_dict` of the resolved pipeline class.

        Returns:
            An instance of the resolved pipeline class (`cls` itself when called on a subclass, otherwise
            the class named by `_class_name` in the config), built from the loaded components.

        Raises:
            ValueError: If a component class matches none of the candidate base classes, so no load
                method is known for it.
        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)

        # 1. Download the checkpoints and configs
        # use snapshot download here to get it working from from_pretrained
        if not os.path.isdir(pretrained_model_name_or_path):
            cached_folder = snapshot_download(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
            )
        else:
            cached_folder = pretrained_model_name_or_path

        config_dict = cls.get_config_dict(cached_folder)

        # 2. Load the pipeline class, if using custom module then load it from the hub
        # if we load from explicit class, let's use it
        if cls != DiffusionPipeline:
            pipeline_class = cls
        else:
            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
            pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

        init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

        init_kwargs = {}

        # import it here to avoid circular import
        from diffusers import pipelines

        # 3. Load each module in the pipeline
        for name, (library_name, class_name) in init_dict.items():
            is_pipeline_module = hasattr(pipelines, library_name)

            # if the model is in a pipeline module, then we load it from the pipeline
            if is_pipeline_module:
                pipeline_module = getattr(pipelines, library_name)
                class_obj = getattr(pipeline_module, class_name)

                importable_classes = ALL_IMPORTABLE_CLASSES
                # every candidate is the class itself, so the subclass check below always succeeds
                class_candidates = {c: class_obj for c in importable_classes}
            else:
                # else we just import it from the library.
                library = importlib.import_module(library_name)
                class_obj = getattr(library, class_name)

                importable_classes = LOADABLE_CLASSES[library_name]
                class_candidates = {c: getattr(library, c) for c in importable_classes}

            load_method_name = None
            # No `break` here: when several candidates match, the last one decides.
            # The loop variable is named distinctly so it does not shadow `class_name`.
            for candidate_name, class_candidate in class_candidates.items():
                if issubclass(class_obj, class_candidate):
                    load_method_name = importable_classes[candidate_name][1]

            if load_method_name is None:
                # Fail with an explicit message instead of the opaque TypeError
                # that `getattr(class_obj, None)` would raise.
                raise ValueError(
                    f"The component `{name}` of type {class_obj} cannot be loaded: it does not derive from"
                    f" any of the base classes defined in {importable_classes}."
                )

            load_method = getattr(class_obj, load_method_name)

            # check if the module is in a subdirectory
            if os.path.isdir(os.path.join(cached_folder, name)):
                loaded_sub_model = load_method(os.path.join(cached_folder, name))
            else:
                # else load from the root directory
                loaded_sub_model = load_method(cached_folder)

            init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

        # 4. Instantiate the pipeline
        model = pipeline_class(**init_kwargs)
        return model