
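Note

You are reading the documentation for MMSelfSup 0.x, which will gradually stop being maintained starting at the end of 2022. We recommend upgrading to MMSelfSup 1.0.0rc to benefit from the new features and better performance that OpenMMLab 2.0 brings. Read the MMSelfSup 1.0.0rc release notes and code documentation for more information.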

Source code for mmselfsup.datasets.deepcluster

# Copyright (c) OpenMMLab. All rights reserved.
import torch

from .base import BaseDataset
from .builder import DATASETS
from .utils import to_numpy


@DATASETS.register_module()
class DeepClusterDataset(BaseDataset):
    """Dataset for DC and ODC.

    The dataset initializes clustering labels and assigns them during
    training.

    Args:
        data_source (dict): Data source defined in
            `mmselfsup.datasets.data_sources`.
        pipeline (list[dict]): A list of dict, where each element represents
            an operation defined in `mmselfsup.datasets.pipelines`.
        prefetch (bool, optional): Whether to prefetch data. Defaults to False.
    """

    def __init__(self, data_source, pipeline, prefetch=False):
        super(DeepClusterDataset, self).__init__(data_source, pipeline,
                                                 prefetch)
        # init clustering labels
        self.clustering_labels = [-1 for _ in range(len(self.data_source))]

    def __getitem__(self, idx):
        img = self.data_source.get_img(idx)
        img = self.pipeline(img)
        clustering_label = self.clustering_labels[idx]
        if self.prefetch:
            img = torch.from_numpy(to_numpy(img))
        return dict(img=img, pseudo_label=clustering_label, idx=idx)

    def assign_labels(self, labels):
        assert len(self.clustering_labels) == len(labels), (
            f'Inconsistent length of assigned labels, '
            f'{len(self.clustering_labels)} vs {len(labels)}')
        self.clustering_labels = labels[:]

    def evaluate(self, results, logger=None):
        return NotImplemented
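The label lifecycle in the class above can be illustrated with a minimal, dependency-free sketch. Every sample starts with the placeholder label -1, `__getitem__` returns the current label as `pseudo_label`, and `assign_labels` replaces the whole list after a clustering round. `ToyDataSource`, `ToyDeepClusterDataset`, the string "images", and the upper-casing pipeline are hypothetical stand-ins, not MMSelfSup APIs. The `prefetch` branch (which converts the image with `to_numpy` and `torch.from_numpy`) is left out so the sketch runs without PyTorch.

```python
# Dependency-free sketch of DeepClusterDataset's label handling.
# ToyDataSource and ToyDeepClusterDataset are hypothetical stand-ins,
# not part of MMSelfSup; the prefetch/torch branch is omitted.


class ToyDataSource:
    """Hypothetical data source that returns stored 'images' by index."""

    def __init__(self, images):
        self.images = list(images)

    def __len__(self):
        return len(self.images)

    def get_img(self, idx):
        return self.images[idx]


class ToyDeepClusterDataset:
    """Mirrors the clustering-label logic of DeepClusterDataset."""

    def __init__(self, data_source, pipeline):
        self.data_source = data_source
        self.pipeline = pipeline
        # Every sample starts unassigned, marked with -1.
        self.clustering_labels = [-1 for _ in range(len(self.data_source))]

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        img = self.pipeline(self.data_source.get_img(idx))
        return dict(img=img, pseudo_label=self.clustering_labels[idx], idx=idx)

    def assign_labels(self, labels):
        # Reject label lists that do not cover the dataset exactly.
        assert len(self.clustering_labels) == len(labels), (
            f'Inconsistent length of assigned labels, '
            f'{len(self.clustering_labels)} vs {len(labels)}')
        # labels[:] copies a Python list; on a NumPy array it is a view.
        self.clustering_labels = labels[:]


source = ToyDataSource(['img0', 'img1', 'img2'])
dataset = ToyDeepClusterDataset(source, pipeline=lambda img: img.upper())

print(dataset[1])  # {'img': 'IMG1', 'pseudo_label': -1, 'idx': 1}

# After a clustering step (e.g. k-means on extracted features), the
# new cluster ids are written back in a single call.
new_labels = [2, 0, 2]
dataset.assign_labels(new_labels)
new_labels[0] = 99  # the dataset holds its own copy of the list
print(dataset[0]['pseudo_label'])  # 2
```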