-
Notifications
You must be signed in to change notification settings - Fork 1
/
start.py
122 lines (104 loc) · 4.16 KB
/
start.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import tornado.web
import tornado.websocket
import tornado.httpserver
import tornado.ioloop
import tornado.options
from uuid import uuid4
import base64
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import tensorflow as tf
import mnist_train
import mnist_cnn as mnist_interence
import numpy as np
import os
def preparePng(path):
    """Load the image at *path* and return it as a 2-D grayscale array.

    Colour images are reduced to one channel with the standard BT.601 luma
    weights (0.299, 0.587, 0.114); any channel after the third (e.g. alpha)
    is ignored. An image that is already 2-D is returned unchanged apart from
    the float32 conversion.

    :param path: filesystem path of the image file to read.
    :returns: a 2-D numpy array (height x width) of grayscale intensities.
    """
    # NOTE(review): assumes imread yields either (H, W) or (H, W, C) with
    # C >= 3 -- confirm for the PNGs produced by the client.
    pixels = mpimg.imread(path)
    # astype() returns a new array rather than converting in place, so the
    # result has to be assigned for the conversion to take effect.
    pixels = pixels.astype(np.float32)
    if pixels.ndim == 2:
        # Already single-channel: slicing [..., :3] here would keep only the
        # first three *columns* and the dot product would collapse the image.
        return pixels
    return np.dot(pixels[..., :3], [0.299, 0.587, 0.114])
def predictInt2(imgdata):
    """Classify one digit image with the trained CNN and return the class index.

    Builds the network from ``mnist_interence`` in a fresh ``tf.Graph``,
    restores the exponential-moving-average (shadow) values of its variables
    from the latest checkpoint in the ``model`` directory, and returns the
    arg-max of the network output for the single input image.

    NOTE(review): the graph is rebuilt and the checkpoint reloaded on every
    call; caching a graph/session would avoid this per-message cost.

    :param imgdata: numpy array with IMAGE_SIZE * IMAGE_SIZE * NUM_CHANNEL
        elements (e.g. the 2-D grayscale array returned by ``preparePng``).
    :returns: the predicted class index (integer scalar from ``tf.argmax``).
    :raises IOError: if no checkpoint can be found in the ``model`` directory.
    """
    model_dir = 'model'
    image_size = mnist_interence.IMAGE_SIZE
    num_channel = mnist_interence.NUM_CHANNEL
    with tf.Graph().as_default():  # keep this model isolated from other graphs
        # Input placeholder: a batch containing exactly one image.
        x = tf.placeholder(tf.float32,
                           shape=[1, image_size, image_size, num_channel],
                           name='x-input')
        regularizer = tf.contrib.layers.l2_regularizer(0.0001)
        # NOTE(review): True is passed as the second argument of interence;
        # if that is a "training" flag in mnist_cnn (e.g. enabling dropout),
        # inference should pass False -- confirm against mnist_cnn.
        y = mnist_interence.interence(x, True, regularizer)
        prediction = tf.argmax(y, 1)

        # Map every model variable to its moving-average shadow variable in
        # the checkpoint, so the smoothed weights are what gets loaded.
        variable_averages = tf.train.ExponentialMovingAverage(
            mnist_train.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        checkpoint_path = tf.train.latest_checkpoint(model_dir)
        if checkpoint_path is None:
            # Without a checkpoint every variable stays uninitialized and
            # evaluation would fail with an opaque TensorFlow error.
            raise IOError('no TensorFlow checkpoint found in %r' % model_dir)

        with tf.Session() as sess:
            # Restore straight into the variables that ``y`` uses. Importing
            # the saved meta graph instead would add a second copy of the
            # network next to this one.
            saver.restore(sess, checkpoint_path)
            xs = imgdata.reshape(1, image_size, image_size, num_channel)
            return prediction.eval(feed_dict={x: xs}, session=sess)[0]
class Connect(object):
    """Registry of connected websocket clients and handler for their drawings.

    Each incoming message is saved as ``static/img/<u>.png``, where ``u`` is
    the sender's ``u`` request argument. The file is classified with
    ``predictInt2``, and the predicted digit is sent back to the sender as
    a string.
    """

    # Characters accepted in the client-supplied file stem ``u``
    # (ShowHandler hands out uuid4 values, which only use hex digits and '-').
    _NAME_CHARS = frozenset('abcdefghijklmnopqrstuvwxyz'
                            'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
                            '0123456789-_')

    def __init__(self):
        # Per-instance list: a class-level list would be shared by every
        # Connect instance.
        self.users = []

    def newUser(self, newMan):
        """Register a newly opened client connection."""
        self.users.append(newMan)

    def exit(self, quitter):
        """Forget a closed client connection; unknown connections are ignored."""
        if quitter in self.users:
            self.users.remove(quitter)

    def receiveMessage(self, sender, message):
        """Save the client's image, classify it and reply with the digit.

        :param sender: the connection that sent *message*; its ``u`` request
            argument names the saved file.
        :param message: base64-encoded image data. Everything up to and
            including the first comma (e.g. a ``data:image/png;base64,``
            header) is stripped; a message without a comma is decoded whole.
        :raises ValueError: if ``u`` is empty or contains characters other
            than ASCII letters, digits, ``-`` and ``_``. Such a value could
            otherwise place the file outside ``static/img/``.
        """
        u = str(sender.get_argument('u'))
        # ``u`` comes straight from the client; reject anything that could
        # contain a path separator or '..' before using it in a file path.
        if not u or not set(u) <= self._NAME_CHARS:
            raise ValueError('invalid image name: %r' % u)
        src = message.split(',', 1)[-1]
        img = base64.b64decode(src)
        path = os.path.join('static', 'img', u + '.png')
        with open(path, 'wb') as image_file:
            image_file.write(img)
        data = preparePng(path)
        # NOTE(review): this runs synchronously inside the websocket
        # on_message callback, so other clients wait while the model runs.
        num = str(predictInt2(data))
        self.sendMessage(sender, num)

    def sendMessage(self, user, message):
        """Send *message* to the client connection *user*."""
        user.write_message(message)
class LoginHandler(tornado.web.RequestHandler):
    """Entry page (called the login step by the original author)."""

    def get(self):
        """Render ``index.html`` without any template arguments."""
        page = 'index.html'
        self.render(page)
class ShowHandler(tornado.web.RequestHandler):
    """Render ``show.html`` for the image given by the ``image_src`` argument."""

    def get(self):
        """Render the page with the image source and a fresh uuid4 token ``u``.

        The token is later sent back by the client to name its saved image.
        """
        token = uuid4()
        source = str(self.get_argument('image_src'))
        self.render('show.html', u=token, image_src=source)
class UpdatesMssageHandler(tornado.websocket.WebSocketHandler):
    """Websocket endpoint that delegates every event to the shared registry.

    The application's ``connect`` object is told when this connection opens
    and closes, and receives every message the client sends.
    """

    def open(self):
        """Register this connection with the application's registry."""
        registry = self.application.connect
        registry.newUser(self)

    def on_close(self):
        """Remove this connection from the application's registry."""
        registry = self.application.connect
        registry.exit(self)

    def on_message(self, message):
        """Hand the client's latest message to the registry for processing."""
        registry = self.application.connect
        registry.receiveMessage(self, message)
class Application(tornado.web.Application):
    """Tornado application: page routes, the websocket route and shared state.

    A single ``Connect`` instance is stored as ``self.connect`` so every
    websocket handler uses the same client registry.
    """

    def __init__(self):
        self.connect = Connect()
        routes = [
            (r'/', LoginHandler),
            (r'/show/update/', UpdatesMssageHandler),
            (r'/show', ShowHandler),
        ]
        super(Application, self).__init__(
            routes,
            template_path='templates',
            static_path='static',
        )
if __name__ == "__main__":
    # Expose the listening port as --port; the default keeps the previous
    # hard-coded 8888.
    tornado.options.define('port', default=8888, type=int,
                           help='port to listen on')
    tornado.options.parse_command_line()
    server = tornado.httpserver.HTTPServer(Application())
    server.listen(tornado.options.options.port)
    # IOLoop.current() replaces IOLoop.instance(), deprecated since Tornado 5.0.
    tornado.ioloop.IOLoop.current().start()