Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Prodigy, SophiaG optimizers #1350

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,26 +972,49 @@ def build(self, total_num_steps):

trainer_kwargs = {}

if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion

lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
if self.cfg.optimizer in ["lion_pytorch", "prodigy", "sophia"]:
custom_optim_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
if "weight_decay" in training_arguments_kwargs:
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
custom_optim_kwargs["weight_decay"] = training_arguments_kwargs[
"weight_decay"
]

if (
"adam_beta1" in training_arguments_kwargs
and "adam_beta2" in training_arguments_kwargs
):
lion_kwargs["betas"] = (
custom_optim_kwargs["betas"] = (
training_arguments_kwargs["adam_beta1"],
training_arguments_kwargs["adam_beta2"],
)

trainer_kwargs["optimizers"] = (
Lion(params=self.model.parameters(), **lion_kwargs),
None,
)
if self.cfg.optimizer == "lion_pytorch":
from axolotl.custom_optim.lion import Lion

trainer_kwargs["optimizers"] = (
Lion(params=self.model.parameters(), **custom_optim_kwargs),
None,
)
if self.cfg.optimizer == "sophia":
from axolotl.custom_optim.sophia import SophiaG

trainer_kwargs["optimizers"] = (
SophiaG(params=self.model.parameters(), **custom_optim_kwargs),
None,
)
if self.cfg.optimizer == "prodigy":
from axolotl.custom_optim.prodigy import Prodigy

trainer_kwargs["optimizers"] = (
Prodigy(
params=filter(
lambda p: p.requires_grad, self.model.parameters()
),
**custom_optim_kwargs,
),
None,
)

# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"

Expand Down
Empty file.
182 changes: 182 additions & 0 deletions src/axolotl/custom_optim/lion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
from typing import Callable, Optional, Tuple

import torch
from torch.optim.optimizer import Optimizer

try:
    import triton
    import triton.language as tl
except ImportError as exc:
    # Raise instead of print()+exit(): exit() terminates the *importing*
    # process with status 0, silently masking the failure for library users.
    # An ImportError propagates normally and keeps the original cause chained.
    raise ImportError(
        "triton is not installed, please install by running `pip install triton -U --pre`"
    ) from exc


def exists(val):
    """Return True when *val* is anything other than None."""
    is_missing = val is None
    return not is_missing


# update functions


def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
    """In-place pure-PyTorch Lion update for a single parameter tensor.

    Applies decoupled (step)weight decay, moves *p* along the sign of the
    beta1-interpolation between the momentum buffer and the current
    gradient, then refreshes the momentum buffer with beta2.

    Args:
        p: parameter tensor, updated in place.
        grad: gradient of *p*.
        exp_avg: momentum (EMA of gradients) buffer, updated in place.
        lr: learning rate.
        wd: decoupled weight-decay coefficient.
        beta1: interpolation coefficient for the sign update.
        beta2: EMA coefficient for the momentum buffer.
    """
    # stepweight decay, folded directly into the parameter
    p.data.mul_(1 - lr * wd)

    # weight update: sign of the beta1-interpolated momentum/gradient
    update = exp_avg.clone().mul_(beta1).add(grad, alpha=1 - beta1).sign_()
    p.add_(update, alpha=-lr)

    # decay the momentum running average coefficient
    exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)


def clone_inplace_updated_params(nargs):
    """Autotune pre-hook: clone the buffers the kernel mutates in place
    (parameter and momentum) so benchmarking candidate configs does not
    corrupt real optimizer state."""
    for key in ("p_ptr", "exp_avg_ptr"):
        nargs[key] = nargs[key].clone()


# triton cuda kernel


# Fused Lion update kernel. Autotune benchmarks both configs for each new
# n_elements; the pre_hook clones p/exp_avg first so the benchmark runs do
# not corrupt the real optimizer state (the kernel writes them in place).
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE": 128}, num_warps=4, pre_hook=clone_inplace_updated_params
        ),
        triton.Config(
            {"BLOCK_SIZE": 1024}, num_warps=8, pre_hook=clone_inplace_updated_params
        ),
    ],
    key=["n_elements"],
)
@triton.jit
def update_fn_kernel(
    p_ptr,
    grad_ptr,
    exp_avg_ptr,
    lr,
    wd,
    beta1,
    beta2,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance handles one contiguous BLOCK_SIZE chunk.
    pid = tl.program_id(axis=0)

    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # guard against the final partial block
    mask = offsets < n_elements

    # offsetted pointers

    offset_p_ptr = p_ptr + offsets
    offset_grad_ptr = grad_ptr + offsets
    offset_exp_avg_ptr = exp_avg_ptr + offsets

    # load

    p = tl.load(offset_p_ptr, mask=mask)
    grad = tl.load(offset_grad_ptr, mask=mask)
    exp_avg = tl.load(offset_exp_avg_ptr, mask=mask)

    # stepweight decay

    p = p * (1 - lr * wd)

    # diff between momentum running average and grad

    diff = exp_avg - grad

    # weight update: diff * beta1 + grad == exp_avg * beta1 + grad * (1 - beta1)

    update = diff * beta1 + grad

    # torch.sign emulation with -lr folded in: where update == 0 the masked
    # multiply contributes nothing; otherwise step by -lr * sign(update).

    can_update = update != 0
    update_sign = tl.where(update > 0, -lr, lr)

    p = p + update_sign * can_update

    # decay the momentum running average coefficient
    # (diff * beta2 + grad == exp_avg * beta2 + grad * (1 - beta2))

    exp_avg = diff * beta2 + grad

    # store new params and momentum running average coefficient

    tl.store(offset_p_ptr, p, mask=mask)
    tl.store(offset_exp_avg_ptr, exp_avg, mask=mask)


def triton_update_fn(
    p: torch.Tensor,
    grad: torch.Tensor,
    exp_avg: torch.Tensor,
    lr: float,
    wd: float,
    beta1: float,
    beta2: float,
):
    """Launch the fused Triton Lion kernel over every element of *p*.

    All three tensors must live on a CUDA device; *p* and *exp_avg* are
    updated in place by the kernel.
    """
    assert all(t.is_cuda for t in (p, grad, exp_avg))

    total = p.numel()

    # One 1-D program per BLOCK_SIZE chunk; BLOCK_SIZE is picked by autotune.
    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    update_fn_kernel[grid](p, grad, exp_avg, lr, wd, beta1, beta2, total)


class Lion(Optimizer):
    """Lion optimizer (EvoLved Sign Momentum, Chen et al., 2023).

    Keeps a single momentum buffer per parameter and updates weights with
    the sign of an interpolation between that buffer and the current
    gradient, plus decoupled weight decay.

    Args:
        params: iterable of parameters (or parameter groups) to optimize.
        lr: learning rate; must be positive.
        betas: (beta1, beta2) interpolation/EMA coefficients, each in [0, 1].
        weight_decay: decoupled weight-decay coefficient.
        use_triton: if True, dispatch to the fused Triton CUDA kernel.

    Raises:
        ValueError: if *lr* or *betas* are out of range.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        use_triton: bool = False,
    ):
        # Raise (not assert) so validation survives `python -O`, matching
        # the convention of torch.optim optimizers.
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not all(0.0 <= beta <= 1.0 for beta in betas):
            raise ValueError(f"Invalid betas: {betas}")

        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)

        super().__init__(params, defaults)

        self.update_fn = update_fn

        if use_triton:
            self.update_fn = triton_update_fn

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; executed with gradients enabled.

        Returns:
            The loss returned by *closure*, or None.
        """
        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Skip parameters that received no gradient this step.
            for p in filter(lambda p: exists(p.grad), group["params"]):
                grad, lr, wd, beta1, beta2, state = (
                    p.grad,
                    group["lr"],
                    group["weight_decay"],
                    *group["betas"],
                    self.state[p],
                )

                # Lazily initialize the momentum buffer on first update.
                if len(state) == 0:
                    state["exp_avg"] = torch.zeros_like(p)

                exp_avg = state["exp_avg"]

                self.update_fn(p, grad, exp_avg, lr, wd, beta1, beta2)

        return loss
Loading
Loading