Shortcuts

mmselfsup.models.necks.beitv2_neck 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from mmcls.models.backbones.beit import BEiTTransformerEncoderLayer
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmselfsup.registry import MODELS


@MODELS.register_module()
class BEiTV2Neck(BaseModule):
    """Neck for BEiTV2 Pre-training.

    This module constructs the decoder for the final prediction.

    Args:
        num_layers (int): Number of encoder layers of neck. Defaults to 2.
        early_layers (int): The layer index of the early output from the
            backbone. Defaults to 9.
        backbone_arch (str | dict): Vision Transformer architecture. Either
            a key of ``arch_zoo`` (matched case-insensitively) or a custom
            dict providing ``embed_dims``, ``depth``, ``num_heads`` and
            ``feedforward_channels``. Defaults to 'base'.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): The initialization value for the
            learnable scaling of attention and FFN. Defaults to 0.1.
        use_rel_pos_bias (bool): Whether to use unique relative position
            bias, if False, use shared relative position bias defined in
            backbone. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to ``dict(type='TruncNormal', layer='Linear',
            std=0.02, bias=0)``.
    """

    # Both the short and the long alias of an architecture map to the same
    # settings dict.
    arch_zoo = {
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'depth': 12,
                'num_heads': 12,
                'feedforward_channels': 3072,
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'depth': 24,
                'num_heads': 16,
                'feedforward_channels': 4096,
            }),
    }

    def __init__(
        self,
        num_layers: int = 2,
        early_layers: int = 9,
        backbone_arch: Union[str, dict] = 'base',
        drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        layer_scale_init_value: float = 0.1,
        use_rel_pos_bias: bool = False,
        norm_cfg: dict = dict(type='LN', eps=1e-6),
        init_cfg: Optional[Union[dict, List[dict]]] = dict(
            type='TruncNormal', layer='Linear', std=0.02, bias=0)
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        if isinstance(backbone_arch, str):
            backbone_arch = backbone_arch.lower()
            assert backbone_arch in set(self.arch_zoo), \
                (f'Arch {backbone_arch} is not in default archs '
                 f'{set(self.arch_zoo)}')
            self.arch_settings = self.arch_zoo[backbone_arch]
        else:
            # These are exactly the keys read from ``self.arch_settings``
            # below; ``depth`` (not ``num_layers``) is the one used for the
            # stochastic depth schedule.
            essential_keys = {
                'embed_dims', 'depth', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(backbone_arch, dict) and essential_keys <= set(
                backbone_arch
            ), f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = backbone_arch

        # stochastic depth decay rule
        self.early_layers = early_layers
        depth = self.arch_settings['depth']
        # The schedule is long enough to be indexed by every neck layer id,
        # even when ``early_layers + num_layers`` exceeds the backbone depth.
        dpr = np.linspace(0, drop_path_rate,
                          max(depth, early_layers + num_layers))

        self.patch_aggregation = nn.ModuleList()
        # Neck layers are numbered as a continuation of the backbone layers,
        # starting at ``early_layers``.
        for i in range(early_layers, early_layers + num_layers):
            _layer_cfg = dict(
                embed_dims=self.arch_settings['embed_dims'],
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                norm_cfg=norm_cfg,
                layer_scale_init_value=layer_scale_init_value,
                window_size=None,
                use_rel_pos_bias=use_rel_pos_bias)
            self.patch_aggregation.append(
                BEiTTransformerEncoderLayer(**_layer_cfg))

        # NOTE(review): this rescaling runs at construction time. If
        # ``init_weights()`` later re-initializes the Linear layers through
        # ``init_cfg``, the rescaled values may be overwritten -- confirm.
        self.rescale_patch_aggregation_init_weight()

        embed_dims = self.arch_settings['embed_dims']
        _, norm = build_norm_layer(norm_cfg, embed_dims)
        self.add_module('norm', norm)

    def rescale_patch_aggregation_init_weight(self):
        """Rescale the initialized weights.

        The attention projection weight and the second FFN layer weight of
        every patch aggregation layer are divided in place by
        ``sqrt(2 * layer_id)``, where ``layer_id`` is the 1-based layer
        index counted from the start of the backbone.
        """

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.patch_aggregation):
            rescale(layer.attn.proj.weight.data,
                    self.early_layers + layer_id + 1)
            rescale(layer.ffn.layers[1].weight.data,
                    self.early_layers + layer_id + 1)

    def forward(self, inputs: Tuple[torch.Tensor], rel_pos_bias: torch.Tensor,
                **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the latent prediction and final prediction.

        Args:
            inputs (Tuple[torch.Tensor]): Features of tokens. ``inputs[0]``
                is the early state output and ``inputs[1]`` is the final
                layer output from the backbone.
            rel_pos_bias (torch.Tensor): Shared relative position bias table.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
              - ``x``: The final layer features from backbone, which are
                normed in ``BEiTV2Neck``.
              - ``x_cls_pt``: The early state features from backbone, which
                consist of the final layer cls_token and the early state
                patch_tokens from backbone and are sent to the
                PatchAggregation layers in the neck.
        """

        early_states, x = inputs[0], inputs[1]
        # Pair the final-layer token at index 0 (cls_token) with the
        # remaining tokens of the early state.
        x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1)
        for layer in self.patch_aggregation:
            x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias)

        # shared norm
        x, x_cls_pt = self.norm(x), self.norm(x_cls_pt)

        # remove cls_token
        x = x[:, 1:]
        x_cls_pt = x_cls_pt[:, 1:]
        return x, x_cls_pt
Read the Docs v: stable
Versions
latest
stable
1.x
dev-1.x
0.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.