Skip to content

Commit 26ea58d

Browse files
committed
Merge branch 'main' of https://github.com/huggingface/diffusers into main
2 parents d1fb309 + 4261c3a commit 26ea58d

30 files changed

+876
-354
lines changed

Makefile

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,28 +34,23 @@ autogenerate_code: deps_table_update
3434
# Check that the repo is in a good state
3535

3636
repo-consistency:
37-
python utils/check_copies.py
38-
python utils/check_table.py
3937
python utils/check_dummies.py
4038
python utils/check_repo.py
4139
python utils/check_inits.py
42-
python utils/check_config_docstrings.py
43-
python utils/tests_fetcher.py --sanity_check
4440

4541
# this target runs checks on all files
4642

4743
quality:
4844
black --check --preview $(check_dirs)
4945
isort --check-only $(check_dirs)
5046
flake8 $(check_dirs)
51-
doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
47+
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
5248

5349
# Format source code automatically and check is there are any problems left that need manual fixing
5450

5551
extra_style_checks:
5652
python utils/custom_init_isort.py
57-
python utils/sort_auto_mappings.py
58-
doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
53+
doc-builder style src/diffusers docs/source --max_len 119 --path_to_docs docs/source
5954

6055
# this target runs checks on all files and potentially modifies some of them
6156

@@ -73,8 +68,6 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
7368

7469
fix-copies:
7570
python utils/check_dummies.py --fix_and_overwrite
76-
python utils/check_table.py --fix_and_overwrite
77-
python utils/check_copies.py --fix_and_overwrite
7871

7972
# Run tests for the library
8073

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import argparse
2+
import os
3+
4+
import torch
5+
import torch.nn.functional as F
6+
7+
import bitsandbytes as bnb
8+
import PIL.Image
9+
from accelerate import Accelerator
10+
from datasets import load_dataset
11+
from diffusers import DDPMScheduler, Glide, GlideUNetModel
12+
from diffusers.hub_utils import init_git_repo, push_to_hub
13+
from diffusers.optimization import get_scheduler
14+
from diffusers.utils import logging
15+
from torchvision.transforms import (
16+
CenterCrop,
17+
Compose,
18+
InterpolationMode,
19+
Normalize,
20+
RandomHorizontalFlip,
21+
Resize,
22+
ToTensor,
23+
)
24+
from tqdm.auto import tqdm
25+
26+
27+
logger = logging.get_logger(__name__)
28+
29+
30+
def main(args):
    """Fine-tune the GLIDE text-conditional UNet on an image/caption dataset.

    Loads the pretrained ``fusing/glide-base`` pipeline, trains its ``text_unet``
    with an 8-bit Adam optimizer and a DDPM noising schedule, and after every
    epoch (on the main process) samples one image for visual inspection and
    saves/pushes the pipeline.

    Args:
        args: Parsed CLI namespace; see the argument parser at the bottom of
            this file for the expected fields (dataset, resolution, batch_size,
            num_epochs, gradient_accumulation_steps, lr, warmup_steps,
            mixed_precision, output_dir, push_to_hub, ...).
    """
    accelerator = Accelerator(mixed_precision=args.mixed_precision)

    pipeline = Glide.from_pretrained("fusing/glide-base")
    model = pipeline.text_unet
    noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
    # 8-bit Adam keeps optimizer state small enough for large UNets.
    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)

    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            # Map pixel values from [0, 1] into [-1, 1], the diffusion convention.
            Normalize([0.5], [0.5]),
        ]
    )
    dataset = load_dataset(args.dataset, split="train")

    # The text encoder is frozen (eval mode, used under no_grad below); only
    # the UNet is trained.
    text_encoder = pipeline.text_encoder.eval()

    def transforms(examples):
        """Turn raw dataset rows into image tensors + precomputed text embeddings."""
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
        text_inputs = text_inputs.input_ids.to(accelerator.device)
        with torch.no_grad():
            text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs).last_hidden_state
        return {"images": images, "text_embeddings": text_embeddings}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        # Scheduler length is counted in optimizer steps, not micro-batches.
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

    def _noise_prediction_loss(noisy_images, timesteps, text_embeddings, noise_samples, n_channels):
        """Forward pass + MSE on the predicted noise, scaled for accumulation."""
        model_output = model(noisy_images, timesteps, text_embeddings)
        # The UNet outputs mean and variance channels stacked along dim 1.
        model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
        # Learn the variance using the variational bound, but don't let
        # it affect our mean prediction.
        # NOTE(review): `frozen_out` is computed but never consumed — it looks
        # intended for a variational-bound loss term; confirm before removing.
        frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
        # Predict the noise residual.
        loss = F.mse_loss(model_output, noise_samples)
        return loss / args.gradient_accumulation_steps

    # Train!
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size() if is_distributed else 1
    total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataloader.dataset)}")
    logger.info(f"  Num Epochs = {args.num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_steps}")

    for epoch in range(args.num_epochs):
        model.train()
        with tqdm(total=len(train_dataloader), unit="ba") as pbar:
            pbar.set_description(f"Epoch {epoch}")
            for step, batch in enumerate(train_dataloader):
                clean_images = batch["images"]
                batch_size, n_channels, height, width = clean_images.shape
                noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
                timesteps = torch.randint(
                    0, noise_scheduler.timesteps, (batch_size,), device=clean_images.device
                ).long()

                # Add noise onto the clean images according to the noise magnitude
                # at each timestep (this is the forward diffusion process).
                noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)

                if step % args.gradient_accumulation_steps != 0:
                    # Accumulation micro-step: skip gradient synchronization and
                    # only accumulate gradients. Do NOT call optimizer.step() here
                    # (the original code did) — stepping on every micro-batch with
                    # unclipped gradients defeats gradient accumulation.
                    with accelerator.no_sync(model):
                        loss = _noise_prediction_loss(
                            noisy_images, timesteps, batch["text_embeddings"], noise_samples, n_channels
                        )
                        accelerator.backward(loss)
                else:
                    # Final micro-step of the accumulation window: sync gradients,
                    # clip, and apply the optimizer/scheduler update.
                    loss = _noise_prediction_loss(
                        noisy_images, timesteps, batch["text_embeddings"], noise_samples, n_channels
                    )
                    accelerator.backward(loss)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                pbar.update(1)
                pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])

        accelerator.wait_for_everyone()

        # Generate a sample image for visual inspection.
        if accelerator.is_main_process:
            model.eval()
            with torch.no_grad():
                pipeline.unet = accelerator.unwrap_model(model)

                # Fixed seed so epoch-to-epoch samples are comparable.
                generator = torch.manual_seed(0)
                # Run pipeline in inference (sample random noise and denoise).
                image = pipeline("a clip art of a corgi", generator=generator, num_upscale_inference_steps=50)

                # Process image to PIL: map [-1, 1] back to [0, 255] uint8.
                image_processed = image.squeeze(0)
                image_processed = ((image_processed + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
                image_pil = PIL.Image.fromarray(image_processed)

                # Save image.
                test_dir = os.path.join(args.output_dir, "test_samples")
                os.makedirs(test_dir, exist_ok=True)
                image_pil.save(f"{test_dir}/{epoch:04d}.png")

            # Save the model.
            if args.push_to_hub:
                push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
            else:
                pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()
if __name__ == "__main__":
    # CLI entry point: parse training hyperparameters and hub options, then
    # hand off to main().
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
    parser.add_argument("--output_dir", type=str, default="glide-text2image")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        # Fixed: the original adjacent string literals were missing separator
        # spaces, rendering as "Choosebetween" and "1.10.and" in --help output.
        help=(
            "Whether to use mixed precision. Choose "
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
            "and an Nvidia Ampere GPU."
        ),
    )

    args = parser.parse_args()
    # torch.distributed launchers communicate the rank via the LOCAL_RANK env
    # var; let it override the CLI flag so both launch styles work.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    main(args)

0 commit comments

Comments
 (0)