diff --git a/networks.py b/networks.py index 07ff051..6a2ef47 100644 --- a/networks.py +++ b/networks.py @@ -26,6 +26,22 @@ def forward(self, inp): class PanNet(nn.Module): def __init__(self, nbands, ratio): super(PanNet, self).__init__() + + bfilter_ms = torch.ones(nbands, 1, 5, 5) + bfilter_ms = bfilter_ms / (bfilter_ms.shape[-2] * bfilter_ms.shape[-1]) + bfilter_pan = torch.ones(1, 1, 5, 5) + bfilter_pan = bfilter_pan / (bfilter_pan.shape[-2] * bfilter_pan.shape[-1]) + + self.dephtconv_ms = nn.Conv2d(in_channels=nbands, out_channels=nbands, padding=(2, 2), + kernel_size=bfilter_ms.shape, groups=nbands, bias=False, padding_mode='replicate') + self.dephtconv_ms.weight.data = bfilter_ms + self.dephtconv_ms.weight.requires_grad = False + + self.dephtconv_pan = nn.Conv2d(in_channels=1, out_channels=1, padding=(2, 2), + kernel_size=bfilter_pan.shape, groups=1, bias=False, padding_mode='replicate') + self.dephtconv_pan.weight.data = bfilter_pan + self.dephtconv_pan.weight.requires_grad = False + self.ratio = ratio self.Conv2d_transpose = nn.ConvTranspose2d(nbands, nbands, 8, 4, padding=(2, 2), bias=False) self.Conv = nn.Conv2d(nbands + 1, 32, 3, padding=(1, 1)) @@ -43,10 +59,13 @@ def forward(self, inp): lms = inp[:, :-1, 2::self.ratio, 2::self.ratio] pan = torch.unsqueeze(inp[:, -1, :, :], dim=1) - x = self.Conv2d_transpose(lms) - x = torch.cat((x, pan), dim=1) + lms_hp = lms - self.dephtconv_ms(lms) + pan_hp = pan - self.dephtconv_pan(pan) - x1 = F.relu(self.Conv(x)) + x = self.Conv2d_transpose(lms_hp) + net_inp = torch.cat((x, pan_hp), dim=1) + + x1 = F.relu(self.Conv(net_inp)) x2 = F.relu(self.Conv_1(x1)) x3 = self.Conv_2(x2) + x1 @@ -62,7 +81,9 @@ def forward(self, inp): x10 = self.Conv_9(x9) - return x10 + x11 = inp[:, :-1, :, :] + x10 + + return x11 class DRPNN(nn.Module): diff --git a/weights/GE1_PanNet-TA-FR_model.tar b/weights/GE1_PanNet-TA-FR_model.tar index 4360b27..1f10145 100644 Binary files a/weights/GE1_PanNet-TA-FR_model.tar and b/weights/GE1_PanNet-TA-FR_model.tar differ diff --git a/weights/GE1_Z-PanNet_model.tar b/weights/GE1_Z-PanNet_model.tar index 559027a..d049b27 100644 Binary files a/weights/GE1_Z-PanNet_model.tar and b/weights/GE1_Z-PanNet_model.tar differ diff --git a/weights/WV2_PanNet-TA-FR_model.tar b/weights/WV2_PanNet-TA-FR_model.tar index 4bf0a87..9955d22 100644 Binary files a/weights/WV2_PanNet-TA-FR_model.tar and b/weights/WV2_PanNet-TA-FR_model.tar differ diff --git a/weights/WV2_Z-PanNet_model.tar b/weights/WV2_Z-PanNet_model.tar index 42189e5..3cb097b 100644 Binary files a/weights/WV2_Z-PanNet_model.tar and b/weights/WV2_Z-PanNet_model.tar differ diff --git a/weights/WV3_PanNet-TA-FR_model.tar b/weights/WV3_PanNet-TA-FR_model.tar index f244e17..a4dd18e 100644 Binary files a/weights/WV3_PanNet-TA-FR_model.tar and b/weights/WV3_PanNet-TA-FR_model.tar differ diff --git a/weights/WV3_Z-PanNet_model.tar b/weights/WV3_Z-PanNet_model.tar index deae100..0ab4241 100644 Binary files a/weights/WV3_Z-PanNet_model.tar and b/weights/WV3_Z-PanNet_model.tar differ