
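Note

You are reading the documentation for MMSelfSup 0.x, which will gradually stop being maintained starting at the end of 2022. We recommend upgrading to MMSelfSup 1.0.0rc to benefit from the new features and better performance that OpenMMLab 2.0 brings. Read the MMSelfSup 1.0.0rc release notes and code documentation for more information.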

Source code for mmselfsup.models.algorithms.maskfeat

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch

from ..builder import ALGORITHMS, build_backbone, build_head
from ..utils.hog_layer import HOGLayerC
from .base import BaseModel


@ALGORITHMS.register_module()
class MaskFeat(BaseModel):
    """MaskFeat.

    Implementation of `Masked Feature Prediction for Self-Supervised Visual
    Pre-Training <https://arxiv.org/abs/2112.09133>`_.

    Args:
        backbone (dict): Config dict for encoder.
        head (dict): Config dict for the head, which computes the losses.
        hog_para (dict): Config dict for the HOG layer. Supported keys:

            - nbins (int): Number of orientation bins. Defaults to 9.
            - pool (int): Size, in pixels, of each cell over which
              gradient histograms are pooled. Defaults to 8.
            - gaussian_window (int): Size of the Gaussian kernel.
              Defaults to 16.

        init_cfg (dict, optional): Config dict for weight initialization.
            Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 head: dict,
                 hog_para: dict,
                 init_cfg: Optional[dict] = None) -> None:
        super().__init__(init_cfg)
        assert backbone is not None
        self.backbone = build_backbone(backbone)
        assert head is not None
        self.head = build_head(head)
        assert hog_para is not None
        self.hog_layer = HOGLayerC(**hog_para)

    def extract_feat(self, input: List[torch.Tensor]) -> torch.Tensor:
        """Function to extract features from backbone.

        Args:
            input (List[torch.Tensor, torch.Tensor]): Input images and masks.

        Returns:
            tuple[Tensor]: backbone outputs.
        """
        img = input[0]
        mask = input[1]
        return self.backbone(img, mask)

    def forward_train(self, input: List[torch.Tensor], **kwargs) -> dict:
        """Forward computation during training.

        Args:
            input (List[torch.Tensor, torch.Tensor]): Input images and masks.
            kwargs: Any keyword arguments to be used to forward.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        img = input[0]
        mask = input[1]

        hog = self.hog_layer(img)
        latent = self.backbone(img, mask)
        losses = self.head(latent, hog, mask)

        return losses
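The `hog_para` keys determine how many HOG values describe each image region. The sketch below counts the HOG values covering one ViT patch, assuming RGB input (one histogram of `nbins` bins per color channel for every `pool` x `pool` cell) and a 16 x 16 pixel patch. The helper name `hog_dim_per_patch` and the patch size are illustrative assumptions, not part of MMSelfSup.

```python
def hog_dim_per_patch(nbins: int = 9, pool: int = 8, patch_size: int = 16,
                      channels: int = 3) -> int:
    """Number of HOG values covering one patch (hypothetical helper).

    Assumes one histogram of `nbins` bins per channel for every
    pool x pool cell, and that `pool` divides `patch_size` evenly.
    """
    if patch_size % pool != 0:
        raise ValueError("patch_size must be a multiple of pool")
    cells_per_side = patch_size // pool  # cells along one side of the patch
    return nbins * channels * cells_per_side ** 2


print(hog_dim_per_patch())  # 9 * 3 * (16 // 8) ** 2 = 108
```

If a head regresses the full HOG descriptor of each patch, a count like this is the output width it needs; check it against the head config you actually use.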
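`forward_train` builds its regression targets with `self.hog_layer(img)`. To illustrate roughly what a HOG layer computes, here is a simplified single-channel version in plain Python. It accumulates gradient magnitudes into `nbins` unsigned orientation bins per `pool` x `pool` cell, then L2-normalizes each histogram. It is not the `HOGLayerC` implementation: it uses central differences rather than convolution kernels, handles one grayscale channel, and omits the Gaussian window. The function name `simple_hog` is hypothetical.

```python
import math
from typing import List


def simple_hog(img: List[List[float]], nbins: int = 9,
               pool: int = 8) -> List[List[List[float]]]:
    """Simplified HOG for one grayscale image (illustrative sketch).

    Returns a grid of cells, each an L2-normalized histogram of
    `nbins` unsigned gradient orientations covering [0, pi).
    """
    h, w = len(img), len(img[0])
    assert h % pool == 0 and w % pool == 0, "pool must divide H and W"
    cells = [[[0.0] * nbins for _ in range(w // pool)]
             for _ in range(h // pool)]
    for y in range(h):
        for x in range(w):
            # Central differences, clamping indices at the border.
            gx = img[y][min(x + 1, w - 1)] - img[y][max(x - 1, 0)]
            gy = img[min(y + 1, h - 1)][x] - img[max(y - 1, 0)][x]
            mag = math.hypot(gx, gy)
            # Fold orientation into [0, pi) so opposite gradients share a bin.
            angle = math.atan2(gy, gx) % math.pi
            b = min(int(angle / math.pi * nbins), nbins - 1)
            cells[y // pool][x // pool][b] += mag
    for row in cells:
        for hist in row:
            # Normalize each cell; an all-zero histogram stays zero.
            norm = math.sqrt(sum(v * v for v in hist)) or 1.0
            hist[:] = [v / norm for v in hist]
    return cells


ramp = [[float(x) for x in range(16)] for _ in range(16)]  # horizontal ramp
hog = simple_hog(ramp)
print(len(hog), len(hog[0]))  # 2 2
print(hog[0][0][:3])          # [1.0, 0.0, 0.0]
```

A horizontal intensity ramp has purely horizontal gradients, so every cell puts all its weight in orientation bin 0.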
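`forward_train` passes the backbone output, the HOG targets, and the mask to the head, which returns the loss dictionary. The MaskFeat paper regresses the HOG features of masked patches with an L2 loss. The plain-Python sketch below shows a masked mean squared error of that kind. The function name `masked_mse`, the one-vector-per-patch list layout, and the `'loss'` key are assumptions for illustration; the real head is whatever `build_head(head)` constructs from the config.

```python
from typing import Dict, Sequence


def masked_mse(pred: Sequence[Sequence[float]],
               target: Sequence[Sequence[float]],
               mask: Sequence[bool]) -> Dict[str, float]:
    """Mean squared error over masked patches only (illustrative sketch).

    pred, target: one feature vector per patch.
    mask: True where the patch was masked and must be reconstructed.
    """
    sq_err, count = 0.0, 0
    for p, t, m in zip(pred, target, mask):
        if not m:
            continue  # visible patches do not contribute to the loss
        sq_err += sum((a - b) ** 2 for a, b in zip(p, t))
        count += len(t)
    return {'loss': sq_err / count if count else 0.0}


pred = [[1.0, 2.0], [0.0, 0.0], [5.0, 5.0]]
target = [[1.0, 0.0], [9.0, 9.0], [4.0, 5.0]]
mask = [True, False, True]
print(masked_mse(pred, target, mask))  # {'loss': 1.25}
```

Averaging over the masked elements only keeps the loss scale independent of how many patches are masked, and the large error on the visible middle patch is ignored.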