From 24ecf20b11a7c8ddbc15e33a30f0be0cc73b145d Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 7 Sep 2021 13:15:30 -0700
Subject: [PATCH] tweak eps for causal linear attn

---
 linear_attention_transformer/linear_attention_transformer.py | 4 ++--
 setup.py                                                      | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/linear_attention_transformer/linear_attention_transformer.py b/linear_attention_transformer/linear_attention_transformer.py
index 1a6193f..5c8e1fc 100644
--- a/linear_attention_transformer/linear_attention_transformer.py
+++ b/linear_attention_transformer/linear_attention_transformer.py
@@ -220,7 +220,7 @@ def linear_attn(q, k, v, kv_mask = None):
     attn = einsum('bhnd,bhde->bhne', q, context)
     return attn.reshape(*q.shape)
 
-def causal_linear_attn(q, k, v, kv_mask = None, bucket_size = None, eps = 1e-6):
+def causal_linear_attn(q, k, v, kv_mask = None, bucket_size = None, eps = 1e-3):
     b, h, n, e, dtype = *q.shape, q.dtype
     bucket_size = default(bucket_size, 64)
     bucket_size = max(bucket_size, 1)
@@ -253,7 +253,7 @@ def causal_linear_attn(q, k, v, kv_mask = None, bucket_size = None, eps = 1e-6):
     b_k_cumsum = F.pad(b_k_cumsum, (0, 0, 1, 0), value = 0.)
     b_k_cumsum, _ = split_at_index(2, -1, b_k_cumsum)
 
-    D_inv = 1. / einsum('bhud,bhund->bhun', b_k_cumsum + eps, b_q)
+    D_inv = 1. / einsum('bhud,bhund->bhun', b_k_cumsum, b_q).clamp(min = eps)
     attn = einsum('bhund,bhude,bhun->bhune', b_q, context, D_inv)
     return attn.reshape(*q.shape)
 
diff --git a/setup.py b/setup.py
index 4e04dbd..4ef2b08 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'linear_attention_transformer',
   packages = find_packages(exclude=['examples']),
-  version = '0.19.0',
+  version = '0.19.1',
   license='MIT',
   description = 'Linear Attention Transformer',
   author = 'Phil Wang',
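
For context, a minimal sketch of what this eps change does numerically. The old code stabilized the attention denominator by adding eps to the cumulative key sums before contracting them with the queries; the new code computes the denominator exactly and then clamps it from below. The toy shapes below are illustrative; only `b_k_cumsum` and `b_q` mirror names from the diff, the rest is not the library's API:

    import torch
    from torch import einsum

    # Toy tensors mirroring the shapes in the diff: b_k_cumsum is the running
    # key sum per bucket (b, h, buckets, d); b_q holds the queries per bucket
    # (b, h, buckets, bucket_size, d).
    b_k_cumsum = torch.rand(1, 2, 4, 8)
    b_q = torch.rand(1, 2, 4, 16, 8)

    # The diff's F.pad(..., value = 0.) prepends an all-zero bucket, so a zero
    # denominator is the expected edge case for the first bucket.
    b_k_cumsum[:, :, 0] = 0.

    # Old strategy (eps = 1e-6): perturb the key cumsum itself. The resulting
    # floor on the denominator scales with the query values, so it can still
    # be tiny and D_inv can still blow up.
    D_inv_old = 1. / einsum('bhud,bhund->bhun', b_k_cumsum + 1e-6, b_q)

    # New strategy (eps = 1e-3): compute the denominator exactly, then clamp,
    # which bounds D_inv at 1 / eps no matter how small the queries are.
    D_inv_new = 1. / einsum('bhud,bhund->bhun', b_k_cumsum, b_q).clamp(min = 1e-3)

    print(D_inv_old[0, 0, 0].max().item())  # on the order of 1e5 or more for the zero bucket
    print(D_inv_new[0, 0, 0].max().item())  # capped at exactly 1e3

Presumably this is why the larger eps is safe here: the clamp acts directly on the denominator, whereas the additive scheme's effective floor depended on the magnitude of b_q.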