
Commit 07ffe73

Style
1 parent bb98a5b commit 07ffe73

11 files changed (+91, -96 lines)


models/vision/glide/convert_weights.py

Lines changed: 8 additions & 7 deletions
@@ -1,9 +1,10 @@
 import torch
 from torch import nn
 
-from transformers import CLIPTextConfig, GPT2Tokenizer
-from diffusers import UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
+from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, UNetGLIDEModel
 from modeling_glide import GLIDE
+from transformers import CLIPTextConfig, GPT2Tokenizer
+
 
 # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
 state_dict = torch.load("base.pt", map_location="cpu")
@@ -22,7 +23,7 @@
 )
 model = CLIPTextModel(config).eval()
 tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>")
-#tokenizer.save_pretrained("./glide-base")
+# tokenizer.save_pretrained("./glide-base")
 
 hf_encoder = model.text_model
 
@@ -51,11 +52,11 @@
 hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
 hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
 
-#inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
-#with torch.no_grad():
+# inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
+# with torch.no_grad():
 #     outputs = model(**inputs)
 
-#model.save_pretrained("./glide-base")
+# model.save_pretrained("./glide-base")
 
 ### Convert the UNet
 
@@ -80,4 +81,4 @@
 
 glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer)
 
-glide.save_pretrained("./glide-base")
+glide.save_pretrained("./glide-base")
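The hunk above shows the tail of the per-layer remapping from OpenAI's checkpoint into the HF CLIP text encoder. A minimal sketch of that mapping pattern, not part of the commit; the helper name and the c_fc keys are assumed by analogy with the c_proj keys visible above:

import torch


def copy_mlp_weights(hf_layer, state_dict, layer_idx):
    # OpenAI's GLIDE transformer stores the feed-forward weights as mlp.c_fc /
    # mlp.c_proj; the HF CLIP encoder layer exposes them as mlp.fc1 / mlp.fc2
    # (assumed mapping, mirroring the fc2 / c_proj lines in the diff above).
    prefix = f"transformer.resblocks.{layer_idx}.mlp"
    with torch.no_grad():
        hf_layer.mlp.fc1.weight.copy_(state_dict[f"{prefix}.c_fc.weight"])
        hf_layer.mlp.fc1.bias.copy_(state_dict[f"{prefix}.c_fc.bias"])
        hf_layer.mlp.fc2.weight.copy_(state_dict[f"{prefix}.c_proj.weight"])
        hf_layer.mlp.fc2.bias.copy_(state_dict[f"{prefix}.c_proj.bias"])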

models/vision/glide/modeling_glide.py

Lines changed: 16 additions & 14 deletions
@@ -14,12 +14,12 @@
 # limitations under the License.
 
 
-from diffusers import DiffusionPipeline, UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
-from transformers import GPT2Tokenizer
+import numpy as np
+import torch
 
 import tqdm
-import torch
-import numpy as np
+from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, UNetGLIDEModel
+from transformers import GPT2Tokenizer
 
 
 def _extract_into_tensor(arr, timesteps, broadcast_shape):
@@ -40,14 +40,16 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
 
 class GLIDE(DiffusionPipeline):
 def __init__(
-self,
-unet: UNetGLIDEModel,
-noise_scheduler: ClassifierFreeGuidanceScheduler,
-text_encoder: CLIPTextModel,
-tokenizer: GPT2Tokenizer
+self,
+unet: UNetGLIDEModel,
+noise_scheduler: ClassifierFreeGuidanceScheduler,
+text_encoder: CLIPTextModel,
+tokenizer: GPT2Tokenizer,
 ):
 super().__init__()
-self.register_modules(unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer)
+self.register_modules(
+unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer
+)
 
 def q_posterior_mean_variance(self, x_start, x_t, t):
 """
@@ -129,7 +131,9 @@ def __call__(self, prompt, generator=None, torch_device=None):
 self.text_encoder.to(torch_device)
 
 # 1. Sample gaussian noise
-image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator)
+image = self.noise_scheduler.sample_noise(
+(1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
+)
 
 # 2. Encode tokens
 # an empty input is needed to guide the model away from (
@@ -141,9 +145,7 @@ def __call__(self, prompt, generator=None, torch_device=None):
 t = torch.tensor([i] * image.shape[0], device=torch_device)
 mean, variance, log_variance, pred_xstart = self.p_mean_variance(self.unet, transformer_out, image, t)
 noise = self.noise_scheduler.sample_noise(image.shape)
-nonzero_mask = (
-(t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))
-) # no noise when t == 0
+nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
 image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
 
 return image
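The reformatted nonzero_mask line is the usual ancestral-sampling guard: noise is only added while t > 0. A standalone illustration, not part of the commit, of how the mask broadcasts over the image dimensions:

import torch

# The (batch,)-shaped mask becomes (batch, 1, 1, 1) so it broadcasts against the
# image tensor, zeroing the noise term exactly where t == 0.
mean = torch.zeros(2, 3, 64, 64)
log_variance = torch.zeros(2, 3, 64, 64)
noise = torch.randn(2, 3, 64, 64)
t = torch.tensor([0, 5])  # first sample is at the final denoising step

nonzero_mask = (t != 0).float().view(-1, *([1] * (len(mean.shape) - 1)))  # shape (2, 1, 1, 1)
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

assert torch.equal(image[0], mean[0])  # t == 0: no noise is added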

models/vision/glide/run_glide.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,8 @@
 import torch
+
 from modeling_glide import GLIDE
 
+
 generator = torch.Generator()
 generator = generator.manual_seed(0)
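The remainder of run_glide.py is not shown in this hunk. A rough sketch of how the seeded generator above would feed into the pipeline call; the from_pretrained loader and the checkpoint path are assumptions, only GLIDE.__call__(prompt, generator=..., torch_device=...) appears in this commit:

# Hypothetical continuation of the script (loader and path are assumed).
glide = GLIDE.from_pretrained("./glide-base")
image = glide("an oil painting of a corgi", generator=generator, torch_device="cpu")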

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -5,10 +5,10 @@
 __version__ = "0.0.1"
 
 from .modeling_utils import ModelMixin
+from .models.clip_text_transformer import CLIPTextModel
 from .models.unet import UNetModel
 from .models.unet_glide import UNetGLIDEModel
 from .models.unet_ldm import UNetLDMModel
-from .models.clip_text_transformer import CLIPTextModel
 from .pipeline_utils import DiffusionPipeline
-from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
+from .schedulers.gaussian_ddpm import GaussianDDPMScheduler

src/diffusers/configuration_utils.py

Lines changed: 3 additions & 7 deletions
@@ -89,7 +89,6 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
 
 self.to_json_file(output_config_file)
 logger.info(f"ConfigMixinuration saved in {output_config_file}")
-
 
 @classmethod
 def get_config_dict(
@@ -183,7 +182,7 @@ def get_config_dict(
 logger.info(f"loading configuration file {config_file}")
 else:
 logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
-
+
 return config_dict
 
 @classmethod
@@ -199,9 +198,8 @@ def extract_init_dict(cls, config_dict, **kwargs):
 # use value from config dict
 init_dict[key] = config_dict.pop(key)
 
-
 unused_kwargs = config_dict.update(kwargs)
-
+
 passed_keys = set(init_dict.keys())
 if len(expected_keys - passed_keys) > 0:
 logger.warn(
@@ -212,9 +210,7 @@
 
 @classmethod
 def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
-config_dict = cls.get_config_dict(
-pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
-)
+config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
 
 init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
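The from_config signature above suggests a one-line load path for config-backed classes. A small usage sketch, assuming ClassifierFreeGuidanceScheduler is a ConfigMixin subclass and that its config was saved under the hypothetical directory shown:

from diffusers import ClassifierFreeGuidanceScheduler

# Assumed usage of ConfigMixin.from_config; with return_unused_kwargs=True the
# call would also return the kwargs the scheduler did not consume.
scheduler = ClassifierFreeGuidanceScheduler.from_config("./glide-base/noise_scheduler")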

src/diffusers/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .clip_text_transformer import CLIPTextModel
 from .unet import UNetModel
 from .unet_glide import UNetGLIDEModel
 from .unet_ldm import UNetLDMModel
-from .clip_text_transformer import CLIPTextModel

src/diffusers/models/clip_text_transformer.py

Lines changed: 48 additions & 50 deletions
@@ -14,14 +14,15 @@
 # limitations under the License.
 """ PyTorch CLIP model."""
 
-from dataclasses import dataclass
 import math
+from dataclasses import dataclass
 from typing import Any, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
 
+from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
@@ -32,7 +33,7 @@
 logging,
 replace_return_docstrings,
 )
-from transformers import CLIPModel, CLIPConfig, CLIPVisionConfig, CLIPTextConfig
+
 
 logger = logging.get_logger(__name__)
 
@@ -153,11 +154,11 @@ def __init__(self, config: CLIPTextConfig):
 self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
 
 def forward(
-self,
-input_ids: Optional[torch.LongTensor] = None,
-position_ids: Optional[torch.LongTensor] = None,
-inputs_embeds: Optional[torch.FloatTensor] = None,
-attention_mask: Optional[torch.Tensor] = None,
+self,
+input_ids: Optional[torch.LongTensor] = None,
+position_ids: Optional[torch.LongTensor] = None,
+inputs_embeds: Optional[torch.FloatTensor] = None,
+attention_mask: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
 
@@ -193,16 +194,15 @@ def __init__(self, config):
 )
 self.scale = 1 / math.sqrt(math.sqrt(self.head_dim))
 
-self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim*3)
+self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)
 self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
 
-
 def forward(
-self,
-hidden_states: torch.Tensor,
-attention_mask: Optional[torch.Tensor] = None,
-causal_attention_mask: Optional[torch.Tensor] = None,
-output_attentions: Optional[bool] = False,
+self,
+hidden_states: torch.Tensor,
+attention_mask: Optional[torch.Tensor] = None,
+causal_attention_mask: Optional[torch.Tensor] = None,
+output_attentions: Optional[bool] = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 """Input shape: Batch x Time x Channel"""
 
@@ -212,9 +212,7 @@ def forward(
 qkv_states = qkv_states.view(bsz, tgt_len, self.num_heads, -1)
 query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=-1)
 
-attn_weights = torch.einsum(
-"bthc,bshc->bhts", query_states * self.scale, key_states * self.scale
-)
+attn_weights = torch.einsum("bthc,bshc->bhts", query_states * self.scale, key_states * self.scale)
 
 wdtype = attn_weights.dtype
 attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).type(wdtype)
@@ -252,11 +250,11 @@ def __init__(self, config: CLIPConfig):
 self.layer_norm2 = nn.LayerNorm(self.embed_dim)
 
 def forward(
-self,
-hidden_states: torch.Tensor,
-attention_mask: torch.Tensor,
-causal_attention_mask: torch.Tensor,
-output_attentions: Optional[bool] = False,
+self,
+hidden_states: torch.Tensor,
+attention_mask: torch.Tensor,
+causal_attention_mask: torch.Tensor,
+output_attentions: Optional[bool] = False,
 ) -> Tuple[torch.FloatTensor]:
 """
 Args:
@@ -313,31 +311,31 @@ def _init_weights(self, module):
 module.padding_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
 elif isinstance(module, CLIPVisionEmbeddings):
 factor = self.config.initializer_factor
-nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim ** -0.5 * factor)
+nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
 nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
 nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
 elif isinstance(module, CLIPAttention):
 factor = self.config.initializer_factor
-in_proj_std = (module.embed_dim ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
-out_proj_std = (module.embed_dim ** -0.5) * factor
+in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+out_proj_std = (module.embed_dim**-0.5) * factor
 nn.init.normal_(module.qkv_proj.weight, std=in_proj_std)
 nn.init.normal_(module.out_proj.weight, std=out_proj_std)
 elif isinstance(module, CLIPMLP):
 factor = self.config.initializer_factor
 in_proj_std = (
-(module.config.hidden_size ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
 )
 fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
 nn.init.normal_(module.fc1.weight, std=fc_std)
 nn.init.normal_(module.fc2.weight, std=in_proj_std)
 elif isinstance(module, CLIPModel):
 nn.init.normal_(
 module.text_projection.weight,
-std=module.text_embed_dim ** -0.5 * self.config.initializer_factor,
+std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
 )
 nn.init.normal_(
 module.visual_projection.weight,
-std=module.vision_embed_dim ** -0.5 * self.config.initializer_factor,
+std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
 )
 
 if isinstance(module, nn.LayerNorm):
@@ -463,13 +461,13 @@ def __init__(self, config: CLIPConfig):
 self.gradient_checkpointing = False
 
 def forward(
-self,
-inputs_embeds,
-attention_mask: Optional[torch.Tensor] = None,
-causal_attention_mask: Optional[torch.Tensor] = None,
-output_attentions: Optional[bool] = None,
-output_hidden_states: Optional[bool] = None,
-return_dict: Optional[bool] = None,
+self,
+inputs_embeds,
+attention_mask: Optional[torch.Tensor] = None,
+causal_attention_mask: Optional[torch.Tensor] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutput]:
 r"""
 Args:
@@ -562,13 +560,13 @@ def __init__(self, config: CLIPTextConfig):
 @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
 def forward(
-self,
-input_ids: Optional[torch.Tensor] = None,
-attention_mask: Optional[torch.Tensor] = None,
-position_ids: Optional[torch.Tensor] = None,
-output_attentions: Optional[bool] = None,
-output_hidden_states: Optional[bool] = None,
-return_dict: Optional[bool] = None,
+self,
+input_ids: Optional[torch.Tensor] = None,
+attention_mask: Optional[torch.Tensor] = None,
+position_ids: Optional[torch.Tensor] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPooling]:
 r"""
 Returns:
@@ -652,13 +650,13 @@ def set_input_embeddings(self, value):
 @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
 def forward(
-self,
-input_ids: Optional[torch.Tensor] = None,
-attention_mask: Optional[torch.Tensor] = None,
-position_ids: Optional[torch.Tensor] = None,
-output_attentions: Optional[bool] = None,
-output_hidden_states: Optional[bool] = None,
-return_dict: Optional[bool] = None,
+self,
+input_ids: Optional[torch.Tensor] = None,
+attention_mask: Optional[torch.Tensor] = None,
+position_ids: Optional[torch.Tensor] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPooling]:
 r"""
 Returns:
@@ -684,4 +682,4 @@ def forward(
 output_attentions=output_attentions,
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
-)
+)
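The CLIPAttention changes above touch a fused-QKV projection that applies the scale 1 / sqrt(sqrt(head_dim)) to both operands of the einsum. A standalone sketch, not part of the commit, showing that this matches the conventional q @ k^T / sqrt(head_dim) scaling:

import math

import torch
from torch import nn

bsz, tgt_len, num_heads, head_dim = 2, 5, 4, 8
embed_dim = num_heads * head_dim

# One Linear produces q, k and v together, as in CLIPAttention above.
qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
hidden_states = torch.randn(bsz, tgt_len, embed_dim)

qkv_states = qkv_proj(hidden_states).view(bsz, tgt_len, num_heads, -1)
query_states, key_states, value_states = torch.split(qkv_states, head_dim, dim=-1)

scale = 1 / math.sqrt(math.sqrt(head_dim))
attn_weights = torch.einsum("bthc,bshc->bhts", query_states * scale, key_states * scale)

# Scaling both q and k by head_dim**-0.25 is equivalent to dividing q @ k^T by
# sqrt(head_dim).
reference = torch.einsum("bthc,bshc->bhts", query_states, key_states) / math.sqrt(head_dim)
assert torch.allclose(attn_weights, reference, atol=1e-5)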

src/diffusers/models/unet_glide.py

Lines changed: 1 addition & 1 deletion
@@ -470,7 +470,7 @@ def __init__(
 self.channel_mult = channel_mult
 self.conv_resample = conv_resample
 self.use_checkpoint = use_checkpoint
-#self.dtype = torch.float16 if use_fp16 else torch.float32
+# self.dtype = torch.float16 if use_fp16 else torch.float32
 self.num_heads = num_heads
 self.num_head_channels = num_head_channels
 self.num_heads_upsample = num_heads_upsample
