#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File              : data_utils.py
# Author            : Yan <yanwong@126.com>
# Date              : 31.03.2020
# Last Modified Date: 10.04.2020
# Last Modified By  : Yan <yanwong@126.com>

import collections
import glob
import logging

import numpy as np
import tensorflow as tf

logging.basicConfig(level=logging.INFO)


def load_w2v(fname, vocab):
  """Loads pre-trained word vectors.

  Args:
    fname: The pre-trained word vector file, created by word2vec in
      non-binary (text) mode.
    vocab: The dict of words appearing in the training corpus.

  Returns:
    A dict of pre-trained word vectors for each word in vocab.
  """
  word_vecs = {}
  with open(fname, 'r') as f:
    header = f.readline()
    vocab_size, layer1_size = map(int, header.split())
    for line in f:
      toks = line.split()
      assert len(toks) == layer1_size + 1
      word = toks[0]
      if word in vocab:
        # np.float was removed from recent NumPy releases; use an
        # explicit float32 instead.
        word_vecs[word] = np.array(toks[1:]).astype(np.float32)
  return word_vecs
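
# A minimal usage sketch for load_w2v. The input is word2vec's text format:
# a header line "vocab_size dim" followed by one "word v1 ... v_dim" line
# per word ('w2v.txt' is a hypothetical path, not part of this module):
#
#   # w2v.txt contains:
#   #   2 3
#   #   cat 0.1 0.2 0.3
#   #   dog 0.4 0.5 0.6
#   word_vecs = load_w2v('w2v.txt', {'cat': 0, 'dog': 1})
#   word_vecs['cat']  # array([0.1, 0.2, 0.3], dtype=float32)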


def load_vocab(vocab_file):
  """Loads the vocab as an ordered dict.

  Args:
    vocab_file: The vocab file in which each line is a single word.

  Returns:
    An ordered dict mapping each word to its id, i.e. its line number.
  """
  vocab = collections.OrderedDict()
  with tf.io.gfile.GFile(vocab_file, 'r') as f:
    for i, line in enumerate(f):
      vocab[line.strip()] = i
  return vocab
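
# For example, a vocab file whose lines are "<PAD>", "<UNK>", "cat" (the
# token spellings here are hypothetical) yields
# OrderedDict([('<PAD>', 0), ('<UNK>', 1), ('cat', 2)]).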


def load_vocab_embeddings(word_vecs, vocab, emb_dim):
  """Loads pre-trained word embeddings for the words in vocab.

  For a word that is in vocab but has no corresponding pre-trained
  embedding, an embedding is generated randomly.

  Args:
    word_vecs: A dict containing pre-trained word embeddings. Every word
      in this dict is also in vocab.
    vocab: An ordered dict mapping each word to its id.
    emb_dim: The dimension of the word embeddings.

  Returns:
    An array of embeddings, one row per word in vocab; this covers the
    PAD and UNK entries as well.
  """
  embeddings = []
  for word in vocab:
    emb = word_vecs.get(word, None)
    if emb is None:
      # No pre-trained vector for this word: draw one uniformly at random.
      emb = np.random.uniform(-0.25, 0.25, emb_dim)
    embeddings.append(emb)
  return np.array(embeddings)
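
# A minimal sketch of wiring the returned matrix into a Keras embedding
# layer (file paths and emb_dim are hypothetical):
#
#   vocab = load_vocab('vocab.txt')
#   word_vecs = load_w2v('w2v.txt', vocab)
#   matrix = load_vocab_embeddings(word_vecs, vocab, emb_dim=300)
#   layer = tf.keras.layers.Embedding(
#       input_dim=matrix.shape[0], output_dim=matrix.shape[1],
#       embeddings_initializer=tf.keras.initializers.Constant(matrix))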


def create_dataset(file_pattern, batch_size):
  """Fetches string values from disk into a tf.data.Dataset.

  Args:
    file_pattern: Comma-separated list of file patterns (e.g.
      "/tmp/train_data-?????-of-00100", where '?' acts as a wildcard
      matching any single character).
    batch_size: Batch size.

  Returns:
    A dataset read from TFRecord files.
  """
  data_files = []
  for pattern in file_pattern.split(','):
    data_files.extend(glob.glob(pattern))
  if not data_files:
    # Unlike tf.logging.fatal in TF1, logging.fatal does not abort the
    # process, so raise explicitly rather than silently building an
    # empty dataset.
    logging.fatal('Found no input files matching %s', file_pattern)
    raise ValueError('Found no input files matching %s' % file_pattern)
  logging.info('Prefetching values from %d files matching %s',
               len(data_files), file_pattern)

  dataset = tf.data.TFRecordDataset(data_files)

  def _parse_record(record):
    # Each record holds two variable-length int64 lists: the token ids
    # of a sentence and the ids of its tags.
    features = {
        'sentence': tf.io.VarLenFeature(dtype=tf.int64),
        'tags': tf.io.VarLenFeature(dtype=tf.int64)
    }
    parsed_features = tf.io.parse_single_example(record, features)
    sent = tf.sparse.to_dense(parsed_features['sentence'])
    tags = tf.sparse.to_dense(parsed_features['tags'])
    return sent, tags

  dataset = dataset.map(_parse_record)
  dataset = dataset.shuffle(buffer_size=100000, seed=42)
  # Pad both sentences and tags to the longest example in each batch; the
  # default padding value 0 is the PAD id.
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=([-1], [-1]))

  return dataset
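
# A minimal iteration sketch (the TFRecord pattern is hypothetical):
#
#   dataset = create_dataset('/tmp/train-?????-of-00010', batch_size=32)
#   for sentences, tags in dataset.take(1):
#     # Both are int64 tensors of shape [batch_size, max_len_in_batch].
#     mask = create_padding_mask(sentences)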


def create_padding_mask(seq):
  """Builds a mask that is 1.0 at real tokens and 0.0 at PAD positions.

  Args:
    seq: An int tensor of token ids in which id 0 denotes padding.

  Returns:
    A float32 tensor of the same shape as seq.
  """
  mask = tf.math.logical_not(tf.math.equal(seq, 0))
  return tf.cast(mask, tf.float32)
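

if __name__ == '__main__':
  # A minimal smoke test for create_padding_mask (a sketch, not part of
  # the original module): id 0 is treated as PAD and masked out.
  batch = tf.constant([[5, 3, 0, 0],
                       [2, 0, 0, 0]], dtype=tf.int64)
  print(create_padding_mask(batch))
  # tf.Tensor([[1. 1. 0. 0.]
  #            [1. 0. 0. 0.]], shape=(2, 4), dtype=float32)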