Shortcuts

数据结构

像 OpenMMLab 中其他仓库一样,MMSelfSup 也定义了一个数据结构,名为 SelfSupDataSample ,这个数据结构用于接收和传递整个训练和测试过程中的数据。 SelfSupDataSample 继承 MMEngine 中使用的 BaseDataElement。如果需要深入了解 BaseDataElement,我们建议参考 BaseDataElement。在这些教程中,我们主要讨论 SelfSupDataSample 中一些定制化的属性。

SelfSupDataSample 中的定制化的属性

在 MMSelfSup 中,SelfSupDataSample 将模型需要的所有信息(除了图片)打包,比如 mask image modeling(MIM) 中请求的 mask 和前置任务中的 pseudo_label 。除了提供信息,它还能接受模型产生的信息,比如预测得分。为实现上述功能, SelfSupDataSample 定义以下五个属性:

  • gt_label(标签数据),包含图片的真实标签。

  • sample_idx(实例数据),包含一开始被数据集初始化的数据列表中的最近的图片的序号。

  • mask(数据基类),包含 MIM 中的面具,比如 SimMIM 和 CAE。

  • pred_label(标签数据),包含模型预测的标签。

  • pseudo_label(数据基类),包含前置任务中用到的假的标签,比如 Relation Location 中的 location。

为了帮助使用者理解 SelfSupDataSample 中的基本思想,我们给出一个关于如何创建 SelfSupDataSample 实例并设置这些属性的简单例子。

import torch
from mmselfsup.core import SelfSupDataSample
from mmengine.data import LabelData, InstanceData, BaseDataElement

selfsup_data_sample = SelfSupDataSample()
# 在 selfsup_data_sample 里加入真实标签数据
# 真实标签数据的类型应与 LabelData 的类型一致
selfsup_data_sample.gt_label = LabelData(value=torch.tensor([1]))

# 如果真实标签数据类型和 LabelData 不一致会报错
selfsup_data_sample.gt_label = torch.tensor([1])
# 报错: AssertionError: tensor([1]) should be a <class 'mmengine.data.label_data.LabelData'> but got <class 'torch.Tensor'>

# 给 selfsup_data_sample 加入样例数据
# 同样的,样例数据里的值的类型应与 InstanceData 保持一致
selfsup_data_sample.sample_idx = InstanceData(value=torch.tensor([1]))

# 给 selfsup_data_sample 加面具
selfsup_data_sample.mask = BaseDataElement(value=torch.ones((3, 3)))

# 给 selfsup_data_sample 加假标签
selfsup_data_sample.pseudo_label = InstanceData(location=torch.tensor([1, 2, 3]))


# 创建这些属性后,您可轻而易举得取这些属性里的值
print(selfsup_data_sample.gt_label.value)
# 输出 tensor([1])
print(selfsup_data_sample.mask.value.shape)
# 输出 torch.Size([3, 3])

用 MMSelfSup 把数据打包给 SelfSupDataSample

在把数据喂给模型之前, MMSelfSup 按照数据流程把数据打包进 SelfSupDataSample 。如果您不熟悉数据流程,可以参考 data transform。我们用一个叫 PackSelfSupInputs的数据变换来打包数据。

class PackSelfSupInputs(BaseTransform):
    """把数据打包并让格式能与函数输入匹配

    需要的值:

    - img

    添加的值:

    - data_sample
    - inputs

    参数:
        key (str): 输入模型的图片的值,默认为 img 。
        algorithm_keys (List[str]): 和算法相关的组成部分的值,比如 mask 。默认为 [] 。
        pseudo_label_keys (List[str]): 假标签对应的属性。默认为 [] 。
        meta_keys (List[str]): 图片的 meta 信息的值。默认为 [] 。

    """

    def __init__(self,
                 key: Optional[str] = 'img',
                 algorithm_keys: Optional[List[str]] = [],
                 pseudo_label_keys: Optional[List[str]] = [],
                 meta_keys: Optional[List[str]] = []) -> None:
        assert isinstance(key, str), f'key should be the type of str, instead \
            of {type(key)}.'

        self.key = key
        self.algorithm_keys = algorithm_keys
        self.pseudo_label_keys = pseudo_label_keys
        self.meta_keys = meta_keys

    def transform(self,
                  results: Dict) -> Dict[torch.Tensor, SelfSupDataSample]:
        """打包数据的方法。

        参数:
            results (Dict): 数据变换返回的字典。

        返回:
            Dict:

            - 'inputs' (List[torch.Tensor]): 模型前面的数据。
            - 'data_sample' (SelfSupDataSample): 前面数据的注释信息。
        """
        packed_results = dict()
        if self.key in results:
            img = results[self.key]
            # if img is not a list, convert it to a list
            if not isinstance(img, List):
                img = [img]
            for i, img_ in enumerate(img):
                if len(img_.shape) < 3:
                    img_ = np.expand_dims(img_, -1)
                img_ = np.ascontiguousarray(img_.transpose(2, 0, 1))
                img[i] = to_tensor(img_)
            packed_results['inputs'] = img

        data_sample = SelfSupDataSample()
        if len(self.pseudo_label_keys) > 0:
            pseudo_label = InstanceData()
            data_sample.pseudo_label = pseudo_label

        # gt_label, sample_idx, mask, pred_label 在此设置
        for key in self.algorithm_keys:
            self.set_algorithm_keys(data_sample, key, results)

        # 除 gt_label, sample_idx, mask, pred_label 外的值会被设为假标签的属性
        for key in self.pseudo_label_keys:
            # convert data to torch.Tensor
            value = to_tensor(results[key])
            setattr(data_sample.pseudo_label, key, value)

        img_meta = {}
        for key in self.meta_keys:
            img_meta[key] = results[key]
        data_sample.set_metainfo(img_meta)
        packed_results['data_sample'] = data_sample

        return packed_results

    @classmethod
    def set_algorithm_keys(self, data_sample: SelfSupDataSample, key: str,
                           results: Dict) -> None:
        """设置 SelfSupDataSample 中算法的值."""
        value = to_tensor(results[key])
        if key == 'sample_idx':
            sample_idx = InstanceData(value=value)
            setattr(data_sample, 'sample_idx', sample_idx)
        elif key == 'mask':
            mask = InstanceData(value=value)
            setattr(data_sample, 'mask', mask)
        elif key == 'gt_label':
            gt_label = LabelData(value=value)
            setattr(data_sample, 'gt_label', gt_label)
        elif key == 'pred_label':
            pred_label = LabelData(value=value)
            setattr(data_sample, 'pred_label', pred_label)
        else:
            raise AttributeError(f'{key} is not a attribute of \
                SelfSupDataSample')

在 SelfSupDataSample 中 algorithm_keys 是除了 pseudo_label 的数据属性, pseudo_label_keys 是 SelfSupDataSample 中假标签对应的分支属性。 感谢读完整个教程。有问题的话可以在 GitHub 上提 issue,我们会尽快联系您。

Read the Docs v: stable
Versions
latest
stable
1.x
dev-1.x
0.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.