Shortcuts

注意

您正在阅读 MMSelfSup 0.x 版本的文档，而 MMSelfSup 0.x 版本将会在 2022 年末开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMSelfSup 1.0.0rc 的发版日志和代码文档以获取更多信息。

mmselfsup.models.heads.cae_head 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings

import torch
from mmcv.runner import BaseModule
from torch import nn

from ..builder import HEADS
from ..utils import Encoder


@HEADS.register_module()
class CAEHead(BaseModule):
    """Pretrain Head for CAE.

    Compute the align loss and the main loss. In addition, this head also
    generates the prediction target generated by dalle.

    Args:
        tokenizer_path (str): The path of the tokenizer.
        lambd (float): The weight for the align loss.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 tokenizer_path: str,
                 lambd: float,
                 init_cfg: dict = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.tokenizer_path = tokenizer_path
        self.lambd = lambd
        # Tokenizer encoder used only to produce prediction targets.
        self.encoder = self._load_encoder()

        self.loss_cross_entropy = nn.CrossEntropyLoss()
        self.loss_mse = nn.MSELoss()

    def _load_encoder(self) -> nn.Module:
        """Build the tokenizer ``Encoder`` and load its weights if present.

        If ``self.tokenizer_path`` does not exist, a warning is emitted and
        the freshly constructed (unloaded) encoder is returned.

        Returns:
            nn.Module: The tokenizer encoder.
        """
        encoder = Encoder()
        if os.path.exists(self.tokenizer_path):
            # Deserialize onto the CPU: without ``map_location`` the tensors
            # are restored to the device they were saved from. That fails on
            # CPU-only hosts and, in distributed runs, makes every rank
            # allocate on that same GPU. ``load_state_dict`` copies the values
            # into the encoder's own parameters, so the final placement is
            # unaffected.
            # NOTE(review): this unpickles the file; only load trusted
            # checkpoints.
            state_dict = torch.load(self.tokenizer_path, map_location='cpu')
            encoder.load_state_dict(state_dict)
        else:
            warnings.warn(
                f'Do not find {self.tokenizer_path}, please download from https://download.openmmlab.com/mmselfsup/cae/dalle_encoder.pth'  # noqa: E501
            )
        return encoder

    @torch.no_grad()
    def _generate_target(self, img_target: torch.Tensor) -> torch.Tensor:
        """Compute token-id targets from the tokenizer encoder.

        Takes the argmax of the encoder logits over dim 1, then flattens every
        dimension after the batch dimension. Runs without autograd.

        Args:
            img_target (torch.Tensor): Image input for the tokenizer.

        Returns:
            torch.Tensor: Token ids, flattened from dim 1 onward.
        """
        logits = self.encoder(img_target)
        target = torch.argmax(logits, dim=1)
        return target.flatten(1)

    def forward(self, img_target: torch.Tensor, outputs: torch.Tensor,
                latent_pred: torch.Tensor, latent_target: torch.Tensor,
                mask: torch.Tensor) -> dict:
        """Compute the CAE losses.

        Args:
            img_target (torch.Tensor): Image input used to generate the
                token targets via the tokenizer encoder.
            outputs (torch.Tensor): Predicted logits for the selected
                positions; passed to ``nn.CrossEntropyLoss`` against the
                masked targets.
            latent_pred (torch.Tensor): Predicted latent representations.
            latent_target (torch.Tensor): Target latents. They are detached,
                so no gradient flows into them.
            mask (torch.Tensor): Index used to select target tokens.
                Presumably a boolean mask aligned with the flattened token
                grid (TODO confirm against the caller).

        Returns:
            dict: ``'loss'`` (total), ``'main'`` (cross-entropy) and
            ``'align'`` (weighted MSE).
        """
        losses = dict()

        target = self._generate_target(img_target)
        target = target[mask]

        loss_main = self.loss_cross_entropy(outputs, target)
        # The alignment term is scaled by ``lambd``; the target branch is
        # detached so only the prediction receives gradients.
        loss_align = self.loss_mse(latent_pred,
                                   latent_target.detach()) * self.lambd

        losses['loss'] = loss_main + loss_align
        losses['main'] = loss_main
        losses['align'] = loss_align
        return losses
Read the Docs v: 0.x
Versions
latest
stable
1.x
dev-1.x
0.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.