diff --git a/ranger21/ranger21.py b/ranger21/ranger21.py index 88968fe..a8ab0a5 100644 --- a/ranger21/ranger21.py +++ b/ranger21/ranger21.py @@ -735,13 +735,16 @@ def step(self, closure=None): if self.use_adabelief: grad_ma.mul_(beta1).add_(grad, alpha=1 - beta1) grad_residual = grad - grad_ma - variance_ma_belief.mul_(beta2).addcmul( + variance_ma_belief.mul_(beta2).addcmul_( grad_residual, grad_residual, value=1 - beta2 ) # print(f"upper loop grad = {grad.shape}") variance_ma.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # print(f"variance_ma, grad adjusted") - variance_ma_debiased = variance_ma / bias_correction2 + if self.use_adabelief: + variance_ma_debiased = variance_ma_belief / bias_correction2 + else: + variance_ma_debiased = variance_ma / bias_correction2 variance_ma_sum += variance_ma_debiased.sum() # print(f"variance_ma_sum = {variance_ma_sum}") @@ -909,7 +912,8 @@ def step(self, closure=None): # Maintains the maximum of all 2nd moment running avg. till now torch.max(max_variance_ma, variance_ma, out=variance_ma) # Use the max. for normalizing running avg. of gradient - denom = (variance_ma.sqrt() / math.sqrt(bias_correction2)).add_( + variance_source = variance_ma_belief if self.use_adabelief else variance_ma + denom = (variance_source.sqrt() / math.sqrt(bias_correction2)).add_( group["eps"] )