Note

You are reading the documentation for MMSelfSup 0.x, which will be deprecated by the end of 2022. We recommend upgrading to the MMSelfSup 1.0.0rc versions to benefit from the new features and better performance of OpenMMLab 2.0. See the changelog, code, and documentation of MMSelfSup 1.0.0rc for details.

Source code for mmselfsup.models.heads.cae_head

# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings

import torch
from mmcv.runner import BaseModule
from torch import nn

from ..builder import HEADS
from ..utils import Encoder


@HEADS.register_module()
class CAEHead(BaseModule):
    """Pretrain Head for CAE.

    Compute the align loss and the main loss. In addition, this head also
    generates the prediction targets with the DALL-E tokenizer.

    Args:
        tokenizer_path (str): The path of the tokenizer.
        lambd (float): The weight for the align loss.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 tokenizer_path: str,
                 lambd: float,
                 init_cfg: dict = None) -> None:
        super(CAEHead, self).__init__(init_cfg=init_cfg)
        self.tokenizer_path = tokenizer_path
        self.lambd = lambd
        self.encoder = self._load_encoder()
        self.loss_cross_entropy = nn.CrossEntropyLoss()
        self.loss_mse = nn.MSELoss()

    def _load_encoder(self) -> nn.Module:
        encoder = Encoder()
        if os.path.exists(self.tokenizer_path):
            state_dict = torch.load(self.tokenizer_path)
            encoder.load_state_dict(state_dict)
        else:
            warnings.warn(
                f'Cannot find {self.tokenizer_path}, please download it from '
                'https://download.openmmlab.com/mmselfsup/cae/dalle_encoder.pth'  # noqa: E501
            )
        return encoder

    @torch.no_grad()
    def _generate_target(self, img_target: torch.Tensor) -> torch.Tensor:
        logits = self.encoder(img_target)
        target = torch.argmax(logits, dim=1)
        return target.flatten(1)

    def forward(self, img_target: torch.Tensor, outputs: torch.Tensor,
                latent_pred: torch.Tensor, latent_target: torch.Tensor,
                mask: torch.Tensor) -> dict:
        losses = dict()
        target = self._generate_target(img_target)
        target = target[mask]
        loss_main = self.loss_cross_entropy(outputs, target)
        loss_align = self.loss_mse(latent_pred,
                                   latent_target.detach()) * self.lambd
        losses['loss'] = loss_main + loss_align
        losses['main'] = loss_main
        losses['align'] = loss_align
        return losses
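The total loss returned by ``forward`` is the token-prediction cross-entropy plus ``lambd`` times the latent alignment MSE, i.e. ``losses['loss'] = losses['main'] + losses['align']``. The sketch below shows how the head might be called in isolation; it is illustrative only. The import path, the ``lambd`` value, and the tensor shapes (196 patch tokens, an 8192-way DALL-E codebook, 768-dim latents) are assumptions chosen to match the default CAE setting and are not taken from this file.

# Illustrative sketch only: assumes mmselfsup 0.x is installed; shapes are
# examples (196 patch tokens, 8192-way DALL-E codebook, 768-dim latents).
import torch
from mmselfsup.models.heads import CAEHead

head = CAEHead(tokenizer_path='dalle_encoder.pth', lambd=2.0)  # example lambd

N, num_patches, vocab_size, embed_dim = 2, 196, 8192, 768
mask = torch.zeros(N, num_patches, dtype=torch.bool)
mask[:, :num_patches // 2] = True        # pretend half of the patches are masked
num_masked = int(mask.sum())

img_target = torch.rand(N, 3, 112, 112)           # images fed to the DALL-E tokenizer
outputs = torch.randn(num_masked, vocab_size)     # token logits for the masked patches
latent_pred = torch.randn(num_masked, embed_dim)  # regressor output for masked patches
latent_target = torch.randn(num_masked, embed_dim)

losses = head(img_target, outputs, latent_pred, latent_target, mask)
# losses['loss'] == losses['main'] + losses['align']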