diff --git a/lion/lion_pytorch.py b/lion/lion_pytorch.py index ffd52d99..b27f4e54 100644 --- a/lion/lion_pytorch.py +++ b/lion/lion_pytorch.py @@ -78,10 +78,7 @@ def step(self, closure=None): # Weight update update = exp_avg * beta1 + grad * (1 - beta1) - p.add_(torch.sign(update), alpha=-group['lr'], inplace=True) - #This has been made more efficient by using the torch.sign function's inplace mode. - #This will prevent the need to create a new tensor for the updated parameter, - #which can save a significant amount of time for large models. + p.add_(update.sign_(), alpha=-group['lr']) # Decay the momentum running average coefficient exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)