# main.py
# Cloud Cho - Book Segmentation
#
# May 6, 2018 ~
#
# Step
# (1) Book segmentation
# (2) Title collection
#
# To do:
#
# Error:
# OpenCV version collision at faster_rcnn/simple_parser.py (essential)
# RoiPoolingConv function should be added in faster_rcnn
#
# Source:
# 1st Trial
# https://github.com/FraPochetti/ImageTextRecognition
# 2nd Trial
#
# Work? - no
from __future__ import division
# from data import OcrData
# from cifar import Cifar
# from userimageski import UserData
import random
import pprint
import sys
import time
import numpy as np
import math
from optparse import OptionParser
import pickle
from keras import backend as K
from keras.optimizers import Adam, SGD, RMSprop
from keras.layers import Input
from keras.models import Model
# from faster_rcnn import config, data_generators
# from faster_rcnn import losses as losses
# import faster_rcnn.roi_helpers as roi_helpers
from keras.utils import generic_utils
# from faster_rcnn import simple_parser, vgg_sixteen
from faster_rcnn import vgg_sixteen
# ----- ----- ----- ----- ----- ----- ----- ----- ----- -----
# 1st Trial
def detection_model():
    """Train and evaluate the binary text/non-text classifier (1st trial).

    Pipeline: load the OCR data, merge in CIFAR images as the non-text
    class, grid-search a linearsvc-hog model, re-train the best
    configuration on the whole train set, and evaluate it against the
    previously pickled model.
    """
    # Load the OCR character data set from its config file.
    ocr_data = OcrData('/home/francesco/Dropbox/DSR/OCR/ocr-config.py')
    # Mix in CIFAR images so the data set contains a non-text class.
    ocr_data.merge_with_cifar()
    # Grid-search cross validation selects the best hyper-parameters.
    ocr_data.perform_grid_search_cv('linearsvc-hog')
    # Re-fit the winning configuration on the full training set.
    ocr_data.generate_best_hog_model()
    # Score the freshly trained model against the stored pickle.
    ocr_data.evaluate('/media/francesco/Francesco/CharacterProject/linearsvc-hog-fulltrain2-90.pickle')
def extraction_model():
    """Train and evaluate the single-character classifier (1st trial).

    Same linearsvc-hog pipeline as detection_model(), but without the
    CIFAR merge, evaluated against the 36-class (characters) pickle.
    """
    # Load the OCR character data set from its config file.
    data = OcrData('/home/francesco/Dropbox/DSR/OCR/ocr-config.py')
    # Grid-search cross validation selects the best hyper-parameters.
    data.perform_grid_search_cv('linearsvc-hog')
    # Re-fit the winning configuration on the full training set.
    data.generate_best_hog_model()
    # Score the freshly trained model against the stored pickle.
    data.evaluate('/media/francesco/Francesco/CharacterProject/linearsvc-hog-fulltrain36-90.pickle')


# Backward-compatible alias: this function was originally misspelled
# 'extraion_model', while main() calls it as 'extraction_model()'.
# Keep the old name bound so any existing callers still work.
extraion_model = extraction_model
def test_model():
    """Run the full raw-image-to-prediction pipeline (1st trial).

    Uses two previously trained pickled models: one to decide whether a
    detected object contains text, and one to classify each character.
    Intermediate results are plotted at every stage for inspection.
    """
    # Load the raw input image.
    session = UserData('lao.jpg')
    # Show the image after preprocessing.
    session.plot_preprocessed_image()
    # Detect candidate objects in the preprocessed image.
    detected = session.get_text_candidates()
    session.plot_to_check(detected, 'Total Objects Detected')
    # Keep only the candidates the first model judges to contain text.
    text_objects = session.select_text_among_candidates('/media/francesco/Francesco/CharacterProject/linearsvc-hog-fulltrain2-90.pickle')
    session.plot_to_check(text_objects, 'Objects Containing Text Detected')
    # Classify each remaining object as a single character.
    recognized = session.classify_text('/media/francesco/Francesco/CharacterProject/linearsvc-hog-fulltrain36-90.pickle')
    session.plot_to_check(recognized, 'Single Character Recognition')
    # Finally, plot the text realigned into reading order.
    session.realign_text()
# ----- ----- ----- ----- ----- ----- ----- ----- ----- -----
# 2nd Trial
def train_model(dataset):
    """Build the Faster R-CNN model graph (2nd trial).

    Constructs a shared VGG16 base, a region proposal network (RPN) head
    and an ROI classifier head, then wraps both heads in one combined
    model used to load/save weights.

    NOTE(review): `dataset` is currently unused -- the real data was meant
    to come from simple_parser.get_data(), which is disabled (see the
    Error note in the file header). TODO: wire the dataset in once
    simple_parser.py is fixed.
    """
    # (1) Inputs: a variable-size RGB image and a batch of 4-number ROIs.
    img_input = Input(shape=(None, None, 3))
    roi_input = Input(shape=(None, 4))

    # Error spot
    # all_imgs, classes_count, class_mapping = simple_parser.get_data(train_path)
    # Hard-coded stand-in classes until simple_parser.py is fixed.
    classes = ['human', 'car']
    classes_count = {classes[0]: 10, classes[1]: 10}

    # Shared VGG16 convolutional base under both heads.
    shared_layers = vgg_sixteen.nn_base(img_input, trainable=True)

    # (2) Network heads: RPN on top of the base layers, plus classifier.
    # Only the *lengths* of the scale/ratio lists are used below.
    anchor_box_scales = [128, 256, 512]
    anchor_box_ratios = [[1, 1],
                         [1. / math.sqrt(2), 2. / math.sqrt(2)],
                         [2. / math.sqrt(2), 1. / math.sqrt(2)]]
    num_rois = 32
    num_anchors = len(anchor_box_scales) * len(anchor_box_ratios)

    rpn = vgg_sixteen.rpn(shared_layers, num_anchors)
    classifier = vgg_sixteen.classifier(shared_layers, roi_input, num_rois,
                                        nb_classes=len(classes_count), trainable=True)

    # Separate models for each head...
    model_rpn = Model(img_input, rpn[:2])
    model_classifier = Model([img_input, roi_input], classifier)
    # ...and one model holding both the RPN and the classifier, used to
    # load/save weights for the models.
    model_all = Model([img_input, roi_input], rpn[:2] + classifier)
    print(type(model_all))
    # print('Training completed')
# -----
#
# -----
def main():
    """Entry point: build a tiny random dataset and run the 2nd-trial model.

    The 1st-trial pipeline calls are kept below, commented out, for
    reference.
    """
    # # 1st Trial
    # detection_model()
    # extraction_model()
    #
    # test_model()

    # 2nd Trial: synthetic data/labels just to exercise train_model().
    width, height, depth = 12, 12, 3
    dataset_size = 10
    data = np.random.randint(0, 2, (width, height, depth))
    label = np.random.randint(0, 10, (dataset_size))
    print(data.shape, label.shape)
    dataset = [data, label]
    print(type(dataset))
    train_model(dataset)


if __name__ == '__main__':
    main()