|
1 | 1 | #!/usr/bin/env python3 |
2 | | -from diffusers import UNetModel, GaussianDiffusion |
| 2 | +from diffusers import UNetModel, GaussianDDPMScheduler |
3 | 3 | import torch |
4 | 4 | import torch.nn.functional as F |
5 | | - |
6 | | -unet = UNetModel.from_pretrained("fusing/ddpm_dummy") |
7 | | -diffusion = GaussianDiffusion.from_config("fusing/ddpm_dummy") |
8 | | - |
| 5 | +import numpy as np |
| 6 | +import PIL.Image |
| 7 | +import tqdm |
| 8 | + |
| 9 | +#torch_device = "cuda" |
| 10 | +# |
| 11 | +#unet = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church") |
| 12 | +#unet.to(torch_device) |
| 13 | +# |
| 14 | +#TIME_STEPS = 10 |
| 15 | +# |
| 16 | +#scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=TIME_STEPS) |
| 17 | +# |
| 18 | +#diffusion_config = { |
| 19 | +# "beta_start": 0.0001, |
| 20 | +# "beta_end": 0.02, |
| 21 | +# "num_diffusion_timesteps": TIME_STEPS, |
| 22 | +#} |
| 23 | +# |
9 | 24 | # 2. Do one denoising step with model |
10 | | -batch_size, num_channels, height, width = 1, 3, 32, 32 |
11 | | -dummy_noise = torch.ones((batch_size, num_channels, height, width)) |
12 | | - |
13 | | - |
14 | | -TIME_STEPS = 10 |
15 | | - |
16 | | - |
| 25 | +#batch_size, num_channels, height, width = 1, 3, 256, 256 |
| 26 | +# |
| 27 | +#torch.manual_seed(0) |
| 28 | +#noise_image = torch.randn(batch_size, num_channels, height, width, device="cuda") |
| 29 | +# |
| 30 | +# |
17 | 31 | # Helper |
18 | | -def extract(a, t, x_shape): |
19 | | - b, *_ = t.shape |
20 | | - out = a.gather(-1, t) |
21 | | - return out.reshape(b, *((1,) * (len(x_shape) - 1))) |
| 32 | +#def noise_like(shape, device, repeat=False): |
| 33 | +# def repeat_noise(): |
| 34 | +# return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) |
| 35 | +# |
| 36 | +# def noise(): |
| 37 | +# return torch.randn(shape, device=device) |
| 38 | +# |
| 39 | +# return repeat_noise() if repeat else noise() |
| 40 | +# |
| 41 | +# |
| 42 | +#betas = np.linspace(diffusion_config["beta_start"], diffusion_config["beta_end"], diffusion_config["num_diffusion_timesteps"], dtype=np.float64) |
| 43 | +#betas = torch.tensor(betas, device=torch_device) |
| 44 | +#alphas = 1.0 - betas |
| 45 | +# |
| 46 | +#alphas_cumprod = torch.cumprod(alphas, axis=0) |
| 47 | +#alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) |
| 48 | +# |
| 49 | +#posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) |
| 50 | +#posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) |
| 51 | +# |
| 52 | +#posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) |
| 53 | +#posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20)) |
| 54 | +# |
| 55 | +# |
| 56 | +#sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod) |
| 57 | +#sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1) |
| 58 | +# |
| 59 | +# |
| 60 | +#noise_coeff = (1 - alphas) / torch.sqrt(1 - alphas_cumprod) |
| 61 | +#coeff = 1 / torch.sqrt(alphas) |
| 62 | + |
| 63 | + |
def real_fn():
    """Manual DDPM ancestral-sampling loop (Algorithm 2 of https://arxiv.org/pdf/2006.11239.pdf).

    NOTE(review): this relies on module-level globals that are currently
    commented out above (noise_image, TIME_STEPS, torch_device, unet,
    noise_like, coeff, noise_coeff, posterior_log_variance_clipped) — calling
    it as-is raises NameError. Re-enable those definitions before use.
    """
    # Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
    # 1: x_t ~ N(0,1)
    x_t = noise_image
    # 2: for t = T, ...., 1 do
    for i in reversed(range(TIME_STEPS)):
        t = torch.tensor([i]).to(torch_device)
        # 3: z ~ N(0, 1)
        noise = noise_like(x_t.shape, torch_device)

        # 4: √1αtxt − √1−αt1−α¯tθ(xt, t) + σtz
        # ------------------------- MODEL ------------------------------------#
        with torch.no_grad():
            pred_noise = unet(x_t, t)  # pred epsilon_theta

        # pred_x = sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * pred_noise
        # pred_x.clamp_(-1.0, 1.0)
        # pred mean
        # posterior_mean = posterior_mean_coef1[t] * pred_x + posterior_mean_coef2[t] * x_t
        # --------------------------------------------------------------------#

        # Posterior mean written directly in terms of x_t and epsilon_theta
        # (the un-clipped form of the commented-out computation above).
        posterior_mean = coeff[t] * (x_t - noise_coeff[t] * pred_noise)

        # ------------------------- Variance Scheduler -----------------------#
        # pred variance (log-space, pre-clamped against log(0))
        posterior_log_variance = posterior_log_variance_clipped[t]

        b, *_, device = *x_t.shape, x_t.device
        # Mask out the noise term at t == 0 (final step is deterministic).
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
        posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
        # --------------------------------------------------------------------#

        x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)
        x_t = x_t_1

    # Debug check: scalar summary of the final sample.
    print(x_t.abs().sum())
| 100 | + |
| 101 | + |
def post_process_to_image(x_t):
    """Convert a batched image tensor in [-1, 1] (N, C, H, W) to a PIL image.

    Only the first element of the batch is returned.
    """
    # detach() so tensors that still track gradients can be converted;
    # NCHW -> NHWC for PIL, and move off the accelerator first.
    image = x_t.detach().cpu().permute(0, 2, 3, 1)
    # Map [-1, 1] -> [0, 255]. Clamp so out-of-range values saturate instead
    # of wrapping around when cast to uint8.
    image = ((image + 1.0) * 127.5).clamp(0, 255)
    image = image.numpy().astype(np.uint8)

    return PIL.Image.fromarray(image[0])
| 108 | + |
| 109 | + |
# Reference implementation kept around for cross-checking scheduler output.
from pytorch_diffusion import Diffusion

# diffusion = Diffusion.from_pretrained("lsun_church")
# samples = diffusion.denoise(1)
#
# image = post_process_to_image(samples)
# image.save("check.png")


device = "cuda"

# Load scheduler and model from a local checkpoint.
# NOTE(review): paths are machine-specific — parameterize before sharing.
scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=10)
model = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church").to(device)


# 1: x_T ~ N(0, I) — fixed seed for reproducible samples.
torch.manual_seed(0)
next_image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=device)

# 2: for t = T, ..., 1 — ancestral sampling (Algorithm 2, https://arxiv.org/pdf/2006.11239.pdf)
for t in tqdm.tqdm(reversed(range(len(scheduler))), total=len(scheduler)):
    # Coefficients for estimating x_0 from (x_t, eps) and for the posterior
    # mean of x_{t-1} given x_t and the clipped x_0 estimate.
    clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
    clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
    image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
    clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))

    # Predict the noise residual eps_theta(x_t, t); inference only.
    with torch.no_grad():
        noise_residual = model(next_image, t)

    # Estimate x_0, clip it to the valid image range, then form the
    # posterior mean of x_{t-1}.
    pred_mean = clip_image_coeff * next_image - clip_noise_coeff * noise_residual
    pred_mean = torch.clamp(pred_mean, -1, 1)
    image = clip_coeff * pred_mean + image_coeff * next_image

    # Sample the variance term (sigma_t * z).
    # NOTE(review): assuming sample_variance returns the already-scaled noise
    # and is zero at t == 0 — confirm against the scheduler implementation.
    variance = scheduler.sample_variance(t, image.shape, device=device)

    # Sample the previous image x_{t-1}.
    sampled_image = image + variance

    next_image = sampled_image


image = post_process_to_image(next_image)
image.save("example_new.png")
0 commit comments