注意
您正在阅读 MMSelfSup 0.x 版本的文档，而 MMSelfSup 0.x 版本将会在 2022 年末开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMSelfSup 1.0.0rc 的发版日志、代码和文档获取更多信息。
mmselfsup.models.heads.cae_head 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings
import torch
from mmcv.runner import BaseModule
from torch import nn
from ..builder import HEADS
from ..utils import Encoder
@HEADS.register_module()
class CAEHead(BaseModule):
    """Pretrain Head for CAE.

    Compute the align loss and the main loss. In addition, this head also
    produces the prediction target with the DALL-E tokenizer encoder.

    Args:
        tokenizer_path (str): The path of the tokenizer.
        lambd (float): The weight for the align loss.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 tokenizer_path: str,
                 lambd: float,
                 init_cfg: dict = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.tokenizer_path = tokenizer_path
        self.lambd = lambd
        self.encoder = self._load_encoder()
        self.loss_cross_entropy = nn.CrossEntropyLoss()
        self.loss_mse = nn.MSELoss()

    def _load_encoder(self) -> nn.Module:
        """Build the tokenizer encoder and load its weights when available.

        Returns:
            nn.Module: The ``Encoder`` instance. When ``tokenizer_path`` does
            not exist, a warning is issued and the encoder is returned
            without any checkpoint loaded into it.
        """
        encoder = Encoder()
        if os.path.exists(self.tokenizer_path):
            # Deserialize onto CPU so a checkpoint saved from a CUDA device
            # can still be read on a machine without that device.
            # ``load_state_dict`` copies the values into the encoder's
            # existing parameters, so its device placement is unchanged.
            state_dict = torch.load(self.tokenizer_path, map_location='cpu')
            encoder.load_state_dict(state_dict)
        else:
            warnings.warn(
                f'Do not find {self.tokenizer_path}, please download from https://download.openmmlab.com/mmselfsup/cae/dalle_encoder.pth'  # noqa: E501
            )
        return encoder

    @torch.no_grad()
    def _generate_target(self, img_target: torch.Tensor) -> torch.Tensor:
        """Compute the token targets with the encoder, without gradients.

        Args:
            img_target (torch.Tensor): Input passed to the tokenizer encoder.

        Returns:
            torch.Tensor: Index of the largest logit along dim 1, with every
            dimension after the first flattened into one.
        """
        logits = self.encoder(img_target)
        target = torch.argmax(logits, dim=1)
        return target.flatten(1)

    def forward(self, img_target: torch.Tensor, outputs: torch.Tensor,
                latent_pred: torch.Tensor, latent_target: torch.Tensor,
                mask: torch.Tensor) -> dict:
        """Compute the main (cross-entropy) loss and the align (MSE) loss.

        Args:
            img_target (torch.Tensor): Input from which the token targets
                are generated.
            outputs (torch.Tensor): Predictions compared against the selected
                targets with cross-entropy.
            latent_pred (torch.Tensor): Predicted latent representation.
            latent_target (torch.Tensor): Target latent representation. It is
                detached, so the align loss sends no gradient through it.
            mask (torch.Tensor): Index used to select entries of the
                generated targets. Presumably a boolean mask of the masked
                positions; confirm against the caller.

        Returns:
            dict: ``'loss'`` (sum of both terms), ``'main'`` (cross-entropy
            loss) and ``'align'`` (MSE loss scaled by ``lambd``).
        """
        losses = dict()

        # Keep only the targets selected by ``mask`` so they line up with
        # ``outputs``.
        target = self._generate_target(img_target)
        target = target[mask]
        loss_main = self.loss_cross_entropy(outputs, target)

        loss_align = self.loss_mse(latent_pred,
                                   latent_target.detach()) * self.lambd

        losses['loss'] = loss_main + loss_align
        losses['main'] = loss_main
        losses['align'] = loss_align
        return losses