[WIP] Testing the lion optimizer #432

Draft · wants to merge 1 commit into base: main
33 changes: 24 additions & 9 deletions src/training/main.py
@@ -31,6 +31,7 @@
 from training.data import get_data
 from training.distributed import is_master, init_distributed_device, broadcast_object
 from training.logger import setup_logging
+from training.optimizers import Lion
 from training.params import parse_args
 from training.scheduler import cosine_lr, const_lr, const_lr_cooldown
 from training.train import train_one_epoch, evaluate
@@ -296,15 +297,29 @@ def main(args):
         gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
         rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]

-        optimizer = optim.AdamW(
-            [
-                {"params": gain_or_bias_params, "weight_decay": 0.},
-                {"params": rest_params, "weight_decay": args.wd},
-            ],
-            lr=args.lr,
-            betas=(args.beta1, args.beta2),
-            eps=args.eps,
-        )
+        if 'lion' in args.opt:
+            logging.info('Using Lion optimizer.')
+            optimizer = Lion(
+                [
+                    {"params": gain_or_bias_params, "weight_decay": 0.},
+                    {"params": rest_params, "weight_decay": args.wd},
+                ],
+                lr=args.lr,
+                betas=(args.beta1, args.beta2),
+                use_triton='triton' in args.opt,
+            )
+        else:
+            logging.info('Using adamw optimizer.')
+            optimizer = optim.AdamW(
+                [
+                    {"params": gain_or_bias_params, "weight_decay": 0.},
+                    {"params": rest_params, "weight_decay": args.wd},
+                ],
+                lr=args.lr,
+                betas=(args.beta1, args.beta2),
+                eps=args.eps,
+            )

         if args.horovod:
             optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
             hvd.broadcast_parameters(model.state_dict(), root_rank=0)
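The branch above dispatches on substring containment, so the three values accepted by the new --opt flag map as follows (a quick illustration, not part of the diff):

for opt in ('adamw', 'lion', 'lion-triton'):
    use_lion = 'lion' in opt       # True for 'lion' and 'lion-triton'
    use_triton = 'triton' in opt   # True only for 'lion-triton'
    print(opt, '->', 'Lion' if use_lion else 'AdamW', '(Triton kernel)' if use_triton else '')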
1 change: 1 addition & 0 deletions src/training/optimizers/__init__.py
@@ -0,0 +1 @@
from .lion import Lion
89 changes: 89 additions & 0 deletions src/training/optimizers/lion.py
@@ -0,0 +1,89 @@
# This file is from https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py
from typing import Tuple, Optional, Callable

import torch
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
    return val is not None

# update functions

def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
    # stepweight decay

    p.data.mul_(1 - lr * wd)

    # weight update

    update = exp_avg.clone().lerp_(grad, 1 - beta1)
    p.add_(torch.sign(update), alpha = -lr)

    # decay the momentum running average coefficient

    exp_avg.lerp_(grad, 1 - beta2)
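# Note (restated for readability; not in the upstream file): for parameter p, gradient grad,
# momentum buffer exp_avg, learning rate lr, weight decay wd and betas (beta1, beta2),
# the three in-place operations above amount to
#   p       <- p * (1 - lr * wd)                                    (decoupled weight decay)
#   p       <- p - lr * sign(beta1 * exp_avg + (1 - beta1) * grad)  (sign-based update)
#   exp_avg <- beta2 * exp_avg + (1 - beta2) * grad                 (momentum EMA)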

# class

class Lion(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        use_triton: bool = False
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])

        defaults = dict(
            lr = lr,
            betas = betas,
            weight_decay = weight_decay
        )

        super().__init__(params, defaults)

        self.update_fn = update_fn

        if use_triton:
            from lion_pytorch.triton import update_fn as triton_update_fn
            self.update_fn = triton_update_fn

    @torch.no_grad()
    def step(
        self,
        closure: Optional[Callable] = None
    ):

        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p]

                # init state - exponential moving average of gradient values

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']

                self.update_fn(
                    p,
                    grad,
                    exp_avg,
                    lr,
                    wd,
                    beta1,
                    beta2
                )

        return loss
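For reference, a minimal usage sketch of the Lion class above; the toy model, data, and hyperparameter values are illustrative only and not part of this PR:

import torch
import torch.nn.functional as F

from training.optimizers import Lion

model = torch.nn.Linear(10, 1)                  # toy model for illustration
optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1)

x, y = torch.randn(8, 10), torch.randn(8, 1)    # random stand-in batch
loss = F.mse_loss(model(x), y)
loss.backward()
optimizer.step()                                # applies the sign-based update from update_fn
optimizer.zero_grad()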
4 changes: 4 additions & 0 deletions src/training/params.py
@@ -140,6 +140,10 @@ def parse_args(args):
     parser.add_argument(
         "--warmup", type=int, default=10000, help="Number of steps to warmup for."
     )
+    parser.add_argument(
+        "--opt", type=str, default='adamw',
+        help="Which optimizer to use. Choices are ['adamw', 'lion', 'lion-triton']."
+    )
     parser.add_argument(
         "--use-bn-sync",
         default=False,
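With the flag in place, Lion can be selected from the command line. A hypothetical invocation, assuming the usual `python -m training.main` entry point; --opt comes from this PR, the other flags are assumed from the existing parser, and the values are placeholders:

python -m training.main \
    --opt lion \
    --lr 1e-4 \
    --wd 0.1 \
    --warmup 10000

Passing --opt lion-triton instead selects the fused Triton kernel via use_triton=True.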