Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dcnv1 #9

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
37b6fc0
init
telin0411 Nov 20, 2021
02fe0ec
updated
telin0411 Nov 20, 2021
be3716c
Merge branch 'choyingw:main' into main
telin0411 Nov 25, 2021
19e875e
update ignore
telin0411 Nov 25, 2021
adc217b
revised losses
telin0411 Nov 25, 2021
0fc631f
updated
telin0411 Nov 26, 2021
d5ad131
Merge branch 'choyingw-main' into main
telin0411 Nov 26, 2021
46a974d
fixed bugs
telin0411 Nov 26, 2021
e846c63
fixed checkpoint mismatch
telin0411 Nov 30, 2021
31b2151
Merge branch 'choyingw:main' into main
telin0411 Dec 3, 2021
438b7f9
updated ignorte
telin0411 Dec 3, 2021
3f84c73
updated scripts
telin0411 Dec 12, 2021
984fa17
revised demo
telin0411 Dec 12, 2021
be2151e
added scripts
telin0411 Dec 12, 2021
acfa056
init commit
DennisVanEe Dec 12, 2021
d57604a
Initial implementation of dcnv1.
jakobd16 Dec 13, 2021
a28aca2
Update dcnv1.py
akuroodi Dec 13, 2021
1700d5c
Added CUDA elements back in
akuroodi Dec 13, 2021
764382a
added gpu compatibility
telin0411 Dec 13, 2021
e71ecaf
Rewrote entire dcnv1 architecture, should be appropriate now.
jakobd16 Dec 13, 2021
8ecb957
Added a couple more deformable convolution layers.
jakobd16 Dec 13, 2021
6cc5453
fixed minor bugs
telin0411 Dec 13, 2021
1b262ae
updated with support for dcnv2
DennisVanEe Dec 15, 2021
49e95c6
Added DCNv1 inference results
akuroodi Dec 15, 2021
707b080
added MAE results
akuroodi Dec 15, 2021
41c2f59
fixed dcnv2
DennisVanEe Dec 15, 2021
09d872a
Added baseline inference outputs
akuroodi Dec 15, 2021
bd1d8ca
Added dcnv1 NME results.
jakobd16 Dec 16, 2021
0606fc9
updated dcnv2
DennisVanEe Dec 16, 2021
91b4113
added landmarks
DennisVanEe Dec 17, 2021
96d17cc
Pushed results folder.
jakobd16 Dec 17, 2021
469af21
Merge branch 'dcnv1' of https://github.com/telin0411/SynergyNet into …
jakobd16 Dec 17, 2021
0049177
dcnv2 results
DennisVanEe Dec 17, 2021
8436dc9
Merge branch 'dcnv1' of https://github.com/telin0411/SynergyNet into …
DennisVanEe Dec 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Binary file added .DS_Store
Binary file not shown.
10 changes: 10 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
__pycache__
results
3dmm_data
aflw2000_data
pretrained
*.zip
train_aug_120x120
ckpts
uv_art
art-all
50 changes: 50 additions & 0 deletions DefConv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch
import torchvision.ops
from torch import nn

class DeformableConv2d(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False):

super(DeformableConv2d, self).__init__()

assert type(kernel_size) == tuple or type(kernel_size) == int

kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
self.stride = stride if type(stride) == tuple else (stride, stride)
self.padding = padding

self.offset_conv = nn.Conv2d(in_channels,
2 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True)

nn.init.constant_(self.offset_conv.weight, 0.)
nn.init.constant_(self.offset_conv.bias, 0.)

self.regular_conv = nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=bias)

def forward(self, x):

offset = self.offset_conv(x)

x = torchvision.ops.deform_conv2d(input=x,
offset=offset,
weight=self.regular_conv.weight,
bias=self.regular_conv.bias,
padding=self.padding,
stride=self.stride,
)
return x
1 change: 1 addition & 0 deletions MAE_results/dcnv1_MAE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Mean MAE = 4.471 (in deg), [yaw,pitch,roll] = [3.882, 5.725, 3.807]
5 changes: 5 additions & 0 deletions MAE_results/dcnv1_NME.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Facial Alignment on AFLW2000-3D (NME):
[ 0, 30] Mean: 3.447, Std: 1.456
[30, 60] Mean: 4.192, Std: 2.255
[60, 90] Mean: 5.655, Std: 4.049
[ 0, 90] Mean: 4.431, Std: 0.917
1 change: 1 addition & 0 deletions MAE_results/mobilenet_MAE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Mean MAE = 3.388 (in deg), [yaw,pitch,roll] = [3.566, 4.059, 2.539]
121 changes: 121 additions & 0 deletions backbone_nets/dcnv1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import torch
import torchvision.ops
from DefConv import DeformableConv2d
from torch import nn

class DefConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
padding = (kernel_size - 1) // 2
if norm_layer is None:
norm_layer = nn.BatchNorm2d
super(DefConvBNReLU, self).__init__(
DeformableConv2d(in_planes, out_planes, kernel_size, stride, padding, False),
norm_layer(out_planes),
nn.ReLU6(inplace=True)
)

class DeformableBackbone(nn.Module):
def __init__(self,
num_classes=1000,
width_mult=1.0,
inverted_residual_setting=None,
round_nearest=8,
block=None,
norm_layer=None):
"""
MobileNet V2 main class
Args:
num_classes (int): Number of classes
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
inverted_residual_setting: Network structure
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
block: Module specifying inverted residual building block for mobilenet
norm_layer: Module specifying the normalization layer to use
"""
super(DeformableBackbone, self).__init__()

if norm_layer is None:
norm_layer = nn.BatchNorm2d

input_channel = 32
last_channel = 1280

# building first layer
self.last_channel = last_channel#_make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [DefConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
features.append(DefConvBNReLU(input_channel, input_channel, stride=2, norm_layer=norm_layer))
features.append(DefConvBNReLU(input_channel, input_channel, stride=2, norm_layer=norm_layer))
features.append(DefConvBNReLU(input_channel, input_channel, stride=2, norm_layer=norm_layer))
features.append(DefConvBNReLU(input_channel, input_channel, stride=2, norm_layer=norm_layer))
features.append(DefConvBNReLU(input_channel, input_channel, stride=2, norm_layer=norm_layer))
# building last several layers
features.append(DefConvBNReLU(input_channel, self.last_channel, stride=2, norm_layer=norm_layer))
# make it nn.Sequential
self.features = nn.Sequential(*features)

# building classifier

self.num_ori = 12
self.num_shape = 40
self.num_exp = 10


self.classifier_ori = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, self.num_ori),
)
self.classifier_shape = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, self.num_shape),
)
self.classifier_exp = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, self.num_exp),
)

# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)

def _forward_impl(self, x):
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass

x = self.features(x)

x = nn.functional.adaptive_avg_pool2d(x, 1)
x = x.reshape(x.shape[0], -1)

pool_x = x.clone()

x_ori = self.classifier_ori(x)
x_shape = self.classifier_shape(x)
x_exp = self.classifier_exp(x)

x = torch.cat((x_ori, x_shape, x_exp), dim=1)
return x, pool_x

def forward(self, x):
return self._forward_impl(x)


def dcnv1(pretrained=False, progress=True, **kwargs):
"""
Constructs a MobileNetV2 architecture from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = DeformableBackbone(**kwargs)
return model
159 changes: 159 additions & 0 deletions backbone_nets/dcnv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
'''
Deformable Convolution operator courtesy of: https://github.com/developer0hye/PyTorch-Deformable-Convolution-v2
'''

import torch
import torchvision.ops
from torch import nn

class DeformableConv2d(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False):

super(DeformableConv2d, self).__init__()

assert type(kernel_size) == tuple or type(kernel_size) == int

kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
self.stride = stride if type(stride) == tuple else (stride, stride)
self.padding = padding

self.offset_conv = nn.Conv2d(in_channels,
2 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True)

nn.init.constant_(self.offset_conv.weight, 0.)
nn.init.constant_(self.offset_conv.bias, 0.)

self.modulator_conv = nn.Conv2d(in_channels,
1 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True)

nn.init.constant_(self.modulator_conv.weight, 0.)
nn.init.constant_(self.modulator_conv.bias, 0.)

self.regular_conv = nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=bias)

def forward(self, x):
#h, w = x.shape[2:]
#max_offset = max(h, w)/4.

offset = self.offset_conv(x)#.clamp(-max_offset, max_offset)
modulator = 2. * torch.sigmoid(self.modulator_conv(x))

x = torchvision.ops.deform_conv2d(input=x,
offset=offset,
weight=self.regular_conv.weight,
bias=self.regular_conv.bias,
padding=self.padding,
mask=modulator,
stride=self.stride)
return x

class DCNv2(nn.Module):
def __init__(self,
num_classes=1000,
width_mult=1.0,
inverted_residual_setting=None,
round_nearest=8,
block=None,
norm_layer=None):

super(DCNv2, self).__init__()

input_channel = 32
last_channel = 1280

# building first layer
self.last_channel = last_channel#_make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [nn.Conv2d(3, input_channel, stride=2, kernel_size=3)]
features.append(nn.ReLU6())
features.append(nn.Conv2d(input_channel, input_channel, stride=2, kernel_size=3))
features.append(nn.ReLU6())
features.append(nn.Conv2d(input_channel, input_channel, stride=2, kernel_size=3))
features.append(nn.ReLU6())
features.append(nn.Conv2d(input_channel, input_channel, stride=2, kernel_size=3))
features.append(nn.ReLU6())
features.append(DeformableConv2d(input_channel, input_channel, stride=2))
features.append(nn.ReLU6())
features.append(DeformableConv2d(input_channel, input_channel, stride=2))
features.append(nn.ReLU6())
# building last several layers
features.append(DeformableConv2d(input_channel, self.last_channel, stride=2))
features.append(nn.ReLU6())
# make it nn.Sequential
self.features = nn.Sequential(*features)

# building classifier

self.num_ori = 12
self.num_shape = 40
self.num_exp = 10


self.classifier_ori = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, self.num_ori),
)
self.classifier_shape = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, self.num_shape),
)
self.classifier_exp = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, self.num_exp),
)

# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)

def _forward_impl(self, x):
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass

x = self.features(x)

x = nn.functional.adaptive_avg_pool2d(x, 1)
x = x.reshape(x.shape[0], -1)

pool_x = x.clone()

x_ori = self.classifier_ori(x)
x_shape = self.classifier_shape(x)
x_exp = self.classifier_exp(x)

x = torch.cat((x_ori, x_shape, x_exp), dim=1)
return x, pool_x

def forward(self, x):
return self._forward_impl(x)

def dcnv2(pretrained=False, progress=True, **kwargs):
model = DCNv2(**kwargs)
return model
Loading