

You are reading the documentation for MMSelfSup 0.x, which will be deprecated by the end of 2022. We recommend that you upgrade to the MMSelfSup 1.0.0rc versions to benefit from the new features and better performance brought by OpenMMLab 2.0. See the changelog, code and documentation of MMSelfSup 1.0.0rc for more details.

Source code for mmselfsup.models.utils.knn_classifier

# Copyright (c) Facebook, Inc. and its affiliates.

# This file is borrowed from
# https://github.com/facebookresearch/dino/blob/main/eval_knn.py

import torch
import torch.nn as nn

@torch.no_grad()
def knn_classifier(train_features,
                   train_labels,
                   test_features,
                   test_labels,
                   k,
                   T,
                   num_classes=1000):
    """Compute accuracy of knn classifier predictions.

    Train and test features are L2-normalized along ``dim=1``, so the dot
    product used for retrieval is the cosine similarity. Each of the ``k``
    most similar training samples votes for its own label with weight
    ``exp(similarity / T)``. Classes are ranked by their summed votes.

    Args:
        train_features (Tensor): Extracted features in the training set,
            one row per sample.
        train_labels (Tensor): Labels in the training set. They are used as
            ``scatter_`` indices, so they must be int64 class ids in
            ``[0, num_classes)``.
        test_features (Tensor): Extracted features in the testing set,
            one row per sample.
        test_labels (Tensor): Labels in the testing set.
        k (int): Number of NN to use.
        T (float): Temperature used in the voting coefficient.
        num_classes (int): Number of classes. Defaults to 1000.

    Returns:
        tuple[float, float]: Top-1 and top-5 accuracy, in percent. When
        ``k`` or ``num_classes`` is smaller than 5, the "top-5" value is
        computed over the first ``min(5, k, num_classes)`` predictions.

    Raises:
        ValueError: If the test set is empty.
    """
    num_test_images = test_labels.shape[0]
    if num_test_images == 0:
        raise ValueError('knn_classifier requires at least one test sample')

    train_features = nn.functional.normalize(train_features, dim=1)
    test_features = nn.functional.normalize(test_features, dim=1)
    # Transpose once so each chunk's similarities are a single matmul.
    train_features = train_features.t()

    # Split the test images into ~100 chunks to prevent out-of-memory on the
    # (chunk, num_train) similarity matrix. Use at least one image per chunk:
    # with fewer than 100 test images the floor division is 0, and a zero
    # step makes ``range`` raise.
    num_chunks = 100
    imgs_per_chunk = max(1, num_test_images // num_chunks)
    # top5 does not make sense if k < 5. Also, ``narrow`` cannot select more
    # columns than ``predictions`` has, which is ``num_classes``.
    num_top = min(5, k, num_classes)

    top1, top5, total = 0.0, 0.0, 0
    for idx in range(0, num_test_images, imgs_per_chunk):
        # get the features for test images
        end = min(idx + imgs_per_chunk, num_test_images)
        features = test_features[idx:end, :]
        targets = test_labels[idx:end]
        batch_size = targets.shape[0]

        # calculate the dot product and compute top-k neighbors
        similarity = torch.mm(features, train_features)
        distances, indices = similarity.topk(k, largest=True, sorted=True)
        candidates = train_labels.view(1, -1).expand(batch_size, -1)
        retrieved_neighbors = torch.gather(candidates, 1, indices)

        # One row per retrieved neighbour, with a 1 in its label's column.
        retrieval_one_hot = torch.zeros(
            batch_size * k, num_classes, device=train_features.device)
        retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)

        # Weight each neighbour's vote by exp(similarity / T), then sum the
        # weighted one-hot rows over the k neighbours to get per-class scores.
        distances_transform = (distances / T).exp()
        probs = torch.sum(
            torch.mul(
                retrieval_one_hot.view(batch_size, -1, num_classes),
                distances_transform.view(batch_size, -1, 1),
            ),
            1,
        )
        _, predictions = probs.sort(1, True)

        # find the predictions that match the target
        correct = predictions.eq(targets.view(-1, 1))
        top1 = top1 + correct.narrow(1, 0, 1).sum().item()
        top5 = top5 + correct.narrow(1, 0, num_top).sum().item()
        total += batch_size

    top1 = top1 * 100.0 / total
    top5 = top5 * 100.0 / total
    return top1, top5
Read the Docs v: 0.x
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.