
Commit

tweak eps for causal linear attn
lucidrains committed Sep 7, 2021
1 parent 734e304 commit 24ecf20
Showing 2 changed files with 3 additions and 3 deletions.
linear_attention_transformer/linear_attention_transformer.py (2 additions, 2 deletions)
@@ -220,7 +220,7 @@ def linear_attn(q, k, v, kv_mask = None):
     attn = einsum('bhnd,bhde->bhne', q, context)
     return attn.reshape(*q.shape)
 
-def causal_linear_attn(q, k, v, kv_mask = None, bucket_size = None, eps = 1e-6):
+def causal_linear_attn(q, k, v, kv_mask = None, bucket_size = None, eps = 1e-3):
     b, h, n, e, dtype = *q.shape, q.dtype
     bucket_size = default(bucket_size, 64)
     bucket_size = max(bucket_size, 1)
@@ -253,7 +253,7 @@ def causal_linear_attn(q, k, v, kv_mask = None, bucket_size = None, eps = 1e-6):
     b_k_cumsum = F.pad(b_k_cumsum, (0, 0, 1, 0), value = 0.)
     b_k_cumsum, _ = split_at_index(2, -1, b_k_cumsum)
 
-    D_inv = 1. / einsum('bhud,bhund->bhun', b_k_cumsum + eps, b_q)
+    D_inv = 1. / einsum('bhud,bhund->bhun', b_k_cumsum, b_q).clamp(min = eps)
     attn = einsum('bhund,bhude,bhun->bhune', b_q, context, D_inv)
     return attn.reshape(*q.shape)
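For context, a minimal sketch of what the change above does. This is not the repository's code: it uses toy shapes and plain per-position cumulative sums instead of the library's bucketed einsum, and all names here are illustrative. Previously eps was added to the key cumulative sums before contracting with the queries; now the contracted denominator itself is clamped away from zero, with a larger floor of 1e-3.

# minimal sketch, assuming toy tensor sizes; not the repository's bucketed implementation
import torch

eps_old, eps_new = 1e-6, 1e-3

q = torch.rand(1, 1, 4, 3)          # (batch, heads, seq, dim)
k = torch.rand(1, 1, 4, 3)

k_cumsum = k.cumsum(dim = 2)        # running sum of (non-negative) key features over the sequence

# before: eps added to the key cumulative sums, then contracted with the queries
D_inv_before = 1. / torch.einsum('bhnd,bhnd->bhn', q, k_cumsum + eps_old)

# after: the contracted denominator itself is clamped away from zero
D_inv_after = 1. / torch.einsum('bhnd,bhnd->bhn', q, k_cumsum).clamp(min = eps_new)

print(D_inv_before.shape, D_inv_after.shape)   # torch.Size([1, 1, 4]) for both

Presumably the motivation is numerical stability: clamping bounds the denominator directly, so a near-zero overlap between a query and the accumulated keys cannot blow up the attention values, whereas adding eps to the cumulative sums only bounds the denominator indirectly through the contraction with q.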
setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'linear_attention_transformer',
   packages = find_packages(exclude=['examples']),
-  version = '0.19.0',
+  version = '0.19.1',
   license='MIT',
   description = 'Linear Attention Transformer',
   author = 'Phil Wang',
