注意
您正在阅读 MMSelfSup 0.x 版本的文档,而 MMSelfSup 0.x 版本将会在 2022 年末开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本,享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMSelfSup 1.0.0rc 的发版日志、代码和文档以获取更多信息。
mmselfsup.models.algorithms.cae 源代码
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import torch
from torchvision.transforms import Normalize
from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
from .base import BaseModel
@ALGORITHMS.register_module()
class CAE(BaseModel):
    """CAE.

    Implementation of `Context Autoencoder for Self-Supervised Representation
    Learning <https://arxiv.org/abs/2202.03026>`_.

    The model holds two networks built from the same ``backbone`` config:
    ``self.backbone``, which is called on the input with the mask, and
    ``self.teacher``, which is called with the inverted mask under
    ``torch.no_grad()`` and is refreshed from the backbone by a momentum
    update on every training forward pass.

    Args:
        backbone (dict): Config dict for module of backbone. The signature
            defaults to None, but None is rejected by an assertion.
        neck (dict): Config dict for module of deep features to
            compact feature vectors. The signature defaults to None, but
            None is rejected by an assertion.
        head (dict): Config dict for module of loss functions. The signature
            defaults to None, but None is rejected by an assertion.
        base_momentum (float): The base momentum coefficient for the target
            network. Defaults to 0.0.
        init_cfg (dict, optional): the config to control the initialization.
            Passed through to the base class constructor.
        **kwargs: Accepted for signature compatibility; not used here.
    """

    def __init__(self,
                 backbone: dict = None,
                 neck: dict = None,
                 head: dict = None,
                 base_momentum: float = 0.0,
                 init_cfg: dict = None,
                 **kwargs) -> None:
        super().__init__(init_cfg)
        assert backbone is not None
        self.backbone = build_backbone(backbone)
        # The teacher is a second, independent instance of the same
        # architecture (built from the same config, not a shared module).
        self.teacher = build_backbone(backbone)
        assert neck is not None
        self.neck = build_neck(neck)
        assert head is not None
        self.head = build_head(head)

        self.momentum = base_momentum

        # Per-channel normalization applied to every input image in
        # ``forward_train``. NOTE(review): these look like the usual
        # ImageNet RGB statistics -- confirm against the data pipeline.
        self.img_norm = Normalize(
            mean=torch.tensor((0.485, 0.456, 0.406)),
            std=torch.tensor((0.229, 0.224, 0.225)))

    def _init_teacher(self) -> None:
        """Copy the backbone weights into the teacher and freeze the teacher.

        NOTE(review): nothing in this class calls this method; presumably a
        weight-initialization hook elsewhere is expected to -- confirm.
        """
        # init the weights of teacher with those of backbone
        for param_backbone, param_teacher in zip(self.backbone.parameters(),
                                                 self.teacher.parameters()):
            # A bare ``param_teacher.detach()`` used to sit here; it returned
            # a new tensor that was discarded, so it had no effect. Freezing
            # is done by clearing ``requires_grad`` below.
            param_teacher.data.copy_(param_backbone.data)
            param_teacher.requires_grad = False

    def momentum_update(self) -> None:
        """Momentum update of the teacher network.

        Every teacher parameter is replaced by
        ``teacher * momentum + backbone * (1 - momentum)``. The update goes
        through ``.data`` so it is not recorded by autograd.
        """
        for param_backbone, param_teacher in zip(self.backbone.parameters(),
                                                 self.teacher.parameters()):
            param_teacher.data = param_teacher.data * self.momentum + \
                param_backbone.data * (1. - self.momentum)

    def extract_feat(self, img: torch.Tensor,
                     mask: torch.Tensor) -> torch.Tensor:
        """Run the backbone on ``img`` with ``mask`` and return its output.

        Args:
            img (torch.Tensor): Input passed as the backbone's first argument.
            mask (torch.Tensor): Mask passed as the backbone's second
                argument.

        Returns:
            torch.Tensor: Whatever the backbone returns, unchanged.
        """
        return self.backbone(img, mask)

    def forward_train(self, samples: Sequence, **kwargs) -> dict:
        """Forward computation during training.

        Args:
            samples (Sequence): A sequence that unpacks into exactly three
                items ``(img, img_target, mask)``.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            dict: The value returned by ``self.head``.
        """
        img, img_target, mask = samples

        # normalize images and the images to get the target
        # Each sample is normalized on its own, then re-batched on dim 0.
        img = torch.stack([self.img_norm(x) for x in img])
        # NOTE(review): affine rescale of the target; presumably maps a
        # [0, 1] image into [0.1, 0.9] for the target tokenizer -- confirm.
        img_target = 0.8 * img_target + 0.1

        # One boolean row per sample.
        mask = mask.flatten(1).to(torch.bool)

        unmasked = self.backbone(img, mask)

        # get the latent prediction for the masked patches
        with torch.no_grad():
            # The teacher gets the complementary mask of the backbone.
            latent_target = self.teacher(img, ~mask)
            # Drop the first token along dim 1.
            # NOTE(review): presumably the class token -- confirm.
            latent_target = latent_target[:, 1:, :]
            self.momentum_update()

        batch_size = img.shape[0]
        pos_embed = self.backbone.pos_embed.expand(batch_size, -1, -1)
        embed_dim = pos_embed.shape[-1]
        # Skip the first position, then split the remaining position
        # embeddings into the masked and unmasked groups. The reshape
        # requires every sample to have the same number of masked entries.
        patch_pos_embed = pos_embed[:, 1:]
        pos_embed_masked = patch_pos_embed[mask].reshape(
            batch_size, -1, embed_dim)
        pos_embed_unmasked = patch_pos_embed[~mask].reshape(
            batch_size, -1, embed_dim)

        # input the unmasked tokens and masked tokens to the decoder
        logits, latent_pred = self.neck(unmasked[:, 1:], pos_embed_masked,
                                        pos_embed_unmasked)

        # Collapse all leading dims so each row is one prediction vector.
        logits = logits.view(-1, logits.shape[-1])

        losses = self.head(img_target, logits, latent_pred, latent_target,
                           mask)

        return losses