ViTAE_NC_Win_RVSA_V3_WSZ7 pretrained weight loading #36

Open
Dawn-creat opened this issue May 24, 2024 · 1 comment

Comments

@Dawn-creat

When I load vitae-b-checkpoint-1599-transform-no-average.pth, I get the following error. The dataset I am using is Potsdam.
Error(s) in loading state_dict for ViTAE_NC_Win_RVSA_V3_WSZ7:
size mismatch for pos_embed: copying a param with shape torch.Size([1, 197, 768]) from checkpoint, the shape in current model is torch.Size([1, 1024, 768]).
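The two shapes in the error are consistent with a ViT-style square patch grid. A minimal arithmetic sketch, assuming a 16-pixel patch stride, 224×224 pretraining with a class token, and 512×512 Potsdam crops with no class token (the error message itself states none of these); `num_pos_tokens` is a hypothetical helper:

```python
def num_pos_tokens(img_size: int, patch_size: int = 16, cls_token: bool = False) -> int:
    """Token count of a square ViT position embedding (hypothetical helper)."""
    grid = img_size // patch_size          # patches per side
    return grid * grid + (1 if cls_token else 0)


# Checkpoint side (assumed): 224 / 16 = 14 -> 14 * 14 + 1 class token
print(num_pos_tokens(224, 16, cls_token=True))    # 197
# Model side (assumed): 512 / 16 = 32 -> 32 * 32, no class token
print(num_pos_tokens(512, 16, cls_token=False))   # 1024
```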

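```python
# Imports this method relies on (assumed here: mmcv 1.x, mmseg 0.x, timm)
import torch.nn as nn
from mmcv.runner import _load_checkpoint
from mmseg.utils import get_root_logger
from timm.models.layers import trunc_normal_


class ViTAE_NC_Win_RVSA_V3_WSZ7(nn.Module):
    # ... __init__ and forward omitted ...

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        pretrained = pretrained or self.pretrained

        def _init_weights(m):
            # Linear layers: truncated-normal weights, zero bias
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # LayerNorm: weight = 1, bias = 0
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            logger = get_root_logger()
            print(f"load from {pretrained}")
            # Load the resolved `pretrained` path (not self.pretrained)
            checkpoint = _load_checkpoint(pretrained, logger=logger, map_location='cpu')
            # Weights may be wrapped under 'state_dict' or 'model'
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint
            # strict=False tolerates missing/unexpected keys, but a shape
            # mismatch (e.g. pos_embed 197 vs 1024 tokens) still raises
            self.load_state_dict(state_dict, strict=False)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')
```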
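The code above passes `strict=False`, yet loading still fails. In plain PyTorch, `nn.Module.load_state_dict` uses `strict` only for missing and unexpected keys; a parameter whose shape differs from the checkpoint is always reported as a `RuntimeError`. A minimal reproduction with a toy `nn.Linear` layer (not the ViTAE model):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)  # weight has shape [2, 4]
bad_state = {
    "weight": torch.zeros(3, 4),  # deliberately wrong shape
    "bias": torch.zeros(2),
}

raised = False
msg = ""
try:
    layer.load_state_dict(bad_state, strict=False)
except RuntimeError as err:
    raised = True
    msg = str(err)
    print(msg)  # the message reports "size mismatch for weight"
```

So with `init_weights` as written, a mismatched `pos_embed` would have to be resized before `load_state_dict` is reached.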
@DotWang
Collaborator

DotWang commented May 26, 2024

@Dawn-creat This doesn't matter: the position embedding is automatically interpolated to match the input size.
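A minimal sketch of that kind of interpolation, assuming the checkpoint's 197 tokens are one class token plus a 14×14 grid and the model's 1024 tokens are a 32×32 grid with no class token. `resize_pos_embed` is a hypothetical helper for illustration, not the repository's own function, and bicubic resampling is one common choice rather than necessarily the one the project uses:

```python
import math

import torch
import torch.nn.functional as F


def resize_pos_embed(pos_embed: torch.Tensor, num_target_tokens: int,
                     has_cls_token: bool = True) -> torch.Tensor:
    """Bicubically resize a [1, N, C] position embedding to a new square grid.

    Hypothetical helper. If has_cls_token, the first token is treated as a
    class token and dropped (assumes the target model has no class token).
    """
    if has_cls_token:
        pos_embed = pos_embed[:, 1:]  # drop the class-token embedding
    _, n_src, dim = pos_embed.shape
    src = math.isqrt(n_src)
    dst = math.isqrt(num_target_tokens)
    assert src * src == n_src and dst * dst == num_target_tokens, "grids must be square"
    # [1, N, C] -> [1, C, H, W] so F.interpolate resizes the spatial grid
    grid = pos_embed.reshape(1, src, src, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(dst, dst), mode="bicubic", align_corners=False)
    # [1, C, H', W'] -> [1, H' * W', C]
    return grid.permute(0, 2, 3, 1).reshape(1, dst * dst, dim)


ckpt_pos = torch.randn(1, 197, 768)        # stand-in for the checkpoint's pos_embed
new_pos = resize_pos_embed(ckpt_pos, 1024)
print(new_pos.shape)                       # torch.Size([1, 1024, 768])
```

In the posted `init_weights`, such a helper would be applied to `state_dict['pos_embed']`, with target length `self.pos_embed.shape[1]`, before `self.load_state_dict` is called.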
