-
Notifications
You must be signed in to change notification settings - Fork 6
/
main.py
executable file
·183 lines (158 loc) · 7.19 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#!/usr/bin/env python
__author__ = 'Daan van Stigt'
import argparse
import os
import subprocess
import re
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from conllxdataset import ConllXDataset
from conlludataset import ConllUDataset
from features import get_feature_opts
from parser import DependencyParser
from evaluate import evaluate
from predict import predict
from plot import plot
from utils import get_size, UD_LANG, UD_SPLIT
def get_data(args):
    """Load the (train, dev, test) datasets selected by the command-line args.

    Reads Penn-Treebank-style CoNLL-X files when ``args.use_ptb`` is set,
    otherwise Universal Dependencies CoNLL-U files for ``args.lang``.
    Returns a (train_dataset, dev_dataset, test_dataset) tuple.
    """
    root = os.path.expanduser(args.data)
    split_names = ('train', 'dev', 'test')
    if args.use_ptb:
        return tuple(
            ConllXDataset(os.path.join(root, f'{name}.conll'))
            for name in split_names)
    # UD files share a per-language prefix plus a per-split suffix.
    base = os.path.join(root, UD_LANG[args.lang])
    return tuple(ConllUDataset(base + UD_SPLIT[name]) for name in split_names)
def train(args):
    """Train a dependency parser on the configured dataset and save it.

    Loads data, builds (or loads) a `DependencyParser`, trains it either
    sequentially or in parallel, evaluates on the dev set, averages and
    prunes the weights, saves the model, and finally writes CoNLL
    predictions via `evaluate`.

    Args:
        args: parsed argparse namespace (see the `__main__` block).
    """
    print(f'Loading dataset from `{args.data}`...')
    train_dataset, dev_dataset, test_dataset = get_data(args)
    train_tokens = train_dataset.tokens[:args.max_lines]  # lose the last sentence with -1 but OK.
    dev_tokens = dev_dataset.tokens
    # BUG FIX: was `dev_dataset.tokens` (copy-paste error).
    test_tokens = test_dataset.tokens
    # exist_ok avoids the racy exists-then-create pattern; also don't shadow builtin `dir`.
    for folder in ('models', 'out'):
        os.makedirs(folder, exist_ok=True)
    # Make model.
    feature_opts = get_feature_opts(args.features)
    model = DependencyParser(feature_opts, args.decoder)
    if args.load:
        print(f'Loading model from `{args.model}`...')
        model.load(args.model)
    else:
        print('Creating feature-set...')
        if len(feature_opts) > 0:
            print(f'Additional features: {", ".join(feature_opts.keys())}.')
        model.make_features(train_tokens)
        print(f'Number of features: {len(model.weights):,}.')
        print(f'Memory used by model: {get_size(model):.3f} GB.')
    # Train model. KeyboardInterrupt lets the user stop early and still
    # evaluate/save whatever has been learned so far.
    try:
        if args.parallel:
            model.train_parallel(args.epochs, train_tokens, dev_tokens, args.nprocs)
        else:
            objective = args.decoder if args.structured else 'greedy'
            print(f'Training with {objective} decoding...')
            model.train(args.epochs, train_tokens, dev_tokens, structured=args.structured)
    except KeyboardInterrupt:
        print('Exiting training early.')
        if args.parallel:
            model.restore_from_parallel()  # restores weights that were distributed
    # Evaluate model.
    print('Evaluating on dev set...')
    dev_acc = model.evaluate(dev_tokens)
    print(f'Dev UAS {dev_acc:.2f}')
    print('Top features:')
    top_features = model.top_features(30)
    print('\n'.join(f'{f} {v:.4f}' for f, v in top_features))
    print()
    if not args.parallel:
        print('Averaging model weights...')
        model.average_weights()
        # Evaluate again (to see difference).
        print('Evaluating on dev set...')
        dev_acc = model.evaluate(dev_tokens)
        print(f'Dev UAS {dev_acc:.2f}')
        print('Top features:')
        top_features = model.top_features(30)
        print('\n'.join(f'{f} {v:.4f}' for f, v in top_features))
        print()
    else:
        print('Weight averaging not yet supported for parallel.')
    before, zeros = model.prune(args.eps)
    print(f'Pruned weights with threshold {args.eps}...')
    print(f'Number of weights before: {before:,} ({zeros:,} exactly zero).')
    print(f'Number of pruned weights: {len(model.weights):,}.')
    print(f'Saving model to `{args.model}`...')
    # BUG FIX: `data_dir` was undefined here (it is local to get_data),
    # causing a NameError on save whenever --use-ptb is not set.
    data_dir = os.path.expanduser(args.data)
    data_path = os.path.join(data_dir, UD_LANG[args.lang]) if not args.use_ptb else args.data
    model.save(
        args.model, data=data_path, epochs=args.epochs, accuracy=dict(dev=round(dev_acc, 2)))
    print()
    print('Producing predictions in conll format and final evaluation...')
    evaluate(args)
    print('Finished.')
if __name__ == '__main__':
    # Command-line interface: build the argument parser, then dispatch
    # on the chosen mode.
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', choices=['train', 'eval', 'plot', 'predict'],
                        help='choose action')

    # Data args.
    parser.add_argument('--data', type=str, default='data/ud',
                        help='data dir')
    parser.add_argument('--lang', type=str, default='en',
                        choices=['cs', 'de', 'en', 'es', 'hi', 'fr', 'nl'],  # TODO: more langs
                        help='language (universal dependencies only)')
    parser.add_argument('--ptb-dir', type=str, default='~/data/stanford-ptb',
                        help='data dir for ptb')
    parser.add_argument('--use-ptb', action='store_true',
                        help='using penn treebank')

    # Model args.
    parser.add_argument('--decoder', choices=['mst', 'eisner'], default='mst',
                        help='decoder used to extract tree from score matrix')
    parser.add_argument('--model', type=str, default='models/model.json',
                        help='path to save model to, or load model from')
    parser.add_argument('--eps', type=float, default=1e-2,
                        help='prune threshold for feature weights')
    parser.add_argument('--load', action='store_true',
                        help='load a pretrained model, specify which with --model')

    # Training args.
    parser.add_argument('--epochs', type=int, default=10,
                        help='epochs to train')
    parser.add_argument('--features', nargs='+', default=[],
                        help='space separated list of additional features',
                        choices=['dist', 'surround', 'between'])
    parser.add_argument('--parallel', action='store_true',
                        help='training in parallel')
    parser.add_argument('--nprocs', type=int, default=-1,
                        help='number of processors to use (default is all)')
    parser.add_argument('--structured', action='store_true',
                        help='using decoding algorithm to train on structured objective '
                             'specified by --decoder')
    parser.add_argument('--out', type=str, default='out',
                        help='dir to put predicted conll files')
    parser.add_argument('-n', '--max-lines', type=int, default=-1,
                        help='number of lines to train on.')

    # Predict args.
    parser.add_argument('--examples', action='store_true',
                        help='parse some example lines that should be pretty hard')
    parser.add_argument('--jabber', action='store_true',
                        help='parse some lines from the jabberwocky poem')
    parser.add_argument('--no-tags', action='store_true',
                        help='parse with no tags')
    parser.add_argument('--plot-dir', default='image',
                        help='dir for the heatmap plot')
    parser.add_argument('--plot-name', default='input',
                        help='name for the heatmap plot')
    parser.add_argument('--ext', default='pdf')

    args = parser.parse_args()

    # Modes are mutually exclusive (argparse `choices`), so a dispatch
    # table is equivalent to the original chain of independent ifs.
    actions = {'train': train, 'eval': evaluate, 'predict': predict, 'plot': plot}
    actions[args.mode](args)