Skip to content

Commit

Permalink
[Enhance] Add iTPN support for non-three-channel images (#1735)
Browse files Browse the repository at this point in the history
* Add a channel argument to mae_head

iTPN pretraining currently supports only images with 3 channels. One of the restrictions comes from the MAE head, whose patchify/unpatchify methods hard-code 3 channels.

* Pass other arguments from iTPNHiViT through to HiViT

HiViT supports specifying the number of input channels, but the iTPNHiViT class cannot pass channel arguments to it. This is one of the reasons the iTPNHiViT implementation only supports images with 3 channels.

* Update itpn.py

Fix hint problem
  • Loading branch information
MGAMZ authored Sep 4, 2023
1 parent e1675e8 commit da1da48
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
23 changes: 13 additions & 10 deletions mmpretrain/models/heads/mae_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,54 +14,57 @@ class MAEPretrainHead(BaseModule):
norm_pix_loss (bool): Whether or not normalize target.
Defaults to False.
patch_size (int): Patch size. Defaults to 16.
in_channels (int): Number of input channels. Defaults to 3.
"""

def __init__(self,
loss: dict,
norm_pix: bool = False,
patch_size: int = 16) -> None:
patch_size: int = 16,
in_channels: int = 3) -> None:
super().__init__()
self.norm_pix = norm_pix
self.patch_size = patch_size
self.in_channels = in_channels
self.loss_module = MODELS.build(loss)

def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
r"""Split images into non-overlapped patches.
Args:
imgs (torch.Tensor): A batch of images. The shape should
be :math:`(B, 3, H, W)`.
be :math:`(B, C, H, W)`.
Returns:
torch.Tensor: Patchified images. The shape is
:math:`(B, L, \text{patch_size}^2 \times 3)`.
:math:`(B, L, \text{patch_size}^2 \times C)`.
"""
p = self.patch_size
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = imgs.reshape(shape=(imgs.shape[0], self.in_channels, h, p, w, p))
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_channels))
return x

def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
r"""Combine non-overlapped patches into images.
Args:
x (torch.Tensor): The shape is
:math:`(B, L, \text{patch_size}^2 \times 3)`.
:math:`(B, L, \text{patch_size}^2 \times C)`.
Returns:
torch.Tensor: The shape is :math:`(B, 3, H, W)`.
torch.Tensor: The shape is :math:`(B, C, H, W)`.
"""
p = self.patch_size
h = w = int(x.shape[1]**.5)
assert h * w == x.shape[1]

x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_channels))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
imgs = x.reshape(shape=(x.shape[0], self.in_channels, h * p, h * p))
return imgs

def construct_target(self, target: torch.Tensor) -> torch.Tensor:
Expand All @@ -71,7 +74,7 @@ def construct_target(self, target: torch.Tensor) -> torch.Tensor:
normalize the image according to ``norm_pix``.
Args:
target (torch.Tensor): Image with the shape of B x 3 x H x W
target (torch.Tensor): Image with the shape of B x C x H x W
Returns:
torch.Tensor: Tokenized images with the shape of B x L x C
Expand Down
5 changes: 4 additions & 1 deletion mmpretrain/models/selfsup/itpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
layer_scale_init_value: float = 0.0,
mask_ratio: float = 0.75,
reconstruction_type: str = 'pixel',
**kwargs,
):
super().__init__(
arch=arch,
Expand All @@ -80,7 +81,9 @@ def __init__(
norm_cfg=norm_cfg,
ape=ape,
rpe=rpe,
layer_scale_init_value=layer_scale_init_value)
layer_scale_init_value=layer_scale_init_value,
**kwargs,
)

self.pos_embed.requires_grad = False
self.mask_ratio = mask_ratio
Expand Down

0 comments on commit da1da48

Please sign in to comment.