Note
You are reading the documentation for MMSelfSup 0.x, which will be deprecated by the end of 2022. We recommend upgrading to the MMSelfSup 1.0.0rc versions to enjoy the new features and better performance brought by OpenMMLab 2.0. See the changelog, code, and documentation of MMSelfSup 1.0.0rc for more details.
Source code for mmselfsup.models.utils.knn_classifier
# Copyright (c) Facebook, Inc. and its affiliates.
# This file is borrowed from
# https://github.com/facebookresearch/dino/blob/main/eval_knn.py
import torch
import torch.nn as nn
@torch.no_grad()
def knn_classifier(train_features,
                   train_labels,
                   test_features,
                   test_labels,
                   k,
                   T,
                   num_classes=1000):
    """Compute accuracy of knn classifier predictions.

    Features are L2-normalised along dim 1, so the dot product between a
    test row and a train row is their cosine similarity. Each test sample is
    classified by a weighted vote of its ``k`` most similar training samples.
    Every neighbour votes for its own label with weight
    ``exp(similarity / T)``.

    Args:
        train_features (Tensor): Extracted features in the training set,
            one row per sample.
        train_labels (Tensor): Integer labels in the training set (used as
            ``scatter_`` indices, so they must lie in ``[0, num_classes)``).
        test_features (Tensor): Extracted features in the testing set,
            one row per sample.
        test_labels (Tensor): Labels in the testing set.
        k (int): Number of NN to use.
        T (float): Temperature used in the voting coefficient.
        num_classes (int): Number of classes. Defaults to 1000.

    Returns:
        tuple[float, float]: Top-1 and top-5 accuracy, in percent. When
        ``min(k, num_classes) < 5`` the "top-5" figure only considers the
        first ``min(5, k, num_classes)`` ranked predictions.

    Raises:
        ValueError: If the testing set is empty.
    """
    num_test_images = test_labels.shape[0]
    if num_test_images == 0:
        raise ValueError('knn_classifier requires at least one test sample.')

    # normalise so that the matrix product below is a cosine similarity;
    # train features are transposed once, outside the loop
    train_features = nn.functional.normalize(train_features, dim=1).t()
    test_features = nn.functional.normalize(test_features, dim=1)

    # split all test images into ~100 chunks to prevent out-of-memory; keep
    # at least one image per chunk so test sets smaller than the chunk count
    # do not yield a zero ``range`` step
    num_chunks = 100
    imgs_per_chunk = max(1, num_test_images // num_chunks)
    # "top-5" cannot look at more ranked classes than exist; it is also
    # capped by k, because top5 does not make sense if k < 5
    top5_width = min(5, k, num_classes)

    top1, top5 = 0.0, 0.0
    for idx in range(0, num_test_images, imgs_per_chunk):
        # get the features for test images (slicing clamps at the end)
        features = test_features[idx:idx + imgs_per_chunk]
        targets = test_labels[idx:idx + imgs_per_chunk]
        batch_size = targets.shape[0]

        # calculate the dot product and compute top-k neighbors
        similarity = torch.mm(features, train_features)
        distances, indices = similarity.topk(k, largest=True, sorted=True)
        candidates = train_labels.view(1, -1).expand(batch_size, -1)
        retrieved_neighbors = torch.gather(candidates, 1, indices)

        # one-hot encode neighbour labels: (batch_size * k, num_classes)
        retrieval_one_hot = torch.zeros(
            batch_size * k, num_classes, device=train_features.device)
        retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)

        # weighted vote: sum over the k neighbours of exp(sim / T) * one_hot
        distances_transform = (distances / T).exp()
        probs = torch.sum(
            retrieval_one_hot.view(batch_size, -1, num_classes) *
            distances_transform.view(batch_size, -1, 1),
            dim=1,
        )
        _, predictions = probs.sort(dim=1, descending=True)

        # find the predictions that match the target
        correct = predictions.eq(targets.view(-1, 1))
        top1 += correct[:, :1].sum().item()
        top5 += correct[:, :top5_width].sum().item()

    top1 = top1 * 100.0 / num_test_images
    top5 = top5 * 100.0 / num_test_images
    return top1, top5