
Note

You are reading the documentation of MMSelfSup 0.x, which will be gradually deprecated starting at the end of 2022. We recommend upgrading to MMSelfSup 1.0.0rc to enjoy the new features and better performance brought by OpenMMLab 2.0. See the release notes and documentation of MMSelfSup 1.0.0rc for more information.

Source code for mmselfsup.models.heads.latent_pred_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, get_dist_info

from ..builder import HEADS, build_neck


@HEADS.register_module()
class LatentPredictHead(BaseModule):
    """Head for latent feature prediction.

    This head builds a predictor, which can be any registered neck component.
    For example, BYOL and SimSiam call this head and build NonLinearNeck.
    It also implements similarity loss between two forward features.

    Args:
        predictor (dict): Config dict for the predictor.
    """

    def __init__(self, predictor: dict) -> None:
        super(LatentPredictHead, self).__init__()
        self.predictor = build_neck(predictor)
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> dict:
        """Forward head.

        Args:
            input (Tensor): NxC input features.
            target (Tensor): NxC target features.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        pred = self.predictor([input])[0]
        target = target.detach()

        # negative mean cosine similarity between L2-normalized features
        pred_norm = nn.functional.normalize(pred, dim=1)
        target_norm = nn.functional.normalize(target, dim=1)
        loss = -(pred_norm * target_norm).sum(dim=1).mean()
        return dict(loss=loss)
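As a usage sketch (not part of the source above), the head can be built through the HEADS registry with a NonLinearNeck predictor, which is how BYOL uses it. The build_head import path and the channel sizes below follow typical MMSelfSup 0.x configs but should be treated as illustrative assumptions:

import torch
from mmselfsup.models import build_head

head = build_head(
    dict(
        type='LatentPredictHead',
        predictor=dict(
            type='NonLinearNeck',
            in_channels=256,      # illustrative sizes, not canonical
            hid_channels=4096,
            out_channels=256,
            with_avg_pool=False)))

online = torch.randn(32, 256)  # NxC features from the online branch
target = torch.randn(32, 256)  # NxC features from the target branch
losses = head(online, target)  # {'loss': negative mean cosine similarity}

Because the target is detached inside forward, gradients only flow through the predictor branch; this stop-gradient is what BYOL and SimSiam rely on to avoid representation collapse.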
@HEADS.register_module()
class LatentClsHead(BaseModule):
    """Head for latent feature classification.

    Args:
        in_channels (int): Number of input channels.
        num_classes (int): Number of classes.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels: int,
                 num_classes: int,
                 init_cfg: dict = dict(
                     type='Normal',
                     std=0.01,
                     layer='Linear',
                 )) -> None:
        super(LatentClsHead, self).__init__(init_cfg)
        self.predictor = nn.Linear(in_channels, num_classes)
        self.criterion = nn.CrossEntropyLoss()
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> dict:
        """Forward head.

        Args:
            input (Tensor): NxC input features.
            target (Tensor): NxC target features.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        pred = self.predictor(input)
        # pseudo-label: argmax of the same predictor applied to the target
        with torch.no_grad():
            label = torch.argmax(self.predictor(target), dim=1).detach()
        loss = self.criterion(pred, label)
        return dict(loss=loss)
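A minimal sketch of how this head behaves (assuming LatentClsHead is exported from mmselfsup.models.heads; sizes are illustrative). The same linear predictor scores both features, and the argmax over the target's scores serves as a pseudo-label for the cross-entropy loss:

import torch
from mmselfsup.models.heads import LatentClsHead

head = LatentClsHead(in_channels=256, num_classes=1000)
online = torch.randn(32, 256)  # NxC online features
target = torch.randn(32, 256)  # NxC target features
losses = head(online, target)  # cross-entropy against argmax pseudo-labels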
@HEADS.register_module()
class LatentCrossCorrelationHead(BaseModule):
    """Head for latent feature cross correlation.

    Part of the code is borrowed from:
    `<https://github.com/facebookresearch/barlowtwins/blob/main/main.py>`_.

    Args:
        in_channels (int): Number of input channels.
        lambd (float): Weight on off-diagonal terms. Defaults to 0.0051.
    """

    def __init__(self, in_channels: int, lambd: float = 0.0051) -> None:
        super(LatentCrossCorrelationHead, self).__init__()
        self.lambd = lambd
        _, self.world_size = get_dist_info()
        self.bn = nn.BatchNorm1d(in_channels, affine=False)
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> dict:
        """Forward head.

        Args:
            input (Tensor): NxC input features.
            target (Tensor): NxC target features.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # cross-correlation matrix
        cross_correlation_matrix = self.bn(input).T @ self.bn(target)
        cross_correlation_matrix.div_(input.size(0) * self.world_size)

        if torch.distributed.is_initialized():
            torch.distributed.all_reduce(cross_correlation_matrix)

        # loss
        on_diag = torch.diagonal(cross_correlation_matrix).add_(-1).pow_(
            2).sum()
        off_diag = self.off_diagonal(cross_correlation_matrix).pow_(2).sum()
        loss = on_diag + self.lambd * off_diag
        return dict(loss=loss)
    def off_diagonal(self, x: torch.Tensor) -> torch.Tensor:
        """Return a flattened view of the off-diagonal elements of a square
        matrix."""
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
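The loss above is the Barlow Twins objective: the diagonal of the CxC cross-correlation matrix is pushed toward 1 (invariance term) and the off-diagonal entries toward 0 (redundancy-reduction term), weighted by lambd. A minimal single-process sketch (world_size is 1, so no all_reduce fires; the import path and sizes are illustrative assumptions):

import torch
from mmselfsup.models.heads import LatentCrossCorrelationHead

head = LatentCrossCorrelationHead(in_channels=128)
z1 = torch.randn(64, 128)  # projection of the first augmented view
z2 = torch.randn(64, 128)  # projection of the second augmented view
losses = head(z1, z2)      # {'loss': on_diag + 0.0051 * off_diag}

The indexing trick in off_diagonal is easiest to see on a small example: flattening an n x n matrix, dropping the last element, and viewing the result as (n - 1) x (n + 1) lines every diagonal element up in column 0, so slicing that column away keeps exactly the off-diagonal entries:

import torch

x = torch.arange(9.).view(3, 3)
# x = [[0., 1., 2.],
#      [3., 4., 5.],
#      [6., 7., 8.]]; the diagonal is 0., 4., 8.
rows = x.flatten()[:-1].view(2, 4)
# rows = [[0., 1., 2., 3.],
#         [4., 5., 6., 7.]] -> diagonal entries 0. and 4. land in column 0
off = rows[:, 1:].flatten()
# off = [1., 2., 3., 5., 6., 7.]: all six off-diagonal elements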