-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
28 lines (23 loc) · 789 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from textgenrnn import textgenrnn
def train(
    txt_path: str,
    weight_path: str,
    num_epochs: int = 10,
    *,
    header: bool = True,
) -> None:
    """Train a char-rnn model on the text in ``txt_path`` and save its weights.

    A fresh default ``textgenrnn`` model is trained with
    ``train_from_file`` for ``num_epochs`` epochs. The resulting weights
    are then written to ``weight_path`` with ``textgenrnn.save``.

    Args:
        txt_path: Path of the training text file.
        weight_path: Destination path for the trained model weights.
        num_epochs: Number of training epochs.
        header: Forwarded to ``train_from_file``. textgenrnn treats the
            first line of the file as a header and skips it when this is
            true (its default, kept here for backward compatibility). Pass
            ``header=False`` for plain text files with no header line, so
            the first line is not silently dropped from the training data.
    """
    textgen = textgenrnn()
    textgen.train_from_file(
        file_path=txt_path,
        header=header,
        num_epochs=num_epochs,
    )
    # Persist the trained weights so generate() can reload them later.
    textgen.save(weights_path=weight_path)
def generate(weight_path, output_path, num_sentences=20):
    """Write ``num_sentences`` generated texts to ``output_path``.

    A ``textgenrnn`` model is built from the weights stored at
    ``weight_path`` and its ``generate_to_file`` method is asked for
    ``num_sentences`` samples.

    Args:
        weight_path: Path of previously saved model weights.
        output_path: File the generated texts are written to.
        num_sentences: How many texts to generate (passed as ``n``).
    """
    # Load the saved weights and write samples in one chained call.
    textgenrnn(weight_path).generate_to_file(output_path, n=num_sentences)