

您正在阅读 MMSelfSup 0.x 版本的文档，而 MMSelfSup 0.x 版本将会在 2022 年末开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。请阅读 MMSelfSup 1.0.0rc 的发版日志和代码文档以获取更多信息。

mmselfsup.models.necks.mae_neck 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
from mmcv.cnn import build_norm_layer
from mmcv.runner import BaseModule

from ..builder import NECKS
from ..utils import build_2d_sincos_position_embedding

@NECKS.register_module()
class MAEPretrainDecoder(BaseModule):
    """Decoder for MAE Pre-training.

    Projects the encoder tokens to the decoder width, re-inserts a learnable
    mask token at every masked position, adds a fixed (non-trainable)
    position embedding, runs a stack of Transformer blocks and finally
    predicts the pixel values of each patch.

    Some of the code is borrowed from
    `https://github.com/facebookresearch/mae`.

    Args:
        num_patches (int): The number of total patches. Defaults to 196.
        patch_size (int): Image patch size. Defaults to 16.
        in_chans (int): The channel of input image. Defaults to 3.
        embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
        decoder_embed_dim (int): Decoder's embedding dimension.
            Defaults to 512.
        decoder_depth (int): The depth of decoder. Defaults to 8.
        decoder_num_heads (int): Number of attention heads of decoder.
            Defaults to 16.
        mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
            Defaults to 4.
        norm_cfg (dict): Normalization layer. Defaults to LayerNorm.

    Example:
        >>> from mmselfsup.models import MAEPretrainDecoder
        >>> import torch
        >>> self = MAEPretrainDecoder()
        >>> self.eval()
        >>> inputs = torch.rand(1, 50, 1024)
        >>> ids_restore = torch.arange(0, 196).unsqueeze(0)
        >>> level_outputs = self.forward(inputs, ids_restore)
        >>> print(tuple(level_outputs.shape))
        (1, 196, 768)
    """

    def __init__(self,
                 num_patches=196,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=1024,
                 decoder_embed_dim=512,
                 decoder_depth=8,
                 decoder_num_heads=16,
                 mlp_ratio=4.,
                 norm_cfg=dict(type='LN', eps=1e-6)):
        super().__init__()
        self.num_patches = num_patches

        # Project encoder tokens to the (usually narrower) decoder width.
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        # One learnable token shared by every masked position.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # Fixed position table, one extra slot for the cls token. It is
        # excluded from gradient updates and filled in by ``init_weights``.
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, decoder_embed_dim),
            requires_grad=False)

        self.decoder_blocks = nn.ModuleList([
            TransformerEncoderLayer(
                decoder_embed_dim,
                decoder_num_heads,
                int(mlp_ratio * decoder_embed_dim),
                qkv_bias=True,
                norm_cfg=norm_cfg) for _ in range(decoder_depth)
        ])

        # The norm layer is registered under the name returned by
        # ``build_norm_layer`` and exposed through the ``decoder_norm``
        # property below.
        self.decoder_norm_name, decoder_norm = build_norm_layer(
            norm_cfg, decoder_embed_dim, postfix=1)
        self.add_module(self.decoder_norm_name, decoder_norm)

        # Each token predicts the flattened pixels of one patch.
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, patch_size**2 * in_chans, bias=True)

    def init_weights(self):
        """Initialize the position embedding, mask token and sub-modules."""
        super().init_weights()

        # initialize position embedding of MAE decoder
        # NOTE(review): ``int(num_patches ** .5)`` assumes a square patch
        # grid, and the helper is presumed to return a tensor broadcastable
        # to ``decoder_pos_embed`` -- confirm against the helper.
        decoder_pos_embed = build_2d_sincos_position_embedding(
            int(self.num_patches**.5),
            self.decoder_pos_embed.shape[-1],
            cls_token=True)
        # Without this copy the computed table is discarded and the frozen
        # parameter stays all-zero, i.e. the decoder gets no position info.
        self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())

        torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initializer used with ``nn.Module.apply``.

        Linear layers get Xavier-uniform weights and zero bias; LayerNorm
        layers get unit weight and zero bias. Other modules are untouched.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @property
    def decoder_norm(self):
        """nn.Module: The final normalization layer of the decoder."""
        return getattr(self, self.decoder_norm_name)

    def forward(self, x, ids_restore):
        """Reconstruct the patches from the visible-token embeddings.

        Args:
            x (torch.Tensor): Encoder output with the cls token first,
                followed by the kept patch tokens; shape
                (B, 1 + num_visible, embed_dim).
            ids_restore (torch.Tensor): Integer indices of shape (B, L) used
                to gather the concatenated [visible, mask] tokens back into
                the original patch order. L must equal ``num_patches`` for
                the position embedding to be added.

        Returns:
            torch.Tensor: Per-patch predictions of shape
            (B, L, patch_size**2 * in_chans); the cls token is dropped.
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        # x holds 1 cls token + the visible patches, so the number of
        # missing patches is L - (x.shape[1] - 1).
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        # Work on patch tokens only (cls token set aside), then unshuffle.
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        # Put the cls token back in front of the restored sequence.
        x = torch.cat([x[:, :1, :], x_], dim=1)

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x
