注意
您正在阅读 MMSelfSup 0.x 版本的文档，而 MMSelfSup 0.x 版本将会在 2022 年末开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMSelfSup 1.0.0rc 的发版日志、代码和文档获取更多信息。
mmselfsup.models.heads.swav_head 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.runner import BaseModule
from mmselfsup.utils import distributed_sinkhorn
from ..builder import HEADS
from ..utils import MultiPrototypes
@HEADS.register_module()
class SwAVHead(BaseModule):
    """The head for SwAV.

    This head contains clustering and sinkhorn algorithms to compute Q codes.
    Part of the code is borrowed from:
    `<https://github.com/facebookresearch/swav>`_.
    The queue is built in `core/hooks/swav_hook.py`.

    Args:
        feat_dim (int): feature dimension of the prototypes.
        sinkhorn_iterations (int): number of iterations in Sinkhorn-Knopp
            algorithm. Defaults to 3.
        epsilon (float): regularization parameter for Sinkhorn-Knopp algorithm.
            Defaults to 0.05.
        temperature (float): temperature parameter in training loss.
            Defaults to 0.1.
        crops_for_assign (list[int], optional): list of crops id used for
            computing assignments. Defaults to None, which is treated as
            [0, 1].
        num_crops (list[int], optional): list of number of crops. Defaults to
            None, which is treated as [2]. The loss of every assignment crop
            is averaged over the other ``sum(num_crops) - 1`` crops, so at
            least two crops in total are expected.
        num_prototypes (int | list[int]): number of prototypes. A list builds
            a ``MultiPrototypes`` layer, a positive number builds a single
            bias-free linear layer. Defaults to 3000.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.

    Raises:
        ValueError: If ``num_prototypes`` is neither a list nor a positive
            number, so that no prototype layer can be built.
    """

    def __init__(self,
                 feat_dim,
                 sinkhorn_iterations=3,
                 epsilon=0.05,
                 temperature=0.1,
                 crops_for_assign=None,
                 num_crops=None,
                 num_prototypes=3000,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.sinkhorn_iterations = sinkhorn_iterations
        self.epsilon = epsilon
        self.temperature = temperature
        # ``None`` sentinels instead of list literals in the signature: each
        # instance gets its own default list rather than sharing one object.
        self.crops_for_assign = ([0, 1] if crops_for_assign is None else
                                 crops_for_assign)
        self.num_crops = [2] if num_crops is None else num_crops
        # The queue is not created here; it stays disabled until something
        # outside this class assigns ``self.queue``.
        self.use_queue = False
        self.queue = None
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1

        # prototype layer
        self.prototypes = None
        if isinstance(num_prototypes, list):
            self.prototypes = MultiPrototypes(feat_dim, num_prototypes)
        elif num_prototypes > 0:
            self.prototypes = nn.Linear(feat_dim, num_prototypes, bias=False)
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``.
        if self.prototypes is None:
            raise ValueError(
                'num_prototypes must be a list or a positive number, '
                f'but got {num_prototypes!r}')

    def forward(self, x):
        """Forward head of swav to compute the loss.

        Args:
            x (Tensor): NxC input features.

        Returns:
            dict[str, Tensor]: A dictionary of loss components, holding the
            averaged swapped-prediction loss under the key ``loss``.
        """
        # normalize the prototypes (L2 norm of every weight row), in place
        # NOTE(review): this reads ``self.prototypes.weight`` and the output is
        # sliced like a tensor below. That matches the ``nn.Linear`` branch;
        # whether ``MultiPrototypes`` supports both is not visible here --
        # confirm before using a list for ``num_prototypes``.
        with torch.no_grad():
            w = self.prototypes.weight.data.clone()
            w = nn.functional.normalize(w, dim=1, p=2)
            self.prototypes.weight.copy_(w)

        embedding, output = x, self.prototypes(x)
        embedding = embedding.detach()

        # The rows of ``x`` are treated as ``num_views`` equally sized chunks,
        # one chunk of ``bs`` samples per crop.
        num_views = sum(self.num_crops)
        bs = embedding.size(0) // num_views

        loss = 0
        for i, crop_id in enumerate(self.crops_for_assign):
            with torch.no_grad():
                out = output[bs * crop_id:bs * (crop_id + 1)].detach()

                # time to use the queue
                if self.queue is not None:
                    # The queue counts as usable once its last row is no
                    # longer all zeros; from then on it is always used.
                    if self.use_queue or not torch.all(
                            self.queue[i, -1, :] == 0):
                        self.use_queue = True
                        out = torch.cat(
                            (torch.mm(self.queue[i],
                                      self.prototypes.weight.t()), out))
                    # fill the queue: shift the stored rows back by one batch
                    # and write the current embeddings at the front
                    self.queue[i, bs:] = self.queue[i, :-bs].clone()
                    self.queue[i, :bs] = embedding[crop_id * bs:(crop_id + 1) *
                                                   bs]

                # get assignments (batch_size * num_prototypes); only the last
                # ``bs`` rows belong to the current batch when the queue was
                # prepended
                q = distributed_sinkhorn(out, self.sinkhorn_iterations,
                                         self.world_size, self.epsilon)[-bs:]

            # cluster assignment prediction: every other crop predicts the
            # codes ``q`` of the current crop
            subloss = 0
            for v in np.delete(np.arange(num_views), crop_id):
                logits = output[bs * v:bs * (v + 1)] / self.temperature
                subloss -= torch.mean(
                    torch.sum(
                        q * nn.functional.log_softmax(logits, dim=1), dim=1))
            loss += subloss / (num_views - 1)

        loss /= len(self.crops_for_assign)
        return dict(loss=loss)