Shortcuts

注意

您正在阅读 MMSelfSup 0.x 版本的文档，而 MMSelfSup 0.x 版本将会在 2022 年末开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMSelfSup 1.0.0rc 的发版日志、代码文档获取更多信息。

mmselfsup.models.heads.cls_head 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner import BaseModule

from ..builder import HEADS
from ..utils import accuracy


@HEADS.register_module()
class ClsHead(BaseModule):
    """Simplest classifier head, with only one fc layer.

    Args:
        with_avg_pool (bool): Whether to apply the average pooling after
            neck. Defaults to False.
        in_channels (int): Number of input channels. Defaults to 2048.
        num_classes (int): Number of classes. Defaults to 1000.
        vit_backbone (bool): If True, ``forward`` takes the last item of the
            single input element as the feature instead of using the element
            directly. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            When omitted, ``Linear`` layers use ``Normal(std=0.01)`` and
            ``_BatchNorm``/``GroupNorm`` layers use ``Constant(val=1)``.
            An explicit value (including ``None``) is passed to
            ``BaseModule`` unchanged.
    """

    # Private sentinel: lets every instance get its own freshly built default
    # ``init_cfg`` list (instead of sharing one mutable default argument),
    # while an explicit ``init_cfg=None`` is still forwarded as-is.
    _DEFAULT_INIT_CFG = object()

    def __init__(self,
                 with_avg_pool=False,
                 in_channels=2048,
                 num_classes=1000,
                 vit_backbone=False,
                 init_cfg=_DEFAULT_INIT_CFG):
        if init_cfg is self._DEFAULT_INIT_CFG:
            init_cfg = [
                dict(type='Normal', std=0.01, layer='Linear'),
                dict(
                    type='Constant', val=1,
                    layer=['_BatchNorm', 'GroupNorm'])
            ]
        super().__init__(init_cfg)
        self.with_avg_pool = with_avg_pool
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.vit_backbone = vit_backbone
        self.criterion = nn.CrossEntropyLoss()
        if self.with_avg_pool:
            self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_cls = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        """Forward head.

        Args:
            x (list[Tensor] | tuple[Tensor]): Input containing exactly one
                element. Normally that element is a feature map of shape
                (N, C, H, W); when ``vit_backbone`` is True it is itself
                indexable and its last item is used as the feature.

        Returns:
            list[Tensor]: A single-element list holding the class scores
            produced by ``fc_cls``, of shape (N, num_classes).
        """
        assert isinstance(x, (tuple, list)) and len(x) == 1
        x = x[0]
        if self.vit_backbone:
            # NOTE(review): presumably the last item is the class token of
            # the ViT output — confirm against the backbone in use.
            x = x[-1]
        if self.with_avg_pool:
            assert x.dim() == 4, \
                f'Tensor must has 4 dims, got: {x.dim()}'
            x = self.avg_pool(x)
        # Collapse all non-batch dims. ``reshape`` matches ``view`` on
        # contiguous tensors but also accepts non-contiguous inputs
        # (copying only when necessary), where ``view`` would raise.
        x = x.reshape(x.size(0), -1)
        cls_score = self.fc_cls(x)
        return [cls_score]

    def loss(self, cls_score, labels):
        """Compute the classification loss and accuracy.

        Args:
            cls_score (list[Tensor] | tuple[Tensor]): Output of ``forward``;
                must contain exactly one tensor of class scores.
            labels (Tensor): Target labels, passed to
                ``nn.CrossEntropyLoss`` and to ``accuracy``.

        Returns:
            dict: ``'loss'`` holds the cross-entropy loss and ``'acc'`` holds
            whatever ``accuracy`` returns for the same scores and labels.
        """
        assert isinstance(cls_score, (tuple, list)) and len(cls_score) == 1
        scores = cls_score[0]
        losses = dict()
        losses['loss'] = self.criterion(scores, labels)
        losses['acc'] = accuracy(scores, labels)
        return losses
Read the Docs v: 0.x
Versions
latest
stable
1.x
dev-1.x
0.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.