Shortcuts

注意

您正在阅读 MMSelfSup 0.x 版本的文档,而 MMSelfSup 0.x 版本将会在 2022 年末开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本,享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMSelfSup 1.0.0rc 的发版日志和代码文档以获取更多信息。

mmselfsup.models.algorithms.mae 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple

import torch

from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
from .base import BaseModel


@ALGORITHMS.register_module()
class MAE(BaseModel):
    """Masked Autoencoder (MAE) algorithm.

    Implementation of `Masked Autoencoders Are Scalable Vision Learners
    <https://arxiv.org/abs/2111.06377>`_.

    The model chains three components: a backbone that returns the latent
    features together with the mask and the restore indices, a neck that
    turns the latent features back into a per-patch prediction, and a head
    that computes the losses.

    Args:
        backbone (dict): Config dict handed to ``build_backbone``. Must not
            be None.
        neck (dict): Config dict handed to ``build_neck``. Must not be
            None.
        head (dict): Config dict handed to ``build_head``. Must not be
            None.
        init_cfg (dict, optional): Config dict for weight initialization.
            Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 init_cfg: Optional[dict] = None) -> None:
        super().__init__(init_cfg)

        # Each component is validated right before it is built, so a
        # missing config fails only after the earlier components exist.
        assert backbone is not None
        self.backbone = build_backbone(backbone)

        assert neck is not None
        self.neck = build_neck(neck)
        # The neck is told how many patches the backbone works with.
        self.neck.num_patches = self.backbone.num_patches

        assert head is not None
        self.head = build_head(head)

    def init_weights(self):
        """Initialize the weights by delegating to the parent class."""
        super().init_weights()

    def extract_feat(self, img: torch.Tensor) -> Tuple[torch.Tensor]:
        """Run the backbone on the input and return its outputs.

        Args:
            img (torch.Tensor): Input images of shape (N, C, H, W).

        Returns:
            Tuple[torch.Tensor]: backbone outputs.
        """
        features = self.backbone(img)
        return features

    def forward_train(self, img: torch.Tensor,
                      **kwargs) -> Dict[str, torch.Tensor]:
        """Compute the training losses for a batch of images.

        Args:
            img (torch.Tensor): Input images of shape (N, C, H, W).
            kwargs: Any keyword arguments to be used to forward.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        latent, mask, ids_restore = self.backbone(img)
        reconstruction = self.neck(latent, ids_restore)
        return self.head(img, reconstruction, mask)

    def forward_test(self, img: torch.Tensor,
                     **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Produce the mask and the prediction as channel-last images.

        Args:
            img (torch.Tensor): Input images of shape (N, C, H, W).
            kwargs: Any keyword arguments to be used to forward.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Output of model test.
                - mask: Mask used to mask image.
                - pred: The output of neck.
        """
        latent, patch_mask, ids_restore = self.backbone(img)

        # Prediction: patches -> image, then channel-last on the CPU.
        patch_pred = self.neck(latent, ids_restore)
        pred_img = self.head.unpatchify(patch_pred)
        pred = torch.einsum('nchw->nhwc', pred_img).detach().cpu()

        # Mask: repeat the per-patch flag across every value of a patch so
        # it can go through the same unpatchify step as the prediction.
        values_per_patch = self.head.patch_size**2 * 3
        expanded = patch_mask.detach().unsqueeze(-1).repeat(
            1, 1, values_per_patch)  # (N, H*W, p*p*3)
        mask_img = self.head.unpatchify(expanded)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask_img).detach().cpu()

        return mask, pred
Read the Docs v: 0.x
Versions
latest
stable
1.x
dev-1.x
0.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.