注意
您正在阅读 MMSelfSup 0.x 版本的文档，而 MMSelfSup 0.x 版本将会在 2022 年末开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。请阅读 MMSelfSup 1.0.0rc 的发版日志、代码和文档以获取更多信息。
mmselfsup.models.algorithms.simclr 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
from ..utils import GatherLayer
from .base import BaseModel
@ALGORITHMS.register_module()
class SimCLR(BaseModel):
    """SimCLR.

    Implementation of `A Simple Framework for Contrastive Learning
    of Visual Representations <https://arxiv.org/abs/2002.05709>`_.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module of deep features to compact
            feature vectors. Required: although the signature defaults it
            to None, an ``AssertionError`` is raised if it is omitted.
        head (dict): Config dict for module of loss functions. Required:
            although the signature defaults it to None, an
            ``AssertionError`` is raised if it is omitted.
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModel.__init__``. Defaults to None.
    """

    def __init__(self, backbone, neck=None, head=None, init_cfg=None):
        super(SimCLR, self).__init__(init_cfg)
        self.backbone = build_backbone(backbone)
        assert neck is not None
        self.neck = build_neck(neck)
        assert head is not None
        self.head = build_head(head)

    @staticmethod
    def _create_buffer(N, device=None):
        """Compute the mask and the index of positive samples.

        The similarity matrix is expected to be ordered so that the two
        views of sample ``k`` occupy rows/columns ``2k`` and ``2k + 1``.

        Args:
            N (int): Number of samples; the similarity matrix is
                ``(2N) x (2N)``.
            device (torch.device | str, optional): Device on which the
                buffers are allocated. Defaults to None, which keeps the
                historical behaviour of allocating on the current CUDA
                device.

        Returns:
            tuple: ``(mask, pos_ind, neg_mask)``.
                ``mask`` is a ``(2N, 2N)`` uint8 tensor, 0 on the diagonal
                and 1 elsewhere. ``pos_ind`` is a ``(rows, cols)`` pair of
                ``(2N,)`` long tensors locating each row's positive inside
                the ``(2N, 2N - 1)`` matrix left after removing the
                diagonal. ``neg_mask`` is a ``(2N, 2N - 1)`` uint8 tensor,
                0 at the positive positions and 1 elsewhere.
        """
        if device is None:
            device = torch.device('cuda')
        mask = 1 - torch.eye(N * 2, dtype=torch.uint8, device=device)
        # After dropping the diagonal, row 2k's positive (original column
        # 2k + 1) shifts left to column 2k, while row 2k + 1's positive
        # (original column 2k) is left of the diagonal and stays at 2k.
        # Hence the positive column for row i is 2 * (i // 2):
        # [0, 0, 2, 2, 4, 4, ...].
        pos_rows = torch.arange(N * 2, device=device)
        pos_cols = 2 * torch.arange(
            N, dtype=torch.long, device=device).repeat_interleave(2)
        pos_ind = (pos_rows, pos_cols)
        neg_mask = torch.ones((N * 2, N * 2 - 1),
                              dtype=torch.uint8,
                              device=device)
        neg_mask[pos_ind] = 0
        return mask, pos_ind, neg_mask

    def extract_feat(self, img):
        """Function to extract features from backbone.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.

        Returns:
            tuple[Tensor]: backbone outputs.
        """
        x = self.backbone(img)
        return x

    def forward_train(self, img, **kwargs):
        """Forward computation during training.

        Args:
            img (list[Tensor]): A list of exactly two augmented views, each
                of shape (N, C, H, W). Typically these should be mean
                centered and std scaled.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert isinstance(img, list)
        assert len(img) == 2, \
            f'SimCLR expects exactly 2 augmented views, got {len(img)}.'
        # (N, 2, C, H, W) -> (2N, C, H, W): the two views of sample k land
        # on rows 2k and 2k + 1, which is the layout _create_buffer assumes.
        img = torch.stack(img, 1)
        img = img.reshape(
            (img.size(0) * 2, img.size(2), img.size(3), img.size(4)))
        x = self.extract_feat(img)  # 2n
        z = self.neck(x)[0]  # (2n)xd
        # L2-normalise each row so z @ z.T holds cosine similarities; the
        # epsilon guards against division by zero.
        z = z / (torch.norm(z, p=2, dim=1, keepdim=True) + 1e-10)
        # NOTE(review): GatherLayer presumably all-gathers z across ranks;
        # concatenating per-rank chunks in order keeps the (2k, 2k + 1)
        # pairing intact. Confirm against ..utils.GatherLayer.
        z = torch.cat(GatherLayer.apply(z), dim=0)  # (2N)xd
        assert z.size(0) % 2 == 0
        N = z.size(0) // 2
        s = torch.matmul(z, z.permute(1, 0))  # (2N)x(2N)
        # Build the buffers on the same device as the features so the loss
        # also works off-GPU.
        mask, pos_ind, neg_mask = self._create_buffer(N, device=z.device)
        # remove diagonal (self-similarity), (2N)x(2N-1)
        s = torch.masked_select(s, mask == 1).reshape(s.size(0), -1)
        positive = s[pos_ind].unsqueeze(1)  # (2N)x1
        # select negative, (2N)x(2N-2)
        negative = torch.masked_select(s, neg_mask == 1).reshape(s.size(0), -1)
        losses = self.head(positive, negative)
        return losses