# Source code for mmselfsup.models.backbones.mixmim_backbone
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Optional, Tuple, Union

import torch
from mmcls.models.backbones import MixMIMTransformer
from torch import nn
from torch.nn import functional as F

from mmselfsup.registry import MODELS
from ..utils import build_2d_sincos_position_embedding
@MODELS.register_module()
class MixMIMTransformerPretrain(MixMIMTransformer):
    """MixMIM backbone during pretraining.

    A PyTorch implementation of: `MixMIM: Mixed and Masked Image
    Modeling for Efficient Visual Representation Learning
    <https://arxiv.org/abs/2205.13137>`_

    Args:
        arch (str | dict): MixMIM architecture. If use string,
            choose from 'base','large' and 'huge'.
            If use dict, it should have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.

            Defaults to 'base'.
        mlp_ratio (float): The mlp ratio in FFN. Defaults to 4.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to
            the most common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The num of input channels. Defaults to 3.
        window_size (list): The height and width of the window.
            Defaults to ``[14, 14, 14, 7]``.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        patch_cfg (dict): Extra config dict for patch embedding.
            Defaults to an empty dict.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        attn_drop_rate (float): Attention drop rate. Defaults to 0.
        use_checkpoint (bool): Whether use the checkpoint to
            reduce GPU memory cost. Defaults to False.
        range_mask_ratio (float): Upper bound of the random offset that
            :meth:`random_masking` adds to its ``mask_ratio``; the offset is
            drawn uniformly from ``[0, range_mask_ratio]``. Defaults to 0.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 arch: Union[str, dict] = 'base',
                 mlp_ratio: float = 4,
                 img_size: int = 224,
                 patch_size: int = 4,
                 in_channels: int = 3,
                 window_size: List = [14, 14, 14, 7],
                 qkv_bias: bool = True,
                 patch_cfg: dict = dict(),
                 norm_cfg: dict = dict(type='LN'),
                 drop_rate: float = 0.0,
                 drop_path_rate: float = 0.0,
                 attn_drop_rate: float = 0.0,
                 use_checkpoint: bool = False,
                 range_mask_ratio: float = 0.0,
                 init_cfg: Optional[dict] = None) -> None:
        # NOTE: the list/dict defaults above are shared between calls; this
        # constructor only forwards them to the parent and never mutates
        # them.
        super().__init__(
            arch=arch,
            mlp_ratio=mlp_ratio,
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            window_size=window_size,
            qkv_bias=qkv_bias,
            patch_cfg=patch_cfg,
            norm_cfg=norm_cfg,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            attn_drop_rate=attn_drop_rate,
            use_checkpoint=use_checkpoint,
            init_cfg=init_cfg)

        # Only consumed by ``random_masking``.
        self.range_mask_ratio = range_mask_ratio

    def init_weights(self) -> None:
        """Initialize position embedding, patch embedding."""
        # NOTE(review): the two-argument ``super`` starts the method lookup
        # *after* ``MixMIMTransformer``, so an ``init_weights`` defined on
        # ``MixMIMTransformer`` itself is skipped. Presumably intentional,
        # since the position embedding is re-initialized right below.
        super(MixMIMTransformer, self).init_weights()

        # The grid side is taken as sqrt(num_patches), i.e. a square patch
        # grid is assumed.
        pos_embed = build_2d_sincos_position_embedding(
            int(self.num_patches**.5),
            self.absolute_pos_embed.shape[-1],
            cls_token=False)
        self.absolute_pos_embed.data.copy_(pos_embed.float())

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize a single sub-module; used as ``self.apply`` callback.

        ``nn.Linear`` gets Xavier-uniform weights and zero bias,
        ``nn.LayerNorm`` gets unit weight and zero bias. Any other module
        type is left untouched.

        Args:
            m (nn.Module): The module to initialize in place.
        """
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # ``weight``/``bias`` are None when the layer was built with
            # ``elementwise_affine=False``; skip them instead of crashing.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def random_masking(
        self,
        x: torch.Tensor,
        mask_ratio: float = 0.5
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate the mask for MixMIM Pretraining.

        One random mask is drawn at the coarsest resolution and shared by
        every image in the batch; the finer masks are nearest-neighbour
        upsamplings of it. Only the spatial size and device of ``x`` are
        used.

        Args:
            x (torch.Tensor): Image with data augmentation applied, which is
                of shape B x C x H x W.
            mask_ratio (float): The mask ratio of total patches. A random
                offset drawn uniformly from ``[0, self.range_mask_ratio]``
                is added to it. Defaults to 0.5.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
              - mask_s1 (torch.Tensor): mask with stride of
                self.encoder_stride // 8, of shape 1 x (64 * L) x 1.
              - mask_s2 (torch.Tensor): mask with stride of
                self.encoder_stride // 4, of shape 1 x (16 * L) x 1.
              - mask_s3 (torch.Tensor): mask with stride of
                self.encoder_stride // 2, of shape 1 x (4 * L) x 1.
              - mask (torch.Tensor): mask with stride of
                self.encoder_stride, of shape 1 x L x 1.

            Here ``L = (H // encoder_stride) * (W // encoder_stride)``.
            Entries are 0. or 1.; 1 marks the selected (masked) positions.
        """
        # Unpacking four values also rejects inputs that are not 4-D.
        _, _, H, W = x.shape
        out_H = H // self.encoder_stride
        out_W = W // self.encoder_stride
        s3_H, s3_W = out_H * 2, out_W * 2
        s2_H, s2_W = out_H * 4, out_W * 4
        s1_H, s1_W = out_H * 8, out_W * 8

        seq_l = out_H * out_W
        # use a shared mask for a batch images
        mask = torch.zeros([1, 1, seq_l], device=x.device)

        mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio)
        noise = torch.rand(1, 1, seq_l, device=x.device)  # noise in [0, 1]
        # Ascending sort: the ``int(seq_l * mask_ratio)`` positions with the
        # smallest noise are the ones set to 1 in the mask.
        mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)]
        mask.scatter_(2, mask_idx, 1)

        # Upsample the coarse mask so every finer stage sees the same
        # spatial pattern.
        mask = mask.reshape(1, 1, out_H, out_W)
        mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest')
        mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest')
        mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest')

        # Flatten the spatial dims: 1 x (h * w) x 1.
        mask = mask.reshape(1, out_H * out_W, 1).contiguous()
        mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous()
        mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous()
        mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous()

        return mask_s1, mask_s2, mask_s3, mask

    def forward(self,
                x: torch.Tensor,
                mask_ratio: float = 0.5
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate features for masked images.

        This function generates mask and masks some patches randomly and get
        the hidden features for visible patches.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask_ratio (float): The mask ratio passed on to
                :meth:`random_masking`. Defaults to 0.5.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
              - x (torch.Tensor): hidden features, which is of shape
                B x L x C.
              - mask_s4 (torch.Tensor): the mask tensor for the last layer.
        """
        masks = self.random_masking(x, mask_ratio)
        mask_s1, mask_s4 = masks[0], masks[-1]

        x, _ = self.patch_embed(x)

        # Mix each sample with the sample at the mirrored batch index:
        # positions where the mask is 1 take their tokens from ``x.flip(0)``.
        x = x * (1. - mask_s1) + x.flip(0) * mask_s1
        x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        # Stage i receives the mask of matching resolution. ``zip`` stops
        # after the four masks, so only the first four layers are ever run.
        for layer, attn_mask in zip(self.layers, masks):
            x = layer(x, attn_mask=attn_mask)

        x = self.norm(x)

        return x, mask_s4