注意
您正在阅读 MMSelfSup 0.x 版本的文档，而 MMSelfSup 0.x 版本将于 2022 年末开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。请阅读 MMSelfSup 1.0.0rc 的发版日志、代码和文档以获取更多信息。
mmselfsup.models.algorithms.swav 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
from .base import BaseModel
@ALGORITHMS.register_module()
class SwAV(BaseModel):
    """SwAV self-supervised algorithm.

    Implementation of `Unsupervised Learning of Visual Features by Contrasting
    Cluster Assignments <https://arxiv.org/abs/2006.09882>`_.

    The queue is built in `core/hooks/swav_hook.py`.

    Args:
        backbone (dict): Config dict used to build the backbone module.
        neck (dict): Config dict for the module that turns deep features
            into compact feature vectors. Although it defaults to None, it
            is required: an AssertionError is raised when it is omitted.
        head (dict): Config dict for the module computing the loss. Although
            it defaults to None, it is required: an AssertionError is raised
            when it is omitted.
        init_cfg (dict, optional): Initialization config, forwarded unchanged
            to the ``BaseModel`` constructor. Defaults to None.
        **kwargs: Accepted for config compatibility and otherwise ignored.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 head=None,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        self.backbone = build_backbone(backbone)
        # Neck and head are mandatory for SwAV even though the signature
        # gives them None defaults.
        assert neck is not None
        self.neck = build_neck(neck)
        assert head is not None
        self.head = build_head(head)

    def extract_feat(self, img):
        """Run the backbone on a batch of images.

        Args:
            img (Tensor): Input images of shape (N, C, H, W), typically
                mean-centered and std-scaled.

        Returns:
            tuple[Tensor]: Whatever the backbone returns for ``img``.
        """
        return self.backbone(img)

    def forward_train(self, img, **kwargs):
        """Compute the SwAV training losses for a list of multi-crop views.

        Consecutive crops that share the same last-dimension size (width)
        are concatenated and sent through the backbone in a single pass.
        The per-group backbone outputs are collected into a list, which is
        handed to the neck. The first element of the neck output is then
        scored by the head.

        Args:
            img (list[Tensor]): Input crops, each of shape (N, C, H, W),
                typically mean-centered and std-scaled.

        Returns:
            dict[str, Tensor]: Loss components produced by the head.
        """
        assert isinstance(img, list)

        # Run-length encode the crop widths: each run of equal consecutive
        # widths forms one resolution group. Cumulative counts give each
        # group's exclusive end index into ``img``.
        widths = torch.tensor([crop.shape[-1] for crop in img])
        _, run_lengths = torch.unique_consecutive(widths, return_counts=True)
        group_ends = torch.cumsum(run_lengths, 0).tolist()

        group_feats = []
        begin = 0
        for end in group_ends:
            batch = torch.cat(img[begin:end])
            group_feats.append(self.backbone(batch))
            begin = end

        embedding = self.neck(group_feats)[0]
        return self.head(embedding)