Note
You are reading the documentation for MMSelfSup 0.x, which will be deprecated by the end of 2022. We recommend upgrading to the MMSelfSup 1.0.0rc versions to benefit from the many new features and better performance brought by OpenMMLab 2.0. See the changelog, code, and documentation of MMSelfSup 1.0.0rc for more details.
Source code for mmselfsup.models.memories.interclr_memory
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.runner import BaseModule, get_dist_info
from sklearn.cluster import KMeans
from mmselfsup.utils import AliasMethod
from ..builder import MEMORIES
@MEMORIES.register_module()
class InterCLRMemory(BaseModule):
    """Memory bank for InterCLR.

    The feature bank is only written on rank 0 (other ranks can receive it
    through :meth:`broadcast_feature_bank`). Centroids are computed on rank 0
    and broadcast to every rank. The label bank is a CPU tensor held by every
    rank; the collective calls below assume it is identical on all ranks.

    Args:
        length (int): Number of features stored in the memory bank.
        feat_dim (int): Dimension of stored features.
        momentum (float): Momentum coefficient for updating features.
        num_classes (int): Number of clusters.
        min_cluster (int): Minimal cluster size.
        **kwargs: Extra options. Only ``debug`` (bool, default False) is
            read; when true, rank 0 prints clustering diagnostics.
    """

    def __init__(self, length, feat_dim, momentum, num_classes, min_cluster,
                 **kwargs):
        super().__init__()
        self.rank, self.num_replicas = get_dist_info()
        # Features and centroids are GPU tensors. Labels stay on the CPU
        # because they are read through ``.numpy()`` throughout this class.
        self.feature_bank = torch.zeros((length, feat_dim),
                                        dtype=torch.float32,
                                        device='cuda')
        self.label_bank = torch.zeros((length, ), dtype=torch.long)
        self.centroids = torch.zeros((num_classes, feat_dim),
                                     dtype=torch.float32,
                                     device='cuda')
        # 2-way k-means used to split the largest cluster when an empty
        # cluster has to be refilled.
        self.kmeans = KMeans(n_clusters=2, random_state=0, max_iter=20)
        self.feat_dim = feat_dim
        self.initialized = False
        self.momentum = momentum
        self.num_classes = num_classes
        self.min_cluster = min_cluster
        # Sampler built from equal weights over all ``length`` entries
        # (presumably used by the model to draw uniform negatives).
        self.multinomial = AliasMethod(torch.ones(length))
        self.multinomial.cuda()
        self.debug = kwargs.get('debug', False)

    def init_memory(self, feature, label):
        """Fill the memory bank with initial features and cluster labels.

        Every rank stores ``label``. Only rank 0 stores the features and
        computes the centroids, which are then broadcast to all ranks.

        Args:
            feature (np.ndarray): Features copied into the
                ``(length, feat_dim)`` feature bank. On rank 0 this array is
                L2-normalized **in place**, so the caller's array is modified.
            label (np.ndarray): Integer cluster label of every sample. Every
                cluster in ``[0, num_classes)`` must be non-empty.
        """
        self.initialized = True
        self.label_bank.copy_(torch.from_numpy(label).long())
        # make sure no empty clusters
        assert (np.bincount(label, minlength=self.num_classes) != 0).all()
        if self.rank == 0:
            # In-place division (mutates the caller's array; presumably
            # chosen to avoid a second full-size copy). The epsilon guards
            # against all-zero rows.
            norms = np.linalg.norm(feature, axis=1).reshape(-1, 1)
            feature /= (norms + 1e-12)
            self.feature_bank.copy_(torch.from_numpy(feature))
            self.centroids.copy_(self._compute_centroids())
        dist.broadcast(self.centroids, 0)

    def assign_label(self, label):
        """Assign offline labels for each epoch.

        Args:
            label (np.ndarray): Integer cluster label of every sample. Every
                cluster in ``[0, num_classes)`` must be non-empty.
        """
        self.label_bank.copy_(torch.from_numpy(label).long())
        # make sure no empty clusters
        assert (np.bincount(label, minlength=self.num_classes) != 0).all()

    def broadcast_feature_bank(self):
        """Copy rank 0's feature bank to every other rank."""
        assert self.initialized
        dist.broadcast(self.feature_bank, 0)

    def _compute_centroids_ind(self, cinds):
        """Compute the centroids of the clusters listed in ``cinds``.

        Must be called on rank 0, which holds the feature bank. A cluster
        with no samples keeps its current centroid (the same policy as
        :meth:`_compute_centroids`) instead of becoming NaN.

        Args:
            cinds (Sequence[int]): Cluster indices.

        Returns:
            torch.Tensor: ``(len(cinds), feat_dim)`` centroids on the GPU.
        """
        assert self.rank == 0
        label_bank_np = self.label_bank.numpy()
        centroids = torch.zeros((len(cinds), self.feat_dim),
                                dtype=torch.float32,
                                device='cuda')
        for i, c in enumerate(cinds):
            ind = np.where(label_bank_np == c)[0]
            if len(ind) > 0:
                centroids[i, :] = self.feature_bank[ind, :].mean(dim=0)
            else:
                # Averaging zero rows would give NaN; keep the old centroid.
                centroids[i, :] = self.centroids[c, :]
        return centroids

    def _compute_centroids(self):
        """Compute the centroids of all non-empty clusters.

        Must be called on rank 0. Centroids of empty clusters are copied
        unchanged from ``self.centroids``.

        Returns:
            torch.Tensor: ``(num_classes, feat_dim)`` centroids on the GPU.
        """
        assert self.rank == 0
        label_bank_np = self.label_bank.numpy()
        # Sort samples by label so that each cluster occupies a contiguous
        # run ``argl[st:ed]``.
        argl = np.argsort(label_bank_np)
        sortl = label_bank_np[argl]
        diff_pos = np.where(sortl[1:] - sortl[:-1] != 0)[0] + 1
        start = np.insert(diff_pos, 0, 0)
        end = np.insert(diff_pos, len(diff_pos), len(label_bank_np))
        class_start = sortl[start]
        # keep empty class centroids unchanged
        centroids = self.centroids.clone()
        for i, st, ed in zip(class_start, start, end):
            centroids[i, :] = self.feature_bank[argl[st:ed], :].mean(dim=0)
        return centroids

    def _gather(self, ind, feature):
        """All-gather indices and features from every rank.

        Args:
            ind (torch.Tensor): Sample indices on this rank.
            feature (torch.Tensor): Features matching ``ind``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Indices and features from all
            ranks, concatenated along dim 0 in rank order.
        """
        # all_gather overwrites the receive buffers entirely, so they do not
        # need to be initialized.
        ind_gathered = [
            torch.empty_like(ind, device='cuda')
            for _ in range(self.num_replicas)
        ]
        feature_gathered = [
            torch.empty_like(feature, device='cuda')
            for _ in range(self.num_replicas)
        ]
        dist.all_gather(ind_gathered, ind)
        dist.all_gather(feature_gathered, feature)
        return (torch.cat(ind_gathered, dim=0),
                torch.cat(feature_gathered, dim=0))

    def _momentum_update(self, ind, feature_norm):
        """Blend ``feature_norm`` into the bank rows ``ind`` (rank 0 only).

        Computes ``normalize((1 - momentum) * old + momentum * new)``, writes
        the result back to the feature bank and returns it.
        """
        assert self.rank == 0
        feature_old = self.feature_bank[ind, :]
        feature_new = (1 - self.momentum) * feature_old + \
            self.momentum * feature_norm
        feature_new = nn.functional.normalize(feature_new)
        self.feature_bank[ind, :] = feature_new
        return feature_new

    def update_simple_memory(self, ind, feature):  # ind, feature: cuda tensor
        """Update features in the memory bank.

        Features from all ranks are gathered, and rank 0 blends them into the
        bank with momentum. Labels and centroids are not touched.

        Args:
            ind (torch.Tensor): Sample indices (CUDA tensor).
            feature (torch.Tensor): New features for ``ind`` (CUDA tensor).
        """
        feature_norm = nn.functional.normalize(feature)
        # ind: (N*w), feature_norm: (N*w)xk, cuda tensors
        ind, feature_norm = self._gather(ind, feature_norm)
        if self.rank == 0:
            self._momentum_update(ind, feature_norm)

    def update_samples_memory(self, ind, feature):  # ind, feature: cuda tensor
        """Update features and labels in the memory bank.

        Features from all ranks are gathered and blended into the bank on
        rank 0. The blended features are broadcast so that every rank can
        re-assign the gathered samples to their most similar centroid.

        Args:
            ind (torch.Tensor): Sample indices (CUDA tensor).
            feature (torch.Tensor): New features for ``ind`` (CUDA tensor).

        Returns:
            torch.Tensor: Fraction of gathered samples whose label changed,
            as a scalar CUDA tensor.
        """
        feature_norm = nn.functional.normalize(feature)
        # ind: (N*w), feature_norm: (N*w)xk, cuda tensors
        ind, feature_norm = self._gather(ind, feature_norm)
        if self.rank == 0:
            feature_norm = self._momentum_update(ind, feature_norm)
        dist.barrier()
        # Every rank labels the samples using rank 0's blended features.
        dist.broadcast(feature_norm, 0)
        # compute new labels
        ind = ind.cpu()
        similarity_to_centroids = torch.mm(self.centroids,
                                           feature_norm.permute(1, 0))  # CxN
        newlabel = similarity_to_centroids.argmax(dim=0).cpu()
        num_changed = (newlabel != self.label_bank[ind]).sum().float()
        change_ratio = num_changed.cuda() / float(newlabel.shape[0])
        self.label_bank[ind] = newlabel.clone()
        return change_ratio

    def deal_with_small_clusters(self):
        """Deal with small clusters.

        Samples in clusters smaller than ``min_cluster`` are moved to their
        most similar remaining cluster. Each emptied cluster is then refilled
        by splitting the current largest cluster.

        Raises:
            RuntimeError: If every cluster is smaller than ``min_cluster``,
                so there is no cluster to move the samples to.
        """
        # check empty class
        histogram = np.bincount(
            self.label_bank.numpy(), minlength=self.num_classes)
        small_clusters = np.where(histogram < self.min_cluster)[0].tolist()
        if self.debug and self.rank == 0:
            print(f'mincluster: {histogram.min()}, '
                  f'num of small class: {len(small_clusters)}')
        if len(small_clusters) == 0:
            return
        # Clusters allowed to receive samples. This is loop-invariant, so it
        # is built once.
        inclusion = np.setdiff1d(
            np.arange(self.num_classes),
            np.array(small_clusters),
            assume_unique=True)
        if len(inclusion) == 0:
            # Checked on every rank (the label bank is identical on all
            # ranks), so all ranks fail together. Otherwise rank 0 would error
            # inside argmax while the other ranks wait forever in all_reduce.
            raise RuntimeError(
                f'All {self.num_classes} clusters are smaller than '
                f'min_cluster={self.min_cluster}; cannot re-assign samples.')
        inclusion = torch.from_numpy(inclusion).cuda()
        # re-assign samples in small clusters to make them empty
        for s in small_clusters:
            ind = np.where(self.label_bank.numpy() == s)[0]
            if len(ind) == 0:
                continue
            if self.rank == 0:
                target_ind = torch.mm(
                    self.centroids[inclusion, :],
                    self.feature_bank[ind, :].permute(1, 0)).argmax(dim=0)
                target = inclusion[target_ind]
            else:
                target = torch.zeros((ind.shape[0], ),
                                     dtype=torch.int64,
                                     device='cuda')
            # Only rank 0 contributes non-zero values, so the sum hands rank
            # 0's targets to every rank.
            dist.all_reduce(target)
            self.label_bank[ind] = torch.from_numpy(target.cpu().numpy())
        # deal with empty cluster
        self._redirect_empty_clusters(small_clusters)

    def update_centroids_memory(self, cinds=None):
        """Update centroids in the memory bank.

        Centroids are recomputed on rank 0 and broadcast to all ranks.

        Args:
            cinds (Sequence[int], optional): Clusters to update. If None,
                all non-empty clusters are recomputed.
        """
        if self.rank == 0:
            if self.debug:
                print('updating centroids ...')
            if cinds is None:
                self.centroids.copy_(self._compute_centroids())
            else:
                centroids = self._compute_centroids_ind(cinds)
                cinds_tensor = torch.as_tensor(
                    cinds, dtype=torch.long, device='cuda')
                self.centroids[cinds_tensor, :] = centroids
        dist.broadcast(self.centroids, 0)

    def _partition_max_cluster(self, max_cluster):
        """Partition the largest cluster into two sub-clusters.

        Must be called on rank 0. Runs 2-way k-means on the cluster's
        features and falls back to a random half split if k-means leaves one
        side empty.

        Args:
            max_cluster (int): Index of the cluster to split. It must contain
                at least 2 samples.

        Returns:
            tuple[np.ndarray, np.ndarray]: Sample indices of the two parts.
        """
        assert self.rank == 0
        max_cluster_inds = np.where(self.label_bank.numpy() == max_cluster)[0]
        assert len(max_cluster_inds) >= 2
        max_cluster_features = self.feature_bank[
            max_cluster_inds, :].cpu().numpy()
        if np.any(np.isnan(max_cluster_features)):
            raise Exception('Has nan in features.')
        kmeans_ret = self.kmeans.fit(max_cluster_features)
        sub_cluster1_ind = max_cluster_inds[kmeans_ret.labels_ == 0]
        sub_cluster2_ind = max_cluster_inds[kmeans_ret.labels_ == 1]
        if not (len(sub_cluster1_ind) > 0 and len(sub_cluster2_ind) > 0):
            print(
                'Warning: kmeans partition fails, resort to random partition.')
            sub_cluster1_ind = np.random.choice(
                max_cluster_inds, len(max_cluster_inds) // 2, replace=False)
            sub_cluster2_ind = np.setdiff1d(
                max_cluster_inds, sub_cluster1_ind, assume_unique=True)
        return sub_cluster1_ind, sub_cluster2_ind

    def _redirect_empty_clusters(self, empty_clusters):
        """Re-direct empty clusters.

        For each empty cluster, rank 0 splits the current largest cluster in
        two and broadcasts the split. Every rank then relabels the second
        part as the empty cluster, and both affected centroids are
        recomputed.

        Args:
            empty_clusters (list[int]): Clusters that currently hold no
                samples.
        """
        for e in empty_clusters:
            assert (self.label_bank != e).all().item(), \
                f'Cluster #{e} is not an empty cluster.'
            max_cluster = np.bincount(
                self.label_bank.numpy(),
                minlength=self.num_classes).argmax().item()
            # gather partitioning indices
            if self.rank == 0:
                sub_cluster1_ind, sub_cluster2_ind = \
                    self._partition_max_cluster(max_cluster)
                size1 = torch.tensor([len(sub_cluster1_ind)],
                                     dtype=torch.long,
                                     device='cuda')
                size2 = torch.tensor([len(sub_cluster2_ind)],
                                     dtype=torch.long,
                                     device='cuda')
                sub_cluster1_ind_tensor = torch.from_numpy(
                    sub_cluster1_ind).long().cuda()
                sub_cluster2_ind_tensor = torch.from_numpy(
                    sub_cluster2_ind).long().cuda()
            else:
                size1 = torch.zeros((1, ), dtype=torch.long, device='cuda')
                size2 = torch.zeros((1, ), dtype=torch.long, device='cuda')
            # Only rank 0 holds non-zero sizes, so the sum shares them.
            dist.all_reduce(size1)
            dist.all_reduce(size2)
            if self.rank != 0:
                # Allocate receive buffers matching rank 0's sizes.
                n1, n2 = int(size1.item()), int(size2.item())
                sub_cluster1_ind_tensor = torch.zeros(
                    (n1, ), dtype=torch.int64, device='cuda')
                sub_cluster2_ind_tensor = torch.zeros(
                    (n2, ), dtype=torch.int64, device='cuda')
            dist.broadcast(sub_cluster1_ind_tensor, 0)
            dist.broadcast(sub_cluster2_ind_tensor, 0)
            if self.rank != 0:
                sub_cluster1_ind = sub_cluster1_ind_tensor.cpu().numpy()
                sub_cluster2_ind = sub_cluster2_ind_tensor.cpu().numpy()
            # reassign samples in partition #2 to the empty class
            self.label_bank[sub_cluster2_ind] = e
            # update centroids of max_cluster and e
            self.update_centroids_memory([max_cluster, e])