-
Notifications
You must be signed in to change notification settings - Fork 25
/
visualizer.py
54 lines (45 loc) · 2.07 KB
/
visualizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
import os
import time
class Visualizer:
    """Log per-epoch training metrics to a text file and optionally to visdom.

    Each call to :meth:`quality` prints a one-line summary and appends it to
    ``<checkpoints_path>/loss_log.txt``. When a visdom port was supplied, it
    also redraws a line plot whose window is named after the metric group.
    """

    def __init__(self, checkpoints_path, visdom_port=None):
        """Open (append to) the log file and optionally connect to visdom.

        Args:
            checkpoints_path: Directory in which ``loss_log.txt`` is created
                or appended to. It must already exist; it is not created here.
            visdom_port: Port of a visdom server. ``None`` disables plotting,
                and visdom is then never imported.
        """
        if visdom_port is not None:
            # Imported lazily so visdom is only required when plotting is used.
            import visdom
            self.vis = visdom.Visdom(port=visdom_port)
            self.use_vis = True
        else:
            self.use_vis = False
        self.checkpoints_path = checkpoints_path
        self.log_name = os.path.join(checkpoints_path, 'loss_log.txt')
        now = time.strftime('%c')
        # Reference point for the "Time elapsed" figure in print_quality.
        self.start_time = time.time()
        with open(self.log_name, 'a') as log_file:
            log_file.write('Training {} \n'.format(now))

    def plot_quality(self, name, quality, epoch):
        """Record one epoch of metrics for ``name`` and redraw its visdom plot.

        Args:
            name: Identifier of the series; used as the visdom window and title.
            quality: Mapping of metric name -> scalar value. The keys seen on
                the first call for ``name`` fix the legend (and column order);
                later calls must provide at least those keys.
            epoch: X-axis value for this point.

        Data is always recorded. Drawing is skipped when plotting is disabled
        or when there are no metrics to draw.
        """
        if not hasattr(self, 'plot_data'):
            self.plot_data = {}
        series = self.plot_data.get(name)
        is_new = series is None
        if is_new:
            series = self.plot_data[name] = {
                'X': [], 'Y': [], 'legend': list(quality.keys())}
        legend = series['legend']
        row = [quality[k] for k in legend]
        if is_new:
            # Seed a new series with a duplicated first point, presumably
            # because a single-point line trace renders nothing visible.
            series['X'].append(epoch)
            series['Y'].append(list(row))
        series['X'].append(epoch)
        series['Y'].append(row)

        # Without a visdom connection there is no self.vis; an empty legend
        # would make np.stack fail on an empty list.
        if not self.use_vis or not legend:
            return
        self.vis.line(
            # One identical X column per metric, so X matches Y's (T, n) shape.
            X=np.stack([np.array(series['X'])] * len(legend), 1),
            Y=np.array(series['Y']),
            opts={
                'title': name,
                'legend': legend,
                'xlabel': 'epoch'},
            win=name)

    def print_quality(self, quality, epoch, epochs):
        """Print a one-line epoch summary and append it to the log file.

        Args:
            quality: Mapping of metric name -> numeric value (formatted ``.4f``).
            epoch: Current epoch number.
            epochs: Total number of epochs, shown as ``epoch/epochs``.
        """
        message = '[Epoch {}/{}] Time elapsed: {:.2f}; '.format(
            epoch, epochs, time.time() - self.start_time)
        message += ''.join(
            '{}: {:.4f}; '.format(k, v) for k, v in quality.items())
        print(message)
        with open(self.log_name, 'a') as log_file:
            log_file.write('{}\n'.format(message))

    def quality(self, name, quality, epoch, epochs):
        """Report one epoch of metrics: print and log them, then plot if enabled.

        Args:
            name: Plot window / series identifier passed to plot_quality.
            quality: Mapping of metric name -> numeric value.
            epoch: Current epoch number.
            epochs: Total number of epochs.
        """
        self.print_quality(quality, epoch, epochs)
        if self.use_vis:
            self.plot_quality(name, quality, epoch)