-
Notifications
You must be signed in to change notification settings - Fork 7
/
preprocess.py
155 lines (122 loc) · 4.34 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import sys
import os
import numpy as np
import csv
import pickle
import argparse
# s = subject
# r = relation
# o = object
from util import AttributeDict
def read_data(file_path):
    """Read tab-separated (subject, relation, object) triples into a nested dict.

    Args:
        file_path: Path to a tsv file whose rows are s, r, o triples.

    Returns:
        dict mapping subject -> {relation -> list of objects}, preserving
        the order objects appear in the file.
    """
    s_dict = dict()
    # newline='' is the csv-module-recommended way to open files for csv.reader.
    with open(file_path, newline='') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter='\t')
        for s, r, o in csv_reader:
            # setdefault replaces the original nested try/except KeyError
            # bookkeeping: create the per-subject dict and per-relation list
            # on first sight, then append.
            s_dict.setdefault(s, dict()).setdefault(r, list()).append(o)
    return s_dict
def create_dataset(s_dict):
    """Build (subject, relation) -> objects pairs plus entity/relation index maps.

    Args:
        s_dict: dict mapping subject -> {relation -> list of objects}, as
            produced by read_data.

    Returns:
        Tuple (x, y, e_to_index, index_to_e, r_to_index, index_to_r) where
        x[i] is an (s, r) pair, y[i] is the list of objects for that pair,
        and the four dicts map entities/relations to contiguous integer ids
        and back (ids assigned in first-seen order).
    """
    def _register(key, to_index, index_to):
        # Assign the next contiguous id to key on first sight; no-op otherwise.
        if key not in to_index:
            index = len(to_index)
            to_index[key] = index
            index_to[index] = key

    x, y = list(), list()
    e_to_index, index_to_e, r_to_index, index_to_r = dict(), dict(), dict(), dict()
    for s, ro in s_dict.items():
        _register(s, e_to_index, index_to_e)
        # NOTE: inner loop variable renamed from `os` -- it shadowed the
        # imported os module inside this function.
        for r, objects in ro.items():
            _register(r, r_to_index, index_to_r)
            for o in objects:
                # sometimes an entity only occurs as an object
                _register(o, e_to_index, index_to_e)
            x.append((s, r))
            y.append(objects)
    return x, y, e_to_index, index_to_e, r_to_index, index_to_r
def preprocess_train(file_path):
    """Preprocess a training triple file and pickle the resulting dataset.

    Reads (s, r, o) triples from file_path, builds index mappings via
    create_dataset, prints a short summary plus a sample, and writes the
    dataset dict to a .pkl file next to the input.

    Args:
        file_path: Path to the train csv/tsv file.
    """
    s_dict = read_data(file_path)
    x, y, e_to_index, index_to_e, r_to_index, index_to_r = create_dataset(s_dict)
    data = {
        'x': x,
        'y': y,
        'e_to_index': e_to_index,
        'index_to_e': index_to_e,
        'r_to_index': r_to_index,
        'index_to_r': index_to_r
    }
    print('#entities: ', len(e_to_index))
    print('#relations: ', len(r_to_index))
    # Print a small sample of the dataset for eyeballing.
    # (builtin min suffices here; np.minimum was overkill for two scalars)
    for i in range(min(len(x), 200)):
        print(x[i], y[i])
    # Spot-check that the forward and inverse index maps agree.
    choice = np.random.choice(len(e_to_index))
    assert choice == e_to_index[index_to_e[choice]]
    choice = np.random.choice(len(r_to_index))
    assert choice == r_to_index[index_to_r[choice]]
    save_file_path = os.path.splitext(file_path)[0] + '.pkl'
    # with-block ensures the pickle file is flushed and closed; the original
    # passed an anonymous open() handle that was never closed.
    with open(save_file_path, 'wb') as f:
        pickle.dump(data, f)
def preprocess_valid(train_path, valid_path):
    """Preprocess a valid/test triple file against an existing train pickle.

    Keeps only triples whose subject, relation, and objects were all seen at
    training time (per the index maps stored in the train .pkl), then writes
    the filtered {'x', 'y'} dataset to a .pkl next to the valid file.

    Args:
        train_path: Path to the preprocessed train .pkl.
        valid_path: Path to the valid/test csv/tsv file.
    """
    x, y = list(), list()
    with open(train_path, 'rb') as f:
        train_data = AttributeDict(pickle.load(f))
    s_dict = read_data(valid_path)
    for s, ro in s_dict.items():
        # Drop subjects unseen at training time.
        if s not in train_data.e_to_index:
            continue
        for r, objects in ro.items():
            # Drop relations unseen at training time.
            if r not in train_data.r_to_index:
                continue
            # sometimes an entity only occurs as an object; keep only the
            # objects known to the training index.
            filtered_objects = [o for o in objects if o in train_data.e_to_index]
            x.append((s, r))
            y.append(filtered_objects)
    data = {
        'x': x,
        'y': y,
    }
    save_file_path = os.path.splitext(valid_path)[0] + '.pkl'
    # with-block ensures the pickle file is flushed and closed; the original
    # passed an anonymous open() handle that was never closed.
    with open(save_file_path, 'wb') as f:
        pickle.dump(data, f)
def parse_args():
    """Parse command-line arguments.

    Two sub-commands are supported: 'train' (one positional path to the raw
    train csv/tsv) and 'valid' (positional paths to the train .pkl and the
    raw valid/test csv/tsv).

    Returns:
        argparse.Namespace with a 'mode' attribute plus the per-mode paths.
    """
    parser = argparse.ArgumentParser(
        description='Preprocess knowledge graph csv train/valid (test) data.')
    sub_parsers = parser.add_subparsers(help='mode', dest='mode')
    sub_parsers.required = True

    # train mode: raw training triples in, .pkl out
    train_parser = sub_parsers.add_parser('train', help='Preprocess a training set')
    train_parser.add_argument('train_path', action='store', type=str,
                              help='Path to train dataset (csv or tsv)')

    # valid mode: needs the already-preprocessed train pickle for filtering
    valid_parser = sub_parsers.add_parser('valid', help='Preprocess a valid or test set')
    valid_parser.add_argument('train_path', action='store', type=str, help='Path to train .pkl')
    valid_parser.add_argument('valid_path', action='store', type=str,
                              help='Path to valid dataset (csv or tsv)')

    return parser.parse_args()
def main():
    """Entry point: dispatch to the preprocessing routine for the chosen mode."""
    args = parse_args()
    if args.mode != 'train':
        # 'valid' is the only other mode the parser accepts.
        preprocess_valid(args.train_path, args.valid_path)
    else:
        preprocess_train(args.train_path)


if __name__ == '__main__':
    main()