Source code for mmselfsup.models.heads.mae_head
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcls.models import LabelSmoothLoss
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner import BaseModule
from torch import nn
from ..builder import HEADS
@HEADS.register_module()
class MAEPretrainHead(BaseModule):
    """Pre-training head for MAE.

    Args:
        norm_pix (bool): Whether or not to normalize the target pixels.
            Defaults to False.
        patch_size (int): Patch size. Defaults to 16.
    """

    def __init__(self, norm_pix: bool = False, patch_size: int = 16) -> None:
        super().__init__()
        self.norm_pix = norm_pix
        self.patch_size = patch_size
    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
        """Split a batch of images into flattened patches.

        Args:
            imgs (torch.Tensor): The shape is (N, 3, H, W).

        Returns:
            x (torch.Tensor): The shape is (N, L, patch_size**2 * 3).
        """
        p = self.patch_size
        # Images must be square and evenly divisible into p x p patches.
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x
    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Reassemble flattened patches into images; the inverse of
        ``patchify``.

        Args:
            x (torch.Tensor): The shape is (N, L, patch_size**2 * 3).

        Returns:
            imgs (torch.Tensor): The shape is (N, 3, H, W).
        """
        p = self.patch_size
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
        return imgs
    def forward(self, x: torch.Tensor, pred: torch.Tensor,
                mask: torch.Tensor) -> dict:
        """Compute the MSE between predicted and target patches, averaged
        over the masked patches only."""
        losses = dict()
        target = self.patchify(x)
        if self.norm_pix:
            # Normalize each target patch by its own mean and std.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target)**2
        loss = loss.mean(dim=-1)  # per-patch MSE, shape (N, L)
        loss = (loss * mask).sum() / mask.sum()  # average over masked patches
        losses['loss'] = loss
        return losses
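For reference, a minimal sketch of how this head can be exercised (not part of the library source; the import path and shapes are assumptions for illustration): with 224x224 inputs and 16x16 patches, ``patchify`` produces L = (224/16)**2 = 196 patches of dimension 16*16*3 = 768.

    import torch

    from mmselfsup.models.heads import MAEPretrainHead  # assumed import path

    head = MAEPretrainHead(norm_pix=True, patch_size=16)
    imgs = torch.randn(2, 3, 224, 224)

    patches = head.patchify(imgs)                        # (2, 196, 768)
    assert head.unpatchify(patches).shape == imgs.shape  # lossless round trip

    pred = torch.randn_like(patches)   # stand-in for the decoder's output
    mask = torch.ones(2, 196)          # 1 marks a masked (reconstructed) patch
    print(head.forward(imgs, pred, mask)['loss'])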
@HEADS.register_module()
class MAEFinetuneHead(BaseModule):
    """Fine-tuning head for MAE.

    Args:
        embed_dim (int): The dim of the feature before the classifier head.
        num_classes (int): The total classes. Defaults to 1000.
        label_smooth_val (float): The degree of label smoothing.
            Defaults to 0.1.
    """

    def __init__(self, embed_dim, num_classes=1000, label_smooth_val=0.1):
        super().__init__()
        self.head = nn.Linear(embed_dim, num_classes)
        self.criterion = LabelSmoothLoss(label_smooth_val, num_classes)

    def init_weights(self):
        nn.init.constant_(self.head.bias, 0)
        trunc_normal_(self.head.weight, std=2e-5)

    def forward(self, x):
        """Get the logits."""
        # ``loss`` below consumes ``outputs[0]``, so the logits are
        # returned as a single-element list.
        outputs = self.head(x)
        return [outputs]

    def loss(self, outputs, labels):
        """Compute the loss."""
        losses = dict()
        losses['loss'] = self.criterion(outputs[0], labels)
        return losses
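A similar hedged sketch for the fine-tuning head (import path, feature shape, and values are assumptions; in the real pipeline the features come from the ViT backbone):

    import torch

    from mmselfsup.models.heads import MAEFinetuneHead  # assumed import path

    head = MAEFinetuneHead(embed_dim=768, num_classes=1000)
    head.init_weights()

    feats = torch.randn(4, 768)    # stand-in for backbone features
    outputs = head.forward(feats)  # list holding one (4, 1000) logits tensor
    labels = torch.randint(0, 1000, (4,))
    print(head.loss(outputs, labels)['loss'])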
@HEADS.register_module()
class MAELinprobeHead(BaseModule):
    """Linear probing head for MAE.

    Args:
        embed_dim (int): The dim of the feature before the classifier head.
        num_classes (int): The total classes. Defaults to 1000.
    """

    def __init__(self, embed_dim, num_classes=1000):
        super().__init__()
        self.head = nn.Linear(embed_dim, num_classes)
        # Parameter-free BatchNorm standardizes the frozen features
        # before the linear classifier.
        self.bn = nn.BatchNorm1d(embed_dim, affine=False, eps=1e-6)
        self.criterion = nn.CrossEntropyLoss()

    def init_weights(self):
        nn.init.constant_(self.head.bias, 0)
        trunc_normal_(self.head.weight, std=0.01)

    def forward(self, x):
        """Get the logits."""
        x = self.bn(x)
        outputs = self.head(x)
        return [outputs]

    def loss(self, outputs, labels):
        """Compute the loss."""
        losses = dict()
        losses['loss'] = self.criterion(outputs[0], labels)
        return losses
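Because all three heads are registered via HEADS.register_module(), they are normally built from a config dict rather than instantiated directly. A hedged config fragment (the wrapper type and omitted fields are assumptions; only the head dict mirrors the constructor above):

    # Hypothetical config fragment: `type` selects the registered class and
    # the remaining keys are passed to its constructor by the HEADS registry.
    model = dict(
        type='Classification',  # assumed wrapper; see the benchmark configs
        head=dict(
            type='MAELinprobeHead',
            embed_dim=768,
            num_classes=1000,
        ),
    )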