
Note

You are reading the documentation for MMSelfSup 0.x, which will be deprecated by the end of 2022. We recommend you upgrade to the MMSelfSup 1.0.0rc versions to enjoy the fruitful new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code and documentation of MMSelfSup 1.0.0rc for more details.

Source code for mmselfsup.models.algorithms.mmcls_classifier_wrapper

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcls.models import ImageClassifier
from mmcv.runner import auto_fp16

from ..builder import ALGORITHMS


@ALGORITHMS.register_module()
class MMClsImageClassifierWrapper(ImageClassifier):
    """Workaround to use models from mmclassification.

    Since the output of the classifier from mmclassification is not
    compatible with MMSelfSup's evaluation functions, we rewrite some key
    components from mmclassification.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict, optional): Config dict for module of neck.
            Defaults to None.
        head (dict, optional): Config dict for module of loss functions.
            Defaults to None.
        pretrained (str, optional): The path of pre-trained checkpoint.
            Defaults to None.
        train_cfg (dict, optional): Config dict for pre-processing utils,
            e.g. mixup. Defaults to None.
        init_cfg (dict, optional): Config dict for initialization.
            Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict = None,
                 head: dict = None,
                 pretrained: str = None,
                 train_cfg: dict = None,
                 init_cfg: dict = None):
        super(MMClsImageClassifierWrapper, self).__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            train_cfg=train_cfg,
            init_cfg=init_cfg)
    @auto_fp16(apply_to=('img', ))
    def forward(self, img, mode='train', **kwargs):
        """Forward function of model.

        Calls either forward_train, forward_test or extract_feat function
        according to the mode.
        """
        if mode == 'train':
            return self.forward_train(img, **kwargs)
        elif mode == 'test':
            return self.forward_test(img, **kwargs)
        elif mode == 'extract':
            return self.extract_feat(img)
        else:
            raise Exception(f'No such mode: {mode}')
    def forward_train(self, img, label, **kwargs):
        """Forward computation during training.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            label (Tensor): It should be of shape (N, 1) encoding the
                ground-truth label of input images for the single-label task.
                It should be of shape (N, C) encoding the ground-truth label
                of input images for the multi-label task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        if self.augments is not None:
            img, label = self.augments(img, label)
        x = self.extract_feat(img)

        losses = dict()
        loss = self.head.forward_train(x, label)
        losses.update(loss)

        return losses
    def forward_test(self, imgs, **kwargs):
        """
        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and the inner Tensor should have a shape of
                (N, C, H, W), which contains all images in the batch.
        """
        kwargs.pop('label', None)
        kwargs.pop('idx', None)
        if isinstance(imgs, torch.Tensor):
            imgs = [imgs]
        for var, name in [(imgs, 'imgs')]:
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got {type(var)}')

        if len(imgs) == 1:
            outs = self.simple_test(imgs[0], post_process=False, **kwargs)
            outs = outs if isinstance(outs, list) else [outs]
            keys = [f'head{i}' for i in self.backbone.out_indices]
            out_tensors = [out.cpu() for out in outs]
            return dict(zip(keys, out_tensors))
        else:
            raise NotImplementedError('aug_test has not been implemented')
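
For context, the sketch below shows how a wrapper like this is typically built through the ALGORITHMS registry and driven via the three forward modes. It is a minimal illustration, not part of this source file: the build_algorithm helper, the backbone/neck/head configuration, and the input shapes are assumptions based on the usual 0.x registry pattern.

# Minimal usage sketch (illustrative only, not part of the module above).
# It assumes the 0.x builder helper ``mmselfsup.models.build_algorithm`` and
# standard mmclassification components; the config values are placeholders.
import torch

from mmselfsup.models import build_algorithm

model_cfg = dict(
    type='MMClsImageClassifierWrapper',
    # The inner modules are built by mmclassification, because the wrapper
    # subclasses mmcls's ImageClassifier.
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0)))

model = build_algorithm(model_cfg)

img = torch.randn(2, 3, 224, 224)
label = torch.randint(0, 1000, (2, ))

losses = model(img, mode='train', label=label)  # dict of loss tensors
feats = model(img, mode='extract')              # backbone/neck features
preds = model(img, mode='test')                 # e.g. {'head3': Tensor}, keyed by out_indices

The mode argument mirrors the dispatch in forward above, so a single model object can produce training losses, extracted features, or test-time predictions in the dictionary format that MMSelfSup's evaluation utilities expect.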