diff --git a/lion/lion_pytorch.py b/lion/lion_pytorch.py index 9f5662f9..ffd52d99 100644 --- a/lion/lion_pytorch.py +++ b/lion/lion_pytorch.py @@ -77,7 +77,12 @@ def step(self, closure=None): # Weight update update = exp_avg * beta1 + grad * (1 - beta1) - p.add_(torch.sign(update), alpha=-group['lr']) + + p.add_(torch.sign(update), alpha=-group['lr'], inplace=True) + #This has been made more efficient by using the torch.sign function's inplace mode. + #This will prevent the need to create a new tensor for the updated parameter, + #which can save a significant amount of time for large models. + # Decay the momentum running average coefficient exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)