address an issue with triton auto-tuner and in-place calls. make the assumption that after the first optimizer.step call, things are properly cached
lucidrains committed May 9, 2023
1 parent ee26e32 commit 6ab873a
Showing 3 changed files with 39 additions and 6 deletions.
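
Background on the fix: triton's autotuner benchmarks each registered triton.Config by re-running the kernel on its first invocation, so a kernel that writes its outputs in-place can touch the live tensors several times before the best config is cached. The commit sidesteps this by running the very first update on clones and copying the result back. A minimal sketch of that pattern, using a hypothetical helper name and launch callable (not code from this repository):

import torch

def first_call_safe(kernel_launch, p: torch.Tensor, exp_avg: torch.Tensor):
    # hypothetical helper: shield the live tensors from the autotuner's benchmark runs
    p_tmp, exp_avg_tmp = p.clone(), exp_avg.clone()
    kernel_launch(p_tmp, exp_avg_tmp)  # the autotuner may replay this launch while timing configs
    p.copy_(p_tmp)                     # live tensors only receive the kernel's final contents
    exp_avg.copy_(exp_avg_tmp)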
15 changes: 14 additions & 1 deletion lion_pytorch/lion_pytorch.py
@@ -47,6 +47,8 @@ def __init__(
         super().__init__(params, defaults)

         self.update_fn = update_fn
+        self.use_triton = use_triton
+        self.took_first_step = False

         if use_triton:
             from lion_pytorch.triton import update_fn as triton_update_fn
@@ -63,6 +65,13 @@ def step(
             with torch.enable_grad():
                 loss = closure()

+        # address an issue with autotune and in-place updates with triton
+        # on the first .step call, simply do not update parameters in-place, if using triton
+
+        update_kwargs = dict(inplace = False) if self.use_triton and not self.took_first_step else dict()
+
+        # update all parameters
+
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):

@@ -82,7 +91,11 @@ def step(
                     lr,
                     wd,
                     beta1,
-                    beta2
+                    beta2,
+                    **update_kwargs
                 )

+        if not self.took_first_step:
+            self.took_first_step = True
+
         return loss
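
A usage sketch of the optimizer-level behaviour introduced above (toy model and hyperparameters are illustrative, assuming a CUDA device is available):

import torch
from lion_pytorch import Lion

model = torch.nn.Linear(512, 512).cuda()
opt = Lion(model.parameters(), lr = 1e-4, weight_decay = 1e-2, use_triton = True)

loss = model(torch.randn(8, 512, device = 'cuda')).sum()
loss.backward()

opt.step()       # first step: update_kwargs = dict(inplace = False), so the triton update runs on clones
opt.zero_grad()  # later steps update in-place, relying on the autotuner's cached choice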
28 changes: 24 additions & 4 deletions lion_pytorch/triton.py
@@ -1,4 +1,5 @@
 import torch
+from torch import Tensor

 try:
     import triton
@@ -7,6 +8,7 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()

+# triton cuda kernel

 @triton.autotune(configs = [
     triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
@@ -72,19 +74,31 @@ def update_fn_kernel(
     tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)

 def update_fn(
-    p: torch.Tensor,
-    grad: torch.Tensor,
-    exp_avg: torch.Tensor,
+    p: Tensor,
+    grad: Tensor,
+    exp_avg: Tensor,
     lr: float,
     wd: float,
     beta1: float,
-    beta2: float
+    beta2: float,
+    inplace: bool = True
 ):
     assert all([t.is_cuda for t in (p, grad, exp_avg)])
     n_elements = p.numel()

     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

+    # address autotune and in-place update issue
+
+    if not inplace:
+        orig_p = p
+        orig_exp_avg = exp_avg
+
+        p = p.clone()
+        exp_avg = exp_avg.clone()
+
+    # call triton cuda kernel
+
     update_fn_kernel[grid](
         p,
         grad,
@@ -95,3 +109,9 @@ def update_fn(
         beta2,
         n_elements
     )
+
+    # update if not in-place call
+
+    if not inplace:
+        orig_p.copy_(p)
+        orig_exp_avg.copy_(exp_avg)
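
The new inplace flag can also be exercised directly against the triton update function; a small sketch with illustrative hyperparameter values (all tensors must live on the GPU, as the assert above enforces):

import torch
from lion_pytorch.triton import update_fn

p = torch.randn(1024, device = 'cuda')
grad = torch.randn_like(p)
exp_avg = torch.zeros_like(p)

# first call: the kernel runs on clones while autotuning, and the results are copied back into p and exp_avg
update_fn(p, grad, exp_avg, lr = 1e-4, wd = 1e-2, beta1 = 0.9, beta2 = 0.99, inplace = False)

# subsequent calls can update in-place, since the autotuner has already picked its config
update_fn(p, grad, exp_avg, lr = 1e-4, wd = 1e-2, beta1 = 0.9, beta2 = 0.99)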
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.7',
+  version = '0.0.8',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
