-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
82 lines (66 loc) · 2.95 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File : model.py
# Author : Yan <yanwong@126.com>
# Date : 30.03.2020
# Last Modified Date: 16.04.2020
# Last Modified By : Yan <yanwong@126.com>
import tensorflow as tf
import tensorflow_addons as tfa
# class PretrainedEmbedding(tf.keras.layers.Layer):
# def __init__(self, embeddings, rate=0.1):
# super(PretrainedEmbedding, self).__init__()
#
# self.embeddings = tf.constant(embeddings)
# self.dropout = tf.keras.layers.Dropout(rate=rate)
#
# def call(self, inputs, training=None):
# output = tf.nn.embedding_lookup(self.embeddings, inputs)
# return self.dropout(output, training=training)
class CRFLayer(tf.keras.layers.Layer):
  """Viterbi-decoding CRF layer.

  Learns a (num_tags, num_tags) transition-score matrix and decodes the
  highest-scoring tag sequence from per-token unary scores using
  tfa.text.crf_decode.
  """

  def __init__(self, num_tags):
    super(CRFLayer, self).__init__()
    self.num_tags = num_tags

  def build(self, input_shape):
    # trans_params[i, j] is the score of transitioning from tag i to tag j.
    self.trans_params = self.add_weight(
        name='trans_params',
        shape=(self.num_tags, self.num_tags))
    # BUG FIX: the original did `self.build = True`, which overwrote this
    # build() method with a bool. Keras tracks build state via the
    # `built` attribute.
    self.built = True

  def call(self, x, seq_len):
    """Decode the best tag path.

    Args:
      x: unary potentials, shape (batch_size, max_seq_len, num_tags).
      seq_len: true (unpadded) sequence lengths, shape (batch_size,).

    Returns:
      tags: decoded tag indices, shape (batch_size, max_seq_len).
      scores: score of each decoded sequence, shape (batch_size,).
    """
    tags, scores = tfa.text.crf_decode(x, self.trans_params, seq_len)
    return tags, scores
class Model(tf.keras.Model):
  """BiLSTM-CRF sequence tagger.

  Pipeline: embedding -> dropout -> bidirectional LSTM -> tanh dense ->
  per-tag logits -> CRF Viterbi decoding of the best tag path.
  """

  def __init__(self, embeddings, vocab_size, config):
    """Create all sub-layers.

    Args:
      embeddings: optional pre-trained embedding matrix of shape
        (vocab_size, config.d_word); None means random initialization.
      vocab_size: vocabulary size.
      config: hyper-parameter object exposing d_word_lstm, lstm_dropout,
        l2_lambda, n_tags, d_word, non_static_emb and emb_dropout.
    """
    super(Model, self).__init__()
    l2 = tf.keras.regularizers.l2
    self.rnn_layer = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(
        config.d_word_lstm,
        return_sequences=True,
        recurrent_dropout=config.lstm_dropout,
        kernel_regularizer=l2(config.l2_lambda),
        recurrent_regularizer=l2(config.l2_lambda)))
    self.hidden_layer = tf.keras.layers.Dense(
        config.d_word_lstm, activation='tanh')
    self.final_layer = tf.keras.layers.Dense(config.n_tags)
    self.crf_layer = CRFLayer(config.n_tags)
    self.embedding_layer = tf.keras.layers.Embedding(
        vocab_size, config.d_word, trainable=config.non_static_emb)
    self.embedding_dropout = tf.keras.layers.Dropout(rate=config.emb_dropout)
    if embeddings is not None:
      # Force weight creation so set_weights can copy the matrix in.
      # NOTE(review): (None, vocab_size) looks odd as an Embedding input
      # shape (indices are (batch, seq_len)), but the weight shape comes
      # from the constructor args, so behavior is unaffected — confirm intent.
      self.embedding_layer.build((None, vocab_size))
      self.embedding_layer.set_weights([embeddings])

  # When deploying, tf needs to know the input signature.
  # @tf.function(input_signature=[tf.TensorSpec([None, None], tf.int64),
  #                               tf.TensorSpec([], tf.bool),
  #                               tf.TensorSpec([None, None], tf.float32)])
  def call(self, x, training, padding_mask):
    """Run the tagger on a batch of token-id sequences.

    Args:
      x: token ids, shape (batch_size, max_seq_len).
      training: whether dropout layers are active.
      padding_mask: 1.0 for real tokens, 0.0 for padding,
        shape (batch_size, max_seq_len).

    Returns:
      tags: CRF-decoded tag ids, shape (batch_size, max_seq_len).
      logits: unary tag scores, shape (batch_size, max_seq_len, n_tags).
    """
    embedded = self.embedding_layer(x)  # (batch_size, max_seq_len, d_word)
    embedded = self.embedding_dropout(embedded, training=training)
    rnn_out = self.rnn_layer(
        embedded, mask=tf.cast(padding_mask, tf.bool), training=training)
    logits = self.final_layer(self.hidden_layer(rnn_out))
    # Count of non-padding tokens per sequence, as required by the CRF.
    true_seq_len = tf.cast(tf.math.reduce_sum(padding_mask, axis=1), tf.int32)
    tags, _ = self.crf_layer(logits, true_seq_len)
    return tags, logits