

# 您正在阅读 MMSelfSup 0.x 版本的文档,而 MMSelfSup 0.x 版本将会在 2022 年末开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本,享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMSelfSup 1.0.0rc 的发版日志、代码文档获取更多信息。

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.runner import BaseModule

from mmselfsup.utils import distributed_sinkhorn
from ..builder import HEADS
from ..utils import MultiPrototypes

@HEADS.register_module()
class SwAVHead(BaseModule):
    """The head for SwAV.

    This head contains clustering and sinkhorn algorithms to compute Q codes.
    Part of the code is borrowed from:
    `<https://github.com/facebookresearch/swav>`_.
    The queue is built in `core/hooks/`.

    Args:
        feat_dim (int): feature dimension of the prototypes.
        sinkhorn_iterations (int): number of iterations in Sinkhorn-Knopp
            algorithm. Defaults to 3.
        epsilon (float): regularization parameter for Sinkhorn-Knopp
            algorithm. Defaults to 0.05.
        temperature (float): temperature parameter in training loss.
            Defaults to 0.1.
        crops_for_assign (list[int]): list of crops id used for computing
            assignments. Defaults to [0, 1].
        num_crops (list[int]): list of number of crops. Defaults to [2].
        num_prototypes (int): number of prototypes. Defaults to 3000.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 feat_dim,
                 sinkhorn_iterations=3,
                 epsilon=0.05,
                 temperature=0.1,
                 crops_for_assign=[0, 1],
                 num_crops=[2],
                 num_prototypes=3000,
                 init_cfg=None):
        super(SwAVHead, self).__init__(init_cfg)
        self.sinkhorn_iterations = sinkhorn_iterations
        self.epsilon = epsilon
        self.temperature = temperature
        # Copy the list arguments so that the instance never aliases the
        # shared default objects from the signature (or a caller's list).
        self.crops_for_assign = list(crops_for_assign)
        self.num_crops = list(num_crops)
        # Both attributes start disabled; `queue` is expected to be assigned
        # from outside this class (see the class docstring).
        self.use_queue = False
        self.queue = None
        self.world_size = dist.get_world_size() if dist.is_initialized() \
            else 1

        # prototype layer
        self.prototypes = None
        if isinstance(num_prototypes, list):
            self.prototypes = MultiPrototypes(feat_dim, num_prototypes)
        elif num_prototypes > 0:
            self.prototypes = nn.Linear(feat_dim, num_prototypes, bias=False)
        # A non-list `num_prototypes` <= 0 leaves no prototype layer at all.
        assert self.prototypes is not None

    def forward(self, x):
        """Forward head of swav to compute the loss.

        Args:
            x (Tensor): NxC input features.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # normalize the prototypes
        # NOTE(review): this reads `self.prototypes.weight`, which exists for
        # the `nn.Linear` branch of `__init__`; whether `MultiPrototypes`
        # exposes `.weight` is not visible from this file - confirm.
        with torch.no_grad():
            w = self.prototypes.weight.data.clone()
            w = nn.functional.normalize(w, dim=1, p=2)
            self.prototypes.weight.copy_(w)

        embedding, output = x, self.prototypes(x)
        embedding = embedding.detach()

        # Total number of crops per sample; the input stacks the crops along
        # dim 0, so each crop occupies `bs` consecutive rows.
        total_crops = int(np.sum(self.num_crops))
        bs = embedding.size(0) // total_crops

        loss = 0
        for i, crop_id in enumerate(self.crops_for_assign):
            with torch.no_grad():
                out = output[bs * crop_id:bs * (crop_id + 1)].detach()

                # time to use the queue
                if self.queue is not None:
                    # The queue is only used once its last slot is non-zero;
                    # after that `use_queue` stays True.
                    if self.use_queue or not torch.all(
                            self.queue[i, -1, :] == 0):
                        self.use_queue = True
                        # Score the queued embeddings against the prototypes
                        # and prepend them, so the current batch remains the
                        # last `bs` rows of `out`.
                        out = torch.cat(
                            (torch.mm(self.queue[i],
                                      self.prototypes.weight.t()), out))
                    # fill the queue: shift existing entries back by `bs`
                    # and write the current embeddings at the front
                    self.queue[i, bs:] = self.queue[i, :-bs].clone()
                    self.queue[i, :bs] = embedding[crop_id * bs:(crop_id + 1) *
                                                   bs]

                # get assignments (batch_size * num_prototypes); keep only
                # the rows belonging to the current batch
                q = distributed_sinkhorn(out, self.sinkhorn_iterations,
                                         self.world_size, self.epsilon)[-bs:]

            # cluster assignment prediction: every other crop predicts the
            # assignment `q` computed from crop `crop_id`
            subloss = 0
            for v in np.delete(np.arange(total_crops), crop_id):
                logits = output[bs * v:bs * (v + 1)] / self.temperature
                subloss -= torch.mean(
                    torch.sum(
                        q * nn.functional.log_softmax(logits, dim=1), dim=1))
            loss += subloss / (total_crops - 1)
        loss /= len(self.crops_for_assign)
        return dict(loss=loss)
