
Note

You are reading the documentation for MMSelfSup 0.x, which will gradually stop being maintained starting at the end of 2022. We recommend upgrading to MMSelfSup 1.0.0rc to enjoy the new features and better performance brought by OpenMMLab 2.0. See the release notes and documentation of MMSelfSup 1.0.0rc for more information.

Source code for mmselfsup.models.algorithms.mocov3

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
from .base import BaseModel


@ALGORITHMS.register_module()
class MoCoV3(BaseModel):
    """MoCo v3.

    Implementation of `An Empirical Study of Training Self-Supervised Vision
    Transformers <https://arxiv.org/abs/2104.02057>`_.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module of deep features to compact
            feature vectors. Defaults to None.
        head (dict): Config dict for module of loss functions.
            Defaults to None.
        base_momentum (float): Momentum coefficient for the momentum-updated
            encoder. Defaults to 0.99.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 backbone,
                 neck,
                 head,
                 base_momentum=0.99,
                 init_cfg=None,
                 **kwargs):
        super(MoCoV3, self).__init__(init_cfg)
        assert neck is not None
        self.base_encoder = nn.Sequential(
            build_backbone(backbone), build_neck(neck))
        self.momentum_encoder = nn.Sequential(
            build_backbone(backbone), build_neck(neck))
        self.backbone = self.base_encoder[0]
        self.neck = self.base_encoder[1]
        assert head is not None
        self.head = build_head(head)

        self.base_momentum = base_momentum
        self.momentum = base_momentum

    def init_weights(self):
        """Initialize base_encoder with init_cfg defined in backbone."""
        super(MoCoV3, self).init_weights()

        # copy the initialized base encoder into the momentum encoder and
        # freeze the momentum branch (it is only updated by EMA)
        for param_b, param_m in zip(self.base_encoder.parameters(),
                                    self.momentum_encoder.parameters()):
            param_m.data.copy_(param_b.data)
            param_m.requires_grad = False

    @torch.no_grad()
    def momentum_update(self):
        """Momentum update of the momentum encoder."""
        for param_b, param_m in zip(self.base_encoder.parameters(),
                                    self.momentum_encoder.parameters()):
            param_m.data = param_m.data * self.momentum + param_b.data * (
                1. - self.momentum)

    def extract_feat(self, img):
        """Function to extract features from backbone.

        Args:
            img (Tensor): Input images. Typically these should be mean
                centered and std scaled.

        Returns:
            tuple[Tensor]: Backbone outputs.
        """
        x = self.backbone(img)
        return x

    def forward_train(self, img, **kwargs):
        """Forward computation during training.

        Args:
            img (list[Tensor]): A list of input images. Typically these
                should be mean centered and std scaled.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert isinstance(img, list)
        view_1 = img[0].cuda(non_blocking=True)
        view_2 = img[1].cuda(non_blocking=True)

        # compute query features, [N, C] each
        q1 = self.base_encoder(view_1)[0]
        q2 = self.base_encoder(view_2)[0]

        # compute key features, [N, C] each, no gradient
        with torch.no_grad():
            # Here we use a hook to update the momentum encoder, which is
            # slightly different from the official implementation but has a
            # negligible influence on the results.
            k1 = self.momentum_encoder(view_1)[0]
            k2 = self.momentum_encoder(view_2)[0]

        losses = self.head(q1, k2)['loss'] + self.head(q2, k1)['loss']
        return dict(loss=losses)
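The snippet below is a minimal, illustrative sketch of how this class is typically exercised: the algorithm is built from a config, init_weights copies the base encoder into the momentum encoder, forward_train consumes two augmented views of the same batch, and momentum_update applies the EMA step (in real training runs a momentum-update hook calls it each iteration, typically annealing self.momentum towards 1.0 with a cosine schedule, as in the MoCo v3 paper). The config filename is a placeholder and the build_algorithm helper is assumed from the MMSelfSup 0.x layout; this is not part of the source file above.

# Illustrative usage sketch (not part of mocov3.py). Assumes an MMSelfSup
# 0.x environment with a GPU; the config path below is a placeholder.
import torch
from mmcv import Config
from mmselfsup.models import build_algorithm

cfg = Config.fromfile('configs/selfsup/mocov3/<a_mocov3_config>.py')
model = build_algorithm(cfg.model).cuda()
model.init_weights()  # copies base_encoder weights into momentum_encoder

# two augmented views of the same image batch, as produced by the pipeline
view_1 = torch.randn(4, 3, 224, 224)
view_2 = torch.randn(4, 3, 224, 224)

losses = model.forward_train([view_1, view_2])
losses['loss'].backward()

# In actual training a hook performs this step after each iteration,
# usually with a cosine schedule that moves model.momentum towards 1.0.
model.momentum_update()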