actually, just follow @ipoletaev advice and remove autotuner for now
lucidrains committed May 9, 2023
1 parent 6ab873a commit 2226ec8
Showing 3 changed files with 26 additions and 36 deletions.
17 changes: 5 additions & 12 deletions lion_pytorch/lion_pytorch.py
@@ -1,3 +1,4 @@
+from functools import partial
 from typing import Tuple, Optional, Callable

 import torch
@@ -33,7 +34,8 @@ def __init__(
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
-        use_triton: bool = False
+        use_triton: bool = False,
+        triton_block_size: int = 1024
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
@@ -52,7 +54,7 @@ def __init__(

         if use_triton:
             from lion_pytorch.triton import update_fn as triton_update_fn
-            self.update_fn = triton_update_fn
+            self.update_fn = partial(triton_update_fn, BLOCK_SIZE = triton_block_size)

     @torch.no_grad()
     def step(
@@ -65,11 +67,6 @@ def step(
             with torch.enable_grad():
                 loss = closure()

-        # address an issue with autotune and in-place updates with triton
-        # on the first .step call, simply do not update parameters in-place, if using triton
-
-        update_kwargs = dict(inplace = False) if self.use_triton and not self.took_first_step else dict()
-
         # update all parameters

         for group in self.param_groups:
@@ -91,11 +88,7 @@ def step(
                     lr,
                     wd,
                     beta1,
-                    beta2,
-                    **update_kwargs
+                    beta2
                 )

-                if not self.took_first_step:
-                    self.took_first_step = True
-
         return loss
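
For context, a minimal usage sketch of the constructor after this change. The model, data, and hyperparameter values below are placeholders, and use_triton = True additionally requires a CUDA device with triton installed:

# Usage sketch (not part of the diff): the new `triton_block_size` kwarg
# replaces the removed autotuner's block-size search with a fixed value.
import torch
from lion_pytorch import Lion

model = torch.nn.Linear(512, 512).cuda()   # placeholder model

opt = Lion(
    model.parameters(),
    lr = 1e-4,
    betas = (0.9, 0.99),
    weight_decay = 1e-2,
    use_triton = True,          # fused update via the Triton kernel
    triton_block_size = 1024    # new in this commit: fixed block size
)

loss = model(torch.randn(8, 512, device = 'cuda')).sum()
loss.backward()
opt.step()
opt.zero_grad()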
43 changes: 20 additions & 23 deletions lion_pytorch/triton.py
@@ -8,12 +8,18 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()

+# helper functions
+
+def calc_num_warps(block_size):
+    num_warps = 4
+    if block_size >= 2048:
+        num_warps = 8
+    if block_size >= 4096:
+        num_warps = 16
+    return num_warps
+
 # triton cuda kernel

-@triton.autotune(configs = [
-    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
-    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
-], key = ['n_elements'])
 @triton.jit
 def update_fn_kernel(
     p_ptr,
@@ -81,37 +87,28 @@ def update_fn(
     wd: float,
     beta1: float,
     beta2: float,
-    inplace: bool = True
+    inplace: bool = True,
+    BLOCK_SIZE: int = 1024
 ):
     assert all([t.is_cuda for t in (p, grad, exp_avg)])
-    n_elements = p.numel()

-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
-
-    # address autotune and in-place update issue
-
-    if not inplace:
-        orig_p = p
-        orig_exp_avg = exp_avg
+    n_elements = p.numel()

-        p = p.clone()
-        exp_avg = exp_avg.clone()
+    block_size = triton.next_power_of_2(BLOCK_SIZE)
+    num_warps = calc_num_warps(block_size)
+    n_rows = triton.cdiv(n_elements, block_size)

     # call triton cuda kernel

-    update_fn_kernel[grid](
+    update_fn_kernel[(n_rows,)](
         p,
         grad,
         exp_avg,
         lr,
         wd,
         beta1,
         beta2,
-        n_elements
+        n_elements,
+        num_warps = num_warps,
+        BLOCK_SIZE = BLOCK_SIZE
     )
-
-    # update if not in-place call
-
-    if not inplace:
-        orig_p.copy_(p)
-        orig_exp_avg.copy_(exp_avg)
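
What replaced the autotuner: the kernel is now launched on a fixed one-dimensional grid derived from the caller-supplied BLOCK_SIZE. The plain-Python sketch below mirrors that arithmetic with stand-ins for triton.next_power_of_2 and triton.cdiv (the element count is an illustrative value):

# Sketch of the launch-parameter arithmetic introduced above (no Triton
# required here; next_power_of_2 / cdiv mirror Triton's helpers).

def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length()

def calc_num_warps(block_size: int) -> int:
    # same thresholds as the new helper in triton.py
    if block_size >= 4096:
        return 16
    if block_size >= 2048:
        return 8
    return 4

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

n_elements = 10_000_000                    # e.g. a 10M-element parameter tensor
BLOCK_SIZE = 1024                          # the new default

block_size = next_power_of_2(BLOCK_SIZE)   # 1024
num_warps  = calc_num_warps(block_size)    # 4
n_rows     = cdiv(n_elements, block_size)  # 9766 program instances in the grid

print(block_size, num_warps, n_rows)       # 1024 4 9766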
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.1.0',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',

1 comment on commit 2226ec8

@ipoletaev

Redundant inplace argument left inside triton'ed update_fn.
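
The point of the comment: after this commit the kernel always updates p and exp_avg in place, so the inplace flag no longer does anything. A possible follow-up, sketched here and not part of this commit, would be to drop it from the signature:

# Hypothetical follow-up sketch, not in this commit: drop the now-unused flag.
import torch

def update_fn(
    p: torch.Tensor,
    grad: torch.Tensor,
    exp_avg: torch.Tensor,
    lr: float,
    wd: float,
    beta1: float,
    beta2: float,
    BLOCK_SIZE: int = 1024
):
    ...  # body unchanged from the diff above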
