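Note

You are reading the documentation for MMSelfSup 0.x, which will gradually stop being maintained starting at the end of 2022. We recommend upgrading to MMSelfSup 1.0.0rc to benefit from the new features and better performance that OpenMMLab 2.0 brings. See the MMSelfSup 1.0.0rc release notes and code documentation for more information.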

Source code for mmselfsup.models.algorithms.simmim

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch

from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
from .base import BaseModel


@ALGORITHMS.register_module()
class SimMIM(BaseModel):
    """SimMIM.

    Implementation of `SimMIM: A Simple Framework for Masked Image Modeling
    <https://arxiv.org/abs/2111.09886>`_.

    Args:
        backbone (dict): Config dict for encoder.
        neck (dict): Config dict for neck.
        head (dict): Config dict for loss functions.
        init_cfg (dict, optional): Config dict for weight initialization.
            Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 init_cfg: Optional[dict] = None) -> None:
        super(SimMIM, self).__init__(init_cfg)
        assert backbone is not None
        self.backbone = build_backbone(backbone)
        assert neck is not None
        self.neck = build_neck(neck)
        assert head is not None
        self.head = build_head(head)

    def extract_feat(self, img: torch.Tensor) -> tuple:
        """Function to extract features from backbone.

        Args:
            img (torch.Tensor): Input images of shape (N, C, H, W).

        Returns:
            tuple[Tensor]: Latent representations of images.
        """
        return self.backbone(img)

    def forward_train(self, x: List[torch.Tensor], **kwargs) -> dict:
        """Forward the masked image and get the reconstruction loss.

        Args:
            x (List[torch.Tensor, torch.Tensor]): Images and masks.

        Returns:
            dict: Reconstruction loss.
        """
        img, mask = x

        img_latent = self.backbone(img, mask)
        img_rec = self.neck(img_latent[0])
        losses = self.head(img, img_rec, mask)

        return losses
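SimMIM.__init__ only asserts that the three config dicts are present and hands them to build_backbone, build_neck and build_head. The config below is a sketch of what such dicts can look like. It is assumed to follow the Swin-B SimMIM config shipped with MMSelfSup 0.x; the registered type names and argument values are assumptions and should be checked against the installed version.

```python
# Assumed SimMIM model config (Swin-B, 192x192 input); verify type names and
# arguments against the MMSelfSup 0.x version you have installed.
model = dict(
    type='SimMIM',
    backbone=dict(
        type='SimMIMSwinTransformer',  # encoder that accepts (img, mask)
        arch='B',
        img_size=192,
        stage_cfgs=dict(block_cfgs=dict(window_size=6))),
    # Swin-B's last stage has 128 * 2**3 = 1024 channels and downsamples by 32
    neck=dict(type='SimMIMNeck', in_channels=128 * 2**3, encoder_stride=32),
    # patch_size: pixels per mask cell, used to expand the mask to pixel level
    head=dict(type='SimMIMHead', patch_size=4, encoder_in_channels=3))
```

Each sub-dict maps one-to-one onto a constructor argument, so swapping the encoder mainly means keeping the neck's in_channels and encoder_stride consistent with the new backbone's output.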
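forward_train unpacks x as an (img, mask) pair, so the mask is presumably produced upstream by the data pipeline rather than by SimMIM itself. The sketch below shows one way to build such a mask following the recipe described in the SimMIM paper. It hides a random 60% of coarse 32x32 regions and then upsamples the result to the model's 4x4 patch grid. The function name simmim_style_mask and its defaults (chosen to match a 192-pixel Swin input) are hypothetical and are not an MMSelfSup API.

```python
import numpy as np


def simmim_style_mask(input_size=192, mask_patch_size=32, model_patch_size=4,
                      mask_ratio=0.6, rng=None):
    """Hypothetical helper: return a binary mask on the model's patch grid,
    shape (input_size // model_patch_size,) * 2, where 1 marks a masked patch."""
    rng = np.random.default_rng() if rng is None else rng
    rand_size = input_size // mask_patch_size      # coarse grid side, e.g. 6
    scale = mask_patch_size // model_patch_size    # model patches per coarse cell, e.g. 8
    token_count = rand_size ** 2
    mask_count = int(np.ceil(token_count * mask_ratio))
    # Choose which coarse cells to hide
    idx = rng.permutation(token_count)[:mask_count]
    mask = np.zeros(token_count, dtype=int)
    mask[idx] = 1
    mask = mask.reshape(rand_size, rand_size)
    # Upsample so each coarse cell covers scale x scale model patches
    return mask.repeat(scale, axis=0).repeat(scale, axis=1)


mask = simmim_style_mask(rng=np.random.default_rng(0))
print(mask.shape, mask.mean())  # 48x48 grid, 22 of 36 coarse cells masked
```

Masking whole 32x32 blocks, rather than individual 4x4 patches, is meant to keep the reconstruction task from being solvable by interpolating from immediately adjacent visible pixels.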
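To make the three steps of forward_train concrete (masked encoding, decoding by the neck, loss from the head), here is a self-contained toy version in plain PyTorch. ToyMaskedEncoder, ToyPixelShuffleNeck and masked_l1_loss are hypothetical stand-ins, not MMSelfSup classes. They assume the design described in the SimMIM paper: mask tokens substituted inside the encoder, a lightweight one-layer prediction head, and an L1 loss computed only on masked pixels. They also assume the mask has shape (N, H/p, W/p) with 1 marking a masked patch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyMaskedEncoder(nn.Module):
    """Hypothetical backbone stand-in: patchify, then replace the embeddings
    of masked patches (mask == 1) with a learnable mask token."""

    def __init__(self, patch_size: int = 4, embed_dim: int = 32) -> None:
        super().__init__()
        self.patch_embed = nn.Conv2d(3, embed_dim, patch_size, stride=patch_size)
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim, 1, 1))
        self.body = nn.Conv2d(embed_dim, embed_dim, 3, padding=1)

    def forward(self, img: torch.Tensor, mask: torch.Tensor) -> tuple:
        x = self.patch_embed(img)                # (N, D, H/p, W/p)
        w = mask.unsqueeze(1).type_as(x)         # (N, 1, H/p, W/p)
        x = x * (1.0 - w) + self.mask_token * w  # swap in the mask token
        return (self.body(x), )                  # tuple, since forward_train uses [0]


class ToyPixelShuffleNeck(nn.Module):
    """Hypothetical neck stand-in: predict p*p*3 values per patch and
    rearrange them into pixels with PixelShuffle."""

    def __init__(self, embed_dim: int = 32, patch_size: int = 4) -> None:
        super().__init__()
        self.proj = nn.Conv2d(embed_dim, patch_size**2 * 3, kernel_size=1)
        self.shuffle = nn.PixelShuffle(patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.shuffle(self.proj(x))        # (N, 3, H, W)


def masked_l1_loss(img, img_rec, mask, patch_size: int = 4) -> dict:
    """Hypothetical head stand-in: L1 loss averaged over masked pixels only."""
    # (N, H/p, W/p) -> (N, 1, H, W): one mask value per pixel
    pix_mask = mask.repeat_interleave(patch_size, 1).repeat_interleave(
        patch_size, 2).unsqueeze(1)
    loss = F.l1_loss(img_rec, img, reduction='none')  # (N, 3, H, W)
    loss = (loss * pix_mask).sum() / (pix_mask.sum() + 1e-5) / img.size(1)
    return dict(loss=loss)


# Mirror the three steps of SimMIM.forward_train on random data
torch.manual_seed(0)
img = torch.randn(2, 3, 32, 32)
mask = (torch.rand(2, 8, 8) < 0.6).float()      # 32 / 4 = 8 patches per side
backbone, neck = ToyMaskedEncoder(), ToyPixelShuffleNeck()

img_latent = backbone(img, mask)
img_rec = neck(img_latent[0])
losses = masked_l1_loss(img, img_rec, mask)
```

Because the loss is normalized by the number of masked pixels and channels, visible patches add nothing to the reconstruction term. The encoder still receives them as context for predicting the hidden regions.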