注意
您正在阅读 MMSelfSup 0.x 版本的文档,而 MMSelfSup 0.x 版本将会在 2022 年末 开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本,享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMSelfSup 1.0.0rc 的 发版日志, 代码 和 文档 获取更多信息。
mmselfsup.models.utils.accuracy 源代码
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule
[文档]def accuracy(pred, target, topk=1):
"""Compute accuracy of predictions.
Args:
pred (Tensor): The output of the model.
target (Tensor): The labels of data.
topk (int | list[int]): Top-k metric selection. Defaults to 1.
"""
assert isinstance(topk, (int, tuple))
if isinstance(topk, int):
topk = (topk, )
return_single = True
else:
return_single = False
maxk = max(topk)
_, pred_label = pred.topk(maxk, dim=1)
pred_label = pred_label.t()
correct = pred_label.eq(target.contiguous().view(1,
-1).expand_as(pred_label))
res = []
for k in topk:
correct_k = correct[:k].contiguous().view(-1).float().sum(
0, keepdim=True)
res.append(correct_k.mul_(100.0 / pred.size(0)))
return res[0] if return_single else res