@@ -74,7 +74,7 @@ def get_chebs(num_epochs):
 
 
 def normalize_gradient(x, use_channels=False, epsilon=1e-8):
-    """use stdev to normalize gradients"""
+    """ use stdev to normalize gradients """
     size = x.dim()
     # print(f"size = {size}")
 
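For context, normalize_gradient scales a gradient tensor by its standard deviation. A minimal standalone sketch of the tensor-wide case (the function above also supports a per-channel path; the helper name here is illustrative, not the repo's API):

    import torch

    def normalize_by_std(grad: torch.Tensor, epsilon: float = 1e-8) -> torch.Tensor:
        # scale the whole tensor so its standard deviation is ~1
        return grad / (grad.std() + epsilon)

    # example: a toy gradient with std ~5 comes out with std ~1
    g = torch.randn(4, 3, 3, 3) * 5.0
    print(g.std(), normalize_by_std(g).std())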
@@ -90,7 +90,7 @@ def normalize_gradient(x, use_channels=False, epsilon=1e-8):
 
 
 def centralize_gradient(x, gc_conv_only=False):
-    """credit - https://github.com/Yonghongwei/Gradient-Centralization"""
+    """credit - https://github.com/Yonghongwei/Gradient-Centralization """
 
     size = x.dim()
     # print(f"size = {size}")
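centralize_gradient applies Gradient Centralization (credited in the docstring): each filter's gradient has its mean over all non-output dimensions subtracted. A minimal sketch, assuming a plain conv/linear weight gradient and an illustrative helper name:

    import torch

    def centralize(grad: torch.Tensor) -> torch.Tensor:
        # only centralize conv/linear weight gradients (dim > 1); leave biases alone
        if grad.dim() > 1:
            mean = grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)
            grad = grad - mean
        return grad

    # example: per-filter means become ~0 after centralization
    g = torch.randn(8, 3, 3, 3) + 0.5
    print(centralize(g).mean(dim=(1, 2, 3)))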
@@ -144,6 +144,7 @@ def __init__(
         warmup_pct_default=0.22,
         logging_active=True,
     ):
+
         # todo - checks on incoming params
         defaults = dict(
             lr=lr, momentum=momentum, betas=betas, eps=eps, weight_decay=weight_decay
@@ -352,13 +353,13 @@ def show_settings(self):
 
         if self.warmdown_active:
             print(
-                f"\nWarm-down: Linear warmdown, starting at {self.warm_down_start_pct * 100}%, iteration {self.start_warm_down} of {self.total_iterations}"
+                f"\nWarm-down: Linear warmdown, starting at {self.warm_down_start_pct * 100}%, iteration {self.start_warm_down} of {self.total_iterations}"
             )
             print(f"warm down will decay until {self.min_lr} lr")
 
     # lookahead functions
     def clear_cache(self):
-        """clears the lookahead cached params"""
+        """clears the lookahead cached params """
 
         print(f"clearing lookahead cache...")
         for group in self.param_groups:
@@ -390,7 +391,7 @@ def backup_and_load_cache(self):
                 p.data.copy_(param_state["lookahead_params"])
 
     def unit_norm(self, x):
-        """axis-based Euclidean norm"""
+        """ axis-based Euclidean norm"""
         # verify shape
         keepdim = True
         dim = None
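unit_norm computes an axis-based Euclidean norm, which the agc method in the next hunk uses for adaptive gradient clipping in the NFNet style (cap the gradient norm relative to the parameter norm). A tensor-wise simplification as a sketch; the helper name and the clipping/eps defaults are illustrative, not the optimizer's own:

    import torch

    def agc_clip(param: torch.Tensor, grad: torch.Tensor,
                 clipping: float = 1e-2, eps: float = 1e-3) -> torch.Tensor:
        # adaptive gradient clipping: cap grad norm at clipping * param norm
        p_norm = param.norm().clamp(min=eps)
        g_norm = grad.norm()
        max_norm = p_norm * clipping
        if g_norm > max_norm:
            grad = grad * (max_norm / g_norm.clamp(min=1e-6))
        return grad

    # example: a large gradient gets scaled down relative to a small weight
    w = torch.full((4, 4), 0.1)
    g = torch.randn(4, 4) * 10.0
    print(g.norm(), agc_clip(w, g).norm())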
@@ -432,6 +433,7 @@ def agc(self, p):
             p.grad.detach().copy_(new_grads)
 
     def warmup_dampening(self, lr, step):
+
         style = self.warmup_type
         warmup = self.num_warmup_iters
 
@@ -440,6 +442,7 @@ def warmup_dampening(self, lr, step):
 
         if step > warmup:
             if not self.warmup_complete:
+
                 if not self.warmup_curr_pct == 1.0:
                     print(
                         f"Error - lr did not achieve full set point from warmup, currently {self.warmup_curr_pct}"
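warmup_dampening ramps the base lr up over the first num_warmup_iters steps. A minimal sketch of the linear variant, with illustrative names rather than the optimizer's actual method signature:

    def linear_warmup_lr(base_lr: float, step: int, warmup_iters: int) -> float:
        # ramp linearly from ~0 to base_lr, then hold at base_lr
        if step >= warmup_iters:
            return base_lr
        return base_lr * min(1.0, step / warmup_iters)

    # example: 5 warmup steps toward lr=1e-3
    print([round(linear_warmup_lr(1e-3, s, 5), 6) for s in range(1, 8)])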
@@ -462,7 +465,7 @@ def warmup_dampening(self, lr, step):
             raise ValueError(f"warmup type {style} not implemented.")
 
     def get_warm_down(self, lr, iteration):
-        """linear style warmdown"""
+        """ linear style warmdown """
         if iteration < self.start_warm_down:
             return lr
 
@@ -475,8 +478,8 @@ def get_warm_down(self, lr, iteration):
             self.warmdown_displayed = True
 
         warmdown_iteration = (
-            (iteration + 1) - self.start_warm_down
-        )  # to force the first iteration to be 1 instead of 0
+            iteration + 1
+        ) - self.start_warm_down  # to force the first iteration to be 1 instead of 0
 
         if warmdown_iteration < 1:
             print(
@@ -486,8 +489,8 @@ def get_warm_down(self, lr, iteration):
         # print(f"warmdown iteration = {warmdown_iteration}")
         # linear start 3672 5650 total iterations 1972 iterations
 
-        warmdown_pct = (
-            warmdown_iteration / (self.warmdown_total_iterations + 1)
+        warmdown_pct = warmdown_iteration / (
+            self.warmdown_total_iterations + 1
         )  # +1 to offset that we have to include first as an iteration to support 1 index instead of 0 based.
         if warmdown_pct > 1.00:
             print(f"error in warmdown pct calc. new pct = {warmdown_pct}")
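get_warm_down decays the lr linearly from its base value toward min_lr once start_warm_down is reached, using the 1-indexed warmdown_pct computed above. A compact sketch of that schedule with hypothetical parameter names:

    def linear_warmdown_lr(base_lr: float, min_lr: float, iteration: int,
                           start_warm_down: int, warmdown_total_iterations: int) -> float:
        # before the warm-down window: keep the base lr
        if iteration < start_warm_down:
            return base_lr
        # 1-indexed position inside the warm-down window
        warmdown_iteration = (iteration + 1) - start_warm_down
        warmdown_pct = min(1.0, warmdown_iteration / (warmdown_total_iterations + 1))
        # interpolate linearly from base_lr down toward min_lr
        return base_lr - warmdown_pct * (base_lr - min_lr)

    # example: decay from 1e-3 toward 1e-5 over the last 100 of 1000 iterations
    print(linear_warmdown_lr(1e-3, 1e-5, 999, 900, 100))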
@@ -534,6 +537,7 @@ def track_epochs(self, iteration):
             self.backup_and_load_cache()
 
     def get_cheb_lr(self, lr, iteration):
+
         # first confirm we are done with warmup
         if self.use_warmup:
             if iteration < self.num_warmup_iters + 1:
@@ -569,6 +573,7 @@ def get_state_values(self, group, state):
     # @staticmethod
     @torch.no_grad()
     def step(self, closure=None):
+
         loss = None
         if closure is not None and isinstance(closure, collections.abc.Callable):
             with torch.enable_grad():
@@ -693,15 +698,15 @@ def step(self, closure=None):
         if not self.param_size:
             self.param_size = param_size
             print(f"params size saved")
-            print(f"total param groups = {i + 1}")
-            print(f"total params in groups = {j + 1}")
+            print(f"total param groups = {i + 1}")
+            print(f"total params in groups = {j + 1}")
 
         if not self.param_size:
             raise ValueError("failed to set param size")
 
         # stable weight decay
         if self.use_madgrad:
-            variance_normalized = torch.pow(variance_ma_sum / param_size, 1 / 3)
+            variance_normalized = torch.pow(variance_ma_sum / param_size, 1 / 3)
         else:
             variance_normalized = math.sqrt(variance_ma_sum / param_size)
         # variance_mean = variance_ma_sum / param_size
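The stable weight decay block divides the decay term by an aggregate second-moment statistic of the gradients (cube root on the MADGRAD path, square root otherwise), so the effective decay stays roughly constant as gradient variance changes. A rough sketch of the Adam-style branch, assuming variance_ma_sum already holds the summed, bias-corrected second moments (the values below are placeholders):

    import math
    import torch

    # hypothetical values for illustration only
    variance_ma_sum = 42.7      # summed (exp_avg_sq / bias_correction2) over all params
    param_size = 1_000_000      # total number of scalar parameters
    lr, weight_decay = 1e-3, 1e-4

    # Adam-style branch: sqrt of the mean second moment
    variance_normalized = math.sqrt(variance_ma_sum / param_size)

    # decoupled weight decay, scaled by the variance statistic
    p = torch.randn(10, 10)
    p.mul_(1 - lr * weight_decay / variance_normalized)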
@@ -849,6 +854,7 @@ def step(self, closure=None):
                 variance_ma_belief = state["variance_ma_belief"]
 
                 if self.momentum_pnm:
+
                     max_variance_ma = state["max_variance_ma"]
 
                     if state["step"] % 2 == 1:
@@ -862,8 +868,8 @@ def step(self, closure=None):
                             state["grad_ma"],
                         )
 
-                bias_correction1 = 1 - beta1 ** step
-                bias_correction2 = 1 - beta2 ** step
+                bias_correction1 = 1 - beta1 ** step
+                bias_correction2 = 1 - beta2 ** step
 
                 if self.momentum_pnm:
                     # Maintains the maximum of all 2nd moment running avg. till now
@@ -883,9 +889,9 @@ def step(self, closure=None):
                     grad = normalize_gradient(grad)
 
                 if not self.use_adabelief:
-                    grad_ma.mul_(beta1 ** 2).add_(grad, alpha=1 - beta1 ** 2)
+                    grad_ma.mul_(beta1 ** 2).add_(grad, alpha=1 - beta1 ** 2)
 
-                noise_norm = math.sqrt((1 + beta2) ** 2 + beta2 ** 2)
+                noise_norm = math.sqrt((1 + beta2) ** 2 + beta2 ** 2)
 
                 step_size = lr / bias_correction1
 
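The last hunks cover the Adam-style bias corrections and the positive-negative momentum (PNM) noise normalization: the two momentum buffers are combined (typically with weights 1 + beta2 and -beta2) and divided by noise_norm = sqrt((1 + beta2)^2 + beta2^2) before the step. A self-contained sketch of one such update; the variable names and buffer bookkeeping are illustrative, not the optimizer's exact code:

    import math
    import torch

    beta1, beta2, eps, lr, step = 0.9, 0.999, 1e-8, 1e-3, 10

    param = torch.randn(10)
    grad = torch.randn(10)
    grad_ma = torch.zeros(10)      # "positive" momentum buffer
    neg_grad_ma = torch.zeros(10)  # "negative" momentum buffer (lags one step)
    variance_ma = torch.zeros(10)

    # Adam-style accumulators and bias corrections (as in the hunks above)
    grad_ma.mul_(beta1 ** 2).add_(grad, alpha=1 - beta1 ** 2)
    variance_ma.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    # PNM: combine the two buffers and renormalize the injected noise
    noise_norm = math.sqrt((1 + beta2) ** 2 + beta2 ** 2)
    pn_momentum = grad_ma.mul(1 + beta2).add(neg_grad_ma, alpha=-beta2).mul(1 / noise_norm)

    denom = (variance_ma.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    step_size = lr / bias_correction1
    param.addcdiv_(pn_momentum, denom, value=-step_size)
    print(param)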