Skip to content

Commit 889aa60

Browse files
author
Nathan Lambert
authored
PNDM API Updates, Tests Cleaning (huggingface#103)
* organize PNDM tests, begin API change * clean timestep API PNDM * update pipeline PNDM * fix typo * API clean round 2 * small nit
1 parent 76f9b52 commit 889aa60

File tree

3 files changed

+128
-101
lines changed

3 files changed

+128
-101
lines changed

src/diffusers/pipelines/pndm/pipeline_pndm.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,16 @@ def __call__(self, batch_size=1, generator=None, torch_device=None, num_inferenc
4343
)
4444
image = image.to(torch_device)
4545

46-
prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
47-
for t in tqdm(range(len(prk_time_steps))):
48-
t_orig = prk_time_steps[t]
49-
model_output = self.unet(image, t_orig)["sample"]
46+
self.scheduler.set_timesteps(num_inference_steps)
47+
for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
48+
model_output = self.unet(image, t)["sample"]
5049

51-
image = self.scheduler.step_prk(model_output, t, image, num_inference_steps)["prev_sample"]
50+
image = self.scheduler.step_prk(model_output, i, image, num_inference_steps)["prev_sample"]
5251

53-
timesteps = self.scheduler.get_time_steps(num_inference_steps)
54-
for t in tqdm(range(len(timesteps))):
55-
t_orig = timesteps[t]
56-
model_output = self.unet(image, t_orig)["sample"]
52+
for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
53+
model_output = self.unet(image, t)["sample"]
5754

58-
image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"]
55+
image = self.scheduler.step_plms(model_output, i, image, num_inference_steps)["prev_sample"]
5956

6057
image = (image / 2 + 0.5).clamp(0, 1)
6158
image = image.cpu().permute(0, 2, 3, 1).numpy()

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
1616

1717
import math
18+
import pdb
1819
from typing import Union
1920

2021
import numpy as np
@@ -71,8 +72,6 @@ def __init__(
7172

7273
self.one = np.array(1.0)
7374

74-
self.set_format(tensor_format=tensor_format)
75-
7675
# For now we only support F-PNDM, i.e. the runge-kutta method
7776
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
7877
# mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -82,49 +81,29 @@ def __init__(
8281
self.cur_model_output = 0
8382
self.cur_sample = None
8483
self.ets = []
85-
self.prk_time_steps = {}
86-
self.time_steps = {}
87-
self.set_prk_mode()
8884

89-
def get_prk_time_steps(self, num_inference_steps):
90-
if num_inference_steps in self.prk_time_steps:
91-
return self.prk_time_steps[num_inference_steps]
85+
# setable values
86+
self.num_inference_steps = None
87+
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
88+
self.prk_timesteps = None
89+
self.plms_timesteps = None
90+
91+
self.tensor_format = tensor_format
92+
self.set_format(tensor_format=tensor_format)
9293

93-
inference_step_times = list(
94+
def set_timesteps(self, num_inference_steps):
95+
self.num_inference_steps = num_inference_steps
96+
self.timesteps = list(
9497
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
9598
)
9699

97-
prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
100+
prk_time_steps = np.array(self.timesteps[-self.pndm_order :]).repeat(2) + np.tile(
98101
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
99102
)
100-
self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
101-
102-
return self.prk_time_steps[num_inference_steps]
103-
104-
def get_time_steps(self, num_inference_steps):
105-
if num_inference_steps in self.time_steps:
106-
return self.time_steps[num_inference_steps]
107-
108-
inference_step_times = list(
109-
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
110-
)
111-
self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
112-
113-
return self.time_steps[num_inference_steps]
114-
115-
def set_prk_mode(self):
116-
self.mode = "prk"
117-
118-
def set_plms_mode(self):
119-
self.mode = "plms"
120-
121-
def step(self, *args, **kwargs):
122-
if self.mode == "prk":
123-
return self.step_prk(*args, **kwargs)
124-
if self.mode == "plms":
125-
return self.step_plms(*args, **kwargs)
103+
self.prk_timesteps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
104+
self.plms_timesteps = list(reversed(self.timesteps[:-3]))
126105

127-
raise ValueError(f"mode {self.mode} does not exist.")
106+
self.set_format(tensor_format=self.tensor_format)
128107

129108
def step_prk(
130109
self,
@@ -138,7 +117,7 @@ def step_prk(
138117
solution to the differential equation.
139118
"""
140119
t = timestep
141-
prk_time_steps = self.get_prk_time_steps(num_inference_steps)
120+
prk_time_steps = self.prk_timesteps
142121

143122
t_orig = prk_time_steps[t // 4 * 4]
144123
t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
@@ -180,7 +159,7 @@ def step_plms(
180159
"for more information."
181160
)
182161

183-
timesteps = self.get_time_steps(num_inference_steps)
162+
timesteps = self.plms_timesteps
184163

185164
t_orig = timesteps[t]
186165
t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]

tests/test_scheduler.py

Lines changed: 103 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ def check_over_configs(self, time_step=0, **config):
7070
num_inference_steps = kwargs.pop("num_inference_steps", None)
7171

7272
for scheduler_class in self.scheduler_classes:
73-
scheduler_class = self.scheduler_classes[0]
7473
sample = self.dummy_sample
7574
residual = 0.1 * sample
7675

@@ -102,7 +101,6 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
102101
sample = self.dummy_sample
103102
residual = 0.1 * sample
104103

105-
scheduler_class = self.scheduler_classes[0]
106104
scheduler_config = self.get_scheduler_config()
107105
scheduler = scheduler_class(**scheduler_config)
108106

@@ -375,108 +373,168 @@ def get_scheduler_config(self, **kwargs):
375373
config.update(**kwargs)
376374
return config
377375

378-
def check_over_configs_pmls(self, time_step=0, **config):
376+
def check_over_configs(self, time_step=0, **config):
379377
kwargs = dict(self.forward_default_kwargs)
380378
sample = self.dummy_sample
381379
residual = 0.1 * sample
382380
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
383381

384382
for scheduler_class in self.scheduler_classes:
385-
scheduler_class = self.scheduler_classes[0]
386383
scheduler_config = self.get_scheduler_config(**config)
387384
scheduler = scheduler_class(**scheduler_config)
385+
scheduler.set_timesteps(kwargs["num_inference_steps"])
388386
# copy over dummy past residuals
389387
scheduler.ets = dummy_past_residuals[:]
390-
scheduler.set_plms_mode()
391388

392389
with tempfile.TemporaryDirectory() as tmpdirname:
393390
scheduler.save_config(tmpdirname)
394391
new_scheduler = scheduler_class.from_config(tmpdirname)
392+
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
395393
# copy over dummy past residuals
396394
new_scheduler.ets = dummy_past_residuals[:]
397-
new_scheduler.set_plms_mode()
398395

399-
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
400-
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
396+
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
397+
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
401398

402399
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
403400

404-
def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
401+
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
402+
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
403+
404+
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
405+
406+
def test_from_pretrained_save_pretrained(self):
407+
pass
408+
409+
def check_over_forward(self, time_step=0, **forward_kwargs):
405410
kwargs = dict(self.forward_default_kwargs)
406411
kwargs.update(forward_kwargs)
407412
sample = self.dummy_sample
408413
residual = 0.1 * sample
409414
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
410415

411416
for scheduler_class in self.scheduler_classes:
412-
scheduler_class = self.scheduler_classes[0]
413417
scheduler_config = self.get_scheduler_config()
414418
scheduler = scheduler_class(**scheduler_config)
419+
scheduler.set_timesteps(kwargs["num_inference_steps"])
420+
415421
# copy over dummy past residuals
416422
scheduler.ets = dummy_past_residuals[:]
417-
scheduler.set_plms_mode()
418423

419424
with tempfile.TemporaryDirectory() as tmpdirname:
420425
scheduler.save_config(tmpdirname)
421426
new_scheduler = scheduler_class.from_config(tmpdirname)
422427
# copy over dummy past residuals
423428
new_scheduler.ets = dummy_past_residuals[:]
424-
new_scheduler.set_plms_mode()
429+
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
425430

426-
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
427-
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
431+
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
432+
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
433+
434+
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
435+
436+
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
437+
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
428438

429439
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
430440

441+
def test_pytorch_equal_numpy(self):
442+
kwargs = dict(self.forward_default_kwargs)
443+
num_inference_steps = kwargs.pop("num_inference_steps", None)
444+
445+
for scheduler_class in self.scheduler_classes:
446+
sample = self.dummy_sample
447+
residual = 0.1 * sample
448+
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
449+
450+
sample_pt = torch.tensor(sample)
451+
residual_pt = 0.1 * sample_pt
452+
dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
453+
454+
scheduler_config = self.get_scheduler_config()
455+
scheduler = scheduler_class(**scheduler_config)
456+
# copy over dummy past residuals
457+
scheduler.ets = dummy_past_residuals[:]
458+
459+
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
460+
# copy over dummy past residuals
461+
scheduler_pt.ets = dummy_past_residuals_pt[:]
462+
463+
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
464+
scheduler.set_timesteps(num_inference_steps)
465+
scheduler_pt.set_timesteps(num_inference_steps)
466+
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
467+
kwargs["num_inference_steps"] = num_inference_steps
468+
469+
output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
470+
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
471+
472+
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
473+
474+
output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
475+
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
476+
477+
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
478+
479+
def test_step_shape(self):
480+
kwargs = dict(self.forward_default_kwargs)
481+
482+
num_inference_steps = kwargs.pop("num_inference_steps", None)
483+
484+
for scheduler_class in self.scheduler_classes:
485+
scheduler_config = self.get_scheduler_config()
486+
scheduler = scheduler_class(**scheduler_config)
487+
488+
sample = self.dummy_sample
489+
residual = 0.1 * sample
490+
# copy over dummy past residuals
491+
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
492+
scheduler.ets = dummy_past_residuals[:]
493+
494+
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
495+
scheduler.set_timesteps(num_inference_steps)
496+
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
497+
kwargs["num_inference_steps"] = num_inference_steps
498+
499+
output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
500+
output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
501+
502+
self.assertEqual(output_0.shape, sample.shape)
503+
self.assertEqual(output_0.shape, output_1.shape)
504+
505+
output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
506+
output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
507+
508+
self.assertEqual(output_0.shape, sample.shape)
509+
self.assertEqual(output_0.shape, output_1.shape)
510+
431511
def test_timesteps(self):
432512
for timesteps in [100, 1000]:
433513
self.check_over_configs(num_train_timesteps=timesteps)
434514

435-
def test_timesteps_pmls(self):
436-
for timesteps in [100, 1000]:
437-
self.check_over_configs_pmls(num_train_timesteps=timesteps)
438-
439515
def test_betas(self):
440516
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
441517
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
442518

443-
def test_betas_pmls(self):
444-
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
445-
self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)
446-
447519
def test_schedules(self):
448520
for schedule in ["linear", "squaredcos_cap_v2"]:
449521
self.check_over_configs(beta_schedule=schedule)
450522

451-
def test_schedules_pmls(self):
452-
for schedule in ["linear", "squaredcos_cap_v2"]:
453-
self.check_over_configs(beta_schedule=schedule)
454-
455523
def test_time_indices(self):
456524
for t in [1, 5, 10]:
457525
self.check_over_forward(time_step=t)
458526

459-
def test_time_indices_pmls(self):
460-
for t in [1, 5, 10]:
461-
self.check_over_forward_pmls(time_step=t)
462-
463527
def test_inference_steps(self):
464528
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
465529
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
466530

467-
def test_inference_steps_pmls(self):
468-
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
469-
self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)
470-
471-
def test_inference_pmls_no_past_residuals(self):
531+
def test_inference_plms_no_past_residuals(self):
472532
with self.assertRaises(ValueError):
473533
scheduler_class = self.scheduler_classes[0]
474534
scheduler_config = self.get_scheduler_config()
475535
scheduler = scheduler_class(**scheduler_config)
476536

477-
scheduler.set_plms_mode()
478-
479-
scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
537+
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
480538

481539
def test_full_loop_no_noise(self):
482540
scheduler_class = self.scheduler_classes[0]
@@ -486,20 +544,15 @@ def test_full_loop_no_noise(self):
486544
num_inference_steps = 10
487545
model = self.dummy_model()
488546
sample = self.dummy_sample_deter
547+
scheduler.set_timesteps(num_inference_steps)
489548

490-
prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
491-
for t in range(len(prk_time_steps)):
492-
t_orig = prk_time_steps[t]
493-
residual = model(sample, t_orig)
494-
495-
sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]
496-
497-
timesteps = scheduler.get_time_steps(num_inference_steps)
498-
for t in range(len(timesteps)):
499-
t_orig = timesteps[t]
500-
residual = model(sample, t_orig)
549+
for i, t in enumerate(scheduler.prk_timesteps):
550+
residual = model(sample, t)
551+
sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]
501552

502-
sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]
553+
for i, t in enumerate(scheduler.plms_timesteps):
554+
residual = model(sample, t)
555+
sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]
503556

504557
result_sum = np.sum(np.abs(sample))
505558
result_mean = np.mean(np.abs(sample))
@@ -562,7 +615,6 @@ def check_over_configs(self, time_step=0, **config):
562615
kwargs = dict(self.forward_default_kwargs)
563616

564617
for scheduler_class in self.scheduler_classes:
565-
scheduler_class = self.scheduler_classes[0]
566618
sample = self.dummy_sample
567619
residual = 0.1 * sample
568620

@@ -591,7 +643,6 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
591643
sample = self.dummy_sample
592644
residual = 0.1 * sample
593645

594-
scheduler_class = self.scheduler_classes[0]
595646
scheduler_config = self.get_scheduler_config()
596647
scheduler = scheduler_class(**scheduler_config)
597648

0 commit comments

Comments
 (0)