Skip to content

Commit fe31373

Browse files
improve
1 parent 3a5c65d commit fe31373

File tree

14 files changed

+691
-760
lines changed

14 files changed

+691
-760
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ One should be able to save both models and samplers as well as load them from th
2727
Example:
2828

2929
```python
30-
from diffusers import UNetModel, GaussianDiffusion
30+
from diffusers import UNetModel, GaussianDDPMScheduler
3131
import torch
3232

3333
# 1. Load model
@@ -40,7 +40,7 @@ time_step = torch.tensor([10])
4040
image = unet(dummy_noise, time_step)
4141

4242
# 3. Load sampler
43-
sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
43+
sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")
4444

4545
# 4. Sample image from sampler passing the model
4646
image = sampler.sample(model, batch_size=1)
@@ -54,12 +54,12 @@ print(image)
5454
Example:
5555

5656
```python
57-
from diffusers import UNetModel, GaussianDiffusion
57+
from diffusers import UNetModel, GaussianDDPMScheduler
5858
from modeling_ddpm import DDPM
5959
import tempfile
6060

6161
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
62-
sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
62+
sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")
6363

6464
# compose Diffusion Pipeline
6565
ddpm = DDPM(unet, sampler)

examples/sample_loop.py

Lines changed: 142 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,99 +1,157 @@
11
#!/usr/bin/env python3
2-
from diffusers import UNetModel, GaussianDiffusion
2+
from diffusers import UNetModel, GaussianDDPMScheduler
33
import torch
44
import torch.nn.functional as F
5-
6-
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
7-
diffusion = GaussianDiffusion.from_config("fusing/ddpm_dummy")
8-
5+
import numpy as np
6+
import PIL.Image
7+
import tqdm
8+
9+
#torch_device = "cuda"
10+
#
11+
#unet = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church")
12+
#unet.to(torch_device)
13+
#
14+
#TIME_STEPS = 10
15+
#
16+
#scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=TIME_STEPS)
17+
#
18+
#diffusion_config = {
19+
# "beta_start": 0.0001,
20+
# "beta_end": 0.02,
21+
# "num_diffusion_timesteps": TIME_STEPS,
22+
#}
23+
#
924
# 2. Do one denoising step with model
10-
batch_size, num_channels, height, width = 1, 3, 32, 32
11-
dummy_noise = torch.ones((batch_size, num_channels, height, width))
12-
13-
14-
TIME_STEPS = 10
15-
16-
25+
#batch_size, num_channels, height, width = 1, 3, 256, 256
26+
#
27+
#torch.manual_seed(0)
28+
#noise_image = torch.randn(batch_size, num_channels, height, width, device="cuda")
29+
#
30+
#
1731
# Helper
18-
def extract(a, t, x_shape):
19-
b, *_ = t.shape
20-
out = a.gather(-1, t)
21-
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
32+
#def noise_like(shape, device, repeat=False):
33+
# def repeat_noise():
34+
# return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
35+
#
36+
# def noise():
37+
# return torch.randn(shape, device=device)
38+
#
39+
# return repeat_noise() if repeat else noise()
40+
#
41+
#
42+
#betas = np.linspace(diffusion_config["beta_start"], diffusion_config["beta_end"], diffusion_config["num_diffusion_timesteps"], dtype=np.float64)
43+
#betas = torch.tensor(betas, device=torch_device)
44+
#alphas = 1.0 - betas
45+
#
46+
#alphas_cumprod = torch.cumprod(alphas, axis=0)
47+
#alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
48+
#
49+
#posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
50+
#posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
51+
#
52+
#posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
53+
#posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
54+
#
55+
#
56+
#sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
57+
#sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
58+
#
59+
#
60+
#noise_coeff = (1 - alphas) / torch.sqrt(1 - alphas_cumprod)
61+
#coeff = 1 / torch.sqrt(alphas)
62+
63+
64+
def real_fn():
65+
# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
66+
# 1: x_t ~ N(0,1)
67+
x_t = noise_image
68+
# 2: for t = T, ...., 1 do
69+
for i in reversed(range(TIME_STEPS)):
70+
t = torch.tensor([i]).to(torch_device)
71+
# 3: z ~ N(0, 1)
72+
noise = noise_like(x_t.shape, torch_device)
73+
74+
# 4: √1αtxt − √1−αt1−α¯tθ(xt, t) + σtz
75+
# ------------------------- MODEL ------------------------------------#
76+
with torch.no_grad():
77+
pred_noise = unet(x_t, t) # pred epsilon_theta
78+
79+
# pred_x = sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * pred_noise
80+
# pred_x.clamp_(-1.0, 1.0)
81+
# pred mean
82+
# posterior_mean = posterior_mean_coef1[t] * pred_x + posterior_mean_coef2[t] * x_t
83+
# --------------------------------------------------------------------#
84+
85+
posterior_mean = coeff[t] * (x_t - noise_coeff[t] * pred_noise)
86+
87+
# ------------------------- Variance Scheduler -----------------------#
88+
# pred variance
89+
posterior_log_variance = posterior_log_variance_clipped[t]
90+
91+
b, *_, device = *x_t.shape, x_t.device
92+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
93+
posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
94+
# --------------------------------------------------------------------#
95+
96+
x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)
97+
x_t = x_t_1
98+
99+
print(x_t.abs().sum())
100+
101+
102+
def post_process_to_image(x_t):
103+
image = x_t.cpu().permute(0, 2, 3, 1)
104+
image = (image + 1.0) * 127.5
105+
image = image.numpy().astype(np.uint8)
106+
107+
return PIL.Image.fromarray(image[0])
108+
109+
110+
from pytorch_diffusion import Diffusion
111+
112+
#diffusion = Diffusion.from_pretrained("lsun_church")
113+
#samples = diffusion.denoise(1)
114+
#
115+
#image = post_process_to_image(samples)
116+
#image.save("check.png")
117+
#import ipdb; ipdb.set_trace()
118+
119+
120+
device = "cuda"
121+
scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=10)
122+
123+
import ipdb; ipdb.set_trace()
124+
125+
model = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church").to(device)
22126

23127

24-
def noise_like(shape, device, repeat=False):
25-
def repeat_noise():
26-
return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
27-
28-
def noise():
29-
return torch.randn(shape, device=device)
30-
31-
return repeat_noise() if repeat else noise()
32-
33-
34-
# Schedule
35-
def cosine_beta_schedule(timesteps, s=0.008):
36-
"""
37-
cosine schedule
38-
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
39-
"""
40-
steps = timesteps + 1
41-
x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
42-
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
43-
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
44-
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
45-
return torch.clip(betas, 0, 0.999)
128+
torch.manual_seed(0)
129+
next_image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=device)
46130

131+
for t in tqdm.tqdm(reversed(range(len(scheduler))), total=len(scheduler)):
132+
# define coefficients for time step t
133+
clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
134+
clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
135+
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
136+
clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
47137

48-
betas = cosine_beta_schedule(TIME_STEPS)
49-
alphas = 1.0 - betas
50-
alphas_cumprod = torch.cumprod(alphas, axis=0)
51-
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
138+
# predict noise residual
139+
with torch.no_grad():
140+
noise_residual = model(next_image, t)
52141

53-
posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
54-
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
142+
# compute prev image from noise
143+
pred_mean = clip_image_coeff * next_image - clip_noise_coeff * noise_residual
144+
pred_mean = torch.clamp(pred_mean, -1, 1)
145+
image = clip_coeff * pred_mean + image_coeff * next_image
55146

56-
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
57-
posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
147+
# sample variance
148+
variance = scheduler.sample_variance(t, image.shape, device=device)
58149

150+
# sample previous image
151+
sampled_image = image + variance
59152

60-
sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
61-
sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
153+
next_image = sampled_image
62154

63-
torch.manual_seed(0)
64155

65-
# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
66-
# 1: x_t ~ N(0,1)
67-
x_t = dummy_noise
68-
# 2: for t = T, ...., 1 do
69-
for i in reversed(range(TIME_STEPS)):
70-
t = torch.tensor([i])
71-
# 3: z ~ N(0, 1)
72-
noise = noise_like(x_t.shape, "cpu")
73-
74-
# 4: √1αtxt − √1−αt1−α¯tθ(xt, t) + σtz
75-
# ------------------------- MODEL ------------------------------------#
76-
pred_noise = unet(x_t, t) # pred epsilon_theta
77-
pred_x = extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape) * pred_noise
78-
pred_x.clamp_(-1.0, 1.0)
79-
# pred mean
80-
posterior_mean = extract(posterior_mean_coef1, t, x_t.shape) * pred_x + extract(posterior_mean_coef2, t, x_t.shape) * x_t
81-
# --------------------------------------------------------------------#
82-
83-
# ------------------------- Variance Scheduler -----------------------#
84-
# pred variance
85-
posterior_log_variance = extract(posterior_log_variance_clipped, t, x_t.shape)
86-
b, *_, device = *x_t.shape, x_t.device
87-
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
88-
posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
89-
# --------------------------------------------------------------------#
90-
91-
x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)
92-
93-
# FOR PATRICK TO VERIFY: make sure manual loop is equal to function
94-
# --------------------------------------------------------------------#
95-
x_t_12 = diffusion.p_sample(unet, x_t, t, noise=noise)
96-
assert (x_t_1 - x_t_12).abs().sum().item() < 1e-3
97-
# --------------------------------------------------------------------#
98-
99-
x_t = x_t_1
156+
image = post_process_to_image(next_image)
157+
image.save("example_new.png")

models/vision/ddpm/example.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
#!/usr/bin/env python3
2-
from diffusers import UNetModel, GaussianDiffusion
3-
from modeling_ddpm import DDPM
42
import tempfile
53

4+
from diffusers import GaussianDDPMScheduler, UNetModel
5+
from modeling_ddpm import DDPM
6+
7+
68
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
7-
sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
9+
sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")
810

911
# compose Diffusion Pipeline
1012
ddpm = DDPM(unet, sampler)

models/vision/ddpm/modeling_ddpm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919

2020
class DDPM(DiffusionPipeline):
21-
2221
def __init__(self, unet, gaussian_sampler):
2322
super().__init__(unet=unet, gaussian_sampler=gaussian_sampler)
2423

models/vision/ddpm/run_ddpm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
#!/usr/bin/env python3
22
import torch
33

4-
from diffusers import GaussianDiffusion, UNetModel
4+
from diffusers import GaussianDDPMScheduler, UNetModel
55

66

77
model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
88

9-
diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2
9+
diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2
1010

1111
training_images = torch.randn(8, 3, 128, 128) # your images need to be normalized from a range of -1 to +1
1212
loss = diffusion(training_images)

src/diffusers/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
__version__ = "0.0.1"
66

7+
from .modeling_utils import PreTrainedModel
78
from .models.unet import UNetModel
8-
from .samplers.gaussian import GaussianDiffusion
9-
109
from .pipeline_utils import DiffusionPipeline
11-
from .modeling_utils import PreTrainedModel
10+
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler

src/diffusers/configuration_utils.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818

1919
import copy
20+
import inspect
2021
import json
2122
import os
2223
import re
23-
import inspect
2424
from typing import Any, Dict, Tuple, Union
2525

2626
from requests import HTTPError
@@ -186,6 +186,11 @@ def get_config_dict(
186186
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
187187
expected_keys.remove("self")
188188

189+
for key in expected_keys:
190+
if key in kwargs:
191+
# overwrite key
192+
config_dict[key] = kwargs.pop(key)
193+
189194
passed_keys = set(config_dict.keys())
190195

191196
unused_kwargs = kwargs
@@ -194,17 +199,16 @@ def get_config_dict(
194199

195200
if len(expected_keys - passed_keys) > 0:
196201
logger.warn(
197-
f"{expected_keys - passed_keys} was not found in config. "
198-
f"Values will be initialized to default values."
202+
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
199203
)
200204

201205
return config_dict, unused_kwargs
202206

203207
@classmethod
204-
def from_config(
205-
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
206-
):
207-
config_dict, unused_kwargs = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
208+
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
209+
config_dict, unused_kwargs = cls.get_config_dict(
210+
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
211+
)
208212

209213
model = cls(**config_dict)
210214

src/diffusers/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
# CHANGE to diffusers.utils
2626
from transformers.utils import (
27+
CONFIG_NAME,
2728
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
2829
EntryNotFoundError,
2930
RepositoryNotFoundError,
@@ -33,7 +34,6 @@
3334
is_offline_mode,
3435
is_remote_url,
3536
logging,
36-
CONFIG_NAME,
3737
)
3838

3939

0 commit comments

Comments
 (0)