Shortcuts

Source code for mmselfsup.utils.distributed_sinkhorn

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.

# This file is modified from
# https://github.com/facebookresearch/swav/blob/main/main_swav.py
import torch
from mmengine.dist import all_reduce


@torch.no_grad()
def distributed_sinkhorn(out: torch.Tensor, sinkhorn_iterations: int,
                         world_size: int, epsilon: float) -> torch.Tensor:
    """Apply the distributed sinknorn optimization on the scores matrix to
    find the assignments.

    Args:
        out (torch.Tensor): The scores matrix
        sinkhorn_iterations (int): Number of iterations in Sinkhorn-Knopp
            algorithm.
        world_size (int): The world size of the process group.
        epsilon (float): regularization parameter for Sinkhorn-Knopp
            algorithm.

    Returns:
        torch.Tensor: Output of sinkhorn algorithm.
    """
    stab_eps = 1e-12
    # Transpose so rows are prototypes and columns are samples
    # (K-by-B, matching the notation of the SwAV paper).
    Q = torch.exp(out / epsilon).t()
    num_samples = Q.shape[1] * world_size  # samples to assign, global count
    num_prototypes = Q.shape[0]

    # Normalize so the whole matrix sums to 1 across every worker.
    total_mass = torch.sum(Q)
    all_reduce(total_mass)
    Q /= total_mass

    for _ in range(sinkhorn_iterations):
        # Row step: force each prototype's total weight to 1/K.
        row_sums = torch.sum(Q, dim=1, keepdim=True)
        if (row_sums == 0).any():
            # Guard against division by zero on empty rows.
            Q += stab_eps
            row_sums = torch.sum(Q, dim=1, keepdim=True, dtype=Q.dtype)
        all_reduce(row_sums)
        Q /= row_sums
        Q /= num_prototypes

        # Column step: force each sample's total weight to 1/B.
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= num_samples

    # Rescale so every column sums to 1, making Q a valid assignment.
    Q *= num_samples
    return Q.t()
Read the Docs v: latest
Versions
latest
stable
1.x
0.x
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.