-
Notifications
You must be signed in to change notification settings - Fork 2
/
predict_resize_TTA.py
281 lines (239 loc) · 10.8 KB
/
predict_resize_TTA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 25 21:06:36 2020
@author: dell
"""
import gdal
import numpy as np
from keras.models import load_model
import datetime
import math
import cv2
from skimage import morphology
import keras.backend as K
from Model.unet_BN_dilationConv import unet
from matplotlib import pyplot as plt
from timeit import default_timer as timer
# delete warning
import logging
logging.disable(30)
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
def dsc(y_true, y_pred):
    """Return the smoothed Dice similarity coefficient of two tensors.

    Both tensors are flattened, so the score is computed over every element
    at once: ``(2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1)``.
    """
    smooth = 1.
    truth_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)
    overlap = K.sum(truth_flat * pred_flat)
    # The smoothing constant appears on both sides of the fraction, so the
    # denominator is never zero when both tensors are all zeros.
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth_flat) + K.sum(pred_flat) + smooth
    return numerator / denominator
def dice_loss(y_true, y_pred):
    """Return the Dice loss, i.e. one minus the Dice coefficient from ``dsc``."""
    return 1 - dsc(y_true, y_pred)
def center_dice_loss(y_true, y_pred):
    """Return the Dice loss of the prediction; a plain alias of ``dice_loss``."""
    center_loss = dice_loss(y_true, y_pred)
    return center_loss
def mse_center_dice_loss(y_true, y_pred):
    """Return the sum of the centre Dice loss and the mean squared error.

    The MSE term is computed with the Keras backend directly (mean of the
    squared difference over the last axis). The previous implementation called
    ``mean_squared_error``, a name that is neither imported nor defined in
    this file, so every call raised ``NameError``.
    """
    squared_error = K.mean(K.square(y_pred - y_true), axis=-1)
    return center_dice_loss(y_true, y_pred) + squared_error
def readTif(fileName, xoff=0, yoff=0, data_width=0, data_height=0):
    """Read a raster file (or a window of it) with GDAL.

    Args:
        fileName: Path of the raster to open.
        xoff: Column offset of the window to read.
        yoff: Row offset of the window to read.
        data_width: Window width in pixels; 0 means "up to the right edge".
        data_height: Window height in pixels; 0 means "up to the bottom edge".

    Returns:
        A tuple ``(width, height, bands, data, geotrans, proj)``: the full
        raster width and height, the band count, the pixel array returned by
        ``ReadAsArray``, the affine geotransform and the projection string.

    Raises:
        OSError: If GDAL cannot open ``fileName``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        # Fail here with a clear message instead of crashing later with an
        # AttributeError on ``None``.
        raise OSError(fileName + "文件无法打开")

    # Number of columns / rows / bands of the raster.
    width = dataset.RasterXSize
    height = dataset.RasterYSize
    bands = dataset.RasterCount

    # A zero size selects the remaining extent from the offset, so a non-zero
    # offset with a default size no longer asks GDAL for an out-of-range window.
    if data_width == 0:
        data_width = width - xoff
    if data_height == 0:
        data_height = height - yoff
    data = dataset.ReadAsArray(xoff, yoff, data_width, data_height)

    # Affine transform and projection, returned so outputs can be georeferenced.
    geotrans = dataset.GetGeoTransform()
    proj = dataset.GetProjection()
    return width, height, bands, data, geotrans, proj
def writeTiff(im_data, im_geotrans, im_proj, path):
    """Write a numpy array to a GeoTIFF file.

    Args:
        im_data: 2-D array (rows, cols) or 3-D array (bands, rows, cols).
        im_geotrans: Affine geotransform to store in the file.
        im_proj: Projection string to store in the file.
        path: Destination file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
    """
    # Validate the shape first; previously any other rank left the band and
    # size variables unbound and failed with an UnboundLocalError.
    if im_data.ndim == 2:
        im_data = np.array([im_data])
    elif im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got {} dimensions".format(im_data.ndim)
        )
    im_bands, im_height, im_width = im_data.shape

    # Choose the GDAL pixel type from the dtype name.
    # NOTE(review): the substring tests also match the signed types, so int8
    # is stored as Byte and int16 as UInt16 -- confirm this is intended.
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        # Still non-fatal, as before, but no longer silent.
        print(path + " could not be created")
        return

    dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
    dataset.SetProjection(im_proj)  # projection
    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    # Drop the reference to the dataset object once all bands are written.
    del dataset
def testGenerator(img):
    """Yield a single model-ready batch built from one image.

    The pixel values are divided by 255.0 and a leading axis of length 1 is
    added, so an array of shape ``S`` is yielded as shape ``(1,) + S``.
    """
    # Normalise; the division allocates a new array, the input is untouched.
    scaled = img / 255.0
    # Same data, one extra leading (batch) dimension.
    batch_shape = (1,) + scaled.shape
    yield scaled.reshape(batch_shape)
# Heat-map threshold used to binarise the prediction; it is also appended to
# the output folder names so runs with different thresholds do not overwrite
# each other.
Tmain = 0.3

# Paths are assembled with os.path.join so the script is not tied to the
# Windows path separator (on Windows the resulting strings are unchanged).
ModelPath = os.path.join(
    "Model",
    "unet_BN_dilationConv_model_weighted_mse_HRSC2016_resize_addval_best.hdf5",
)
_TestRoot = os.path.join("HRSC2016_resize", "test")
TifFolder = os.path.join(_TestRoot, "image")
HeatmapFolder = os.path.join(_TestRoot, "heatmap_resize_addval_TTA_" + str(Tmain))
BboxFolder = os.path.join(_TestRoot, "bbox_resize_addval_TTA_" + str(Tmain))
HbboxFolder = os.path.join(_TestRoot, "predict_resize_addval_TTA_" + str(Tmain))
DrawBboxFolder = os.path.join(_TestRoot, "drawBbox_resize_addval_TTA_" + str(Tmain))

# Create every output folder up front; exist_ok avoids the separate
# check-then-create step.
for _folder in (HeatmapFolder, BboxFolder, HbboxFolder, DrawBboxFolder):
    os.makedirs(_folder, exist_ok=True)
def _predict_heatmap_tta(model, image):
    """Average the model's heat map over four flip augmentations.

    The image is predicted as-is, flipped horizontally, flipped vertically and
    flipped on both axes. Each prediction is flipped back before the four maps
    are summed (in that order) and divided by four.
    """
    # cv2.flip codes: 1 = horizontal, 0 = vertical, -1 = both axes.
    flip_codes = (None, 1, 0, -1)
    heatmap_sum = None
    for code in flip_codes:
        augmented = image if code is None else cv2.flip(image, code)
        batch = next(testGenerator(augmented))
        # predict() on the batch array replaces the deprecated
        # predict_generator(generator, 1) call.
        prediction = model.predict(batch, verbose=1)[0, :, :, 0]
        if code is not None:
            # Applying the same flip again restores the original orientation.
            prediction = cv2.flip(prediction, code)
        heatmap_sum = prediction if heatmap_sum is None else heatmap_sum + prediction
    return heatmap_sum / len(flip_codes)


def _extract_ship_boxes(heatmap, center_threshold, out_width, out_height):
    """Turn a heat map into rotated boxes in output-image coordinates.

    Returns a list of ``(4, 2)`` int32 arrays holding the corner points of the
    minimum-area rectangle around each accepted connected component.
    """
    # 1 where the heat map exceeds the threshold, 0 elsewhere.
    heatmap_center = np.where(heatmap > center_threshold, 1, 0).astype(np.uint8)
    # Connected components: label count, label image, per-label stats
    # (left, top, width, height, area) and centroids (unused here).
    n_labels, labels, stats, _centroids = cv2.connectedComponentsWithStats(
        heatmap_center, connectivity=4)
    img_h, img_w = heatmap_center.shape[:2]

    boxes = []
    for k in range(1, n_labels):  # label 0 is the background
        # Discard components smaller than 100 pixels.
        area = int(stats[k, cv2.CC_STAT_AREA])
        if area < 100:
            continue

        # Discard components whose peak score is below 0.9. The masked max
        # replaces copying and zeroing the whole heat map per component.
        region = labels == k
        if heatmap[region].max() < 0.9:
            continue

        # Axis-aligned bounding box of the component.
        x = int(stats[k, cv2.CC_STAT_LEFT])
        y = int(stats[k, cv2.CC_STAT_TOP])
        w = int(stats[k, cv2.CC_STAT_WIDTH])
        h = int(stats[k, cv2.CC_STAT_HEIGHT])
        size = w * h

        # The fill ratio area/size is used as a measure of how tilted the
        # component is; it selects one of two empirical dilation sizes.
        if area * 1. / size > 0.4:
            niter = int(math.sqrt(area * min(w, h) / size) * 4.3)
        else:
            # Strongly tilted: use the bounding-box diagonal in place of the width.
            diagonal = math.sqrt(w ** 2 + h ** 2)
            niter = int(math.sqrt(area * 1.0 / diagonal) * 4.3)

        # Dilation window, clamped to the heat-map bounds.
        sx = max(x - niter, 0)
        sy = max(y - niter, 0)
        ex = min(x + w + niter + 1, img_w)
        ey = min(y + h + niter + 1, img_h)

        # Component mask, dilated only inside the window.
        segmap = np.zeros(heatmap_center.shape, np.uint8)
        segmap[region] = 255
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1 + niter, 1 + niter))
        segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)

        # Scale the mask to the output size so the box is in output coordinates.
        segmap = cv2.resize(segmap, (out_width, out_height))
        # np.where yields (rows, cols); rolling swaps them into (x, y) points.
        np_contours = np.roll(np.array(np.where(segmap != 0)), 1, axis=0).transpose().reshape(-1, 2)

        # Minimum-area rectangle and its four corner points.
        rectangle = cv2.minAreaRect(np_contours)
        box = cv2.boxPoints(rectangle).astype('int32')
        boxes.append(box)
        print("最小外接矩形坐标:\n", box)
    return boxes


image_list = os.listdir(TifFolder)
model = unet(ModelPath)
start = timer()
for image_name in image_list:
    # splitext keeps the whole stem regardless of extension length
    # (the old [:-4] slice broke names such as "x.tiff").
    image_stem = os.path.splitext(image_name)[0]
    TifPath = os.path.join(TifFolder, image_name)
    HeatmapPath = os.path.join(HeatmapFolder, image_name)
    DrawBboxPath = os.path.join(DrawBboxFolder, image_name)
    HbboxPath = os.path.join(HbboxFolder, image_stem + ".txt")

    im_width, im_height, im_bands, im_data, im_geotrans, im_proj = readTif(TifPath)
    # Keep the first three bands and move the band axis last.
    # NOTE(review): assumes readTif returns a (bands, rows, cols) array with at
    # least three bands -- confirm for single-band inputs.
    im_data = np.ascontiguousarray(np.transpose(im_data[0:3], (1, 2, 0)))
    # Every image is predicted at a fixed 512x512 size.
    im_data = cv2.resize(im_data, (512, 512))

    heatmap = _predict_heatmap_tta(model, im_data)
    # NOTE(review): the 512x512 heat map is written with the geotransform of
    # the full-size image -- confirm that this is acceptable downstream.
    writeTiff(heatmap, im_geotrans, im_proj, HeatmapPath)

    bboxes = _extract_ship_boxes(heatmap, Tmain, im_width, im_height)
    # Axis-aligned envelope (min corner, max corner) of every rotated box.
    envelopes = [(np.min(box, 0), np.max(box, 0)) for box in bboxes]

    # Draw on the network input scaled back up to the original size.
    # NOTE(review): the indentation of this step was lost in the source; it is
    # assumed to run once per image, after all boxes have been collected.
    image = cv2.resize(im_data, (im_width, im_height))
    for box, (xymin, xymax) in zip(bboxes, envelopes):
        # np.int was removed in NumPy 1.24; polylines expects int32 points.
        cv2.polylines(image, np.array([box], np.int32), True, 255)
        # Plain Python ints are accepted by every OpenCV build.
        cv2.rectangle(image,
                      (int(xymin[0]), int(xymin[1])),
                      (int(xymax[0]), int(xymax[1])),
                      (0, 0, 255), thickness=3)

    # Save the horizontal boxes, one "ship 1 xmin ymin xmax ymax" line each.
    # The rotated boxes themselves are only drawn, not written to a file.
    with open(HbboxPath, "w") as fHbbox:
        for xymin, xymax in envelopes:
            fHbbox.write("ship 1 {} {} {} {}\n".format(
                xymin[0], xymin[1], xymax[0], xymax[1]))

    # NOTE(review): the bands are written in file order while cv2.imwrite
    # treats the channels as BGR -- confirm the saved colours are as intended.
    cv2.imwrite(DrawBboxPath, image)

end = timer()
print(end - start)