-
Notifications
You must be signed in to change notification settings - Fork 3
/
predict_rsimage.py
227 lines (212 loc) · 10.4 KB
/
predict_rsimage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# -*- coding: utf-8 -*-
"""
@author: xqxqxxq
"""
import gdal
import numpy as np
from keras.models import load_model
from keras import losses
import datetime
import math,os
import sys
from dataProcess import color_dict
# Read a TIFF dataset
def readTif(fileName, xoff = 0, yoff = 0, data_width = 0, data_height = 0):
    """Open a raster with GDAL and return its size, pixels and georeferencing.

    Args:
        fileName: Path of the raster file to open.
        xoff: Column offset of the window passed to ``ReadAsArray``.
        yoff: Row offset of the window passed to ``ReadAsArray``.
        data_width: Width of the window to read, in pixels.
        data_height: Height of the window to read, in pixels. When both
            ``data_width`` and ``data_height`` are 0, the raster's full
            width and height are used as the window size.

    Returns:
        A tuple ``(width, height, bands, data, geotrans, proj)`` holding the
        raster's ``RasterXSize``, ``RasterYSize`` and ``RasterCount``, the
        array returned by ``ReadAsArray``, the geotransform and the
        projection string.

    Raises:
        IOError: If GDAL cannot open ``fileName``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        # Fail here with the file name instead of crashing on the first
        # attribute access of a None dataset a few lines further down.
        raise IOError(fileName + "文件无法打开")

    # Number of columns and rows of the raster grid, and number of bands.
    width = dataset.RasterXSize
    height = dataset.RasterYSize
    bands = dataset.RasterCount

    # A 0 x 0 window means "use the full raster size".
    if data_width == 0 and data_height == 0:
        data_width = width
        data_height = height
    data = dataset.ReadAsArray(xoff, yoff, data_width, data_height)

    # Affine geotransform and projection of the source raster.
    geotrans = dataset.GetGeoTransform()
    proj = dataset.GetProjection()
    return width, height, bands, data, geotrans, proj
# Save an array as a TIFF file
def writeTiff(im_data, im_geotrans, im_proj, path):
    """Write an array to ``path`` as a compressed GeoTIFF.

    The data is first written uncompressed to a temporary sibling file
    (``<name>orig<ext>``) and then handed to ``compress``, which produces
    ``path`` and removes the temporary file.

    Args:
        im_data: NumPy array shaped ``(rows, cols)`` or
            ``(bands, rows, cols)``.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_proj: Projection passed to ``SetProjection``.
        path: Destination path of the final (compressed) file.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        IOError: If GDAL cannot create the temporary file.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        # Matches both int8 and uint8, as before.
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':
        # Signed 16-bit data must not be stored as unsigned, otherwise
        # negative values wrap around.
        datatype = gdal.GDT_Int16
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim == 2:
        # Promote a single band to (1, rows, cols).
        im_data = np.array([im_data])
    elif im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got {0} dimensions".format(im_data.ndim)
        )
    im_bands, im_height, im_width = im_data.shape

    # Build the temporary name from the extension only.  A plain
    # str.replace('.tif', ...) leaves the name unchanged when the suffix is
    # absent (so the output would later be deleted as "temporary") and also
    # rewrites '.tif' occurring inside directory names.
    root, ext = os.path.splitext(path)
    temp_path = root + 'orig' + ext

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(temp_path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        raise IOError("GDAL could not create " + temp_path)
    dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
    dataset.SetProjection(im_proj)        # projection
    for band_index, band_data in enumerate(im_data, start=1):
        dataset.GetRasterBand(band_index).WriteArray(band_data)
    # Dropping the last reference closes the file so compress() can read it.
    del dataset
    compress(temp_path, path)
def compress(path, target_path,method="LZW"):
    """Write a tiled, compressed GeoTIFF copy of ``path`` and delete ``path``.

    Original author's note: LZW is a lossless method and works better on
    black-and-white images.

    Args:
        path: Source raster; it is removed once the copy has been written.
        target_path: Destination of the compressed copy.
        method: Value used for the GTiff ``COMPRESS`` creation option.

    Raises:
        ValueError: If ``path`` and ``target_path`` are the same file.
        IOError: If the source cannot be opened or the copy cannot be
            created. The source file is left in place in both cases.
    """
    if os.path.abspath(path) == os.path.abspath(target_path):
        # Copying onto itself and then removing `path` would destroy the data.
        raise ValueError("source and target must differ: " + path)

    dataset = gdal.Open(path)
    if dataset is None:
        raise IOError("GDAL could not open " + path)

    driver = gdal.GetDriverByName('GTiff')
    copied = driver.CreateCopy(target_path, dataset, strict=1, options=["TILED=YES", "COMPRESS={0}".format(method)])
    succeeded = copied is not None
    # Drop both handles so the files are closed before the source is removed.
    del copied
    del dataset
    if not succeeded:
        # Keep the uncompressed source: it is the only copy of the data.
        raise IOError("GDAL could not write " + target_path)
    os.remove(path)
# Crop a TIFF into tiles (pixel data, margin length)
def TifCroppingArray(img, SideLength, tile_size=256):
    """Cut an image into a grid of overlapping square tiles.

    Tiles are taken every ``tile_size - 2 * SideLength`` pixels along both
    axes. Because the image size is generally not a multiple of that step,
    one extra tile row and one extra tile column are cropped flush with the
    bottom and right image borders.

    Args:
        img: Array whose first two axes are (rows, cols); any further axes
            are carried through untouched.
        SideLength: Margin, in pixels, by which neighbouring tiles overlap
            on each side.
        tile_size: Edge length of each square tile (default 256).

    Returns:
        A tuple ``(tiles, RowOver, ColumnOver)``. ``tiles`` is a list of tile
        rows, each a list of ``tile_size x tile_size`` slices of ``img``.
        ``RowOver`` is ``(cols - 2 * SideLength) % step + SideLength`` and
        ``ColumnOver`` is the same quantity computed from the row count.

    Raises:
        ValueError: If the margin leaves no positive step, or if the image
            is smaller than one tile along either axis.
    """
    height, width = img.shape[0], img.shape[1]
    step = tile_size - SideLength * 2
    if SideLength < 0 or step <= 0:
        raise ValueError(
            "SideLength must satisfy 0 <= 2 * SideLength < tile_size, got {0}".format(SideLength)
        )
    if height < tile_size or width < tile_size:
        # A negative slice start would silently wrap around and yield
        # undersized tiles.
        raise ValueError(
            "image of size {0}x{1} is smaller than one {2}x{2} tile".format(height, width, tile_size)
        )

    usable_rows = height - SideLength * 2
    usable_cols = width - SideLength * 2

    # Regularly spaced tile origins, plus one origin flush with the far
    # border to pick up whatever the regular grid left over.
    row_starts = [k * step for k in range(usable_rows // step)] + [height - tile_size]
    col_starts = [k * step for k in range(usable_cols // step)] + [width - tile_size]

    TifArrayReturn = [
        [img[top:top + tile_size, left:left + tile_size] for left in col_starts]
        for top in row_starts
    ]

    # Leftover along the row axis and along the column axis.
    ColumnOver = usable_rows % step + SideLength
    RowOver = usable_cols % step + SideLength
    return TifArrayReturn, RowOver, ColumnOver
# Label visualisation: give every pixel of class n the value n
def labelVisualize(img):
    """Collapse per-pixel class scores into a map of class indices.

    Args:
        img: Array whose first two axes are (rows, cols); everything after
            them is treated as that pixel's score vector.

    Returns:
        A float64 array of shape ``(rows, cols)`` where each entry is the
        index of the largest score at that pixel (first index on ties).
    """
    scores = np.asarray(img)
    rows, cols = scores.shape[0], scores.shape[1]
    if rows == 0 or cols == 0:
        return np.zeros((rows, cols))
    # One vectorised argmax instead of a Python-level loop over every pixel.
    # Flattening the trailing axes mirrors np.argmax(img[i][j]), which works
    # on the flattened per-pixel data.
    per_pixel = scores.reshape(rows, cols, -1)
    return per_pixel.argmax(axis=2).astype(np.float64)
# Normalise the test images and give them the same dimensions as the training images
def testGenerator(TifArray):
    """Yield every tile of a tile grid as a normalised single-image batch.

    Args:
        TifArray: List of tile rows, each a list of image arrays.

    Yields:
        For each tile, in row-major order, the tile divided by 255.0 with a
        leading axis of length 1 prepended (shape ``(1,) + tile.shape``).
    """
    for tile_row in TifArray:
        for tile in tile_row:
            # Divide by 255.0, then add the leading axis without touching
            # the data.
            scaled = np.asarray(tile) / 255.0
            yield scaled[np.newaxis, ...]
# Build the stitched result matrix
def Result(shape, TifArray, npyfile, num_class, RepetitiveLength, RowOver, ColumnOver, tile_size=256):
    """Stitch per-tile predictions back into one full-size label map.

    Each tile contributes only its central region; the margin of width
    ``RepetitiveLength`` is discarded, except on sides that lie on the image
    border. The last tile row and column, which are cropped flush with the
    border, contribute a strip ``ColumnOver`` / ``RowOver`` pixels wide.

    Args:
        shape: ``(rows, cols)`` of the output array.
        TifArray: The tile grid; only its dimensions (number of tile rows
            and length of the first row) are used.
        npyfile: Iterable of per-tile score arrays in row-major tile order,
            with the class scores on the last axis.
            NOTE(review): assumed to be (tile_size, tile_size, classes) --
            confirm against the model output.
        num_class: Unused; kept for interface compatibility.
        RepetitiveLength: Overlap margin of the tiles, in pixels.
        RowOver: Width of the strip taken from the right-most tile column.
        ColumnOver: Height of the strip taken from the bottom tile row.
        tile_size: Edge length of the square tiles (default 256).

    Returns:
        A ``uint8`` array of shape ``shape`` holding the class index of
        every pixel.
    """
    result = np.zeros(shape, np.uint8)
    height, width = shape[0], shape[1]
    n_rows = len(TifArray)
    n_cols = len(TifArray[0]) if TifArray else 0
    margin = RepetitiveLength
    step = tile_size - 2 * margin
    inner_end = tile_size - margin

    for index, item in enumerate(npyfile):
        # Tiles arrive in row-major order.
        row, col = divmod(index, n_cols)
        labels = np.argmax(item, axis=-1).astype(np.uint8)

        # Vertical extent: which result rows to fill, and from which tile rows.
        if row == 0:
            # Top tile row: keep the top margin, it lies on the image border.
            dst_rows = slice(0, inner_end)
            src_rows = slice(0, inner_end)
        elif row == n_rows - 1:
            # Bottom tile row: take the strip that reaches the bottom border.
            # NOTE(review): for the left-most column the original code
            # extends this strip upwards by one margin; the other columns do
            # not. The extension is preserved here -- confirm it is intended.
            extra = margin if col == 0 else 0
            dst_rows = slice(height - ColumnOver - extra, height)
            src_rows = slice(tile_size - ColumnOver - extra, tile_size)
        else:
            dst_rows = slice(row * step + margin, (row + 1) * step + margin)
            src_rows = slice(margin, inner_end)

        # Horizontal extent, following the same three cases.
        if col == 0:
            dst_cols = slice(0, inner_end)
            src_cols = slice(0, inner_end)
        elif col == n_cols - 1:
            dst_cols = slice(width - RowOver, width)
            src_cols = slice(tile_size - RowOver, tile_size)
        else:
            dst_cols = slice(col * step + margin, (col + 1) * step + margin)
            src_cols = slice(margin, inner_end)

        result[dst_rows, dst_cols] = labels[src_rows, src_cols]
    return result
def predict(TifPath,ResultPath):
    """Run the model over every raster in a folder and save colour-coded maps.

    For each file in ``TifPath`` the image is read, cut into overlapping
    tiles, passed through the model, stitched back together, mapped to
    colours via ``colorDict_RGB`` and written to ``ResultPath`` under the
    same file name.

    Reads the module-level names ``ModelPath``, ``RepetitiveLength`` and
    ``colorDict_RGB``.

    Args:
        TifPath: Directory containing the input rasters.
        ResultPath: Directory receiving the predictions; created if missing.
    """
    model = load_model(ModelPath)
    print("\r", end="")

    # The GTiff driver cannot create files in a directory that does not exist.
    os.makedirs(ResultPath, exist_ok=True)
    # Skip sub-directories and process files in a deterministic order.
    files = sorted(
        name for name in os.listdir(TifPath)
        if os.path.isfile(os.path.join(TifPath, name))
    )

    for i, file in enumerate(files):
        _, _, _, im_data, im_geotrans, im_proj = readTif(os.path.join(TifPath, file))
        # Move the first axis to the end: (a, b, c) -> (b, c, a), the same
        # result as swapaxes(1, 0) followed by swapaxes(1, 2).
        # NOTE(review): presumably (bands, rows, cols) -> (rows, cols, bands);
        # confirm against GDAL's ReadAsArray layout.
        im_data = np.transpose(im_data, (1, 2, 0))

        TifArray, RowOver, ColumnOver = TifCroppingArray(im_data, RepetitiveLength)
        testGene = testGenerator(TifArray)
        results = model.predict_generator(testGene,
                                          len(TifArray) * len(TifArray[0]),
                                          verbose = 0)

        # Stitch the tile predictions into one label map.
        result_shape = (im_data.shape[0], im_data.shape[1])
        result_data = Result(result_shape, TifArray, results, 2, RepetitiveLength, RowOver, ColumnOver)

        # Look up a colour for every class index, then move the colour axis
        # to the front, which is the layout writeTiff expects for 3-D data.
        img_color = np.uint8(colorDict_RGB[result_data.astype(np.uint8)])
        img_color = np.transpose(img_color, [2, 0, 1])
        writeTiff(img_color, im_geotrans, im_proj, os.path.join(ResultPath, file))

        print("\r", end="")
        print("Predict progress: {}%: ".format(int((i+1)/len(files)*100)), "▋" * (i+1), end="")
        sys.stdout.flush()
# Overlap-area parameter; it only feeds the RepetitiveLength formula below.
area_perc = 0.6
# Number of classes, counting the background class.
classNum=2
# Path of the saved trained model (loaded by predict via load_model).
ModelPath = r"./model/unet_model.hdf5"
# Directory containing the test images.
TifPath = r"D:\XUQI\dataset\NEW2-AerialImageDataset\austin\image"
# Directory where the prediction results are saved.
ResultPath = r"D:\XUQI\dataset\NEW2-AerialImageDataset\austin\predict"
# Label directory handed to color_dict below.
train_label_path="./dataset/train/label"
# Tile margin in pixels; predict passes it to TifCroppingArray and Result.
RepetitiveLength = int((1 - math.sqrt(area_perc)) * 256 / 2)
# color_dict is imported from dataProcess and is not visible here; predict
# indexes colorDict_RGB with the class-index map to obtain the output colours.
colorDict_RGB, colorDict_GRAY = color_dict(train_label_path, classNum)
# Runs the whole prediction at import time; every name above must be defined first.
predict(TifPath,ResultPath)