From c7392f2bab3165244d1c565b66409fa11fa82367 Mon Sep 17 00:00:00 2001 From: Sachin Panicker <86884610+sachinspanicker@users.noreply.github.com> Date: Thu, 10 Aug 2023 04:29:26 +0530 Subject: [PATCH] p.add_(torch.sign(update), alpha=-group['lr']) has been made more efficient by using the torch.sign function's inplace mode. This will prevent the need to create a new tensor for the updated parameter, which can save a significant amount of time for large models. (#1193) --- lion/lion_pytorch.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lion/lion_pytorch.py b/lion/lion_pytorch.py index 9f5662f9..ffd52d99 100644 --- a/lion/lion_pytorch.py +++ b/lion/lion_pytorch.py @@ -77,7 +77,12 @@ def step(self, closure=None): # Weight update update = exp_avg * beta1 + grad * (1 - beta1) - p.add_(torch.sign(update), alpha=-group['lr']) + + p.add_(torch.sign(update), alpha=-group['lr'], inplace=True) + #This has been made more efficient by using the torch.sign function's inplace mode. + #This will prevent the need to create a new tensor for the updated parameter, + #which can save a significant amount of time for large models. + # Decay the momentum running average coefficient exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)