-
Notifications
You must be signed in to change notification settings - Fork 0
/
usquarednet.py
400 lines (309 loc) · 17 KB
/
usquarednet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
#!/usr/bin/env python
# coding: utf-8
# U^2 Net - https://github.com/NathanUA/U-2-Net/blob/master/model/u2net.py
#
# Adapted from the reference implementation at https://github.com/NathanUA/U-2-Net/blob/master/model/u2net.py
# In[ ]:
import tensorflow as tf
# In[ ]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
BatchNormalization,
Conv2D,
Conv2DTranspose,
MaxPooling2D,
Dropout,
SpatialDropout2D,
UpSampling2D,
Input,
concatenate,
multiply,
add,
Activation,
GlobalAveragePooling2D,
Dense,
Multiply,
Input,
)
from tensorflow.keras import backend as K
# In[ ]:
def rebnconv_block(
    inputs,
    filters=3,
    dirate=1,
    afunc='swish',
    dropout=0.0,
    kernel_initializer="he_normal",
    padding="same",
):
    """Apply a dilated 3x3 convolution, batch norm, optional dropout and activation.

    This is the REBNCONV unit that every RSU block below is built from.

    Args:
        inputs: Keras tensor to transform.
        filters: number of output channels of the convolution.
        dirate: dilation rate, applied to both spatial axes.
        afunc: activation passed to ``Activation`` (default ``'swish'``).
        dropout: ``SpatialDropout2D`` rate. The dropout layer is only inserted
            when this is strictly positive.
        kernel_initializer: initializer for the convolution kernel.
        padding: convolution padding mode.

    Returns:
        The activated Keras tensor.
    """
    dilation = (dirate, dirate)
    features = Conv2D(
        filters,
        3,
        dilation_rate=dilation,
        kernel_initializer=kernel_initializer,
        padding=padding,
    )(inputs)
    features = BatchNormalization()(features)
    # Dropout is placed between normalisation and the non-linearity.
    if dropout > 0.0:
        features = SpatialDropout2D(dropout)(features)
    return Activation(afunc)(features)
# In[ ]:
def rsu4f(inputs, in_ch=3, mid_ch=12, out_ch=3, enc_dropout=0.0, dec_dropout=0.0, *, residual=True):
    """Residual U-block without pooling (RSU-4F) from U^2-Net.

    The encoder does not down-sample. It widens its receptive field with
    dilation rates 1, 2, 4 and 8, so every intermediate tensor keeps the
    spatial size of ``inputs``.

    Args:
        inputs: Keras tensor.
        in_ch: input channel count. Unused, because Keras infers it from
            ``inputs``; kept for parity with the reference signature.
        mid_ch: channels of the inner convolutions.
        out_ch: channels of the input projection and of the block output.
        enc_dropout: spatial dropout rate for the encoder convolutions.
        dec_dropout: spatial dropout rate for the dilation-8 bridge and the
            decoder convolutions.
        residual: if True (default), add the input projection to the decoder
            output, which is the "residual" of a Residual U-block. Set False
            to get a plain U-shaped block without the shortcut.

    Returns:
        Keras tensor with ``out_ch`` channels and the spatial size of ``inputs``.
    """
    rebnconvin = rebnconv_block(inputs, out_ch, dropout=enc_dropout)

    # Encoder: growing dilation instead of pooling.
    rebnconv1 = rebnconv_block(rebnconvin, mid_ch, dirate=1, dropout=enc_dropout)
    rebnconv2 = rebnconv_block(rebnconv1, mid_ch, dirate=2, dropout=enc_dropout)
    rebnconv3 = rebnconv_block(rebnconv2, mid_ch, dirate=4, dropout=enc_dropout)

    # Bridge.
    rebnconv4 = rebnconv_block(rebnconv3, mid_ch, dirate=8, dropout=dec_dropout)

    # Decoder: mirror the dilation rates, concatenating the encoder skips.
    rebnconv3d = rebnconv_block(concatenate([rebnconv4, rebnconv3]), mid_ch, dirate=4, dropout=dec_dropout)
    rebnconv2d = rebnconv_block(concatenate([rebnconv3d, rebnconv2]), mid_ch, dirate=2, dropout=dec_dropout)
    rebnconv1d = rebnconv_block(concatenate([rebnconv2d, rebnconv1]), out_ch, dirate=1, dropout=dec_dropout)

    if residual:
        # Both tensors have out_ch channels and the input's spatial size.
        return add([rebnconv1d, rebnconvin])
    return rebnconv1d
# In[ ]:
def rsu4(inputs, in_ch=3, mid_ch=12, out_ch=3, enc_dropout=0.0, dec_dropout=0.0, *, residual=True):
    """Residual U-block of height 4 (RSU-4) from U^2-Net.

    A small U-Net nested inside one stage. The encoder has two 2x
    max-poolings and is followed by a dilation-2 bridge. The decoder uses
    bilinear 2x upsampling and concatenates the encoder skips.

    The spatial size of ``inputs`` must be divisible by 4, so that the
    upsampled tensors match their skips.

    Args:
        inputs: Keras tensor.
        in_ch: input channel count. Unused, because Keras infers it from
            ``inputs``; kept for parity with the reference signature.
        mid_ch: channels of the inner convolutions.
        out_ch: channels of the input projection and of the block output.
        enc_dropout: spatial dropout rate for the encoder convolutions.
        dec_dropout: spatial dropout rate for the bridge and decoder convolutions.
        residual: if True (default), add the input projection to the decoder
            output, which is the "residual" of a Residual U-block. Set False
            to get a plain U-shaped block without the shortcut.

    Returns:
        Keras tensor with ``out_ch`` channels and the spatial size of ``inputs``.
    """
    rebnconvin = rebnconv_block(inputs, out_ch, dropout=enc_dropout)

    # Encoder.
    rebnconv1 = rebnconv_block(rebnconvin, mid_ch, dirate=1, dropout=enc_dropout)
    pool1 = MaxPooling2D((2, 2))(rebnconv1)
    rebnconv2 = rebnconv_block(pool1, mid_ch, dirate=1, dropout=enc_dropout)
    pool2 = MaxPooling2D((2, 2))(rebnconv2)
    rebnconv3 = rebnconv_block(pool2, mid_ch, dirate=1, dropout=enc_dropout)

    # Bridge: dilation instead of a further pooling step.
    rebnconv4 = rebnconv_block(rebnconv3, mid_ch, dirate=2, dropout=dec_dropout)

    # Decoder.
    rebnconv3d = rebnconv_block(concatenate([rebnconv4, rebnconv3]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv3dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv3d)
    rebnconv2d = rebnconv_block(concatenate([rebnconv3dup, rebnconv2]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv2dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv2d)
    rebnconv1d = rebnconv_block(concatenate([rebnconv2dup, rebnconv1]), out_ch, dirate=1, dropout=dec_dropout)

    if residual:
        # Both tensors have out_ch channels and the input's spatial size.
        return add([rebnconv1d, rebnconvin])
    return rebnconv1d
# In[ ]:
def rsu5(inputs, in_ch=3, mid_ch=12, out_ch=3, enc_dropout=0.0, dec_dropout=0.0, *, residual=True):
    """Residual U-block of height 5 (RSU-5) from U^2-Net.

    The encoder has three 2x max-poolings and is followed by a dilation-2
    bridge. The decoder uses bilinear 2x upsampling and concatenates the
    encoder skips.

    The spatial size of ``inputs`` must be divisible by 8, so that the
    upsampled tensors match their skips.

    Args:
        inputs: Keras tensor.
        in_ch: input channel count. Unused, because Keras infers it from
            ``inputs``; kept for parity with the reference signature.
        mid_ch: channels of the inner convolutions.
        out_ch: channels of the input projection and of the block output.
        enc_dropout: spatial dropout rate for the encoder convolutions.
        dec_dropout: spatial dropout rate for the bridge and decoder convolutions.
        residual: if True (default), add the input projection to the decoder
            output, which is the "residual" of a Residual U-block. Set False
            to get a plain U-shaped block without the shortcut.

    Returns:
        Keras tensor with ``out_ch`` channels and the spatial size of ``inputs``.
    """
    rebnconvin = rebnconv_block(inputs, out_ch, dropout=enc_dropout)

    # Encoder.
    rebnconv1 = rebnconv_block(rebnconvin, mid_ch, dirate=1, dropout=enc_dropout)
    pool1 = MaxPooling2D((2, 2))(rebnconv1)
    rebnconv2 = rebnconv_block(pool1, mid_ch, dirate=1, dropout=enc_dropout)
    pool2 = MaxPooling2D((2, 2))(rebnconv2)
    rebnconv3 = rebnconv_block(pool2, mid_ch, dirate=1, dropout=enc_dropout)
    pool3 = MaxPooling2D((2, 2))(rebnconv3)
    rebnconv4 = rebnconv_block(pool3, mid_ch, dirate=1, dropout=enc_dropout)

    # Bridge: dilation instead of a further pooling step.
    rebnconv5 = rebnconv_block(rebnconv4, mid_ch, dirate=2, dropout=dec_dropout)

    # Decoder.
    rebnconv4d = rebnconv_block(concatenate([rebnconv5, rebnconv4]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv4dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv4d)
    rebnconv3d = rebnconv_block(concatenate([rebnconv4dup, rebnconv3]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv3dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv3d)
    rebnconv2d = rebnconv_block(concatenate([rebnconv3dup, rebnconv2]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv2dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv2d)
    rebnconv1d = rebnconv_block(concatenate([rebnconv2dup, rebnconv1]), out_ch, dirate=1, dropout=dec_dropout)

    if residual:
        # Both tensors have out_ch channels and the input's spatial size.
        return add([rebnconv1d, rebnconvin])
    return rebnconv1d
# In[ ]:
def rsu6(inputs, in_ch=3, mid_ch=12, out_ch=3, enc_dropout=0.0, dec_dropout=0.0, *, residual=True):
    """Residual U-block of height 6 (RSU-6) from U^2-Net.

    The encoder has four 2x max-poolings and is followed by a dilation-2
    bridge. The decoder uses bilinear 2x upsampling and concatenates the
    encoder skips.

    The spatial size of ``inputs`` must be divisible by 16, so that the
    upsampled tensors match their skips.

    Args:
        inputs: Keras tensor.
        in_ch: input channel count. Unused, because Keras infers it from
            ``inputs``; kept for parity with the reference signature.
        mid_ch: channels of the inner convolutions.
        out_ch: channels of the input projection and of the block output.
        enc_dropout: spatial dropout rate for the encoder convolutions.
        dec_dropout: spatial dropout rate for the bridge and decoder convolutions.
        residual: if True (default), add the input projection to the decoder
            output, which is the "residual" of a Residual U-block. Set False
            to get a plain U-shaped block without the shortcut.

    Returns:
        Keras tensor with ``out_ch`` channels and the spatial size of ``inputs``.
    """
    rebnconvin = rebnconv_block(inputs, out_ch, dropout=enc_dropout)

    # Encoder.
    rebnconv1 = rebnconv_block(rebnconvin, mid_ch, dirate=1, dropout=enc_dropout)
    pool1 = MaxPooling2D((2, 2))(rebnconv1)
    rebnconv2 = rebnconv_block(pool1, mid_ch, dirate=1, dropout=enc_dropout)
    pool2 = MaxPooling2D((2, 2))(rebnconv2)
    rebnconv3 = rebnconv_block(pool2, mid_ch, dirate=1, dropout=enc_dropout)
    pool3 = MaxPooling2D((2, 2))(rebnconv3)
    rebnconv4 = rebnconv_block(pool3, mid_ch, dirate=1, dropout=enc_dropout)
    pool4 = MaxPooling2D((2, 2))(rebnconv4)
    rebnconv5 = rebnconv_block(pool4, mid_ch, dirate=1, dropout=enc_dropout)

    # Bridge: dilation instead of a further pooling step.
    rebnconv6 = rebnconv_block(rebnconv5, mid_ch, dirate=2, dropout=dec_dropout)

    # Decoder.
    rebnconv5d = rebnconv_block(concatenate([rebnconv6, rebnconv5]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv5dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv5d)
    rebnconv4d = rebnconv_block(concatenate([rebnconv5dup, rebnconv4]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv4dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv4d)
    rebnconv3d = rebnconv_block(concatenate([rebnconv4dup, rebnconv3]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv3dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv3d)
    rebnconv2d = rebnconv_block(concatenate([rebnconv3dup, rebnconv2]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv2dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv2d)
    rebnconv1d = rebnconv_block(concatenate([rebnconv2dup, rebnconv1]), out_ch, dirate=1, dropout=dec_dropout)

    if residual:
        # Both tensors have out_ch channels and the input's spatial size.
        return add([rebnconv1d, rebnconvin])
    return rebnconv1d
# In[ ]:
def rsu7(inputs, in_ch=3, mid_ch=12, out_ch=3, enc_dropout=0.0, dec_dropout=0.0, *, residual=True):
    """Residual U-block of height 7 (RSU-7) from U^2-Net.

    The encoder has five 2x max-poolings and is followed by a dilation-2
    bridge. The decoder uses bilinear 2x upsampling and concatenates the
    encoder skips.

    The spatial size of ``inputs`` must be divisible by 32, so that the
    upsampled tensors match their skips.

    Args:
        inputs: Keras tensor.
        in_ch: input channel count. Unused, because Keras infers it from
            ``inputs``; kept for parity with the reference signature.
        mid_ch: channels of the inner convolutions.
        out_ch: channels of the input projection and of the block output.
        enc_dropout: spatial dropout rate for the encoder convolutions.
        dec_dropout: spatial dropout rate for the bridge and decoder convolutions.
        residual: if True (default), add the input projection to the decoder
            output, which is the "residual" of a Residual U-block. Set False
            to get a plain U-shaped block without the shortcut.

    Returns:
        Keras tensor with ``out_ch`` channels and the spatial size of ``inputs``.
    """
    rebnconvin = rebnconv_block(inputs, out_ch, dropout=enc_dropout)

    # Encoder.
    rebnconv1 = rebnconv_block(rebnconvin, mid_ch, dirate=1, dropout=enc_dropout)
    pool1 = MaxPooling2D((2, 2))(rebnconv1)
    rebnconv2 = rebnconv_block(pool1, mid_ch, dirate=1, dropout=enc_dropout)
    pool2 = MaxPooling2D((2, 2))(rebnconv2)
    rebnconv3 = rebnconv_block(pool2, mid_ch, dirate=1, dropout=enc_dropout)
    pool3 = MaxPooling2D((2, 2))(rebnconv3)
    rebnconv4 = rebnconv_block(pool3, mid_ch, dirate=1, dropout=enc_dropout)
    pool4 = MaxPooling2D((2, 2))(rebnconv4)
    rebnconv5 = rebnconv_block(pool4, mid_ch, dirate=1, dropout=enc_dropout)
    pool5 = MaxPooling2D((2, 2))(rebnconv5)
    rebnconv6 = rebnconv_block(pool5, mid_ch, dirate=1, dropout=enc_dropout)

    # Bridge: dilation instead of a further pooling step.
    rebnconv7 = rebnconv_block(rebnconv6, mid_ch, dirate=2, dropout=dec_dropout)

    # Decoder.
    rebnconv6d = rebnconv_block(concatenate([rebnconv7, rebnconv6]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv6dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv6d)
    rebnconv5d = rebnconv_block(concatenate([rebnconv6dup, rebnconv5]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv5dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv5d)
    rebnconv4d = rebnconv_block(concatenate([rebnconv5dup, rebnconv4]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv4dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv4d)
    rebnconv3d = rebnconv_block(concatenate([rebnconv4dup, rebnconv3]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv3dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv3d)
    rebnconv2d = rebnconv_block(concatenate([rebnconv3dup, rebnconv2]), mid_ch, dirate=1, dropout=dec_dropout)
    rebnconv2dup = UpSampling2D((2, 2), interpolation='bilinear')(rebnconv2d)
    rebnconv1d = rebnconv_block(concatenate([rebnconv2dup, rebnconv1]), out_ch, dirate=1, dropout=dec_dropout)

    if residual:
        # Both tensors have out_ch channels and the input's spatial size.
        return add([rebnconv1d, rebnconvin])
    return rebnconv1d
# In[ ]:
def u2net_block(
    inputs,
    in_ch=3,
    out_ch=1,
    kernel_initializer='he_normal',
    padding='same',
    inter_enc_dropout=0.0,
    inter_dec_dropout=0.0,
    enc_intra_enc_dropout=0.0,
    enc_intra_dec_dropout=0.0,
    dec_intra_enc_dropout=0.0,
    dec_intra_dec_dropout=0.0,
):
    """Build the full-size U^2-Net graph on ``inputs``.

    The encoder has six stages (RSU-7, RSU-6, RSU-5, RSU-4, RSU-4F, RSU-4F)
    separated by 2x max-pooling. Five decoder stages mirror them. Each
    decoder stage, plus the bottom stage, gets a 3x3 side convolution. The
    side maps are bilinearly upsampled to input resolution and fused by a
    final convolution.

    The spatial size of ``inputs`` must be divisible by 32. The graph pools
    five times and upsamples by fixed factors, so other sizes produce
    mismatched concatenations.

    Args:
        inputs: Keras input tensor.
        in_ch: forwarded to the first RSU block. It is informational only,
            since Keras infers channel counts from the tensors.
        out_ch: channels of every side output and of the fused output.
        kernel_initializer: initializer for the side and fusion convolutions.
        padding: padding mode for the side and fusion convolutions.
        inter_enc_dropout: SpatialDropout2D rate after encoder stages 1-5.
        inter_dec_dropout: SpatialDropout2D rate after stage 6 and after every
            decoder stage.
        enc_intra_enc_dropout, enc_intra_dec_dropout: ``enc_dropout`` and
            ``dec_dropout`` forwarded to the RSU blocks of encoder stages 1-5.
        dec_intra_enc_dropout, dec_intra_dec_dropout: the same, for stage 6
            and the decoder stages.

    Returns:
        Tuple of seven sigmoid tensors named ``'ad0'`` to ``'ad6'``. The fused
        map comes first, then the side outputs from shallowest to deepest.
        Because these layer names are fixed, building this block twice in one
        model would produce duplicate layer names.
    """
    def maybe_drop(tensor, rate):
        # Inter-stage SpatialDropout2D, inserted only for a positive rate.
        if rate > 0.0:
            return SpatialDropout2D(rate)(tensor)
        return tensor

    def down(tensor):
        return MaxPooling2D((2, 2))(tensor)

    def up(tensor):
        return UpSampling2D((2, 2), interpolation='bilinear')(tensor)

    enc_cfg = dict(enc_dropout=enc_intra_enc_dropout, dec_dropout=enc_intra_dec_dropout)
    dec_cfg = dict(enc_dropout=dec_intra_enc_dropout, dec_dropout=dec_intra_dec_dropout)

    # Encoder.
    stage1 = maybe_drop(rsu7(inputs, in_ch, 32, 64, **enc_cfg), inter_enc_dropout)
    stage2 = maybe_drop(rsu6(down(stage1), 64, 32, 128, **enc_cfg), inter_enc_dropout)
    stage3 = maybe_drop(rsu5(down(stage2), 128, 64, 256, **enc_cfg), inter_enc_dropout)
    stage4 = maybe_drop(rsu4(down(stage3), 256, 128, 512, **enc_cfg), inter_enc_dropout)
    stage5 = maybe_drop(rsu4f(down(stage4), 512, 256, 512, **enc_cfg), inter_enc_dropout)

    # Bottom stage: uses the decoder-side dropout settings.
    stage6 = maybe_drop(rsu4f(down(stage5), 512, 256, 512, **dec_cfg), inter_dec_dropout)

    # Decoder: upsample, concatenate with the matching encoder stage, refine.
    stage5d = maybe_drop(rsu4f(concatenate([up(stage6), stage5]), 1024, 256, 512, **dec_cfg), inter_dec_dropout)
    stage4d = maybe_drop(rsu4(concatenate([up(stage5d), stage4]), 1024, 128, 256, **dec_cfg), inter_dec_dropout)
    stage3d = maybe_drop(rsu5(concatenate([up(stage4d), stage3]), 512, 64, 128, **dec_cfg), inter_dec_dropout)
    stage2d = maybe_drop(rsu6(concatenate([up(stage3d), stage2]), 256, 32, 64, **dec_cfg), inter_dec_dropout)
    # Mirrors stage1: the last decoder stage of U^2-Net is RSU-7(128, 16, 64).
    stage1d = maybe_drop(rsu7(concatenate([up(stage2d), stage1]), 128, 16, 64, **dec_cfg), inter_dec_dropout)

    # Side outputs, shallowest first.
    side_convs = [
        Conv2D(out_ch, 3, kernel_initializer=kernel_initializer, padding=padding)(stage)
        for stage in (stage1d, stage2d, stage3d, stage4d, stage5d, stage6)
    ]
    # Side k (0-based) sits at 1 / 2**k of input resolution; bring all to full size.
    side_maps = [side_convs[0]] + [
        UpSampling2D((2 ** k, 2 ** k), interpolation='bilinear')(side)
        for k, side in enumerate(side_convs[1:], start=1)
    ]

    # NOTE(review): the referenced U-2-Net appears to fuse with a 1x1 conv.
    # This port uses 3x3; it is kept unchanged so existing fusion weights fit.
    fused = Conv2D(out_ch, 3, kernel_initializer=kernel_initializer, padding=padding)(
        concatenate(side_maps)
    )
    return tuple(
        Activation('sigmoid', name=f'ad{idx}')(tensor)
        for idx, tensor in enumerate([fused] + side_maps)
    )
# In[ ]:
def u2netp_block(
    inputs,
    in_ch=3,
    out_ch=1,
    kernel_initializer='he_normal',
    padding='same',
    rsu_mid_ch=16,
    rsu_out_ch=64,
    inter_enc_dropout=0.0,
    inter_dec_dropout=0.0,
    enc_intra_enc_dropout=0.0,
    enc_intra_dec_dropout=0.0,
    dec_intra_enc_dropout=0.0,
    dec_intra_dec_dropout=0.0,
):
    """Build the lightweight U^2-Net (U^2-Net-P) graph on ``inputs``.

    The topology is the same as the full model: RSU-7/6/5/4/4F/4F encoder
    stages, a mirrored decoder, and side outputs fused by a final
    convolution. The difference is that every RSU block uses the same
    ``rsu_mid_ch`` / ``rsu_out_ch`` widths.

    The spatial size of ``inputs`` must be divisible by 32. The graph pools
    five times and upsamples by fixed factors, so other sizes produce
    mismatched concatenations.

    Args:
        inputs: Keras input tensor.
        in_ch: forwarded to the first RSU block. It is informational only,
            since Keras infers channel counts from the tensors.
        out_ch: channels of every side output and of the fused output.
        kernel_initializer: initializer for the side and fusion convolutions.
        padding: padding mode for the side and fusion convolutions.
        rsu_mid_ch: inner channel width of every RSU block.
        rsu_out_ch: output channel width of every RSU block.
        inter_enc_dropout: SpatialDropout2D rate after encoder stages 1-5.
        inter_dec_dropout: SpatialDropout2D rate after stage 6 and after every
            decoder stage.
        enc_intra_enc_dropout, enc_intra_dec_dropout: ``enc_dropout`` and
            ``dec_dropout`` forwarded to the RSU blocks of encoder stages 1-5.
        dec_intra_enc_dropout, dec_intra_dec_dropout: the same, for stage 6
            and the decoder stages.

    Returns:
        Tuple of seven sigmoid tensors named ``'ad0'`` to ``'ad6'``. The fused
        map comes first, then the side outputs from shallowest to deepest.
        Because these layer names are fixed, building this block twice in one
        model would produce duplicate layer names.
    """
    def maybe_drop(tensor, rate):
        # Inter-stage SpatialDropout2D, inserted only for a positive rate.
        if rate > 0.0:
            return SpatialDropout2D(rate)(tensor)
        return tensor

    def down(tensor):
        return MaxPooling2D((2, 2))(tensor)

    def up(tensor):
        return UpSampling2D((2, 2), interpolation='bilinear')(tensor)

    mid, out = rsu_mid_ch, rsu_out_ch
    cat_ch = 2 * rsu_out_ch  # channels after concatenating a decoder input with its skip
    enc_cfg = dict(enc_dropout=enc_intra_enc_dropout, dec_dropout=enc_intra_dec_dropout)
    dec_cfg = dict(enc_dropout=dec_intra_enc_dropout, dec_dropout=dec_intra_dec_dropout)

    # Encoder.
    stage1 = maybe_drop(rsu7(inputs, in_ch, mid, out, **enc_cfg), inter_enc_dropout)
    stage2 = maybe_drop(rsu6(down(stage1), out, mid, out, **enc_cfg), inter_enc_dropout)
    stage3 = maybe_drop(rsu5(down(stage2), out, mid, out, **enc_cfg), inter_enc_dropout)
    stage4 = maybe_drop(rsu4(down(stage3), out, mid, out, **enc_cfg), inter_enc_dropout)
    stage5 = maybe_drop(rsu4f(down(stage4), out, mid, out, **enc_cfg), inter_enc_dropout)

    # Bottom stage: uses the decoder-side dropout settings.
    stage6 = maybe_drop(rsu4f(down(stage5), out, mid, out, **dec_cfg), inter_dec_dropout)

    # Decoder: upsample, concatenate with the matching encoder stage, refine.
    stage5d = maybe_drop(rsu4f(concatenate([up(stage6), stage5]), cat_ch, mid, out, **dec_cfg), inter_dec_dropout)
    stage4d = maybe_drop(rsu4(concatenate([up(stage5d), stage4]), cat_ch, mid, out, **dec_cfg), inter_dec_dropout)
    stage3d = maybe_drop(rsu5(concatenate([up(stage4d), stage3]), cat_ch, mid, out, **dec_cfg), inter_dec_dropout)
    stage2d = maybe_drop(rsu6(concatenate([up(stage3d), stage2]), cat_ch, mid, out, **dec_cfg), inter_dec_dropout)
    # Mirrors stage1: the last decoder stage of U^2-Net-P is RSU-7.
    stage1d = maybe_drop(rsu7(concatenate([up(stage2d), stage1]), cat_ch, mid, out, **dec_cfg), inter_dec_dropout)

    # Side outputs, shallowest first.
    side_convs = [
        Conv2D(out_ch, 3, kernel_initializer=kernel_initializer, padding=padding)(stage)
        for stage in (stage1d, stage2d, stage3d, stage4d, stage5d, stage6)
    ]
    # Side k (0-based) sits at 1 / 2**k of input resolution; bring all to full size.
    side_maps = [side_convs[0]] + [
        UpSampling2D((2 ** k, 2 ** k), interpolation='bilinear')(side)
        for k, side in enumerate(side_convs[1:], start=1)
    ]

    # NOTE(review): the referenced U-2-Net appears to fuse with a 1x1 conv.
    # This port uses 3x3; it is kept unchanged so existing fusion weights fit.
    fused = Conv2D(out_ch, 3, kernel_initializer=kernel_initializer, padding=padding)(
        concatenate(side_maps)
    )
    return tuple(
        Activation('sigmoid', name=f'ad{idx}')(tensor)
        for idx, tensor in enumerate([fused] + side_maps)
    )
# In[ ]:
if __name__ == "__main__":
    # Smoke check: build U^2-Net on a single-channel 256x256 input and print
    # the layer summary. 256 is divisible by 32, as the graph requires.
    inputs = Input((256, 256, 1))
    sides = u2net_block(inputs)
    model = Model(inputs=[inputs], outputs=sides)
    # summary() prints on its own and returns None; wrapping it in print()
    # would emit a stray "None".
    model.summary()
# In[ ]:
# In[ ]: