Shortcuts

Source code for mmselfsup.models.losses.simmim_loss

# Copyright (c) OpenMMLab. All rights reserved.

import torch
from mmengine.model import BaseModule
from torch.nn import functional as F

from mmselfsup.registry import MODELS


@MODELS.register_module()
class SimMIMReconstructionLoss(BaseModule):
    """L1 reconstruction loss for SimMIM.

    Computes the element-wise L1 distance between the reconstruction and the
    target, keeps only the masked region (by multiplying with ``mask``), and
    normalizes the sum by the mask total and by ``encoder_in_channels``.

    Args:
        encoder_in_channels (int): Number of input channels for encoder.
            Used as an extra divisor when normalizing the loss.
    """

    def __init__(self, encoder_in_channels: int) -> None:
        super().__init__()
        self.encoder_in_channels = encoder_in_channels

    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        """Forward function of the SimMIM reconstruction loss.

        Args:
            pred (torch.Tensor): The reconstructed image.
            target (torch.Tensor): The target image.
            mask (torch.Tensor): The mask of the target image. It is
                multiplied element-wise with the per-element L1 loss, so it
                must be broadcastable to ``pred``/``target``; positions with
                zero mask do not contribute to the loss.

        Returns:
            torch.Tensor: The reconstruction loss (a scalar tensor).
        """
        # Per-element absolute error; no reduction so the mask can be applied.
        loss_rec = F.l1_loss(target, pred, reduction='none')
        # Sum over masked positions and divide by the mask total; the 1e-5
        # guards against division by zero when nothing is masked. The extra
        # division by the channel count presumably compensates for a
        # single-channel mask broadcast across all image channels
        # (NOTE(review): confirm mask shape with the caller).
        loss = (loss_rec * mask).sum() / (mask.sum() + 1e-5) \
            / self.encoder_in_channels
        return loss
Read the Docs v: stable
Versions
latest
stable
1.x
dev-1.x
0.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.