attempt to substantially reduce kernel launch overhead (#3638)
This improves kernel launch latency by 2.2x (from 108us to 49us, measured
with @bertmaher's benchmarking script in issue #3619). Thanks also to
@liboyue for the analysis and suggestions.

See the discussion in the third-party PR #3503.
apgoucher committed Apr 12, 2024
1 parent 8c5e33c commit e1d0fea
Showing 1 changed file: python/triton/runtime/jit.py (108 additions, 54 deletions)
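The 2.2x figure above is the kind of number you get by timing a trivial kernel launch in a tight loop. A rough sketch of such a measurement (not the actual script from issue #3619; the kernel and sizes here are illustrative):

import time
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(dst_ptr, src_ptr, n, BLOCK_SIZE: tl.constexpr):
    # trivial body so that Python-side launch overhead dominates the loop
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=offs < n), mask=offs < n)

src = torch.randn(1024, device="cuda")
dst = torch.empty_like(src)
copy_kernel[(1, )](dst, src, 1024, BLOCK_SIZE=1024)  # first call compiles
torch.cuda.synchronize()

n_iters = 10_000
t0 = time.perf_counter()
for _ in range(n_iters):
    copy_kernel[(1, )](dst, src, 1024, BLOCK_SIZE=1024)
torch.cuda.synchronize()
print(f"{(time.perf_counter() - t0) / n_iters * 1e6:.1f} us per launch")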
@@ -132,6 +132,17 @@ def annotation(self):
             return ""
         return _normalize_ty(self._param.annotation)
 
+    @cached_property
+    def annotation_type(self):
+        annotation = self.annotation
+        for ty1, ty2 in [("uint", 'u'), ("int", 'i')]:
+            width = annotation[annotation.find(ty1) + len(ty1):]
+            if width and ty1 in annotation:
+                return f"{ty2}{width}"
+        if annotation == "bool":
+            return "u1"
+        return ""
+
     @cached_property
     def is_constexpr(self):
         return "constexpr" in self.annotation
@@ -164,18 +175,10 @@ def name(self):
         return self.param.name
 
     def mangled_type(self):
-        annotation = self.param.annotation
-        for ty1, ty2 in [("uint", 'u'), ("int", 'i')]:
-            width = annotation[annotation.find(ty1) + len(ty1):]
-            if width and ty1 in annotation:
-                return f"{ty2}{width}"
-        if annotation == "bool":
-            return "u1"
-
-        if "Tensor" in annotation:
-            key = self.value.dtype
-        else:
-            key = JITFunction._key_of(self.value)
+        annotation_type = self.param.annotation_type
+        if annotation_type:
+            return annotation_type
+        key = JITFunction._key_of(self.value)
         return JITFunction._type_of(key, self.param.is_const)
 
     def specialization_key(self):
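The "Tensor" special case can be dropped because _key_of already returns the dtype for any argument that has one, so the fallback path is uniform. A sketch of what the fallback sees (illustrative values; assumes torch is importable):

import torch
JITFunction._key_of(torch.empty(8, dtype=torch.float16))  # -> torch.float16 (tensors report their dtype)
JITFunction._key_of(4096)                                  # -> "i32" (the int fits in 32 bits)
JITFunction._key_of(True)                                  # -> "i1"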
@@ -218,6 +221,73 @@ def serialize_specialization_data(name, signature, constants, attrs, options, key):
     return serialized_obj
 
 
+def create_function_from_signature(sig):
+    """
+    Equivalent to sig.bind followed by apply_defaults. This generates a
+    native Python function (using exec) which can be memoized on a per-kernel
+    basis to avoid having to run these expensive functions -- which constitute
+    much of the kernel launch overhead -- every time we run the kernel.
+    """
+
+    # Create the function argument list and the dict entries for the return statement
+    func_args = []
+    dict_entries = []
+    for name, param in sig.parameters.items():
+        if param.default is inspect.Parameter.empty:
+            func_args.append(name)
+            dict_entries.append(f"'{name}': {name}")
+        else:
+            func_args.append(f"{name}=default_{name}")
+            dict_entries.append(f"'{name}': {name}")
+
+    # Join all arguments into a function definition string
+    args_str = ', '.join(func_args)
+    dict_str = ', '.join(dict_entries)
+    func_body = f"def dynamic_func({args_str}):\n return {{{dict_str}}}"
+
+    # Prepare defaults to be inserted into function namespace
+    func_namespace = {
+        f"default_{name}": param.default
+        for name, param in sig.parameters.items()
+        if param.default is not inspect.Parameter.empty
+    }
+
+    # Execute the function string in func_namespace to create the function
+    exec(func_body, func_namespace)
+
+    # Extract the newly created function from the namespace
+    return func_namespace['dynamic_func']
+
+
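To make the exec trick concrete: for a typical kernel signature the generated binder is an ordinary Python function, so binding collapses to one plain call. A sketch, with the function above in scope (the kernel signature and asserts are illustrative, not part of the diff):

import inspect

def add_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE=1024):  # hypothetical signature
    pass

sig = inspect.signature(add_kernel)
binder = create_function_from_signature(sig)
# The generated source is equivalent to:
#   def dynamic_func(x_ptr, y_ptr, n_elements, BLOCK_SIZE=default_BLOCK_SIZE):
#       return {'x_ptr': x_ptr, 'y_ptr': y_ptr, 'n_elements': n_elements, 'BLOCK_SIZE': BLOCK_SIZE}

fast = binder(1, 2, 3)       # one plain function call per launch
slow = sig.bind(1, 2, 3)     # inspect machinery, previously run on every launch
slow.apply_defaults()
assert fast == dict(slow.arguments)  # both give {'x_ptr': 1, 'y_ptr': 2, 'n_elements': 3, 'BLOCK_SIZE': 1024}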
+type_canonicalisation_dict = {
+    "bool": "i1",
+    "float8e4nv": "fp8e4nv",
+    "float8e5": "fp8e5",
+    "float8e4b15": "fp8e4b15",
+    "float8_e4m3fn": "fp8e4nv",
+    "float8e4b8": "fp8e4b8",
+    "float8_e4m3fnuz": "fp8e4b8",
+    "float8_e5m2": "fp8e5",
+    "float8e5b16": "fp8e5b16",
+    "float8_e5m2fnuz": "fp8e5b16",
+    "float16": "fp16",
+    "bfloat16": "bf16",
+    "float32": "fp32",
+    "float64": "fp64",
+    "int8": "i8",
+    "int16": "i16",
+    "int32": "i32",
+    "int64": "i64",
+    "uint8": "u8",
+    "uint16": "u16",
+    "uint32": "u32",
+    "uint64": "u64",
+}
+
+for v in list(type_canonicalisation_dict.values()):
+    type_canonicalisation_dict[v] = v
+
+
 class JITFunction(KernelInterface[T]):
     # Hook for inspecting compiled functions and modules
     cache_hook = None
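The self-mapping loop makes the table idempotent, so framework dtype names and already-canonical names both resolve through a single lookup. A quick check against the table above:

assert type_canonicalisation_dict["float32"] == "fp32"
assert type_canonicalisation_dict["fp32"] == "fp32"              # added by the loop
assert type_canonicalisation_dict["float8_e4m3fn"] == "fp8e4nv"  # torch name -> triton name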
@@ -285,36 +355,13 @@ def _type_of(key, is_const=False):
         # `None` is nullptr. Implicitly convert to *i8.
         if key is None:
             return "*i8"
+        elif isinstance(key, str):
+            return key
+
         dtype_str = str(key).split(".")[-1]
-        tys = {
-            "bool": "i1",
-            "float8e4nv": "fp8e4nv",
-            "float8e5": "fp8e5",
-            "float8e4b15": "fp8e4b15",
-            "float8_e4m3fn": "fp8e4nv",
-            "float8e4b8": "fp8e4b8",
-            "float8_e4m3fnuz": "fp8e4b8",
-            "float8_e5m2": "fp8e5",
-            "float8e5b16": "fp8e5b16",
-            "float8_e5m2fnuz": "fp8e5b16",
-            "float16": "fp16",
-            "bfloat16": "bf16",
-            "float32": "fp32",
-            "float64": "fp64",
-            "int8": "i8",
-            "int16": "i16",
-            "int32": "i32",
-            "int64": "i64",
-            "uint8": "u8",
-            "uint16": "u16",
-            "uint32": "u32",
-            "uint64": "u64",
-        }
-        # reinterpret can create triton type
-        for v in list(tys.values()):
-            tys[v] = v
-        const_str = "k" if is_const else ""
-        return key if isinstance(key, str) else f"*{const_str}{tys[dtype_str]}"
+        dtype_str = type_canonicalisation_dict[dtype_str]
+        const_str = "*k" if is_const else "*"
+        return const_str + dtype_str
 
     def _make_constants(self, constexpr_key):
         constants = dict(zip(self.constexprs, constexpr_key))
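With the table hoisted to module scope, _type_of no longer rebuilds the dtype dictionary on every uncached type query. A sketch of the resulting behaviour (assumes torch is importable):

import torch
assert JITFunction._type_of(torch.float32) == "*fp32"                  # str(key) ends in "float32"
assert JITFunction._type_of(torch.float32, is_const=True) == "*kfp32"
assert JITFunction._type_of(None) == "*i8"                             # None is nullptr
assert JITFunction._type_of("i32") == "i32"                            # new string fast path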
@@ -396,26 +443,35 @@ def run(self, *args, grid, warmup, **kwargs):
             hook(*args, **kwargs)
 
         # bind non-reserved keyword args and set defaults
+        if self.binder is None:
+            self.binder = create_function_from_signature(self.signature)
+            self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr]
+            self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr]
+            self.specialised_indices = [i for (i, p) in enumerate(self.params) if not p.do_not_specialize]
+
         kwargs = {k: v for k, v in kwargs.items() if k not in options.__dict__}
-        bound_args = self.signature.bind(*args, **kwargs)
-        bound_args.apply_defaults()
-        assert len(bound_args.arguments) == len(self.params)
+        bound_args = self.binder(*args, **kwargs)
+        assert len(bound_args) == len(self.params)
         # canonicalize grid
         assert grid is not None
         if callable(grid):
             # Arguments are passed as a dict to `grid`, by contract.
             # TODO(jlebar): In the new launch API, pass the compiler flags as a
             # second parameter to `grid`.
-            grid = grid(dict(bound_args.arguments))
+            grid = grid(dict(bound_args))
         grid_size = len(grid)
         grid_0 = grid[0]
         grid_1 = grid[1] if grid_size > 1 else 1
         grid_2 = grid[2] if grid_size > 2 else 1
 
         # compute cache key
-        args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.arguments.items(), self.params)]
-        sig_key = tuple(arg.mangled_type() for arg in args if not arg.param.is_constexpr)
-        spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
-        constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr)
+        args = [KernelArg(arg_value, param) for (_, arg_value), param in zip(bound_args.items(), self.params)]
+
+        signature = {args[i].param.name: args[i].mangled_type() for i in self.non_constexpr_indices}
+        sig_key = tuple(signature.values())
+        spec_key = tuple(args[i].specialization_key() for i in self.specialised_indices)
+        constexpr_key = tuple(args[i].value for i in self.constexpr_indices)
+
         key = (sig_key, constexpr_key, spec_key, options)
         key = str(key)
         # Kernel is not cached; we have to compile.
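The one-time setup above freezes which parameter positions are constexpr, non-constexpr, and specialisable, so the hot path swaps per-argument predicate filtering for plain list indexing. A toy illustration (hypothetical four-parameter kernel):

params = ["x_ptr", "y_ptr", "n_elements", "BLOCK_SIZE"]
is_constexpr = [False, False, False, True]  # e.g. BLOCK_SIZE: tl.constexpr
non_constexpr_indices = [i for i, c in enumerate(is_constexpr) if not c]  # built once
values = ["<ptr>", "<ptr>", 4096, 1024]
assert [values[i] for i in non_constexpr_indices] == ["<ptr>", "<ptr>", 4096]  # per-launch: indexing only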
@@ -430,9 +486,6 @@ def run(self, *args, grid, warmup, **kwargs):
                 if callable(arg):
                     raise TypeError(f"Callable constexpr at index {i} is not supported")
 
-            # Build kernel signature -- doesn't include constexpr arguments.
-            signature = {arg.param.name: arg.mangled_type() for arg in args if not arg.param.is_constexpr}
-
             if self._call_hook(key, signature, device, constants, options, configs):
                 return None
             # compile the kernel
@@ -446,14 +499,13 @@ def run(self, *args, grid, warmup, **kwargs):
             kernel = self.cache[device][key]
 
         # Verify key signature from the cache
-        signature = {arg.param.name: arg.mangled_type() for arg in args if not arg.param.is_constexpr}
         if kernel.src.signature != signature:
             raise RuntimeError(f"Signature mismatch for cached kernel {self.fn.__name__}:\n"
                                f"  Cached signature: {kernel.src.signature}\n"
                                f"  Call signature: {signature}")
 
         if not warmup:
-            args = [arg.value for arg in args if not arg.param.is_constexpr]
+            args = [args[i].value for i in self.non_constexpr_indices]
             launch_metadata = kernel.launch_metadata(grid, stream, *args)
             kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
                        CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args)
@@ -472,6 +524,8 @@ def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None, repr=None, launch_metadata=None):
         self.repr = lambda _: fn.__name__ if repr is None else repr(_)
         self.launch_metadata = launch_metadata
 
+        self.binder = None
+
         self.params = []
         for i, param in enumerate(self.signature.parameters.values()):
             dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize)
