-
Notifications
You must be signed in to change notification settings - Fork 28
/
LDA_SLIC.py
188 lines (160 loc) · 7.19 KB
/
LDA_SLIC.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import numpy as np
import matplotlib.pyplot as plt
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from skimage.segmentation import slic,mark_boundaries,felzenszwalb,quickshift,random_walker
from sklearn import preprocessing
import cv2
import math
def LSC_superpixel(I, nseg):
superpixelNum = nseg
ratio = 0.075
size = int(math.sqrt(((I.shape[0] * I.shape[1]) / nseg)))
superpixelLSC = cv2.ximgproc.createSuperpixelLSC(
I,
region_size=size,
ratio=0.005)
superpixelLSC.iterate()
superpixelLSC.enforceLabelConnectivity(min_element_size=25)
segments = superpixelLSC.getLabels()
return np.array(segments,np.int64)
def SEEDS_superpixel(I, nseg):
I=np.array(I[:, :, 0:3], np.float32).copy()
I_new = cv2.cvtColor(I, cv2.COLOR_BGR2HSV)
# I_new =np.array( I[:,:,0:3],np.float32).copy()
height, width, channels = I_new.shape
superpixelNum = nseg
seeds = cv2.ximgproc.createSuperpixelSEEDS(width, height, channels, int(superpixelNum), num_levels=2,prior=1,histogram_bins=5)
seeds.iterate(I_new,4)
segments = seeds.getLabels()
# segments=SegmentsLabelProcess(segments) # 排除labels中不连续的情况
return segments
def SegmentsLabelProcess(labels):
'''
对labels做后处理,防止出现label不连续现象
'''
labels = np.array(labels, np.int64)
H, W = labels.shape
ls = list(set(np.reshape(labels, [-1]).tolist()))
dic = {}
for i in range(len(ls)):
dic[ls[i]] = i
new_labels = labels
for i in range(H):
for j in range(W):
new_labels[i, j] = dic[new_labels[i, j]]
return new_labels
class SLIC(object):
def __init__(self, HSI,labels, n_segments=1000, compactness=20, max_iter=20, sigma=0, min_size_factor=0.3,
max_size_factor=2):
self.n_segments = n_segments
self.compactness = compactness
self.max_iter = max_iter
self.min_size_factor = min_size_factor
self.max_size_factor = max_size_factor
self.sigma = sigma
# 数据standardization标准化,即提前全局BN
height, width, bands = HSI.shape # 原始高光谱数据的三个维度
data = np.reshape(HSI, [height * width, bands])
minMax = preprocessing.StandardScaler()
data = minMax.fit_transform(data)
self.data = np.reshape(data, [height, width, bands])
self.labels=labels
def get_Q_and_S_and_Segments(self):
# 执行 SLCI 并得到Q(nxm),S(m*b)
img = self.data
(h, w, d) = img.shape
# 计算超像素S以及相关系数矩阵Q
segments = slic(img, n_segments=self.n_segments, compactness=self.compactness, max_iter=self.max_iter,
convert2lab=False,sigma=self.sigma, enforce_connectivity=True,
min_size_factor=self.min_size_factor, max_size_factor=self.max_size_factor,slic_zero=False)
# segments = felzenszwalb(img, scale=1,sigma=0.5,min_size=25)
# segments = quickshift(img,ratio=1,kernel_size=5,max_dist=4,sigma=0.8, convert2lab=False)
# segments=LSC_superpixel(img,self.n_segments)
# segments=SEEDS_superpixel(img,self.n_segments)
# 判断超像素label是否连续,否则予以校正
if segments.max()+1!=len(list(set(np.reshape(segments,[-1]).tolist()))): segments = SegmentsLabelProcess(segments)
self.segments = segments
superpixel_count = segments.max() + 1
self.superpixel_count = superpixel_count
print("superpixel_count", superpixel_count)
# ######################################显示超像素图片
out = mark_boundaries(img[:,:,[0,1,2]], segments)
# out = (img[:, :, [0, 1, 2]]-np.min(img[:, :, [0, 1, 2]]))/(np.max(img[:, :, [0, 1, 2]])-np.min(img[:, :, [0, 1, 2]]))
plt.figure()
plt.imshow(out)
plt.show()
segments = np.reshape(segments, [-1])
S = np.zeros([superpixel_count, d], dtype=np.float32)
Q = np.zeros([w * h, superpixel_count], dtype=np.float32)
x = np.reshape(img, [-1, d])
for i in range(superpixel_count):
idx = np.where(segments == i)[0]
count = len(idx)
pixels = x[idx]
superpixel = np.sum(pixels, 0) / count
S[i] = superpixel
Q[idx, i] = 1
self.S = S
self.Q = Q
return Q, S , self.segments
def get_A(self, sigma: float):
'''
根据 segments 判定邻接矩阵
:return:
'''
A = np.zeros([self.superpixel_count, self.superpixel_count], dtype=np.float32)
(h, w) = self.segments.shape
for i in range(h - 2):
for j in range(w - 2):
sub = self.segments[i:i + 2, j:j + 2]
sub_max = np.max(sub).astype(np.int32)
sub_min = np.min(sub).astype(np.int32)
# if len(sub_set)>1:
if sub_max != sub_min:
idx1 = sub_max
idx2 = sub_min
if A[idx1, idx2] != 0:
continue
pix1 = self.S[idx1]
pix2 = self.S[idx2]
diss = np.exp(-np.sum(np.square(pix1 - pix2)) / sigma ** 2)
A[idx1, idx2] = A[idx2, idx1] = diss
return A
class LDA_SLIC(object):
def __init__(self,data,labels,n_component):
self.data=data
self.init_labels=labels
self.curr_data=data
self.n_component=n_component
self.height,self.width,self.bands=data.shape
self.x_flatt=np.reshape(data,[self.width*self.height,self.bands])
self.y_flatt=np.reshape(labels,[self.height*self.width])
self.labes=labels
def LDA_Process(self,curr_labels):
'''
:param curr_labels: height * width
:return:
'''
curr_labels=np.reshape(curr_labels,[-1])
idx=np.where(curr_labels!=0)[0]
x=self.x_flatt[idx]
y=curr_labels[idx]
lda = LinearDiscriminantAnalysis()#n_components=self.n_component
lda.fit(x,y-1)
X_new = lda.transform(self.x_flatt)
return np.reshape(X_new,[self.height, self.width,-1])
def SLIC_Process(self,img,scale=25):
n_segments_init=self.height*self.width/scale
print("n_segments_init",n_segments_init)
myslic=SLIC(img,n_segments=n_segments_init,labels=self.labes, compactness=1,sigma=1, min_size_factor=0.1, max_size_factor=2)
Q, S, Segments = myslic.get_Q_and_S_and_Segments()
A=myslic.get_A(sigma=10)
return Q,S,A,Segments
def simple_superpixel(self,scale):
curr_labels = self.init_labels
X = self.LDA_Process(curr_labels)
Q, S, A, Seg = self.SLIC_Process(X,scale=scale)
return Q, S, A,Seg
def simple_superpixel_no_LDA(self,scale):
Q, S, A, Seg = self.SLIC_Process(self.data,scale=scale)
return Q, S, A,Seg