# Source code for mmselfsup.models.backbones.maskfeat_vit
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Union
import torch
from mmcls.models import VisionTransformer
from torch import nn
from mmselfsup.registry import MODELS
@MODELS.register_module()
class MaskFeatViT(VisionTransformer):
    """Vision Transformer for MaskFeat pre-training.

    A PyTorch implementation of: `Masked Feature Prediction for
    Self-Supervised Visual Pre-Training <https://arxiv.org/abs/2112.09133>`_.

    On top of the base ``VisionTransformer`` this class adds a learnable
    ``mask_token`` and a ``forward`` that substitutes that token for the
    embeddings of masked patches before running the transformer layers.

    Args:
        arch (str | dict): Vision Transformer architecture.
            Defaults to 'b'.
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
        output_cls_token (bool): Whether to output the cls_token. If set
            True, ``with_cls_token`` must be True; that option is not
            forwarded by this constructor, so the base-class default
            applies. Defaults to True.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding.
            Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 arch: Union[str, dict] = 'b',
                 img_size: Union[int, tuple] = 224,
                 patch_size: Union[int, tuple] = 16,
                 out_indices: Union[Sequence, int] = -1,
                 drop_rate: float = 0,
                 drop_path_rate: float = 0,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 final_norm: bool = True,
                 output_cls_token: bool = True,
                 interpolate_mode: str = 'bicubic',
                 patch_cfg: dict = dict(),
                 layer_cfgs: Union[Sequence, dict] = dict(),
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            out_indices=out_indices,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
            output_cls_token=output_cls_token,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
            init_cfg=init_cfg)

        # Learnable embedding substituted for every masked patch in forward().
        self.mask_token = nn.parameter.Parameter(
            torch.zeros(1, 1, self.embed_dims), requires_grad=True)
        self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]

    def init_weights(self) -> None:
        """Initialize position embedding, mask token and cls token.

        The base-class initialization always runs first. The MaskFeat
        specific re-initialization below is skipped when ``init_cfg`` is a
        dict of type ``'Pretrained'``, so that it does not overwrite
        parameters again after the base class has processed that config.
        """
        super().init_weights()

        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            nn.init.trunc_normal_(self.cls_token, std=.02)
            nn.init.trunc_normal_(self.mask_token, std=.02)
            nn.init.trunc_normal_(self.pos_embed, std=.02)

            # Re-initialize every Linear / Conv / LayerNorm submodule.
            self.apply(self._init_weights)

    def _init_weights(self, m: torch.nn.Module) -> None:
        """Initialize a single submodule; used as the ``apply`` callback.

        Args:
            m (torch.nn.Module): The submodule to initialize. Linear and
                2D/3D convolution weights get a truncated normal with
                ``std=0.02``; LayerNorm is reset to weight 1 and bias 0.
                Any other module type is left untouched.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            # Only Linear biases are zeroed; convolution biases keep
            # whatever value they already have.
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Generate features for masked images.

        Args:
            x (torch.Tensor): Input images.
            mask (torch.Tensor): Input masks with a leading batch dimension.
                All remaining dimensions are flattened and must yield one
                entry per patch. Entries equal to 1 mark patches whose
                embedding is replaced by the mask token; entries equal to 0
                keep the patch embedding.

        Returns:
            torch.Tensor: Features with cls_tokens, i.e. the full token
            sequence of shape (B, L + 1, C) with the cls token at index 0.
            The sequence is returned as-is; this method does not consult
            ``out_indices`` or ``output_cls_token``.
        """
        x = self.patch_embed(x)[0]
        B, L, _ = x.shape

        # Masking keeps the sequence length unchanged: every masked patch
        # embedding is swapped for the shared learnable mask token.
        mask_tokens = self.mask_token.expand(B, L, -1)
        mask = mask.flatten(1).unsqueeze(-1)  # (B, L, 1), broadcasts over C
        x = x * (1 - mask.int()) + mask_tokens * mask

        # Prepend the cls token, then add the position embedding
        # (assumes pos_embed covers L + 1 tokens -- TODO confirm).
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.drop_after_pos(x)

        last_idx = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)

            # The final norm is applied once, right after the last layer.
            if i == last_idx and self.final_norm:
                x = self.norm1(x)

        return x