-
Notifications
You must be signed in to change notification settings - Fork 25
/
ModelLib.py
32 lines (24 loc) · 1.44 KB
/
ModelLib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
import transfer_learning
class ModelLib:
    """Base class providing a score-ordered (curriculum) training schedule.

    Subclasses must implement :meth:`build_classifier_model`; this base class
    only raises ``NotImplementedError`` for it.
    """

    def build_classifier_model(self, dataset):
        """Build a classifier for ``dataset``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def corriculum_svm_based_training_data(self, dataset, anti_corriculum=False, random=False,
                                          *, num_phases=10, epochs_per_phase=10,
                                          total_epochs=100):
        """Return a phased training schedule ordered by SVM-based scores.

        Training examples are ranked by ``transfer_learning.rank_data_according_to_score``
        using the train scores from ``transfer_learning.get_svm_scores``, which is applied
        to the values returned by ``transfer_learning.get_transfer_values_inception``.
        Training is split into ``num_phases`` phases. Phase ``i`` (0-based) uses the first
        ``size * (i + 1) // num_phases`` entries of that ranking, where ``size`` is
        ``dataset.x_train.shape[0]``. The last phase therefore uses ``size`` entries,
        which is all of them, presumably one per training example (TODO confirm against
        ``rank_data_according_to_score``).

        Args:
            dataset: object exposing ``x_train``, ``y_train`` and ``y_test`` (read here);
                it is also passed through to the ``transfer_learning`` helpers.
            anti_corriculum: forwarded as ``reverse=`` to the ranking helper
                (presumably reverses the easy-to-hard order; confirm in transfer_learning).
            random: forwarded as ``random=`` to the ranking helper.
            num_phases: number of curriculum phases (default 10, i.e. 10% steps).
            epochs_per_phase: epochs initially scheduled for every phase.
            total_epochs: budget expressed in full passes over the training set. Any
                shortfall of scheduled example visits (epochs * phase size) relative to
                ``total_epochs * size`` is added to the last phase. The last phase is
                never shortened.

        Returns:
            A tuple ``(epochs, data_function)``. ``epochs[i]`` is the epoch count for
            phase ``i``. ``data_function(x, y, cur_phase, num_phases)`` returns the rows
            of ``x`` and ``y`` selected for phase ``cur_phase``, indexing only the
            leading axis.

        Raises:
            ValueError: if ``num_phases < 1`` or ``dataset.x_train`` has no rows.
        """
        if num_phases < 1:
            raise ValueError("num_phases must be at least 1")
        size_data = dataset.x_train.shape[0]
        if size_data == 0:
            # Fail fast: the epoch top-up below divides by the training-set size,
            # and there is nothing to schedule anyway.
            raise ValueError("dataset.x_train is empty; cannot build a curriculum")

        transfer_values_train, transfer_values_test = (
            transfer_learning.get_transfer_values_inception(dataset))
        # Only the train scores are needed to rank the training examples.
        train_scores, _ = transfer_learning.get_svm_scores(
            transfer_values_train, dataset.y_train,
            transfer_values_test, dataset.y_test, dataset)
        order = transfer_learning.rank_data_according_to_score(
            train_scores, dataset.y_train, reverse=anti_corriculum, random=random)

        # Exact integer phase sizes (no float-step accumulation); the last entry
        # is always size_data. NOTE: with fewer samples than phases, the earliest
        # phases get 0 examples.
        data_sizes = [size_data * (phase + 1) // num_phases for phase in range(num_phases)]
        epochs = [epochs_per_phase] * num_phases

        # Example visits already scheduled vs. the budget of `total_epochs` full
        # passes; convert any shortfall into extra full-data epochs for the last phase.
        scheduled_visits = sum(n_epochs * n_rows for n_epochs, n_rows in zip(epochs, data_sizes))
        budget_visits = total_epochs * size_data
        epochs[-1] += max(0, budget_visits - scheduled_visits) // size_data

        def data_function(x, y, cur_phase, num_phases):
            # `num_phases` is part of the callback signature the caller uses but is
            # not needed here: the schedule was fixed when this closure was built.
            selected = order[:data_sizes[cur_phase]]
            # Index only the leading (sample) axis; `...` keeps every other axis, so
            # arrays of any rank work (identical to `[idx, :, :, :]` for 4-D input).
            return x[selected, ...], y[selected, ...]

        return epochs, data_function