Shortcuts

Source code for mmselfsup.models.heads.cae_head

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
from mmengine.model import BaseModule

from mmselfsup.registry import MODELS


@MODELS.register_module()
class CAEHead(BaseModule):
    """Pretrain Head for CAE.

    Compute the align loss and the main loss. In addition, this head also
    generates the prediction target from the logits produced by DALL-E.

    Args:
        loss (dict): The config of loss. It is built through the ``MODELS``
            registry and is called with
            ``(logits, target, latent_pred, latent_target)``.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 loss: dict,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # The loss module computes both the main and the align loss at once.
        self.loss = MODELS.build(loss)

    @torch.no_grad()
    def _generate_target(self, logits_target: torch.Tensor) -> torch.Tensor:
        """Generate the reconstruction target.

        Takes the argmax over dim 1 of ``logits_target`` and flattens every
        dimension after the first into one. Runs under ``torch.no_grad``, so
        no gradient flows back into ``logits_target``.

        Args:
            logits_target (torch.Tensor): The logits generated by DALL-E.
                Dim 1 is reduced by argmax. This presumably is the token
                vocabulary axis, e.g. (B, num_tokens, H, W); TODO confirm
                against the caller.

        Returns:
            torch.Tensor: Integer indices of the argmax, flattened from dim 1
            (e.g. (B, H * W) for a 4-D input).
        """
        target = torch.argmax(logits_target, dim=1)
        return target.flatten(1)

    def forward(self, logits: torch.Tensor, logits_target: torch.Tensor,
                latent_pred: torch.Tensor, latent_target: torch.Tensor,
                mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate loss.

        Args:
            logits (torch.Tensor): Logits generated by decoder.
            logits_target (torch.Tensor): Target logits generated by DALL-E
                for the decoder prediction. They are converted to token
                indices with ``_generate_target``.
            latent_pred (torch.Tensor): Latent prediction by regressor.
            latent_target (torch.Tensor): Target for latent prediction,
                generated by teacher.
            mask (torch.Tensor): Index used to select entries of the
                flattened target (``target[mask]``). Presumably a boolean
                mask with the same shape as the flattened target, marking
                the masked patches; verify against the caller.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The tuple of loss.

                - loss_main (torch.Tensor): Cross entropy loss.
                - loss_align (torch.Tensor): MSE loss.
        """
        target = self._generate_target(logits_target)  # target features
        # Keep only the positions selected by ``mask``. The target is already
        # computed under no_grad; detach keeps it explicitly out of the graph.
        target = target[mask].detach()

        # loss main for decoder, loss align for regressor
        loss_main, loss_align = self.loss(logits, target, latent_pred,
                                          latent_target)

        return (loss_main, loss_align)
Read the Docs v: latest
Versions
latest
stable
1.x
0.x
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.