@@ -70,7 +70,6 @@ def check_over_configs(self, time_step=0, **config):
7070 num_inference_steps = kwargs .pop ("num_inference_steps" , None )
7171
7272 for scheduler_class in self .scheduler_classes :
73- scheduler_class = self .scheduler_classes [0 ]
7473 sample = self .dummy_sample
7574 residual = 0.1 * sample
7675
@@ -102,7 +101,6 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
102101 sample = self .dummy_sample
103102 residual = 0.1 * sample
104103
105- scheduler_class = self .scheduler_classes [0 ]
106104 scheduler_config = self .get_scheduler_config ()
107105 scheduler = scheduler_class (** scheduler_config )
108106
@@ -375,108 +373,168 @@ def get_scheduler_config(self, **kwargs):
375373 config .update (** kwargs )
376374 return config
377375
378- def check_over_configs_pmls (self , time_step = 0 , ** config ):
376+ def check_over_configs (self , time_step = 0 , ** config ):
379377 kwargs = dict (self .forward_default_kwargs )
380378 sample = self .dummy_sample
381379 residual = 0.1 * sample
382380 dummy_past_residuals = [residual + 0.2 , residual + 0.15 , residual + 0.1 , residual + 0.05 ]
383381
384382 for scheduler_class in self .scheduler_classes :
385- scheduler_class = self .scheduler_classes [0 ]
386383 scheduler_config = self .get_scheduler_config (** config )
387384 scheduler = scheduler_class (** scheduler_config )
385+ scheduler .set_timesteps (kwargs ["num_inference_steps" ])
388386 # copy over dummy past residuals
389387 scheduler .ets = dummy_past_residuals [:]
390- scheduler .set_plms_mode ()
391388
392389 with tempfile .TemporaryDirectory () as tmpdirname :
393390 scheduler .save_config (tmpdirname )
394391 new_scheduler = scheduler_class .from_config (tmpdirname )
392+ new_scheduler .set_timesteps (kwargs ["num_inference_steps" ])
395393 # copy over dummy past residuals
396394 new_scheduler .ets = dummy_past_residuals [:]
397- new_scheduler .set_plms_mode ()
398395
399- output = scheduler .step (residual , time_step , sample , ** kwargs )["prev_sample" ]
400- new_output = new_scheduler .step (residual , time_step , sample , ** kwargs )["prev_sample" ]
396+ output = scheduler .step_prk (residual , time_step , sample , ** kwargs )["prev_sample" ]
397+ new_output = new_scheduler .step_prk (residual , time_step , sample , ** kwargs )["prev_sample" ]
401398
402399 assert np .sum (np .abs (output - new_output )) < 1e-5 , "Scheduler outputs are not identical"
403400
404- def check_over_forward_pmls (self , time_step = 0 , ** forward_kwargs ):
401+ output = scheduler .step_plms (residual , time_step , sample , ** kwargs )["prev_sample" ]
402+ new_output = new_scheduler .step_plms (residual , time_step , sample , ** kwargs )["prev_sample" ]
403+
404+ assert np .sum (np .abs (output - new_output )) < 1e-5 , "Scheduler outputs are not identical"
405+
406+ def test_from_pretrained_save_pretrained (self ):
407+ pass
408+
409+ def check_over_forward (self , time_step = 0 , ** forward_kwargs ):
405410 kwargs = dict (self .forward_default_kwargs )
406411 kwargs .update (forward_kwargs )
407412 sample = self .dummy_sample
408413 residual = 0.1 * sample
409414 dummy_past_residuals = [residual + 0.2 , residual + 0.15 , residual + 0.1 , residual + 0.05 ]
410415
411416 for scheduler_class in self .scheduler_classes :
412- scheduler_class = self .scheduler_classes [0 ]
413417 scheduler_config = self .get_scheduler_config ()
414418 scheduler = scheduler_class (** scheduler_config )
419+ scheduler .set_timesteps (kwargs ["num_inference_steps" ])
420+
415421 # copy over dummy past residuals
416422 scheduler .ets = dummy_past_residuals [:]
417- scheduler .set_plms_mode ()
418423
419424 with tempfile .TemporaryDirectory () as tmpdirname :
420425 scheduler .save_config (tmpdirname )
421426 new_scheduler = scheduler_class .from_config (tmpdirname )
422427 # copy over dummy past residuals
423428 new_scheduler .ets = dummy_past_residuals [:]
424- new_scheduler .set_plms_mode ( )
429+ new_scheduler .set_timesteps ( kwargs [ "num_inference_steps" ] )
425430
426- output = scheduler .step (residual , time_step , sample , ** kwargs )["prev_sample" ]
427- new_output = new_scheduler .step (residual , time_step , sample , ** kwargs )["prev_sample" ]
431+ output = scheduler .step_prk (residual , time_step , sample , ** kwargs )["prev_sample" ]
432+ new_output = new_scheduler .step_prk (residual , time_step , sample , ** kwargs )["prev_sample" ]
433+
434+ assert np .sum (np .abs (output - new_output )) < 1e-5 , "Scheduler outputs are not identical"
435+
436+ output = scheduler .step_plms (residual , time_step , sample , ** kwargs )["prev_sample" ]
437+ new_output = new_scheduler .step_plms (residual , time_step , sample , ** kwargs )["prev_sample" ]
428438
429439 assert np .sum (np .abs (output - new_output )) < 1e-5 , "Scheduler outputs are not identical"
430440
441+ def test_pytorch_equal_numpy (self ):
442+ kwargs = dict (self .forward_default_kwargs )
443+ num_inference_steps = kwargs .pop ("num_inference_steps" , None )
444+
445+ for scheduler_class in self .scheduler_classes :
446+ sample = self .dummy_sample
447+ residual = 0.1 * sample
448+ dummy_past_residuals = [residual + 0.2 , residual + 0.15 , residual + 0.1 , residual + 0.05 ]
449+
450+ sample_pt = torch .tensor (sample )
451+ residual_pt = 0.1 * sample_pt
452+ dummy_past_residuals_pt = [residual_pt + 0.2 , residual_pt + 0.15 , residual_pt + 0.1 , residual_pt + 0.05 ]
453+
454+ scheduler_config = self .get_scheduler_config ()
455+ scheduler = scheduler_class (** scheduler_config )
456+ # copy over dummy past residuals
457+ scheduler .ets = dummy_past_residuals [:]
458+
459+ scheduler_pt = scheduler_class (tensor_format = "pt" , ** scheduler_config )
460+ # copy over dummy past residuals
461+ scheduler_pt .ets = dummy_past_residuals_pt [:]
462+
463+ if num_inference_steps is not None and hasattr (scheduler , "set_timesteps" ):
464+ scheduler .set_timesteps (num_inference_steps )
465+ scheduler_pt .set_timesteps (num_inference_steps )
466+ elif num_inference_steps is not None and not hasattr (scheduler , "set_timesteps" ):
467+ kwargs ["num_inference_steps" ] = num_inference_steps
468+
469+ output = scheduler .step_prk (residual , 1 , sample , num_inference_steps , ** kwargs )["prev_sample" ]
470+ output_pt = scheduler_pt .step_prk (residual_pt , 1 , sample_pt , num_inference_steps , ** kwargs )["prev_sample" ]
471+
472+ assert np .sum (np .abs (output - output_pt .numpy ())) < 1e-4 , "Scheduler outputs are not identical"
473+
474+ output = scheduler .step_plms (residual , 1 , sample , num_inference_steps , ** kwargs )["prev_sample" ]
475+ output_pt = scheduler_pt .step_plms (residual_pt , 1 , sample_pt , num_inference_steps , ** kwargs )["prev_sample" ]
476+
477+ assert np .sum (np .abs (output - output_pt .numpy ())) < 1e-4 , "Scheduler outputs are not identical"
478+
479+ def test_step_shape (self ):
480+ kwargs = dict (self .forward_default_kwargs )
481+
482+ num_inference_steps = kwargs .pop ("num_inference_steps" , None )
483+
484+ for scheduler_class in self .scheduler_classes :
485+ scheduler_config = self .get_scheduler_config ()
486+ scheduler = scheduler_class (** scheduler_config )
487+
488+ sample = self .dummy_sample
489+ residual = 0.1 * sample
490+ # copy over dummy past residuals
491+ dummy_past_residuals = [residual + 0.2 , residual + 0.15 , residual + 0.1 , residual + 0.05 ]
492+ scheduler .ets = dummy_past_residuals [:]
493+
494+ if num_inference_steps is not None and hasattr (scheduler , "set_timesteps" ):
495+ scheduler .set_timesteps (num_inference_steps )
496+ elif num_inference_steps is not None and not hasattr (scheduler , "set_timesteps" ):
497+ kwargs ["num_inference_steps" ] = num_inference_steps
498+
499+ output_0 = scheduler .step_prk (residual , 0 , sample , num_inference_steps , ** kwargs )["prev_sample" ]
500+ output_1 = scheduler .step_prk (residual , 1 , sample , num_inference_steps , ** kwargs )["prev_sample" ]
501+
502+ self .assertEqual (output_0 .shape , sample .shape )
503+ self .assertEqual (output_0 .shape , output_1 .shape )
504+
505+ output_0 = scheduler .step_plms (residual , 0 , sample , num_inference_steps , ** kwargs )["prev_sample" ]
506+ output_1 = scheduler .step_plms (residual , 1 , sample , num_inference_steps , ** kwargs )["prev_sample" ]
507+
508+ self .assertEqual (output_0 .shape , sample .shape )
509+ self .assertEqual (output_0 .shape , output_1 .shape )
510+
431511 def test_timesteps (self ):
432512 for timesteps in [100 , 1000 ]:
433513 self .check_over_configs (num_train_timesteps = timesteps )
434514
435- def test_timesteps_pmls (self ):
436- for timesteps in [100 , 1000 ]:
437- self .check_over_configs_pmls (num_train_timesteps = timesteps )
438-
439515 def test_betas (self ):
440516 for beta_start , beta_end in zip ([0.0001 , 0.001 , 0.01 ], [0.002 , 0.02 , 0.2 ]):
441517 self .check_over_configs (beta_start = beta_start , beta_end = beta_end )
442518
443- def test_betas_pmls (self ):
444- for beta_start , beta_end in zip ([0.0001 , 0.001 , 0.01 ], [0.002 , 0.02 , 0.2 ]):
445- self .check_over_configs_pmls (beta_start = beta_start , beta_end = beta_end )
446-
447519 def test_schedules (self ):
448520 for schedule in ["linear" , "squaredcos_cap_v2" ]:
449521 self .check_over_configs (beta_schedule = schedule )
450522
451- def test_schedules_pmls (self ):
452- for schedule in ["linear" , "squaredcos_cap_v2" ]:
453- self .check_over_configs (beta_schedule = schedule )
454-
455523 def test_time_indices (self ):
456524 for t in [1 , 5 , 10 ]:
457525 self .check_over_forward (time_step = t )
458526
459- def test_time_indices_pmls (self ):
460- for t in [1 , 5 , 10 ]:
461- self .check_over_forward_pmls (time_step = t )
462-
463527 def test_inference_steps (self ):
464528 for t , num_inference_steps in zip ([1 , 5 , 10 ], [10 , 50 , 100 ]):
465529 self .check_over_forward (time_step = t , num_inference_steps = num_inference_steps )
466530
467- def test_inference_steps_pmls (self ):
468- for t , num_inference_steps in zip ([1 , 5 , 10 ], [10 , 50 , 100 ]):
469- self .check_over_forward_pmls (time_step = t , num_inference_steps = num_inference_steps )
470-
471- def test_inference_pmls_no_past_residuals (self ):
531+ def test_inference_plms_no_past_residuals (self ):
472532 with self .assertRaises (ValueError ):
473533 scheduler_class = self .scheduler_classes [0 ]
474534 scheduler_config = self .get_scheduler_config ()
475535 scheduler = scheduler_class (** scheduler_config )
476536
477- scheduler .set_plms_mode ()
478-
479- scheduler .step (self .dummy_sample , 1 , self .dummy_sample , 50 )["prev_sample" ]
537+ scheduler .step_plms (self .dummy_sample , 1 , self .dummy_sample , 50 )["prev_sample" ]
480538
481539 def test_full_loop_no_noise (self ):
482540 scheduler_class = self .scheduler_classes [0 ]
@@ -486,20 +544,15 @@ def test_full_loop_no_noise(self):
486544 num_inference_steps = 10
487545 model = self .dummy_model ()
488546 sample = self .dummy_sample_deter
547+ scheduler .set_timesteps (num_inference_steps )
489548
490- prk_time_steps = scheduler .get_prk_time_steps (num_inference_steps )
491- for t in range (len (prk_time_steps )):
492- t_orig = prk_time_steps [t ]
493- residual = model (sample , t_orig )
494-
495- sample = scheduler .step_prk (residual , t , sample , num_inference_steps )["prev_sample" ]
496-
497- timesteps = scheduler .get_time_steps (num_inference_steps )
498- for t in range (len (timesteps )):
499- t_orig = timesteps [t ]
500- residual = model (sample , t_orig )
549+ for i , t in enumerate (scheduler .prk_timesteps ):
550+ residual = model (sample , t )
551+ sample = scheduler .step_prk (residual , i , sample , num_inference_steps )["prev_sample" ]
501552
502- sample = scheduler .step_plms (residual , t , sample , num_inference_steps )["prev_sample" ]
553+ for i , t in enumerate (scheduler .plms_timesteps ):
554+ residual = model (sample , t )
555+ sample = scheduler .step_plms (residual , i , sample , num_inference_steps )["prev_sample" ]
503556
504557 result_sum = np .sum (np .abs (sample ))
505558 result_mean = np .mean (np .abs (sample ))
@@ -562,7 +615,6 @@ def check_over_configs(self, time_step=0, **config):
562615 kwargs = dict (self .forward_default_kwargs )
563616
564617 for scheduler_class in self .scheduler_classes :
565- scheduler_class = self .scheduler_classes [0 ]
566618 sample = self .dummy_sample
567619 residual = 0.1 * sample
568620
@@ -591,7 +643,6 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
591643 sample = self .dummy_sample
592644 residual = 0.1 * sample
593645
594- scheduler_class = self .scheduler_classes [0 ]
595646 scheduler_config = self .get_scheduler_config ()
596647 scheduler = scheduler_class (** scheduler_config )
597648
0 commit comments