Shortcuts

Source code for mmselfsup.models.algorithms.moco

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.model import ExponentialMovingAverage

from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
from mmselfsup.utils import (batch_shuffle_ddp, batch_unshuffle_ddp,
                             concat_all_gather)
from .base import BaseModel


@MODELS.register_module()
class MoCo(BaseModel):
    """MoCo.

    Implementation of `Momentum Contrast for Unsupervised Visual
    Representation Learning <https://arxiv.org/abs/1911.05722>`_.
    Part of the code is borrowed from:
    `<https://github.com/facebookresearch/moco/blob/master/moco/builder.py>`_.

    The query branch is ``backbone -> neck``. The key branch is an
    exponential-moving-average copy of that same pair, and a fixed-size
    queue of past keys supplies the negative samples.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module of deep features to compact
            feature vectors.
        head (dict): Config dict for module of head functions.
        queue_len (int): Number of negative keys maintained in the
            queue. Defaults to 65536.
        feat_dim (int): Dimension of compact feature vectors.
            Defaults to 128.
        momentum (float): Momentum coefficient for the momentum-updated
            encoder. Defaults to 0.999.
        pretrained (str, optional): The pretrained checkpoint path, support
            local path and remote path. Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None or no specified type, it will use
            "SelfSupDataPreprocessor" as type.
            See :class:`SelfSupDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 queue_len: int = 65536,
                 feat_dim: int = 128,
                 momentum: float = 0.999,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # Create the momentum (key) encoder as an EMA wrapper around the
        # query encoder (backbone followed by neck). ``1 - momentum`` is
        # passed because the wrapper's coefficient is the complement of the
        # MoCo momentum (see the mmengine ExponentialMovingAverage docs).
        self.encoder_k = ExponentialMovingAverage(
            nn.Sequential(self.backbone, self.neck), 1 - momentum)

        # Create the queue: one column per stored key, each column
        # L2-normalised (normalisation runs over dim 0, the feature axis).
        self.queue_len = queue_len
        self.register_buffer('queue', torch.randn(feat_dim, queue_len))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        # Write position of the next batch of keys inside the queue.
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None:
        """Update queue.

        Overwrites the oldest ``batch_size`` columns of the queue with the
        newest keys and advances the write pointer, wrapping around at
        ``queue_len``.

        Args:
            keys (torch.Tensor): Key features of the current batch. They
                are gathered with ``concat_all_gather`` before being
                written, and transposed so each key becomes one column.

        Raises:
            AssertionError: If ``queue_len`` is not a multiple of the
                gathered batch size.
        """
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        # Explicit check instead of a bare ``assert`` so it is not stripped
        # under ``python -O``. The slice write below relies on the batch
        # never straddling the end of the queue.
        if self.queue_len % batch_size != 0:  # for simplicity
            raise AssertionError(
                f'queue_len ({self.queue_len}) must be divisible by the '
                f'gathered batch size ({batch_size}).')

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1)
        ptr = (ptr + batch_size) % self.queue_len  # move pointer

        self.queue_ptr[0] = ptr

    def extract_feat(self, inputs: List[torch.Tensor],
                     **kwarg) -> Tuple[torch.Tensor]:
        """Function to extract features from backbone.

        Only the first element of ``inputs`` is used.

        Args:
            inputs (List[torch.Tensor]): The input images.
            **kwarg: Extra keyword arguments; accepted and ignored.

        Returns:
            Tuple[torch.Tensor]: Backbone outputs.
        """
        x = self.backbone(inputs[0])
        return x

    def loss(self, inputs: List[torch.Tensor],
             data_samples: List[SelfSupDataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images. ``inputs[0]`` is
                used as the query view and ``inputs[1]`` as the key view.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        im_q = inputs[0]
        im_k = inputs[1]
        # compute query features from encoder_q
        q = self.neck(self.backbone(im_q))[0]  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            # update the key encoder
            self.encoder_k.update_parameters(
                nn.Sequential(self.backbone, self.neck))

            # shuffle for making use of BN
            im_k, idx_unshuffle = batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)[0]  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        # The queue is cloned and detached so the in-place queue update
        # below cannot affect the tensors used for this step's loss.
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        loss = self.head(l_pos, l_neg)
        # update the queue
        self._dequeue_and_enqueue(k)

        losses = dict(loss=loss)
        return losses
Read the Docs v: latest
Versions
latest
stable
1.x
0.x
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.