-
Notifications
You must be signed in to change notification settings - Fork 2
/
vai_pesudo.py
69 lines (50 loc) · 2.18 KB
/
vai_pesudo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Prepare pseudo labels to train the segmentation network
"""
# The star import stays first so the explicit imports below win over any
# same-named names it may re-export (it presumably supplies read_ply).
from utils.mayavi_visu import *
import glob
import pickle
from os.path import join

import numpy as np
from sklearn.neighbors import NearestNeighbors
base_folder = '.../test'
# NOTE(review): the original script left the next values blank (so it did not
# parse). Fill them in before running; main() refuses to run while empty.
model_n = ''  # model/log folder under base_folder
t_list = [0, 1, 2]  # confidence threshold for each run is conf_step * t
ck = ''  # checkpoint folder name, e.g. 'chkp_**'
# Path of the point cloud the pseudo labels are transferred onto (original
# placeholder: '/.ply'). Any '{}' in it is replaced by the prediction file
# stem, so a per-file path can be written as e.g. '<dir>/{}.ply'.
sub_ply_fmt = ''

prob_file = 'probs/Vaihingen3D_Train.pts.ply'  # relative to the checkpoint dir
class_lb_file = 'class_lb.txt'
# Probability fields read from prob_file; column order presumably matches the
# predicted label ids 0..8 -- TODO confirm.
prob_fields = ['Powerline', 'Low_vegetation', 'Impervious_surfaces',
               'Car', 'Fence/Hedge', 'Roof', 'Facade',
               'Shrub', 'Tree']
n_classes = len(prob_fields)
ignore_label = 10  # label written for points below the confidence threshold
conf_step = 0.1


def class_weights(labels, n_classes=9):
    """Return normalised log inverse-frequency weights for classes 0..n_classes-1.

    Each class gets ``w_c = log(N / count_c)``, where ``N`` counts only labels
    in ``[0, n_classes)``. Other labels, such as the ignore label, are excluded.
    The weights are then divided by their sum.

    Counts are tied to their class through ``np.bincount``. Positional slicing
    of ``np.unique`` counts would shift weights onto the wrong classes
    whenever a class is absent.

    Args:
        labels: array-like of integer class labels.
        n_classes: number of real classes to weight.

    Returns:
        float array of length ``n_classes``. An absent class gets weight 0
        (not ``inf``). If all weights are 0, the zero vector is returned
        instead of NaNs.
    """
    labels = np.asarray(labels).astype(np.int64).ravel()
    valid = labels[(labels >= 0) & (labels < n_classes)]
    counts = np.bincount(valid, minlength=n_classes).astype(np.float64)
    weights = np.zeros(n_classes, dtype=np.float64)
    present = counts > 0
    weights[present] = np.log(counts.sum() / counts[present])
    total = weights.sum()
    return weights / total if total > 0 else weights


def load_masked_probs(base_path):
    """Read the per-point class probabilities and the region class mask once.

    Returns:
        (probs, region_class):
        - ``probs`` stacks the ``prob_fields`` columns of
          ``<base_path>/<prob_file>``.
        - ``region_class`` is ``class_lb_file`` parsed by ``np.genfromtxt``.
          It is later multiplied elementwise with the gathered probabilities,
          so its shape must broadcast against them. This is not checked here.
    """
    data = read_ply(join(base_path, prob_file))
    probs = np.vstack([data[k] for k in prob_fields]).T
    region_class = np.genfromtxt(class_lb_file, delimiter=' ')
    return probs, region_class


def main():
    """Write thresholded pseudo labels and class weights for each prediction PLY.

    For every ``<base_folder>/<model_n>/<ck>/predictions/*.ply``:
    - Each point of the target cloud (``sub_ply_fmt``) takes the prediction of
      its nearest predicted point.
    - For each ``t`` in ``t_list``, points whose max masked probability is
      below ``conf_step * t`` are relabelled ``ignore_label``.
    - The labels are written to
      ``<base_folder>/<model_n>/<stem>_<ck>_<t>reglb_pseudo.txt``.
    - The class weights go to the matching ``_weight.txt`` file.
    """
    if not (model_n and ck and sub_ply_fmt):
        raise SystemExit('Set model_n, ck and sub_ply_fmt before running this script.')
    base_path = join(base_folder, model_n, ck)
    out_dir = join(base_folder, model_n)
    fn_list = glob.glob(base_path + '/predictions/*.ply')
    # Same files for every prediction file and threshold: read them once.
    probs_all, region_class = load_masked_probs(base_path)
    for fn in fn_list:
        data = read_ply(fn)
        pts_lbs = np.array([data['x'], data['y'], data['z']]).T
        ff = fn.split('/')[-1].split('.pts.ply')[0]
        data_sub = read_ply(sub_ply_fmt.format(ff))
        pts_sub = np.array([data_sub['x'], data_sub['y'], data_sub['z']]).T
        # Nearest predicted point for every target point. This depends only on
        # fn, so it is computed once and reused for all thresholds.
        nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(pts_lbs)
        _, nn_idx = nbrs.kneighbors(pts_sub)
        indices = nn_idx[:, 0]  # (n, 1) -> (n,), safe even when n == 1
        pseudo_sub = data['preds'][indices]
        conf = np.max(probs_all[indices] * region_class, axis=-1)
        for t in t_list:
            # Copy so one threshold's ignore labels never leak into the next.
            new_lbs = pseudo_sub.copy()
            new_lbs[conf < conf_step * t] = ignore_label
            w_n = class_weights(new_lbs, n_classes)
            stem = ff + '_' + ck + '_' + str(t) + 'reglb'
            np.savetxt(join(out_dir, stem + '_pseudo.txt'), new_lbs, fmt='%i')
            np.savetxt(join(out_dir, stem + '_weight.txt'), w_n, fmt='%.3f')
            print(fn)


if __name__ == "__main__":
    main()