Source code for mmselfsup.models.heads.cae_head
# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings

import torch
from mmcv.runner import BaseModule
from torch import nn

from ..builder import HEADS
from ..utils import Encoder


@HEADS.register_module()
class CAEHead(BaseModule):
"""Pretrain Head for CAE.
Compute the align loss and the main loss. In addition, this head also
generates the prediction target generated by dalle.
Args:
tokenizer_path (str): The path of the tokenizer.
lambd (float): The weight for the align loss.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
    def __init__(self,
                 tokenizer_path: str,
                 lambd: float,
                 init_cfg: dict = None) -> None:
        super(CAEHead, self).__init__(init_cfg=init_cfg)
        self.tokenizer_path = tokenizer_path
        self.lambd = lambd
        # The DALL-E tokenizer that generates the prediction targets.
        self.encoder = self._load_encoder()
        self.loss_cross_entropy = nn.CrossEntropyLoss()
        self.loss_mse = nn.MSELoss()

    def _load_encoder(self) -> nn.Module:
        encoder = Encoder()
        if os.path.exists(self.tokenizer_path):
            state_dict = torch.load(self.tokenizer_path)
            encoder.load_state_dict(state_dict)
        else:
            warnings.warn(
                f'Cannot find {self.tokenizer_path}, please download it from '
                'https://download.openmmlab.com/mmselfsup/cae/dalle_encoder.pth'  # noqa: E501
            )
        return encoder

    @torch.no_grad()
    def _generate_target(self, img_target: torch.Tensor) -> torch.Tensor:
        # Tokenize the target image with the DALL-E encoder and take the most
        # probable visual token at each spatial position, flattened to
        # (B, num_tokens).
        logits = self.encoder(img_target)
        target = torch.argmax(logits, dim=1)
        return target.flatten(1)

    def forward(self, img_target: torch.Tensor, outputs: torch.Tensor,
                latent_pred: torch.Tensor, latent_target: torch.Tensor,
                mask: torch.Tensor) -> dict:
        losses = dict()
        # Main loss: cross-entropy between the predicted token logits and the
        # DALL-E tokens of the target image, at the masked positions only.
        target = self._generate_target(img_target)
        target = target[mask]
        loss_main = self.loss_cross_entropy(outputs, target)
        # Align loss: MSE between the predicted latent and the detached
        # target latent, weighted by lambd.
        loss_align = self.loss_mse(latent_pred,
                                   latent_target.detach()) * self.lambd
        losses['loss'] = loss_main + loss_align
        losses['main'] = loss_main
        losses['align'] = loss_align
        return losses
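
For reference, below is a minimal usage sketch of this head. Everything in it that is not visible in the source above is an assumption: the config values (a hypothetical tokenizer_path, lambd=2.0), the tensor shapes (224x224 images with 16x16 patches give 14x14 = 196 visual tokens; the DALL-E tokenizer consumes a 112x112 target image and has a vocabulary of 8192), and the latent dimension of 768. It illustrates the interface, not a verbatim training snippet.

# Minimal sketch: build CAEHead through the HEADS registry and call it with
# dummy tensors. All shapes and config values are assumptions (see the note
# above), not taken from a released config.
import torch

from mmselfsup.models.builder import HEADS

head = HEADS.build(
    dict(
        type='CAEHead',
        # Hypothetical path; a missing file only triggers the warning in
        # _load_encoder, so the sketch still runs with random weights.
        tokenizer_path='data/dalle_encoder.pth',
        lambd=2.0))

B, num_tokens, vocab, latent_dim = 2, 196, 8192, 768
num_masked_per_img = 75

mask = torch.zeros(B, num_tokens, dtype=torch.bool)
mask[:, :num_masked_per_img] = True  # pretend the first 75 patches are masked
num_masked = int(mask.sum())

img_target = torch.rand(B, 3, 112, 112)   # fed to the DALL-E encoder
outputs = torch.randn(num_masked, vocab)  # token logits for masked patches
latent_pred = torch.randn(B, num_masked_per_img, latent_dim)
latent_target = torch.randn(B, num_masked_per_img, latent_dim)

losses = head(img_target, outputs, latent_pred, latent_target, mask)
print(losses['loss'], losses['main'], losses['align'])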