注意
您正在阅读 MMSelfSup 0.x 版本的文档,而 MMSelfSup 0.x 版本将会在 2022 年末开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本,享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMSelfSup 1.0.0rc 的发版日志、代码和文档以获取更多信息。
mmselfsup.utils.extractor 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from torch.utils.data import Dataset
from mmselfsup.utils import dist_forward_collect, nondist_forward_collect
class Extractor:
    """Feature extractor.

    Builds a (optionally distributed) data loader over ``dataset`` and, when
    called with a runner, collects the flattened last-stage neck features for
    every sample in dataset order.

    Args:
        dataset (Dataset | dict): A PyTorch dataset or a config dict that
            indicates the dataset (built via ``datasets.build_dataset``).
        samples_per_gpu (int): Number of images on each GPU, i.e., batch size
            of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        dist_mode (bool): Use distributed extraction or not. Defaults to False.
        persistent_workers (bool): If True, the data loader will not shutdown
            the worker processes after a dataset has been consumed once.
            This allows to maintain the workers Dataset instances alive.
            The argument also has effect in PyTorch>=1.7.0.
            Defaults to True.
        **kwargs: Extra loader options; only ``prefetch`` (bool, defaults to
            False) and ``img_norm_cfg`` (dict, defaults to empty) are
            forwarded to ``build_dataloader``.

    Raises:
        TypeError: If ``dataset`` is neither a ``Dataset`` nor a dict.
    """

    def __init__(self,
                 dataset,
                 samples_per_gpu,
                 workers_per_gpu,
                 dist_mode=False,
                 persistent_workers=True,
                 **kwargs):
        # Imported lazily to avoid a circular import between
        # mmselfsup.utils and mmselfsup.datasets.
        from mmselfsup import datasets
        if isinstance(dataset, Dataset):
            self.dataset = dataset
        elif isinstance(dataset, dict):
            self.dataset = datasets.build_dataset(dataset)
        else:
            raise TypeError(f'dataset must be a Dataset object or a dict, '
                            f'not {type(dataset)}')
        self.data_loader = datasets.build_dataloader(
            self.dataset,
            samples_per_gpu=samples_per_gpu,
            workers_per_gpu=workers_per_gpu,
            dist=dist_mode,
            shuffle=False,  # keep feature order aligned with the dataset
            persistent_workers=persistent_workers,
            prefetch=kwargs.get('prefetch', False),
            img_norm_cfg=kwargs.get('img_norm_cfg', dict()))
        self.dist_mode = dist_mode
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def _forward_func(self, runner, **x):
        """Run one extraction step and return flattened features on CPU.

        NOTE(review): reaching the neck via ``runner.model.module`` assumes
        the model is wrapped (e.g. by (Distributed)DataParallel) — confirm
        for non-wrapped models.
        """
        backbone_feat = runner.model(mode='extract', **x)
        # Feed only the last backbone stage through the neck, then flatten
        # each sample's feature map to a (N, D) matrix.
        last_layer_feat = runner.model.module.neck([backbone_feat[-1]])[0]
        last_layer_feat = last_layer_feat.view(last_layer_feat.size(0), -1)
        return dict(feature=last_layer_feat.cpu())

    def __call__(self, runner):
        """Extract features for the whole dataset.

        Args:
            runner: Runner whose ``model`` performs the forward pass (and,
                in distributed mode, provides ``rank``).

        Returns:
            Tensor: Collected features of shape (N, D), where N is the
            dataset length.
        """
        # Closure handed to the collect helpers; binds the runner so the
        # helpers only need to supply the batch data.
        def func(**x):
            return self._forward_func(runner, **x)

        if self.dist_mode:
            feats = dist_forward_collect(
                func,
                self.data_loader,
                runner.rank,
                len(self.dataset),
                ret_rank=-1)['feature']  # NxD
        else:
            feats = nondist_forward_collect(func, self.data_loader,
                                            len(self.dataset))['feature']
        return feats