
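Note

You are reading the documentation for MMSelfSup 0.x, which will gradually stop being maintained starting at the end of 2022. We recommend upgrading to MMSelfSup 1.0.0rc to benefit from the new features and better performance that OpenMMLab 2.0 brings. See the MMSelfSup 1.0.0rc release notes and code documentation for more information.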

Source code for mmselfsup.models.utils.accuracy

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule


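def accuracy(pred, target, topk=1):
    """Compute accuracy of predictions.

    Args:
        pred (Tensor): The output of the model, of shape (N, C).
        target (Tensor): The labels of data, of shape (N, ).
        topk (int | tuple[int] | list[int]): Top-k metric selection.
            Defaults to 1.

    Returns:
        Tensor | list[Tensor]: Top-k accuracy in percent. A single tensor
        if ``topk`` is an int, otherwise a list with one tensor per k.
    """
    # Accept lists as well as tuples, matching the documented types.
    assert isinstance(topk, (int, tuple, list))
    if isinstance(topk, int):
        topk = (topk, )
        return_single = True
    else:
        return_single = False

    maxk = max(topk)
    # Indices of the maxk highest scores per sample, transposed to (maxk, N).
    _, pred_label = pred.topk(maxk, dim=1)
    pred_label = pred_label.t()
    # correct[i, j] is True if the i-th best guess for sample j is its label.
    correct = pred_label.eq(
        target.contiguous().view(1, -1).expand_as(pred_label))

    res = []
    for k in topk:
        # Count samples whose label is among their top-k predictions.
        correct_k = correct[:k].contiguous().view(-1).float().sum(
            0, keepdim=True)
        res.append(correct_k.mul_(100.0 / pred.size(0)))
    return res[0] if return_single else res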
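To make the computation concrete, the following dependency-free sketch mirrors the steps of `accuracy` on plain Python lists instead of tensors: rank each sample's class scores, check whether the label is among the first k, and turn the hit count into a percentage. The name `topk_accuracy` is hypothetical and not part of MMSelfSup, and ties between equal scores may be broken differently than by `Tensor.topk`.

```python
def topk_accuracy(scores, labels, topk=1):
    """Top-k accuracy in percent, mirroring mmselfsup's ``accuracy``.

    scores: list of N lists, each holding C class scores.
    labels: list of N integer class indices.
    topk: an int, or a tuple/list of ints.
    """
    assert isinstance(topk, (int, tuple, list))
    return_single = isinstance(topk, int)
    if return_single:
        topk = (topk, )
    maxk = max(topk)

    # For each sample, the indices of its maxk highest scores (best first).
    pred_label = [
        sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:maxk]
        for row in scores
    ]

    res = []
    for k in topk:
        # Count samples whose label is among their k best guesses.
        correct_k = sum(
            label in row[:k] for row, label in zip(pred_label, labels))
        res.append(correct_k * 100.0 / len(scores))
    return res[0] if return_single else res


scores = [
    [0.1, 0.7, 0.2],  # best guess 1, runner-up 2
    [0.5, 0.3, 0.2],  # best guess 0, runner-up 1
    [0.2, 0.3, 0.5],  # best guess 2, runner-up 1
    [0.6, 0.1, 0.3],  # best guess 0, runner-up 2
]
labels = [1, 1, 2, 1]

print(topk_accuracy(scores, labels))          # 50.0
print(topk_accuracy(scores, labels, (1, 2)))  # [50.0, 75.0]
```

Two of the four samples are right on the first guess, and a third (the second sample) is rescued by its runner-up, so top-2 accuracy rises to 75%.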


class Accuracy(BaseModule):
    """Implementation of accuracy computation."""

    def __init__(self, topk=(1, )):
        super().__init__()
        self.topk = topk

    def forward(self, pred, target):
        return accuracy(pred, target, self.topk)