Shortcuts

注意

您正在阅读 MMSelfSup 0.x 版本的文档，而 MMSelfSup 0.x 版本将会在 2022 年末开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。请阅读 MMSelfSup 1.0.0rc 的发版日志和代码文档以获取更多信息。

mmselfsup.models.algorithms.npid 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from ..builder import (ALGORITHMS, build_backbone, build_head, build_memory,
                       build_neck)
from .base import BaseModel


@ALGORITHMS.register_module()
class NPID(BaseModel):
    """NPID.

    Implementation of `Unsupervised Feature Learning via Non-parametric
    Instance Discrimination <https://arxiv.org/abs/1805.01978>`_.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module of deep features to compact
            feature vectors. Defaults to None.
        head (dict): Config dict for module of loss functions. Although it
            defaults to None, a value is required (asserted in ``__init__``).
        memory_bank (dict): Config dict for module of memory banks. Although
            it defaults to None, a value is required (asserted in
            ``__init__``).
        neg_num (int): Number of negative samples for each image.
            Defaults to 65536.
        ensure_neg (bool): If True, any sampled negative index equal to the
            image's own index is re-drawn until no such collision remains.
            If False, there is a small probability that the negative samples
            contain the positive one. Defaults to False.
        init_cfg (dict, optional): Initialization config passed through to
            ``BaseModel``. Defaults to None.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 head=None,
                 memory_bank=None,
                 neg_num=65536,
                 ensure_neg=False,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)
        assert head is not None, 'NPID requires a `head` config.'
        self.head = build_head(head)
        assert memory_bank is not None, \
            'NPID requires a `memory_bank` config.'
        self.memory_bank = build_memory(memory_bank)
        self.neg_num = neg_num
        self.ensure_neg = ensure_neg

    def extract_feat(self, img):
        """Function to extract features from backbone.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.

        Returns:
            tuple[Tensor]: backbone outputs.
        """
        x = self.backbone(img)
        return x

    def forward_train(self, img, idx, **kwargs):
        """Forward computation during training.

        Positive logits are the dot product between each image's current
        (L2-normalised) feature and the memory-bank entry at its own index.
        Negative logits use ``neg_num`` bank entries drawn from
        ``self.memory_bank.multinomial``. After the loss is computed, the
        memory bank is updated with the current features.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            idx (Tensor): Index corresponding to each image.
            kwargs: Any keyword arguments to be used to forward.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        feature = self.extract_feat(img)

        feature_bank = self.memory_bank.feature_bank
        # ``index_select`` (and the bank update below) need the indices on
        # the same device as the bank; follow the bank's device instead of
        # hard-coding ``.cuda()``.
        idx = idx.to(feature_bank.device)

        if self.with_neck:
            feature = self.neck(feature)[0]
        # NOTE(review): without a neck, ``feature`` is still the backbone
        # output (documented as a tuple in ``extract_feat``), which
        # ``normalize`` presumably cannot handle — confirm configs always
        # supply a neck.
        feature = nn.functional.normalize(feature)  # BxC
        bs, feat_dim = feature.shape[:2]

        neg_idx = self.memory_bank.multinomial.draw(bs * self.neg_num)
        if self.ensure_neg:
            # Re-draw every negative that collides with its row's positive
            # index until none remain. ``view`` shares storage with the flat
            # draw, so the masked in-place writes are kept after flatten().
            neg_idx = neg_idx.view(bs, -1)
            while True:
                wrong = neg_idx == idx.view(-1, 1)
                # Compute the collision count once (each .item() is a sync).
                num_wrong = wrong.sum().item()
                if num_wrong == 0:
                    break
                neg_idx[wrong] = self.memory_bank.multinomial.draw(num_wrong)
            neg_idx = neg_idx.flatten()

        pos_feat = torch.index_select(feature_bank, 0, idx)  # BxC
        neg_feat = torch.index_select(feature_bank, 0, neg_idx).view(
            bs, self.neg_num, feat_dim)  # BxKxC

        # Row-wise dot products: Bx1 positive logits, BxK negative logits.
        pos_logits = torch.einsum('nc,nc->n', pos_feat, feature).unsqueeze(-1)
        neg_logits = torch.bmm(neg_feat, feature.unsqueeze(2)).squeeze(2)

        losses = self.head(pos_logits, neg_logits)

        # update memory bank (no gradient flows into the bank)
        with torch.no_grad():
            self.memory_bank.update(idx, feature.detach())

        return losses
Read the Docs v: 0.x
Versions
latest
stable
1.x
dev-1.x
0.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.