Add fallback GELU implementation
daskol committed Feb 4, 2022
1 parent 1ad44f7 commit 8ce2935
Showing 2 changed files with 66 additions and 4 deletions.
50 changes: 47 additions & 3 deletions fewbit/functional/activations.py
@@ -10,6 +10,14 @@
from sys import modules
from typing import Optional, Tuple, Union


# Python 3.9 introduced functools.cache; fall back to lru_cache on older
# versions.
try:
    from functools import cache
except ImportError:
    from functools import lru_cache as cache


# Stepwise activation functions.
STEPWISE = ('hardshrink', 'hardsigmoid', 'hardtanh', 'leaky_relu', 'relu',
            'relu6', 'softshrink', 'stepwise', 'threshold')
@@ -21,6 +29,37 @@
__all__ = STEPWISE + CONTINOUS + ('store', )


class GeluFallbackFunc(T.autograd.Function):
    """Class GeluFallbackFunc is a fallback implementation of GELU activation
    function in pure Python.
    """

    @staticmethod
    @cache
    def xs(device, dtype, bits: int):
        # Inner bucket boundaries of the quantized GELU gradient.
        return store.get('gelu', bits, device, dtype)[0][1:-1]

    @staticmethod
    @cache
    def ys(device, dtype, bits: int):
        # Quantized gradient value for each bucket.
        return store.get('gelu', bits, device, dtype)[1]

    @staticmethod
    def forward(ctx, x: T.Tensor, bits: int = 3):
        # Save only the bucket index of every element for the backward pass.
        xs = GeluFallbackFunc.xs(x.device, x.dtype, bits)
        discr = T.searchsorted(xs, x.float()).type(T.uint8)
        ctx.save_for_backward(discr)
        ctx.bits = bits
        return T.nn.functional.gelu(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Look up the stepwise gradient approximation by bucket index.
        discr, = ctx.saved_tensors
        ys = GeluFallbackFunc.ys(grad_output.device, grad_output.dtype,
                                 ctx.bits)
        return ys[discr.type(T.int64)] * grad_output, None


class StepwiseStore:
"""Class StepwiseStore is a singleton object to store and cache stepwise
approximation for gradients of activation functions.
@@ -48,12 +87,15 @@ def get(self,
            bits: int,
            device: Union[None, str, T.device] = None,
            dtype: Optional[T.dtype] = None):
-       key = (name, bits, T.device(device or 'cpu'), dtype or T.float32)
+       device = T.device(device or 'cpu')
+       dtype = dtype or T.float32
+       key = (name, bits, device, dtype)
        if (leaf := self.CACHE.get(key, None)):
            return leaf
        if (leaf := self.STORE.get(key[:2], None)):
-           self.CACHE[key] = leaf
-           return leaf
+           leaf_cached = tuple(el.to(device, dtype) for el in leaf)
+           self.CACHE[key] = leaf_cached
+           return leaf_cached
        raise KeyError(f'There are no {bits}-bit quantized gradients for '
                       f'activation function {name}.')
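
For illustration (not part of the diff): GeluFallbackFunc above consumes the pair returned by StepwiseStore.get, where element [0] holds the bucket borders and element [1] the quantized gradient values. A minimal sketch, run in the context of this module and assuming the bundled store ships a 3-bit 'gelu' table whose leaf is exactly a (borders, values) pair:

    # Illustration only; `store` and `T` are the module-level names used above.
    borders, values = store.get('gelu', 3)      # defaults to CPU and float32
    inner = borders[1:-1]                       # drop the outermost borders
    codes = T.searchsorted(inner, T.randn(16))  # bucket index per input element
    grads = values[codes]                       # stepwise d(GELU)/dx per element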

@@ -179,3 +221,5 @@ def stub(name, *args, **kwargs):
for name in CONTINOUS:
    setattr(modules[__name__], name, dispatch(name, get_operator(name)))
del name

gelu = GeluFallbackFunc.apply # TODO: Force fallback implementation for now.
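
With the gelu alias above, the functional fallback can be exercised end to end. A minimal usage sketch (not part of the commit; it assumes fewbit.functional re-exports gelu and that the bundled store provides the 3-bit 'gelu' table):

    import torch as T
    from fewbit.functional import gelu  # assumption: re-exported from this module

    x = T.randn(4, 8, requires_grad=True)
    y = gelu(x, 3)       # forward: exact GELU values; compact bucket codes saved
    y.sum().backward()   # backward: gradient looked up from the quantized table
    print(x.grad.shape)  # torch.Size([4, 8])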
20 changes: 19 additions & 1 deletion fewbit/modules/activations.py
@@ -8,7 +8,7 @@
from typing import Optional, Tuple

from .. import functional
-from ..functional.activations import stepwise
+from ..functional.activations import GeluFallbackFunc, stepwise

# Stepwise activation functions.
STEPWISE = ('Hardshrink', 'Hardsigmoid', 'Hardtanh', 'LeakyReLU', 'ReLU',
@@ -127,10 +127,28 @@ def forward(self, xs: T.Tensor) -> T.Tensor:
        return self._impl(xs, *self.args, **self.kwargs)


class GeluFallback(T.nn.Module):
    """Class GeluFallback implements the GELU activation function in pure
    Python.
    """

    def __init__(self, bits: int = 3):
        super().__init__()
        self.bits = bits

    def forward(self, x):
        return GeluFallbackFunc.apply(x, self.bits)

    def extra_repr(self) -> str:
        return f'GeluFallback(bits={self.bits})'


# Produce PyTorch modules as in-place alternatives to the built-in PyTorch
# activation functions enumerated above, manually at runtime.
for name in __all__:
    if not hasattr(modules[__name__], name):
        ty = type(name, (BuiltInStepwiseFunction, ), {})
        setattr(modules[__name__], name, ty)
del name, ty


Gelu = GeluFallback # TODO: Force fallback implementation for now.
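
For the module interface, a comparable sketch (not part of the commit) drops the fallback into a small model. It assumes the package imports cleanly without the native extension and that Gelu is reachable as below; the layer sizes are arbitrary:

    import torch as T
    from fewbit.modules.activations import Gelu  # alias of GeluFallback above

    model = T.nn.Sequential(
        T.nn.Linear(32, 64),
        Gelu(bits=3),  # saves small bucket codes for backward instead of inputs
        T.nn.Linear(64, 10),
    )
    out = model(T.randn(8, 32))
    out.sum().backward()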
