Skip to content

Commit

Permalink
[cuda kernels] only compile them when initializing (#29133)
Browse files Browse the repository at this point in the history
* only compile when needed

* fix mra as well

* fix yoso as well

* update

* rempve comment

* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py

* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py

* opps

* Update src/transformers/models/deta/modeling_deta.py

* nit
  • Loading branch information
ArthurZucker committed Feb 20, 2024
1 parent a7755d2 commit 5e95dca
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 68 deletions.
53 changes: 42 additions & 11 deletions src/transformers/models/deformable_detr/modeling_deformable_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

import copy
import math
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -46,21 +48,42 @@
from ...utils import is_accelerate_available, is_ninja_available, logging
from ...utils.backbone_utils import load_backbone
from .configuration_deformable_detr import DeformableDetrConfig
from .load_custom import load_cuda_kernels


logger = logging.get_logger(__name__)

# Move this to not compile only when importing, this needs to happen later, like in __init__.
if is_torch_cuda_available() and is_ninja_available():
logger.info("Loading custom CUDA kernels...")
try:
MultiScaleDeformableAttention = load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
MultiScaleDeformableAttention = None
else:
MultiScaleDeformableAttention = None
MultiScaleDeformableAttention = None


def load_cuda_kernels():
from torch.utils.cpp_extension import load

global MultiScaleDeformableAttention

root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deta"
src_files = [
root / filename
for filename in [
"vision.cpp",
os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
os.path.join("cuda", "ms_deform_attn_cuda.cu"),
]
]

MultiScaleDeformableAttention = load(
"MultiScaleDeformableAttention",
src_files,
with_cuda=True,
extra_include_paths=[str(root)],
extra_cflags=["-DWITH_CUDA=1"],
extra_cuda_cflags=[
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
],
)


if is_vision_available():
from transformers.image_transforms import center_to_corners_format
Expand Down Expand Up @@ -590,6 +613,14 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):

def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int):
super().__init__()

kernel_loaded = MultiScaleDeformableAttention is not None
if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
try:
load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")

if config.d_model % num_heads != 0:
raise ValueError(
f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
Expand Down
29 changes: 13 additions & 16 deletions src/transformers/models/deta/modeling_deta.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,15 @@

logger = logging.get_logger(__name__)

MultiScaleDeformableAttention = None


# Copied from models.deformable_detr.load_cuda_kernels
def load_cuda_kernels():
from torch.utils.cpp_extension import load

global MultiScaleDeformableAttention

root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deta"
src_files = [
root / filename
Expand All @@ -78,22 +83,6 @@ def load_cuda_kernels():
],
)

import MultiScaleDeformableAttention as MSDA

return MSDA


# Move this to not compile only when importing, this needs to happen later, like in __init__.
if is_torch_cuda_available() and is_ninja_available():
logger.info("Loading custom CUDA kernels...")
try:
MultiScaleDeformableAttention = load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
MultiScaleDeformableAttention = None
else:
MultiScaleDeformableAttention = None


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
class MultiScaleDeformableAttentionFunction(Function):
Expand Down Expand Up @@ -596,6 +585,14 @@ class DetaMultiscaleDeformableAttention(nn.Module):

def __init__(self, config: DetaConfig, num_heads: int, n_points: int):
super().__init__()

kernel_loaded = MultiScaleDeformableAttention is not None
if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
try:
load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")

if config.d_model % num_heads != 0:
raise ValueError(
f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
Expand Down
40 changes: 15 additions & 25 deletions src/transformers/models/mra/modeling_mra.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,36 +58,19 @@
# See all Mra models at https://huggingface.co/models?filter=mra
]

mra_cuda_kernel = None


def load_cuda_kernels():
global cuda_kernel
global mra_cuda_kernel
src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "mra"

def append_root(files):
return [src_folder / file for file in files]

src_files = append_root(["cuda_kernel.cu", "cuda_launch.cu", "torch_extension.cpp"])

cuda_kernel = load("cuda_kernel", src_files, verbose=True)

import cuda_kernel


cuda_kernel = None


if is_torch_cuda_available() and is_ninja_available():
logger.info("Loading custom CUDA kernels...")

try:
load_cuda_kernels()
except Exception as e:
logger.warning(
"Failed to load CUDA kernels. Mra requires custom CUDA kernels. Please verify that compatible versions of"
f" PyTorch and CUDA Toolkit are installed: {e}"
)
else:
pass
mra_cuda_kernel = load("cuda_kernel", src_files, verbose=True)


def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block):
Expand All @@ -112,7 +95,7 @@ def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block):
indices = indices.int()
indices = indices.contiguous()

max_vals, max_vals_scatter = cuda_kernel.index_max(index_vals, indices, query_num_block, key_num_block)
max_vals, max_vals_scatter = mra_cuda_kernel.index_max(index_vals, indices, query_num_block, key_num_block)
max_vals_scatter = max_vals_scatter.transpose(-1, -2)[:, :, None, :]

return max_vals, max_vals_scatter
Expand Down Expand Up @@ -178,7 +161,7 @@ def mm_to_sparse(dense_query, dense_key, indices, block_size=32):
indices = indices.int()
indices = indices.contiguous()

return cuda_kernel.mm_to_sparse(dense_query, dense_key, indices.int())
return mra_cuda_kernel.mm_to_sparse(dense_query, dense_key, indices.int())


def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_size=32):
Expand Down Expand Up @@ -216,7 +199,7 @@ def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_siz
indices = indices.contiguous()
dense_key = dense_key.contiguous()

dense_qk_prod = cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
dense_qk_prod = mra_cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
dense_qk_prod = dense_qk_prod.transpose(-1, -2).reshape(batch_size, query_num_block * block_size, dim)
return dense_qk_prod

Expand Down Expand Up @@ -393,7 +376,7 @@ def mra2_attention(
"""
Use Mra to approximate self-attention.
"""
if cuda_kernel is None:
if mra_cuda_kernel is None:
return torch.zeros_like(query).requires_grad_()

batch_size, num_head, seq_len, head_dim = query.size()
Expand Down Expand Up @@ -561,6 +544,13 @@ def __init__(self, config, position_embedding_type=None):
f"heads ({config.num_attention_heads})"
)

kernel_loaded = mra_cuda_kernel is not None
if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
try:
load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")

self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
Expand Down
39 changes: 23 additions & 16 deletions src/transformers/models/yoso/modeling_yoso.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,14 @@
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_ninja_available,
is_torch_cuda_available,
logging,
)
from .configuration_yoso import YosoConfig


Expand All @@ -49,28 +56,22 @@
# See all YOSO models at https://huggingface.co/models?filter=yoso
]

lsh_cumulation = None


def load_cuda_kernels():
global lsh_cumulation
try:
from torch.utils.cpp_extension import load
from torch.utils.cpp_extension import load

def append_root(files):
src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "yoso"
return [src_folder / file for file in files]

src_files = append_root(
["fast_lsh_cumulation_torch.cpp", "fast_lsh_cumulation.cu", "fast_lsh_cumulation_cuda.cu"]
)
def append_root(files):
src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "yoso"
return [src_folder / file for file in files]

load("fast_lsh_cumulation", src_files, verbose=True)
src_files = append_root(["fast_lsh_cumulation_torch.cpp", "fast_lsh_cumulation.cu", "fast_lsh_cumulation_cuda.cu"])

import fast_lsh_cumulation as lsh_cumulation
load("fast_lsh_cumulation", src_files, verbose=True)

return True
except Exception:
lsh_cumulation = None
return False
import fast_lsh_cumulation as lsh_cumulation


def to_contiguous(input_tensors):
Expand Down Expand Up @@ -305,6 +306,12 @@ def __init__(self, config, position_embedding_type=None):
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
kernel_loaded = lsh_cumulation is not None
if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
try:
load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")

self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
Expand Down

0 comments on commit 5e95dca

Please sign in to comment.