-
Notifications
You must be signed in to change notification settings - Fork 0
/
Attention_Class.py
26 lines (20 loc) · 894 Bytes
/
Attention_Class.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from keras.layers import Layer
import keras.backend as K
# Attention layer: pools a sequence into one vector using learned per-timestep weights.
class attention(Layer):
    """Attention pooling over the time axis of a 3-D input.

    For an input ``x`` of shape ``(batch, timesteps, features)`` the layer
    computes one score per timestep, ``e = tanh(dot(x, W) + b)``, normalises
    the scores with a softmax across timesteps, and returns the
    score-weighted sum of ``x`` over the time axis, which has shape
    ``(batch, features)``.

    Weights created in ``build``:
        W: shape ``(features, 1)``, ``"normal"`` initializer.
        b: shape ``(timesteps, 1)``, ``"zeros"`` initializer.

    Because ``b`` holds one entry per timestep, the time dimension must be
    known when the layer is built; variable-length sequences are rejected.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base ``Layer`` constructor."""
        super(attention, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the weights ``W`` and ``b`` for the given input shape.

        Args:
            input_shape: shape of the layer input, expected to be
                ``(batch, timesteps, features)``.

        Raises:
            ValueError: if ``input_shape`` does not have exactly three
                dimensions, or if its timestep or feature dimension is
                undefined (``None``).
        """
        if len(input_shape) != 3:
            raise ValueError(
                "attention expects a 3-D input (batch, timesteps, features); "
                "got input shape %r" % (input_shape,)
            )
        timesteps = input_shape[1]
        features = input_shape[-1]
        # Both dimensions size a weight, so neither may be left undefined.
        # Failing here gives a readable message instead of an opaque error
        # from inside add_weight.
        if timesteps is None or features is None:
            raise ValueError(
                "attention requires fixed timestep and feature dimensions; "
                "got input shape %r" % (input_shape,)
            )
        self.W = self.add_weight(
            name="att_weight",
            shape=(features, 1),
            initializer="normal",
        )
        self.b = self.add_weight(
            name="att_bias",
            shape=(timesteps, 1),
            initializer="zeros",
        )
        super(attention, self).build(input_shape)

    def call(self, x):
        """Return the attention-weighted sum of ``x`` over the time axis.

        Args:
            x: input tensor of shape ``(batch, timesteps, features)``.

        Returns:
            Tensor of shape ``(batch, features)``.
        """
        # (batch, timesteps, features) . (features, 1) -> (batch, timesteps, 1);
        # the bias broadcasts over the batch axis.
        scores = K.tanh(K.dot(x, self.W) + self.b)
        # Drop the trailing singleton so the softmax (last axis by default)
        # normalises across timesteps.
        scores = K.squeeze(scores, axis=-1)
        weights = K.softmax(scores)
        # Restore the trailing axis so the weights broadcast over features.
        weights = K.expand_dims(weights, axis=-1)
        weighted = x * weights
        return K.sum(weighted, axis=1)

    def compute_output_shape(self, input_shape):
        """Return ``(batch, features)``: the time axis is summed away."""
        return (input_shape[0], input_shape[-1])

    def get_config(self):
        """Return the base layer config; this layer adds no options of its own."""
        return super(attention, self).get_config()