-
Notifications
You must be signed in to change notification settings - Fork 174
/
linear_attention.py
92 lines (73 loc) · 3.32 KB
/
linear_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
"""Implement unmasked linear attention."""
import torch
from torch.nn import Module
from ..attention_registry import AttentionRegistry, Optional, Callable, Int, \
EventDispatcherInstance
from ..events import EventDispatcher
from ..feature_maps import elu_feature_map
class LinearAttention(Module):
    """Compute unmasked attention in O(N D^2) via kernel feature maps.

    Instead of the quadratic softmax attention

        V' = softmax(Q.mm(K.t()), dim=-1).mm(V),

    a feature map Φ(.) is applied to queries and keys and the linearized
    form

        V' = normalize(Φ(Q).mm(Φ(K).t())).mm(V)

    is evaluated. This costs O(N D^2), where N is the sequence length and
    D the dimensionality of Q, K and V; the attention's expressiveness,
    however, may be limited by the chosen feature map.

    Arguments
    ---------
        feature_map: callable, a callable that applies the feature map to the
                     last dimension of a tensor (default: elu(x)+1)
        eps: float, a small number to ensure the numerical stability of the
             denominator (default: 1e-6)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, query_dimensions, feature_map=None, eps=1e-6,
                 event_dispatcher=""):
        super(LinearAttention, self).__init__()
        # Fall back to the elu(x)+1 feature map when no factory is given.
        if feature_map:
            self.feature_map = feature_map(query_dimensions)
        else:
            self.feature_map = elu_feature_map(query_dimensions)
        self.eps = eps
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        # Pass queries and keys through the feature map Φ.
        self.feature_map.new_feature_map(queries.device)
        Q = self.feature_map.forward_queries(queries)
        K = self.feature_map.forward_keys(keys)

        # Only an all-ones attention mask is supported; an arbitrary mask
        # would break the linearized factorization below.
        if not attn_mask.all_ones:
            raise RuntimeError(("LinearAttention does not support arbitrary "
                                "attention masks"))

        # Zero out the padded key positions.
        K = K * key_lengths.float_matrix[:, :, None, None]

        # Contract keys with values first so the L x S attention matrix is
        # never materialized, then normalize and apply the queries.
        KV = torch.einsum("nshd,nshm->nhmd", K, values)
        Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
        out = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z)

        return out.contiguous()
# Expose this implementation to the builders under the key "linear",
# together with the constructor parameters they may configure.
AttentionRegistry.register(
    "linear",
    LinearAttention,
    [("query_dimensions", Int),
     ("feature_map", Optional(Callable)),
     ("event_dispatcher", Optional(EventDispatcherInstance, ""))],
)