-
Notifications
You must be signed in to change notification settings - Fork 6
/
s1_offset_generator.py
121 lines (96 loc) · 4.06 KB
/
s1_offset_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import sys
import cv2
import torch
import argparse
import subprocess
import numpy as np
from glob import glob
from PIL import Image
import os.path as osp
import scipy.io as io
from tqdm import tqdm
from scipy.ndimage.morphology import distance_transform_edt, distance_transform_cdt
# Move four directories above this script (presumably the repository root)
# so that relative paths such as ``config.profile`` resolve from there, and
# put that directory on sys.path for imports in this process.
script_path = osp.abspath(osp.join(osp.dirname(__file__)))
os.chdir(osp.join(script_path, '..', '..', '..', '..'))
sys.path.insert(0, os.getcwd())
# Export the same directory on PYTHONPATH so child Python processes see it too.
os.environ['PYTHONPATH'] = os.getcwd() + ':' + os.environ.get('PYTHONPATH', '')
# Read DATA_ROOT by sourcing config.profile in a bash subshell; depends on
# the chdir above because the path is relative.
# NOTE(review): DATA_ROOT is not used by the code visible here, and --datadir
# defaults to a hard-coded path instead — confirm whether the default should
# be derived from DATA_ROOT.
DATA_ROOT = subprocess.check_output(
['bash', '-c', "source config.profile; echo $DATA_ROOT"]
).decode().strip()
# Command-line options.
#   --datadir: dataset root; each split is read from <datadir>/<split>/label/*.png
#   --outname: name of the per-split output directory for the generated .mat files
#   --split:   one or more split sub-directories to process
#   --ksize:   side length of the square gradient kernel (must be odd)
#   --metric:  distance transform: 'euc' (Euclidean) or 'taxicab'
parser = argparse.ArgumentParser()
parser.add_argument("--datadir", dest='datadir', default='/data/haonan.guo/Inria')
parser.add_argument("--outname", default='dt_offset')
parser.add_argument('--split', nargs='+', default=['train','test'])
parser.add_argument("--ksize", type=int, default=5)
parser.add_argument('--metric', default='euc', choices=['euc', 'taxicab'])
args = parser.parse_args()
def sobel_kernel(shape, axis):
    """Build a 2-D gradient kernel whose taps are inverse-distance-weighted offsets.

    For a tap at integer offset ``(dy, dx)`` from the kernel centre the weight is
    ``dx / (dx**2 + dy**2)`` when ``axis == 0`` and ``dy / (dx**2 + dy**2)``
    otherwise. The centre tap is 0. Correlating with the kernel therefore
    responds positively where the input increases along +x (``axis == 0``) or
    +y (any other ``axis``).

    Args:
        shape: ``(height, width)`` of the kernel; both must be odd, e.g. ``(5, 5)``.
        axis: 0 for the positive-x direction, 1 for the positive-y direction.

    Returns:
        A float64 ``torch.Tensor`` of shape ``(1, height, width)``.

    Raises:
        ValueError: if either dimension is even. Such a kernel has no centre
            tap, and the per-tap formula would divide by zero.
    """
    height, width = shape[0], shape[1]
    if height % 2 == 0 or width % 2 == 0:
        raise ValueError(
            f"kernel shape must be odd in both dimensions, got {tuple(shape)}"
        )
    half_h, half_w = height // 2, width // 2
    # Integer offsets of every tap relative to the centre, shape (height, width).
    dy, dx = np.mgrid[-half_h:half_h + 1, -half_w:half_w + 1]
    sq_dist = dx * dx + dy * dy
    numerator = dx if axis == 0 else dy
    kernel = np.zeros((height, width))
    # Only the centre tap has zero distance; it stays 0.
    off_centre = sq_dist != 0
    kernel[off_centre] = numerator[off_centre] / sq_dist[off_centre]
    return torch.from_numpy(kernel).unsqueeze(0)
label_list = [0,255]
def _encode_label(labelmap):
encoded_labelmap = np.ones_like(labelmap, dtype=np.uint16) * 255
for i, class_id in enumerate(label_list):
encoded_labelmap[labelmap == class_id] = i
return encoded_labelmap
def _mat_filename(basename):
    """Return *basename* with its final extension replaced by ``.mat``.

    Only the extension is swapped, so stems that merely contain "png"
    (e.g. ``png_01.png``) are left intact.
    """
    return osp.splitext(basename)[0] + ".mat"


def _gradient_field(depth):
    """Correlate a 2-D *depth* map with ``sobel_ker`` and return an (H, W, 2) float32 array.

    The last axis follows the output channels of the module-level ``sobel_ker``.
    Padding is half the kernel width, so the spatial size is preserved for an
    odd, square kernel.
    """
    batch = torch.from_numpy(depth).float().view(1, 1, *depth.shape)
    response = torch.nn.functional.conv2d(
        batch, sobel_ker, padding=sobel_ker.shape[-1] // 2
    )
    # Take the single batch item instead of calling squeeze(): squeeze() would
    # also drop a spatial axis of size 1 and break the permute on images that
    # are one pixel high or wide.
    return response[0].permute(1, 2, 0).numpy()


def process(inp):
    """Compute distance and direction maps for one label image and save them as .mat.

    Args:
        inp: ``(indir, outdir, basename)``; the label PNG is ``indir/basename``
            and the result is written to ``outdir/<stem>.mat``.

    For each class in ``label_list``, the distance transform of that class's
    mask is computed with ``args.metric``. The transforms are summed into
    ``depth``. The spatial gradient of each transform (via ``sobel_ker``) is
    kept only on that class's own pixels and summed into a 2-channel direction
    map.

    Saved keys:
        depth:      summed distance, clipped to 250, uint8.
        dir_deg:    direction angle in degrees shifted to [0, 360], divided by
                    ``deg_reduce``, uint8.
        deg_reduce: the divisor applied to the angle (2).

    Raises:
        RuntimeError: if ``args.metric`` is neither 'euc' nor 'taxicab'.

    Relies on the module globals ``args``, ``sobel_ker``, ``label_list`` and
    ``_encode_label``.
    """
    indir, outdir, basename = inp
    print(inp)
    with Image.open(osp.join(indir, basename)) as image:
        # NOTE(review): convert("P") is assumed to keep the raw 0/255 label
        # values as palette indices (true for single-channel masks) — confirm
        # if label images can be RGB.
        raw = np.array(image.convert("P")).astype(np.int16)
    # Class k of label_list becomes id k + 1. Unmapped pixels (encoded as 255)
    # become 256 and therefore match no class id in the loop below.
    labelmap = _encode_label(raw) + 1

    depth_map = np.zeros(labelmap.shape, dtype=np.float32)
    dir_map = np.zeros((*labelmap.shape, 2), dtype=np.float32)
    for class_id in range(1, len(label_list) + 1):
        class_mask = labelmap == class_id
        if args.metric == 'euc':
            depth_i = distance_transform_edt(class_mask)
        elif args.metric == 'taxicab':
            depth_i = distance_transform_cdt(class_mask, metric='taxicab')
        else:
            raise RuntimeError(f"unsupported metric: {args.metric!r}")
        depth_map += depth_i
        dir_i = _gradient_field(depth_i)
        # The distance is 0 outside the mask, but the kernel footprint still
        # picks up this class's distances on neighbouring pixels of other
        # classes. Zero those so each pixel only carries its own class's direction.
        dir_i[~class_mask] = 0
        dir_map += dir_i

    depth_map = np.minimum(depth_map, 250).astype(np.uint8)

    deg_reduce = 2
    # arctan2(channel 0, channel 1) lies in [-180, 180] degrees. With sobel_ker
    # ordered [y, x], that is arctan2(dy, dx). Shift to [0, 360] and divide by
    # deg_reduce so the angle fits in uint8 (at most 180).
    dir_deg_map = (
        np.degrees(np.arctan2(dir_map[:, :, 0], dir_map[:, :, 1])) + 180
    ) / deg_reduce
    print(dir_deg_map.min(), dir_deg_map.max())
    dir_deg_map = dir_deg_map.astype(np.uint8)
    io.savemat(
        osp.join(outdir, _mat_filename(basename)),
        {"dir_deg": dir_deg_map, "depth": depth_map, 'deg_reduce': deg_reduce},
        do_compression=True,
    )
ksize = args.ksize
# Two correlation kernels stacked as output channels: channel 0 is the
# +y kernel, channel 1 the +x kernel; shape (2, 1, ksize, ksize), float32.
sobel_x, sobel_y = (sobel_kernel((ksize, ksize), i) for i in (0, 1))
sobel_ker = torch.cat([sobel_y, sobel_x], dim=0).view(2, 1, ksize, ksize).float()
for dataset in args.split:
    indir = osp.join(args.datadir, dataset, 'label')
    outdir = osp.join(args.datadir, dataset, args.outname)
    # The .mat files are written into outdir, which may not exist yet.
    os.makedirs(outdir, exist_ok=True)
    # Sort so every run processes the files in the same order.
    args_to_apply = [
        (indir, outdir, osp.basename(path))
        for path in sorted(glob(osp.join(indir, "*.png")))
    ]
    for item in tqdm(args_to_apply):
        process(item)