A few more features_intermediate() models, AttentionExtract helper, related minor cleanup. #2168

Merged · 9 commits · May 11, 2024
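
For orientation, a rough usage sketch of the AttentionExtract helper named in the title. The import path, constructor arguments, and return format shown here are assumptions based on the PR description, not taken from this diff.

import torch
import timm
from timm.utils import AttentionExtract  # helper referenced in the PR title; path assumed

model = timm.create_model('vit_base_patch16_224', pretrained=False).eval()
extractor = AttentionExtract(model, method='fx')  # 'hook' is presumably the non-fx alternative
with torch.no_grad():
    attn_maps = extractor(torch.randn(1, 3, 224, 224))
for name, attn in attn_maps.items():
    print(name, tuple(attn.shape))  # per-block attention tensors keyed by module/node name
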
tests/test_models.py (2 additions, 1 deletion)
@@ -51,7 +51,8 @@
 FEAT_INTER_FILTERS = [
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
-    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet'
+    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer'
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.

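The FEAT_INTER_FILTERS list above gates which model families the intermediate-feature tests exercise; this PR adds 'regnet', 'byobnet', 'byoanet', and 'mlp_mixer'. As a minimal sketch of what those families are now expected to support (the model name and the exact return signature are assumptions based on the existing timm forward_intermediates() API, not this diff):

import torch
import timm

# 'regnety_002' is used purely as an illustrative member of the newly listed 'regnet' family
model = timm.create_model('regnety_002', pretrained=False)
x = torch.randn(1, 3, 224, 224)
final, intermediates = model.forward_intermediates(x)  # assumed default return: (final features, list of feature maps)
for t in intermediates:
    print(tuple(t.shape))
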
timm/layers/__init__.py (1 addition, 1 deletion)
@@ -47,7 +47,7 @@
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
-from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace
+from .space_to_depth import SpaceToDepth, DepthToSpace
 from .split_attn import SplitAttn
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame

timm/layers/activations_jit.py (90 deletions)

This file was deleted.

timm/layers/activations_me.py (41 additions, 51 deletions)
@@ -3,8 +3,8 @@
 A collection of activations fn and modules with a common interface so that they can
 easily be swapped. All have an `inplace` arg even if not used.

-These activations are not compatible with jit scripting or ONNX export of the model, please use either
-the JIT or basic versions of the activations.
+These activations are not compatible with jit scripting or ONNX export of the model, please use
+basic versions of the activations.

 Hacked together by / Copyright 2020 Ross Wightman
 """
@@ -14,19 +14,17 @@
 from torch.nn import functional as F


-@torch.jit.script
-def swish_jit_fwd(x):
+def swish_fwd(x):
     return x.mul(torch.sigmoid(x))


-@torch.jit.script
-def swish_jit_bwd(x, grad_output):
+def swish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


-class SwishJitAutoFn(torch.autograd.Function):
-    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
+class SwishAutoFn(torch.autograd.Function):
+    """ optimised Swish w/ memory-efficient checkpoint
     Inspired by conversation btw Jeremy Howard & Adam Pazske
     https://twitter.com/jeremyphoward/status/1188251041835315200
     """
@@ -37,123 +35,117 @@ def symbolic(g, x):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return swish_jit_fwd(x)
+        return swish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return swish_jit_bwd(x, grad_output)
+        return swish_bwd(x, grad_output)


 def swish_me(x, inplace=False):
-    return SwishJitAutoFn.apply(x)
+    return SwishAutoFn.apply(x)


 class SwishMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(SwishMe, self).__init__()

     def forward(self, x):
-        return SwishJitAutoFn.apply(x)
+        return SwishAutoFn.apply(x)


-@torch.jit.script
-def mish_jit_fwd(x):
+def mish_fwd(x):
     return x.mul(torch.tanh(F.softplus(x)))


-@torch.jit.script
-def mish_jit_bwd(x, grad_output):
+def mish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     x_tanh_sp = F.softplus(x).tanh()
     return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


-class MishJitAutoFn(torch.autograd.Function):
+class MishAutoFn(torch.autograd.Function):
     """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    A memory efficient, jit scripted variant of Mish
+    A memory efficient variant of Mish
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return mish_jit_fwd(x)
+        return mish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return mish_jit_bwd(x, grad_output)
+        return mish_bwd(x, grad_output)


 def mish_me(x, inplace=False):
-    return MishJitAutoFn.apply(x)
+    return MishAutoFn.apply(x)


 class MishMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(MishMe, self).__init__()

     def forward(self, x):
-        return MishJitAutoFn.apply(x)
+        return MishAutoFn.apply(x)


-@torch.jit.script
-def hard_sigmoid_jit_fwd(x, inplace: bool = False):
+def hard_sigmoid_fwd(x, inplace: bool = False):
     return (x + 3).clamp(min=0, max=6).div(6.)


-@torch.jit.script
-def hard_sigmoid_jit_bwd(x, grad_output):
+def hard_sigmoid_bwd(x, grad_output):
     m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
     return grad_output * m


-class HardSigmoidJitAutoFn(torch.autograd.Function):
+class HardSigmoidAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_sigmoid_jit_fwd(x)
+        return hard_sigmoid_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_sigmoid_jit_bwd(x, grad_output)
+        return hard_sigmoid_bwd(x, grad_output)


 def hard_sigmoid_me(x, inplace: bool = False):
-    return HardSigmoidJitAutoFn.apply(x)
+    return HardSigmoidAutoFn.apply(x)


 class HardSigmoidMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(HardSigmoidMe, self).__init__()

     def forward(self, x):
-        return HardSigmoidJitAutoFn.apply(x)
+        return HardSigmoidAutoFn.apply(x)


-@torch.jit.script
-def hard_swish_jit_fwd(x):
+def hard_swish_fwd(x):
     return x * (x + 3).clamp(min=0, max=6).div(6.)


-@torch.jit.script
-def hard_swish_jit_bwd(x, grad_output):
+def hard_swish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= 3.)
     m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
     return grad_output * m


-class HardSwishJitAutoFn(torch.autograd.Function):
-    """A memory efficient, jit-scripted HardSwish activation"""
+class HardSwishAutoFn(torch.autograd.Function):
+    """A memory efficient HardSwish activation"""
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_swish_jit_fwd(x)
+        return hard_swish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_swish_jit_bwd(x, grad_output)
+        return hard_swish_bwd(x, grad_output)

     @staticmethod
     def symbolic(g, self):
@@ -164,55 +156,53 @@ def symbolic(g, self):


 def hard_swish_me(x, inplace=False):
-    return HardSwishJitAutoFn.apply(x)
+    return HardSwishAutoFn.apply(x)


 class HardSwishMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(HardSwishMe, self).__init__()

     def forward(self, x):
-        return HardSwishJitAutoFn.apply(x)
+        return HardSwishAutoFn.apply(x)


-@torch.jit.script
-def hard_mish_jit_fwd(x):
+def hard_mish_fwd(x):
     return 0.5 * x * (x + 2).clamp(min=0, max=2)


-@torch.jit.script
-def hard_mish_jit_bwd(x, grad_output):
+def hard_mish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= -2.)
     m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
     return grad_output * m


-class HardMishJitAutoFn(torch.autograd.Function):
-    """ A memory efficient, jit scripted variant of Hard Mish
+class HardMishAutoFn(torch.autograd.Function):
+    """ A memory efficient variant of Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
     https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_mish_jit_fwd(x)
+        return hard_mish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_mish_jit_bwd(x, grad_output)
+        return hard_mish_bwd(x, grad_output)


 def hard_mish_me(x, inplace: bool = False):
-    return HardMishJitAutoFn.apply(x)
+    return HardMishAutoFn.apply(x)


 class HardMishMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(HardMishMe, self).__init__()

     def forward(self, x):
-        return HardMishJitAutoFn.apply(x)
+        return HardMishAutoFn.apply(x)
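
With the @torch.jit.script decorators gone, the hand-written backward passes above are plain Python; a quick way to sanity-check them is torch.autograd.gradcheck on double-precision inputs. A minimal sketch, assuming the module path and the new class names from this diff:

import torch
from timm.layers.activations_me import SwishAutoFn, MishAutoFn

# gradcheck compares the custom backward against numerical finite differences
x = torch.randn(8, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(SwishAutoFn.apply, (x,))
assert torch.autograd.gradcheck(MishAutoFn.apply, (x,))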


