Skip to content

Commit

Permalink
[FIX] Bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
matciotola@gmail.com authored and matciotola@gmail.com committed May 26, 2022
1 parent 693e377 commit b0c47bd
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@ def forward(self, inp):
class PanNet(nn.Module):
def __init__(self, nbands, ratio):
super(PanNet, self).__init__()

bfilter_ms = torch.ones(nbands, 1, 5, 5)
bfilter_ms = bfilter_ms / (bfilter_ms.shape[-2] * bfilter_ms.shape[-1])
bfilter_pan = torch.ones(1, 1, 5, 5)
bfilter_pan = bfilter_pan / (bfilter_pan.shape[-2] * bfilter_pan.shape[-1])

self.dephtconv_ms = nn.Conv2d(in_channels=nbands, out_channels=nbands, padding=(2, 2),
kernel_size=bfilter_ms.shape, groups=nbands, bias=False, padding_mode='replicate')
self.dephtconv_ms.weight.data = bfilter_ms
self.dephtconv_ms.weight.requires_grad = False

self.dephtconv_pan = nn.Conv2d(in_channels=1, out_channels=1, padding=(2, 2),
kernel_size=bfilter_pan.shape, groups=1, bias=False, padding_mode='replicate')
self.dephtconv_pan.weight.data = bfilter_pan
self.dephtconv_pan.weight.requires_grad = False

self.ratio = ratio
self.Conv2d_transpose = nn.ConvTranspose2d(nbands, nbands, 8, 4, padding=(2, 2), bias=False)
self.Conv = nn.Conv2d(nbands + 1, 32, 3, padding=(1, 1))
Expand All @@ -43,10 +59,13 @@ def forward(self, inp):
lms = inp[:, :-1, 2::self.ratio, 2::self.ratio]
pan = torch.unsqueeze(inp[:, -1, :, :], dim=1)

x = self.Conv2d_transpose(lms)
x = torch.cat((x, pan), dim=1)
lms_hp = lms - self.dephtconv_ms(lms)
pan_hp = pan - self.dephtconv_pan(pan)

x1 = F.relu(self.Conv(x))
x = self.Conv2d_transpose(lms_hp)
net_inp = torch.cat((x, pan_hp), dim=1)

x1 = F.relu(self.Conv(net_inp))

x2 = F.relu(self.Conv_1(x1))
x3 = self.Conv_2(x2) + x1
Expand All @@ -62,7 +81,9 @@ def forward(self, inp):

x10 = self.Conv_9(x9)

return x10
x11 = inp[:, :-1, :, :] + x10

return x11


class DRPNN(nn.Module):
Expand Down
Binary file modified weights/GE1_PanNet-TA-FR_model.tar
Binary file not shown.
Binary file modified weights/GE1_Z-PanNet_model.tar
Binary file not shown.
Binary file modified weights/WV2_PanNet-TA-FR_model.tar
Binary file not shown.
Binary file modified weights/WV2_Z-PanNet_model.tar
Binary file not shown.
Binary file modified weights/WV3_PanNet-TA-FR_model.tar
Binary file not shown.
Binary file modified weights/WV3_Z-PanNet_model.tar
Binary file not shown.

0 comments on commit b0c47bd

Please sign in to comment.