-
Notifications
You must be signed in to change notification settings - Fork 1
/
bricks.py
106 lines (86 loc) · 3.79 KB
/
bricks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import torch.nn as nn
import torch.nn.functional as F
from sync_bn.nn.modules import SynchronizedBatchNorm2d
from functools import partial
norm_layer = partial(SynchronizedBatchNorm2d, momentum=3e-4)
class DWConv(nn.Module):
    """Depth-wise 3x3 convolution applied to a flattened token sequence.

    Accepts tokens of shape (B, N, C) with N == H * W, reshapes them into a
    spatial map, convolves each channel independently, and flattens back to
    the (B, N, C) token layout.
    """

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        # groups=dim makes the 3x3 convolution depth-wise (one filter per channel).
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        batch, _, channels = x.shape  # second dim is N == H * W
        # (B, N, C) -> (B, C, H, W) so Conv2d can operate spatially.
        spatial = x.transpose(1, 2).view(batch, channels, H, W)
        spatial = self.dwconv(spatial)
        # Back to the token layout (B, N, C).
        return spatial.flatten(2).transpose(1, 2)
#//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
def stochastic_depth(input: torch.Tensor, p: float,
                     mode: str, training: bool = True) -> torch.Tensor:
    """Randomly drop a residual branch ("Deep Networks with Stochastic Depth").

    Args:
        input: tensor of shape (B, ...) — the branch output to possibly drop.
        p: drop probability, must be in [0.0, 1.0].
        mode: 'row' zeroes randomly selected samples of the batch;
              'batch' zeroes the entire tensor with probability p.
        training: dropping is applied only when True.

    Returns:
        The input scaled/zeroed by the survival mask (the input itself when
        not training or p == 0).

    Raises:
        ValueError: if p is outside [0, 1] or mode is not 'batch'/'row'.
    """
    # Validate up front: the original code left `shape` unbound for an
    # unknown mode and crashed later with a NameError instead of a clear error.
    if p < 0.0 or p > 1.0:
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
    if mode not in ('batch', 'row'):
        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
    if not training or p == 0.0:
        return input
    survival_rate = 1.0 - p
    if mode == 'row':
        # One Bernoulli draw per sample: [B, 1, ..., 1] broadcasts over the rest.
        shape = [input.shape[0]] + [1] * (input.ndim - 1)
    else:  # 'batch'
        shape = [1] * input.ndim
    noise = torch.empty(shape, dtype=input.dtype, device=input.device)
    noise = noise.bernoulli_(survival_rate)
    if survival_rate > 0.0:
        # Rescale survivors so the expectation matches the deterministic pass.
        noise.div_(survival_rate)
    return input * noise
class StochasticDepth(nn.Module):
    """Stochastic Depth module (row-wise branch dropping).

    Thin ``nn.Module`` wrapper around :func:`stochastic_depth`.
    mode (str): ``"batch"`` or ``"row"``. ``"batch"`` randomly zeroes the
    entire input, ``"row"`` zeroes randomly selected rows from the batch.

    References:
        - https://pytorch.org/vision/stable/_modules/torchvision/ops/stochastic_depth.html#stochastic_depth
    """

    def __init__(self, p=0.5, mode='row'):
        super().__init__()
        self.p = p
        self.mode = mode

    def forward(self, input):
        # Delegate to the functional form; self.training toggles dropping.
        return stochastic_depth(input, self.p, self.mode, self.training)

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p})"
#//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
def resize(input,
           size=None,
           scale_factor=None,
           mode='bilinear',
           align_corners=None,
           warning=True):
    """Resize a tensor via ``F.interpolate``.

    ``warning`` is accepted for API compatibility with callers but is not
    used by this implementation.
    """
    return F.interpolate(input, size=size, scale_factor=scale_factor,
                         mode=mode, align_corners=align_corners)
#//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class ConvModule(nn.Module):
    """Conv2d -> SyncBatchNorm -> activation, fused into one module.

    All convolution hyper-parameters are forwarded to ``nn.Conv2d``;
    ``act_layer`` is instantiated with no arguments (default ``nn.ReLU``).
    """

    def __init__(self, inChannels, outChannels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, act_layer=nn.ReLU):
        super().__init__()
        self.conv = nn.Conv2d(inChannels, outChannels,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation,
                              groups=groups, bias=bias)
        # norm_layer is the module-level SynchronizedBatchNorm2d partial.
        self.norm = norm_layer(outChannels)
        self.act = act_layer()

    def forward(self, x):
        # conv -> norm -> activation, returned directly.
        return self.act(self.norm(self.conv(x)))
class DepthWiseConv(nn.Module):
    """Depth-wise separable convolution built from two ConvModules.

    When ``kernel_size == 1`` the depth-wise stage is skipped entirely and
    only the point-wise 1x1 projection is applied; note that stride,
    padding, and dilation are ignored in that case.
    """

    def __init__(self, inChannels, outChannels, kernel_size, stride=1, padding=0, dilation=1, bias=True):
        super(DepthWiseConv, self).__init__()
        self.kernel_size = kernel_size
        if kernel_size != 1:
            # Per-channel spatial filtering (groups == inChannels).
            self.depthwise = ConvModule(inChannels, inChannels,
                                        kernel_size=kernel_size, stride=stride,
                                        padding=padding, dilation=dilation,
                                        groups=inChannels, bias=bias)
        # 1x1 conv mixes information across channels.
        self.pointwise = ConvModule(inChannels, outChannels, kernel_size=1, bias=bias)

    def forward(self, x):
        out = x if self.kernel_size == 1 else self.depthwise(x)
        return self.pointwise(out)