-
Notifications
You must be signed in to change notification settings - Fork 20
/
utils.py
71 lines (57 loc) · 1.91 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import tensorflow as tf
import imageio
import numpy as np
import math
def imread(path):
    """Read the image at ``path`` and return it as float32 divided by 255.

    Dividing by 255 maps 8-bit pixel data onto [0, 1]; presumably callers
    only pass 8-bit images — TODO confirm (higher bit-depth files would
    produce values above 1.0).
    """
    # NOTE(review): recent imageio releases deprecate the top-level imread in
    # favour of imageio.v2.imread / imageio.v3.imread — check the pinned version.
    return imageio.imread(path).astype(np.float32) / 255.
def load(saver, sess, checkpoint_dir, folder):
    """Restore the most recent checkpoint stored in ``checkpoint_dir/folder``.

    Args:
        saver: object whose ``restore(sess, path)`` loads the weights
            (presumably a ``tf.train.Saver``).
        sess: session handed to ``saver.restore``.
        checkpoint_dir: root directory containing per-model folders.
        folder: name of the sub-directory to look in.

    Returns:
        True if a checkpoint was found and restored, False otherwise.
    """
    print(" ========== Reading Checkpoints ============")
    model_dir = os.path.join(checkpoint_dir, folder)
    state = tf.train.get_checkpoint_state(model_dir)

    if not (state and state.model_checkpoint_path):
        print(" ============= Failed to find a checkpoint =============")
        return False

    # Only the file name is taken from the recorded path and re-joined onto
    # model_dir, presumably so checkpoints still load if the directory moved.
    name = os.path.basename(state.model_checkpoint_path)
    saver.restore(sess, os.path.join(model_dir, name))
    print(" ============== Success to read {} ===============".format(name))
    return True
def save(saver, sess, checkpoint_dir, trial, step):
    """Save a checkpoint of ``sess`` into ``checkpoint_dir/Model<trial>``.

    Args:
        saver: object exposing ``save(sess, path, global_step=...)``
            (presumably a ``tf.train.Saver``).
        sess: session passed through to ``saver.save``.
        checkpoint_dir: root directory holding one sub-directory per trial.
        trial: integer trial id, formatted into the sub-directory name.
        step: global step forwarded to ``saver.save``.

    Returns:
        Whatever ``saver.save`` returns (for ``tf.train.Saver`` this is the
        path prefix of the written checkpoint), so callers can log it.
    """
    model_name = "model"
    checkpoint = os.path.join(checkpoint_dir, "Model%d" % trial)
    # exist_ok=True replaces the exists()/makedirs() pair, which could raise
    # if another process created the directory between the two calls.
    os.makedirs(checkpoint, exist_ok=True)
    return saver.save(sess, os.path.join(checkpoint, model_name), global_step=step)
def psnr(img1, img2, pixel_max=None):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        img1: reference image; anything convertible to a float64 array.
        img2: image compared against ``img1``; must broadcast with it.
        pixel_max: peak value of the pixel range. When None (the default)
            it is guessed from ``img1``: 1.0 if ``max(img1) <= 1.0``,
            otherwise 255.0. Pass it explicitly when the guess can be wrong,
            e.g. a very dark 8-bit image whose maximum is 0 or 1 would
            otherwise be scored as if its range were [0, 1].

    Returns:
        The PSNR as a float, or 100 when the images are identical. Zero MSE
        would otherwise divide by zero, so 100 dB is used as a cap.
    """
    img1 = np.float64(img1)
    img2 = np.float64(img2)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return 100
    if pixel_max is None:
        # Range heuristic: data already scaled to [0, 1] vs. 8-bit data.
        pixel_max = 1.0 if np.max(img1) <= 1.0 else 255.0
    return 20 * math.log10(pixel_max / math.sqrt(mse))
def modcrop(imgs, modulo):
    """Crop ``imgs`` so its first two axes are multiples of ``modulo``.

    Rows and columns are dropped from the end of axes 0 and 1 (bottom and
    right edges); every later axis (e.g. channels) is kept whole, so 2-D,
    3-D and higher-rank arrays are all supported.

    Args:
        imgs: array-like with a ``shape`` attribute and at least 2 dimensions.
        modulo: value the cropped sizes of axes 0 and 1 must be divisible by.

    Returns:
        The cropped slice ``imgs[:h, :w, ...]``.

    Raises:
        ValueError: if ``imgs`` has fewer than two dimensions.
    """
    if len(imgs.shape) < 2:
        raise ValueError(
            "modcrop expects at least 2 dimensions, got shape {}".format(imgs.shape))
    # numpy arithmetic (rather than Python ints) keeps the original's
    # semantics for odd modulo values such as 0.
    hw = np.asarray(imgs.shape[:2])
    hw = hw - hw % modulo
    return imgs[:hw[0], :hw[1], ...]
def rgb2y(x):
    """Return the luma (Y) channel of an image using ITU-R BT.601 weights.

    Channels are read from axis 2 (``x[:, :, 0..2]``), assumed to be in
    R, G, B order.

    * uint8 input: computed in float64, rounded and returned as uint8 in
      [16, 235].
    * any other dtype: returned as floats with a 16/255 offset. This implies
      the input is scaled to [0, 1] — TODO confirm for callers.
    """
    kr, kg, kb = 65.481 / 255., 128.553 / 255., 24.966 / 255.

    if x.dtype != np.uint8:
        return kr * x[:, :, 0] + kg * x[:, :, 1] + kb * x[:, :, 2] + 16 / 255

    rgb = np.float64(x)
    luma = kr * rgb[:, :, 0] + kg * rgb[:, :, 1] + kb * rgb[:, :, 2] + 16
    return np.round(luma).astype(np.uint8)