Shortcuts

Source code for mmselfsup.models.backbones.mixmim_backbone

# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Optional, Union

import torch
from mmcls.models.backbones import MixMIMTransformer
from torch import nn
from torch.nn import functional as F

from mmselfsup.registry import MODELS
from ..utils import build_2d_sincos_position_embedding


@MODELS.register_module()
class MixMIMTransformerPretrain(MixMIMTransformer):
    """MixMIM backbone during pretraining.

    A PyTorch implementation of: `MixMIM: Mixed and Masked Image Modeling
    for Efficient Visual Representation Learning
    <https://arxiv.org/abs/2205.13137>`_

    Args:
        arch (str | dict): MixMIM architecture. If use string, choose from
            'base', 'large' and 'huge'. If use dict, it should have below
            keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.

            Defaults to 'base'.
        mlp_ratio (float): The mlp ratio in FFN. Defaults to 4.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The num of input channels. Defaults to 3.
        window_size (list, optional): Window sizes forwarded to
            :class:`MixMIMTransformer`. ``None`` means ``[14, 14, 14, 7]``.
            Defaults to None.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        patch_cfg (dict, optional): Extra config dict for patch embedding.
            ``None`` means an empty dict. Defaults to None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            ``None`` means ``dict(type='LN')``. Defaults to None.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        attn_drop_rate (float): Attention drop rate. Defaults to 0.
        use_checkpoint (bool): Whether to use checkpointing to reduce GPU
            memory cost. Defaults to False.
        range_mask_ratio (float): Upper bound of a random offset, drawn
            uniformly from ``[0, range_mask_ratio]``, that is added to the
            requested mask ratio each time a mask is generated.
            Defaults to 0.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 arch: Union[str, dict] = 'base',
                 mlp_ratio: float = 4,
                 img_size: int = 224,
                 patch_size: int = 4,
                 in_channels: int = 3,
                 window_size: Optional[List] = None,
                 qkv_bias: bool = True,
                 patch_cfg: Optional[dict] = None,
                 norm_cfg: Optional[dict] = None,
                 drop_rate: float = 0.0,
                 drop_path_rate: float = 0.0,
                 attn_drop_rate: float = 0.0,
                 use_checkpoint: bool = False,
                 range_mask_ratio: float = 0.0,
                 init_cfg: Optional[dict] = None) -> None:
        # Resolve None sentinels to the historical defaults; a fresh object
        # is built per call so instances never share a mutable default.
        if window_size is None:
            window_size = [14, 14, 14, 7]
        if patch_cfg is None:
            patch_cfg = dict()
        if norm_cfg is None:
            norm_cfg = dict(type='LN')

        super().__init__(
            arch=arch,
            mlp_ratio=mlp_ratio,
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            window_size=window_size,
            qkv_bias=qkv_bias,
            patch_cfg=patch_cfg,
            norm_cfg=norm_cfg,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            attn_drop_rate=attn_drop_rate,
            use_checkpoint=use_checkpoint,
            init_cfg=init_cfg)

        self.range_mask_ratio = range_mask_ratio

    def init_weights(self):
        """Initialize weights for pretraining.

        Fills ``absolute_pos_embed`` with a fixed 2D sin-cos position
        embedding, then applies :meth:`_init_weights` to every submodule.
        """
        # ``super(MixMIMTransformer, self)`` deliberately skips
        # ``MixMIMTransformer.init_weights`` and calls the next class in the
        # MRO, so the parent's own initialization is replaced by the one
        # below.
        super(MixMIMTransformer, self).init_weights()

        # Assumes a square patch grid: side = sqrt(num_patches).
        pos_embed = build_2d_sincos_position_embedding(
            int(self.num_patches**.5),
            self.absolute_pos_embed.shape[-1],
            cls_token=False)
        self.absolute_pos_embed.data.copy_(pos_embed.float())

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize a single module (used with ``nn.Module.apply``).

        Linear layers get Xavier-uniform weights and zero bias; LayerNorms
        get unit weight and zero bias. Other module types are left untouched.
        """
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x: torch.Tensor, mask_ratio: float = 0.5):
        """Generate the masks for MixMIM pretraining.

        A single random mask is drawn for the whole batch on a grid of
        ``(H // self.encoder_stride, W // self.encoder_stride)`` cells and
        nearest-upsampled by 2x, 4x and 8x for the finer resolutions. In
        every mask, 1 marks a masked cell and 0 a visible one.

        Args:
            x (torch.Tensor): Input images of shape B x C x H x W. Only its
                spatial size and device are used.
            mask_ratio (float): The base mask ratio of total patches. A random
                offset in ``[0, self.range_mask_ratio]`` is added to it.
                Defaults to 0.5.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            Each mask has shape ``(1, N, 1)``, where ``N`` is the number of
            cells at that resolution, and is shared by all images.

            - mask_s1 (torch.Tensor): mask with stride of
              ``self.encoder_stride // 8``.
            - mask_s2 (torch.Tensor): mask with stride of
              ``self.encoder_stride // 4``.
            - mask_s3 (torch.Tensor): mask with stride of
              ``self.encoder_stride // 2``.
            - mask (torch.Tensor): mask with stride of
              ``self.encoder_stride``.
        """
        _, _, H, W = x.shape
        out_H = H // self.encoder_stride
        out_W = W // self.encoder_stride
        s3_H, s3_W = out_H * 2, out_W * 2
        s2_H, s2_W = out_H * 4, out_W * 4
        s1_H, s1_W = out_H * 8, out_W * 8

        seq_l = out_H * out_W
        # use a shared mask for all images in the batch
        mask = torch.zeros([1, 1, seq_l], device=x.device)

        mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio)
        noise = torch.rand(1, 1, seq_l, device=x.device)  # noise in [0, 1]
        # The cells with the smallest noise values are marked as masked
        # (set to 1).
        mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)]
        mask.scatter_(2, mask_idx, 1)
        mask = mask.reshape(1, 1, out_H, out_W)

        mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest')
        mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest')
        mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest')

        mask = mask.reshape(1, out_H * out_W, 1).contiguous()
        mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous()
        mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous()
        mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous()

        return mask_s1, mask_s2, mask_s3, mask

    def forward(self, x: torch.Tensor, mask_ratio: float = 0.5):
        """Generate features for mixed images.

        A random mask is generated (see :meth:`random_masking`). At masked
        positions, each image's patch tokens are replaced by those of the
        image at the mirrored batch index (``x.flip(0)``). The mixed tokens
        are then encoded, with each stage receiving the mask matching its
        resolution as ``attn_mask``.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask_ratio (float): The base mask ratio passed to
                :meth:`random_masking`. Defaults to 0.5.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:

            - x (torch.Tensor): hidden features, which is of shape B x L x C.
            - mask_s4 (torch.Tensor): the mask tensor for the last layer.
        """
        mask_s1, mask_s2, mask_s3, mask_s4 = self.random_masking(x, mask_ratio)

        x, _ = self.patch_embed(x)

        # Mix: masked tokens come from the mirrored sample in the batch.
        x = x * (1. - mask_s1) + x.flip(0) * mask_s1
        x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        # One mask per stage, from finest to coarsest resolution.
        stage_masks = (mask_s1, mask_s2, mask_s3, mask_s4)
        for layer, attn_mask in zip(self.layers, stage_masks):
            x = layer(x, attn_mask=attn_mask)

        x = self.norm(x)

        return x, mask_s4
Read the Docs v: stable
Versions
latest
stable
1.x
dev-1.x
0.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.