-
Notifications
You must be signed in to change notification settings - Fork 0
/
lstm.ts
117 lines (90 loc) · 2.8 KB
/
lstm.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import * as tf from "@tensorflow/tfjs-node-gpu"
import * as fs from 'fs'
// class Model {
// constructor({
// path = "./models/default",
// data = "",
// parameters,
// }) {}
// public train() {}
// public generate() {}
// public save() {}
// }
// Hyperparameters for the character-level LSTM.
const LSTM_LAYER_SIZE = 32;        // units per recurrent layer
const SAMPLE_LENGTH = 128;         // characters per training window
const SAMPLE_STEP = SAMPLE_LENGTH; // stride between windows (== length, so non-overlapping)
const NUM_EPOCS = 1;               // epochs per era
const NUM_ERA = 5;                 // outer training rounds
const BATCH_SIZE = 512;
const LENGTH = 128;                // characters to produce when generating
const TEMPERATURE = 0.05;          // sampling temperature (low = conservative picks)
const EXAMPLES_PER_EPOC = 10;
import CharacterSet from "./charset"
// NOTE(review): module-level registry, never read or written in this file —
// presumably meant to cache loaded models by path; confirm before removing.
const Models = {}
/**
 * Thin wrapper around a compiled character-level LSTM model.
 */
class LSTM {
    model: tf.LayersModel

    /**
     * Build and compile a single-layer LSTM that maps a one-hot window of
     * `sampleLength` characters to a softmax over `charsetSize` characters.
     *
     * Bug fixes vs. original:
     *  - the compiled model was built but never returned (callers got
     *    `undefined`); it is now returned.
     *  - `returnSequences` is now `false`: this LSTM feeds the final Dense
     *    layer directly, and the label tensors built elsewhere in this file
     *    are rank-2 ([samples, charsetSize]), which requires a single output
     *    vector per sample, not one per timestep.
     *
     * @param sampleLength - characters per input window
     * @param charsetSize  - size of the one-hot character vocabulary
     * @returns the compiled tf.LayersModel
     */
    static create({
        sampleLength = SAMPLE_LENGTH,
        charsetSize = 128
    }) {
        const model = tf.sequential();
        model.add(tf.layers.lstm({
            units: LSTM_LAYER_SIZE,
            inputShape: [sampleLength, charsetSize],
            // Single recurrent layer: emit only the final hidden state.
            returnSequences: false,
        }));
        model.add(tf.layers.dense({ units: charsetSize, activation: 'softmax' }));
        model.compile({
            optimizer: tf.train.rmsprop(0.05),
            loss: 'categoricalCrossentropy',
        });
        return model;
    }

    constructor({
        model = tf.sequential()
    }) {
        // Bug fix: the original constructor destructured `model` and then
        // discarded it, leaving the declared `this.model` field undefined.
        this.model = model;
    }
}
export function create() {
// define the models layer structure
const model = tf.sequential();
const lstm1 = tf.layers.lstm({
units: LSTM_LAYER_SIZE,
inputShape: [SAMPLE_LENGTH, charset.length],
returnSequences: true,
});
const lstm2 = tf.layers.lstm({
units: LSTM_LAYER_SIZE,
returnSequences: false,
});
const optimizer = tf.train.rmsprop(0.05);
model.add(lstm1);
//model.add(lstm2);
model.add(tf.layers.dense({ units: charset.length, activation: 'softmax' }));
model.compile({ optimizer: optimizer, loss: 'categoricalCrossentropy' });
}
/**
 * One-hot encode `text` into (input window, next character) training pairs.
 *
 * Bug fixes vs. original:
 *  - the loop ran a window from every position in `text`, so `text[i + j]`
 *    read up to `sampleLength - 1` characters past the end of the string,
 *    passing `undefined` into `charset.getIndexFromChar`; the sample count
 *    is now clamped to `text.length - sampleLength`.
 *  - the inner loop bound used the module constant SAMPLE_LENGTH instead of
 *    the `sampleLength` parameter.
 *  - the label was the FIRST character of the window — a character already
 *    present in the input; a next-character model must be labelled with the
 *    character immediately AFTER the window.
 *  - the options object now defaults to `{}`, so `encode(text)` no longer
 *    throws while destructuring.
 *
 * @param text - corpus to encode
 * @param sampleLength - characters per input window
 * @param charset - CharacterSet used for one-hot indices (built from `text` by default)
 * @returns [trainingData, labelData] TensorBuffers of shape
 *          [numSamples, sampleLength, charset.size] and [numSamples, charset.size]
 */
export function encode(text, {
    sampleLength = SAMPLE_LENGTH,
    charset = CharacterSet.create(text)
} = {}) {
    // Each sample needs sampleLength input characters plus one label character.
    const numSamples = Math.max(0, text.length - sampleLength);
    const trainingData = tf.buffer([numSamples, sampleLength, charset.size])
    const labelData = tf.buffer([numSamples, charset.size])
    console.log("generating training data")
    for (let i = 0; i < numSamples; i++) {
        for (let j = 0; j < sampleLength; j++) {
            const k = charset.getIndexFromChar(text[i + j])
            trainingData.set(1, i, j, k)
        }
        // Label: the character immediately following the window.
        labelData.set(1, i, charset.getIndexFromChar(text[i + sampleLength]))
    }
    return [trainingData, labelData]
}
// NOTE(review): the four exports below are unimplemented stubs — presumably
// fit / sample / persist / restore the model; confirm the intended API
// before wiring up callers.
export function train() {}
export function generate() {}
export function save() {}
export function load(path) {}