diff --git a/fastai_contrib/learner.py b/fastai_contrib/learner.py
index 67e79cb..d761b35 100644
--- a/fastai_contrib/learner.py
+++ b/fastai_contrib/learner.py
@@ -39,7 +39,9 @@ def bilm_text_classifier_learner(data: DataBunch, bptt: int = 70, max_len: int =
     ds = data.train_ds
     vocab_size, n_class = len(data.vocab.itos), data.c
     if bicls_head == 'BiPoolingLinearClassifier':
-        count = 3*2
+        count = 3 * 2
+    elif bicls_head == 'BiAttentionPoolingClassifier':
+        count = 5
     else:
         count = 3
     layers = [emb_sz * count] + lin_ftrs + [n_class]
diff --git a/fastai_contrib/models.py b/fastai_contrib/models.py
index 92ac7e4..89e8514 100644
--- a/fastai_contrib/models.py
+++ b/fastai_contrib/models.py
@@ -70,6 +70,118 @@ def forward(self, input:LongTensor)->Tuple[Tensor,Tensor]:
         outputs.append(o)
         return self.concat(raw_outputs), self.concat(outputs)
 
+class BiAttentionPoolingClassifier(nn.Module):
+    "BiLM pooling over the forward and backward passes with multi-head self-attention."
+
+    def __init__(self, layers:Collection[int], drops:Collection[float], emb_sz:int):
+        super().__init__()
+        mod_layers = []
+        activs = [nn.ReLU(inplace=True)] * (len(layers) - 2) + [None]
+        for n_in,n_out,p,actn in zip(layers[:-1], layers[1:], drops, activs):
+            mod_layers += bn_drop_lin(n_in, n_out, p=p, actn=actn)
+        self.self_attn = MultiHeadAttention(n_head=8, d_model=emb_sz, d_k=64, d_v=64, dropout=0.1)
+        self.layers = nn.Sequential(*mod_layers)
+
+    def pool(self, x:Tensor, bs:int, is_max:bool):
+        "Pool the tensor along the seq_len dimension."
+        f = F.adaptive_max_pool1d if is_max else F.adaptive_avg_pool1d
+        # x is (bs, seq_len, emb_sz): move emb_sz to the channel dim before pooling
+        return f(x.permute(0, 2, 1), (1,)).view(bs, -1)
+
+    def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
+        raw_outputs, outputs = input
+        output = outputs[-1]
+        assert len(output.size()) == 4, 'Expected input dimension 4'
+        bs, sl, em_sz, passes = output.size()
+
+        # concatenate the forward and backward passes along seq_len, then self-attend
+        x = torch.cat([output[..., 0], output[..., 1]], 1)
+        x, _ = self.self_attn(x, x, x)
+
+        avgpool = self.pool(x, bs, False)
+        mxpool = self.pool(x, bs, True)
+
+        # last attention step + max/avg pools + last fwd/bwd hidden states:
+        # 5 * emb_sz features, matching count = 5 in learner.py
+        x = torch.cat([output[:,-1,..., 0], x[:,-1], mxpool,
+                       avgpool, output[:,-1,..., 1]], 1)
+        x = self.layers(x)
+        return x, raw_outputs, outputs
+
+class ScaledDotProductAttention(nn.Module):
+    r"""
+    Scaled Dot-Product Attention
+    based on: https://github.com/jadore801120/attention-is-all-you-need-pytorch
+    """
+
+    def __init__(self, temperature:float, attn_dropout:float=0.1):
+        super().__init__()
+        self.temperature = temperature
+        self.dropout = nn.Dropout(attn_dropout)
+        self.softmax = nn.Softmax(dim=2)
+
+    def forward(self, q, k, v):
+        attn = torch.bmm(q, k.transpose(1, 2))
+        attn = attn / self.temperature
+
+        attn = self.softmax(attn)
+        attn = self.dropout(attn)
+        output = torch.bmm(attn, v)
+
+        return output, attn
+
+class MultiHeadAttention(nn.Module):
+    r"""
+    Multi-Head Attention module
+    based on: https://github.com/jadore801120/attention-is-all-you-need-pytorch
+    """
+
+    def __init__(self, n_head:int, d_model:int, d_k:int, d_v:int, dropout:float=0.1):
+        super().__init__()
+
+        self.n_head = n_head
+        self.d_k = d_k
+        self.d_v = d_v
+
+        self.w_qs = nn.Linear(d_model, n_head * d_k)
+        self.w_ks = nn.Linear(d_model, n_head * d_k)
+        self.w_vs = nn.Linear(d_model, n_head * d_v)
+        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
+        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
+        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
+
+        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
+        self.layer_norm = nn.LayerNorm(d_model)
+
+        self.fc = nn.Linear(n_head * d_v, d_model)
+        nn.init.xavier_normal_(self.fc.weight)
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, q, k, v):
+        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+
+        sz_b, len_q, _ = q.size()
+        sz_b, len_k, _ = k.size()
+        sz_b, len_v, _ = v.size()
+        residual = q
+
+        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
+        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
+        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
+
+        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k)  # (n*b) x lq x dk
+        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k)  # (n*b) x lk x dk
+        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v)  # (n*b) x lv x dv
+
+        x, attn = self.attention(q, k, v)
+
+        x = x.view(n_head, sz_b, len_q, d_v)
+        x = x.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)  # b x lq x (n*dv)
+
+        x = self.dropout(self.fc(x))
+        x = self.layer_norm(x + residual)
+
+        return x, attn
+
 class BiPoolingLinearClassifier(PoolingLinearClassifier):
     "Create a linear classifier with pooling."
@@ -126,7 +238,6 @@ def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
         x = self.layers(x)
         return x, raw_outputs, outputs
 
-
 def get_bilm(vocab_sz:int, emb_sz:int, n_hid:int, n_layers:int, pad_token:int, tie_weights:bool=True,
              qrnn:bool=False, bias:bool=True, bidir:bool=False, output_p:float=0.4, hidden_p:float=0.2,
              input_p:float=0.6, embed_p:float=0.1, weight_p:float=0.5)->nn.Module:
@@ -156,11 +267,20 @@ def get_birnn_classifier(bptt:int, max_seq:int, n_class:int, vocab_sz:int, emb_s
                          qrnn=qrnn, hidden_p=hidden_p, input_p=input_p, embed_p=embed_p, weight_p=weight_p)
     head = BiPoolingLinearClassifier
-    if bicls_head == 'BiPoolingLinearClassifier': head = BiPoolingLinearClassifier
-    elif bicls_head == 'AvgPoolingLinearClassifier': head = AvgPoolingLinearClassifier
-
-    model = SequentialRNN(BiLMModel(fwd_rnn_enc, bwd_rnn_enc), head(layers, drops))
-    model.reset()
+
+    if bicls_head == 'BiPoolingLinearClassifier':
+        head = BiPoolingLinearClassifier
+        model = SequentialRNN(BiLMModel(fwd_rnn_enc, bwd_rnn_enc), head(layers, drops))
+    elif bicls_head == 'AvgPoolingLinearClassifier':
+        head = AvgPoolingLinearClassifier
+        model = SequentialRNN(BiLMModel(fwd_rnn_enc, bwd_rnn_enc), head(layers, drops))
+    elif bicls_head == 'BiAttentionPoolingClassifier':
+        head = BiAttentionPoolingClassifier
+        # the attention head needs emb_sz as an extra argument;
+        # consider **kwargs if more heads gain their own arguments
+        model = SequentialRNN(BiLMModel(fwd_rnn_enc, bwd_rnn_enc), head(layers, drops, emb_sz))
+    else:
+        # fall back to the default head so `model` is always bound
+        model = SequentialRNN(BiLMModel(fwd_rnn_enc, bwd_rnn_enc), head(layers, drops))
+
+    model.reset()
     return model
 
-#endregion
\ No newline at end of file
+#endregion
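
For reviewers, a minimal sketch of driving the new head end to end, assuming the standard fastai v1 text pipeline and that `bilm_text_classifier_learner` accepts `bicls_head` as a keyword argument (its full signature is truncated in the hunk above); the path and CSV name are placeholders:

```python
from fastai.text import TextClasDataBunch
from fastai_contrib.learner import bilm_text_classifier_learner

# placeholder data; any TextClasDataBunch works here
data = TextClasDataBunch.from_csv('.', 'texts.csv')

# selecting the new head sets count = 5, so the classifier's first
# linear layer sees emb_sz * 5 pooled features
learn = bilm_text_classifier_learner(data, bicls_head='BiAttentionPoolingClassifier')
learn.fit_one_cycle(1)
```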
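The 4-D contract on `outputs[-1]` is the subtlest part of `BiAttentionPoolingClassifier.forward`, so here is a shape walkthrough with random tensors; the sizes are arbitrary, and `layers`/`drops` mirror what learner.py would build with count = 5:

```python
import torch
from fastai_contrib.models import BiAttentionPoolingClassifier

bs, sl, emb_sz, n_class = 4, 10, 400, 2
layers = [emb_sz * 5, 50, n_class]   # emb_sz * count with count = 5
drops = [0.1, 0.1]
head = BiAttentionPoolingClassifier(layers, drops, emb_sz)

# outputs[-1] must be (bs, sl, emb_sz, 2): one slice per direction
output = torch.randn(bs, sl, emb_sz, 2)
x, raw_outputs, outputs = head(([output], [output]))
assert x.shape == (bs, n_class)
```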
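And a quick self-contained check that `MultiHeadAttention` preserves the `(bs, seq_len, d_model)` layout the pooling code relies on; the sizes are arbitrary, while `n_head`/`d_k`/`d_v` match the values hard-coded in `BiAttentionPoolingClassifier.__init__`:

```python
import torch
from fastai_contrib.models import MultiHeadAttention

bs, sl, emb_sz = 4, 10, 400
mha = MultiHeadAttention(n_head=8, d_model=emb_sz, d_k=64, d_v=64, dropout=0.1)

x = torch.randn(bs, 2 * sl, emb_sz)   # fwd + bwd passes concatenated on seq_len
out, attn = mha(x, x, x)              # self-attention: q = k = v

assert out.shape == (bs, 2 * sl, emb_sz)       # residual + fc keep d_model
assert attn.shape == (8 * bs, 2 * sl, 2 * sl)  # per-head attention maps
```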