
Note

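You are reading the documentation for MMSelfSup 0.x. MMSelfSup 0.x will gradually stop being maintained starting at the end of 2022. We recommend upgrading to MMSelfSup 1.0.0rc to get the new features and better performance that come with OpenMMLab 2.0. See the MMSelfSup 1.0.0rc release notes and code documentation for more information.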

Source code for mmselfsup.utils.distributed_sinkhorn

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.

# This file is modified from
# https://github.com/facebookresearch/swav/blob/main/main_swav.py

import torch
import torch.distributed as dist


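@torch.no_grad()
def distributed_sinkhorn(out, sinkhorn_iterations, world_size, epsilon):
    """Apply the distributed Sinkhorn optimization on the scores matrix to
    find the assignments.

    Args:
        out (Tensor): Scores of shape (B_local, K): the local batch of
            samples scored against K prototypes.
        sinkhorn_iterations (int): Number of Sinkhorn-Knopp iterations.
        world_size (int): Number of processes. The global batch size is
            taken to be B_local * world_size.
        epsilon (float): Regularization parameter; scores are divided by it
            before exponentiation.

    Returns:
        Tensor: Soft assignments of shape (B_local, K); each row sums to 1.
    """
    eps_num_stab = 1e-12
    # Q is K-by-B for consistency with notations from our paper
    Q = torch.exp(out / epsilon).t()
    B = Q.shape[1] * world_size  # number of samples to assign
    K = Q.shape[0]  # how many prototypes

    # make the matrix sum to 1 (over all processes)
    sum_Q = torch.sum(Q)
    if dist.is_initialized():
        dist.all_reduce(sum_Q)
    Q /= sum_Q

    for _ in range(sinkhorn_iterations):
        # normalize each row: total weight per prototype must be 1/K
        u = torch.sum(Q, dim=1, keepdim=True)
        if len(torch.nonzero(u == 0)) > 0:
            # avoid dividing by zero when a prototype received no mass
            Q += eps_num_stab
            u = torch.sum(Q, dim=1, keepdim=True, dtype=Q.dtype)
        if dist.is_initialized():
            dist.all_reduce(u)
        Q /= u
        Q /= K

        # normalize each column: total weight per sample must be 1/B
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= B

    Q *= B  # the columns must sum to 1 so that Q is an assignment
    return Q.t()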
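To make the normalization steps easier to follow, here is a minimal NumPy sketch. It is not part of MMSelfSup, and both function names are hypothetical. `sinkhorn` mirrors the single-process case (no `torch.distributed`, `world_size=1`) for a B x K score matrix like `out`. `sharded_sinkhorn` emulates the distributed case by splitting the samples into equal-sized "ranks" and adding up the total sum and the row sums across shards in place of `dist.all_reduce`; column sums stay local, which is why the function above only all-reduces `sum_Q` and `u`. The sketch assumes equal shard sizes (the source computes `B = Q.shape[1] * world_size`) and omits the `eps_num_stab` guard.

```python
import numpy as np


def sinkhorn(scores, n_iters=3, epsilon=0.05):
    """Single-process Sinkhorn-Knopp on a B x K score matrix (sketch)."""
    q = np.exp(scores / epsilon).T  # K x B, as in the source
    k, b = q.shape
    q /= q.sum()  # whole matrix sums to 1
    for _ in range(n_iters):
        q /= q.sum(axis=1, keepdims=True)  # each prototype row sums to 1/K
        q /= k
        q /= q.sum(axis=0, keepdims=True)  # each sample column sums to 1/B
        q /= b
    q *= b  # each sample column now sums to 1
    return q.T  # B x K soft assignments


def sharded_sinkhorn(shards, n_iters=3, epsilon=0.05):
    """Emulate the distributed version; each shard is one rank's scores.

    The total and the row sums are added across shards (standing in for
    dist.all_reduce); column sums are computed per shard only.
    """
    qs = [np.exp(s / epsilon).T for s in shards]
    k = qs[0].shape[0]
    b = sum(q.shape[1] for q in qs)  # global batch size
    total = sum(q.sum() for q in qs)  # emulates all_reduce(sum_Q)
    qs = [q / total for q in qs]
    for _ in range(n_iters):
        rows = sum(q.sum(axis=1, keepdims=True) for q in qs)  # all_reduce(u)
        qs = [q / rows / k for q in qs]
        qs = [q / q.sum(axis=0, keepdims=True) / b for q in qs]
    return [(q * b).T for q in qs]


rng = np.random.default_rng(0)
scores = rng.uniform(-1.0, 1.0, size=(8, 4))  # B=8 samples, K=4 prototypes
q = sinkhorn(scores)
parts = sharded_sinkhorn([scores[:4], scores[4:]])  # two equal "ranks"
print(q.sum(axis=1))  # each sample's assignment sums to 1
print(np.allclose(np.vstack(parts), q))  # True
```

Because the column normalization is the last step of every iteration, each returned row sums to 1 after any number of iterations, while the per-prototype totals only approach B/K as the number of iterations grows.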