Shortcuts

mmselfsup.models.algorithms.maskfeat 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple

import torch
from mmengine.structures import BaseDataElement

from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
from .base import BaseModel


@MODELS.register_module()
class MaskFeat(BaseModel):
    """MaskFeat.

    Implementation of `Masked Feature Prediction for Self-Supervised Visual
    Pre-Training <https://arxiv.org/abs/2112.09133>`_.

    Note:
        ``reconstruct`` reads ``self.mask``, which is only written by
        ``extract_feat``. Call ``extract_feat`` before ``reconstruct``.
    """

    def extract_feat(self,
                     inputs: List[torch.Tensor],
                     data_samples: List[SelfSupDataSample],
                     compute_hog: bool = True,
                     **kwargs) -> Tuple[torch.Tensor]:
        """The forward function to extract features from neck.

        Besides returning the prediction, this method caches the stacked
        mask on ``self.mask`` and, when ``compute_hog`` is True, the HOG
        image of the input on ``self.target``.

        Args:
            inputs (List[torch.Tensor]): The input images and mask. Only
                ``inputs[0]`` is read here.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function. ``mask.value`` of every sample
                is stacked into one batch.
            compute_hog (bool): Whether to compute hog during extraction. If
                True, the batch size of inputs need to be 1.
                Defaults to True.

        Returns:
            torch.Tensor: Neck output viewed as ``(B, L, -1)`` with the
            first position along dim 1 (the cls token) removed. The
            signature annotation is left unchanged, but the returned value
            is a single tensor slice, not a tuple.

        Raises:
            AssertionError: If ``compute_hog`` is True and the batch size of
                ``inputs[0]`` is not 1.
        """
        img = inputs[0]
        # Cached for a later call to ``reconstruct``.
        self.mask = torch.stack(
            [data_sample.mask.value for data_sample in data_samples])
        latent = self.backbone(img, self.mask)
        B, L, C = latent.shape
        # The neck is fed the tokens flattened over batch and sequence.
        pred = self.neck([latent.view(B * L, C)])
        pred = pred[0].view(B, L, -1)

        # compute hog
        if compute_hog:
            assert img.size(0) == 1, 'Currently only support batch size 1.'
            # The call's return value is discarded; the generator's ``out``
            # attribute is read right after it.
            _ = self.target_generator(img)
            hog_image = torch.from_numpy(
                self.target_generator.generate_hog_image(
                    self.target_generator.out)).unsqueeze(0).unsqueeze(0)
            # Expand the single-channel HOG image to 3 channels.
            self.target = hog_image.expand(-1, 3, -1, -1)

        return pred[:, 1:, :]  # remove cls token

    def loss(self, inputs: List[torch.Tensor],
             data_samples: List[SelfSupDataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images. Only
                ``inputs[0]`` is read here.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function. ``mask.value`` of every sample
                is stacked into one batch and cast to ``torch.bool``.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components. It
            holds a single entry under the key ``'loss'``.
        """
        img = inputs[0]
        mask = torch.stack(
            [data_sample.mask.value for data_sample in data_samples])
        mask = mask.to(torch.bool)

        latent = self.backbone(img, mask)
        B, L, C = latent.shape
        # The neck is fed the tokens flattened over batch and sequence.
        pred = self.neck([latent.view(B * L, C)])
        pred = pred[0].view(B, L, -1)
        hog = self.target_generator(img)

        loss = self.head(pred, hog, mask)
        losses = dict(loss=loss)
        return losses

    def reconstruct(self,
                    features: List[torch.Tensor],
                    data_samples: Optional[List[SelfSupDataSample]] = None,
                    **kwargs) -> SelfSupDataSample:
        """The function is for image reconstruction.

        Args:
            features (List[torch.Tensor]): The predicted feature embeddings.
                The signature annotation is left unchanged, but the value is
                used as a single tensor (``.shape[2]`` and ``.unflatten``
                on dims 1 and 2); presumably this is the output of
                ``extract_feat`` - confirm against the caller.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function. Not read by this method.

        Returns:
            SelfSupDataSample: The prediction from model. ``pred.value`` is
            the HOG image rendered from ``features`` and ``mask.value`` is
            the cached patch mask expanded to pixel resolution; both are
            detached CPU tensors.

        Raises:
            AttributeError: If no mask has been cached, i.e.
                ``extract_feat`` has not been called yet.
        """
        # ``self.mask`` is only written by ``extract_feat``. Check up front
        # so the failure names the cause instead of surfacing as a bare
        # attribute lookup error after the HOG image has been rendered.
        mask = getattr(self, 'mask', None)
        if mask is None:
            raise AttributeError(
                'MaskFeat.reconstruct() needs the mask cached by '
                'extract_feat(); call extract_feat() first.')

        # recover to HOG description from feature embeddings
        # (the trailing shape comments are example values)
        unfold_size = self.target_generator.unfold_size
        tmp4 = features.unflatten(2,
                                  (features.shape[2] // unfold_size**2,
                                   unfold_size, unfold_size))  # 1,196,27,2,2
        tmp3 = tmp4.unflatten(1, self.backbone.patch_resolution)

        b, p1, p2, c_nbins, _, _ = tmp3.shape  # 1,14,14,27,2,2
        tmp2 = tmp3.permute(0, 1, 2, 5, 3, 4).reshape(
            (b, p1, p2 * unfold_size, c_nbins, unfold_size))
        tmp1 = tmp2.permute(0, 1, 4, 2, 3).reshape(
            (b, p1 * unfold_size, p2 * unfold_size, c_nbins))
        tmp0 = tmp1.permute(0, 3, 1, 2)  # 1,27,28,28
        hog_out = tmp0.unflatten(1,
                                 (int(c_nbins // self.target_generator.nbins),
                                  self.target_generator.nbins))  # 1,3,9,28,28

        # generate prediction of HOG
        hog_image = torch.from_numpy(
            self.target_generator.generate_hog_image(hog_out))
        hog_image = hog_image.unsqueeze(0).unsqueeze(0)
        # Channels-last layout, then expand the last dim to 3.
        pred = torch.einsum('nchw->nhwc', hog_image).expand(-1, -1, -1,
                                                            3).detach().cpu()

        # transform patch mask to pixel mask
        patch_dim_1 = int(self.backbone.patch_embed.init_input_size[0] //
                          self.backbone.patch_resolution[0])
        patch_dim_2 = int(self.backbone.patch_embed.init_input_size[1] //
                          self.backbone.patch_resolution[1])
        mask = mask.repeat_interleave(
            patch_dim_1, dim=1).repeat_interleave(
                patch_dim_2, dim=2).unsqueeze(-1).repeat(1, 1, 1, 3)
        # 1 is removing, 0 is keeping
        mask = mask.detach().cpu()

        results = SelfSupDataSample()
        results.mask = BaseDataElement(**dict(value=mask))
        results.pred = BaseDataElement(**dict(value=pred))

        return results
Read the Docs v: latest
Versions
latest
stable
1.x
0.x
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.