# Source code for mmselfsup.datasets.transforms.formatting
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import numpy as np
import torch
from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
from mmengine.structures import InstanceData, LabelData

from mmselfsup.registry import TRANSFORMS
from mmselfsup.structures import SelfSupDataSample
@TRANSFORMS.register_module()
class PackSelfSupInputs(BaseTransform):
    """Pack data into the format compatible with the inputs of algorithm.

    Required Keys:

    - img

    Added Keys:

    - data_samples
    - inputs

    Args:
        key (str): The key of image inputted into the model. Defaults to 'img'.
        algorithm_keys (List[str], optional): Keys of elements related
            to algorithms, e.g. mask. ``None`` is treated as an empty list.
            Defaults to None.
        pseudo_label_keys (List[str], optional): Keys set to be the
            attributes of pseudo_label. ``None`` is treated as an empty list.
            Defaults to None.
        meta_keys (List[str], optional): The keys of meta info of an image.
            ``None`` is treated as an empty list. Defaults to None.
    """

    # Supported algorithm keys and the data structure each value is wrapped
    # in before being attached to the SelfSupDataSample.
    _ALGORITHM_KEY_WRAPPERS = {
        'sample_idx': InstanceData,
        'mask': InstanceData,
        'gt_label': LabelData,
        'pred_label': LabelData,
    }

    def __init__(self,
                 key: str = 'img',
                 algorithm_keys: Optional[List[str]] = None,
                 pseudo_label_keys: Optional[List[str]] = None,
                 meta_keys: Optional[List[str]] = None) -> None:
        # Implicit string concatenation (not a backslash continuation) so the
        # message does not pick up the source file's indentation.
        assert isinstance(key, str), (
            f'key should be the type of str, instead of {type(key)}.')
        self.key = key
        # Copy into fresh lists: avoids shared mutable defaults and decouples
        # this transform from lists owned by the caller.
        self.algorithm_keys = ([] if algorithm_keys is None else
                               list(algorithm_keys))
        self.pseudo_label_keys = ([] if pseudo_label_keys is None else
                                  list(pseudo_label_keys))
        self.meta_keys = [] if meta_keys is None else list(meta_keys)

    @staticmethod
    def _format_img(img: np.ndarray) -> torch.Tensor:
        """Convert a single image array into a tensor.

        2-dim arrays get a trailing channel axis. 3-dim channel-last arrays
        are transposed to channel-first. 5-dim arrays are kept as they are.

        Args:
            img (np.ndarray): The image array to convert.

        Returns:
            torch.Tensor: The converted image.

        Raises:
            ValueError: If ``img`` is not 2, 3 or 5 dimensional.
        """
        # to handle the single channel image
        if len(img.shape) == 2:
            img = np.expand_dims(img, -1)
        if len(img.shape) == 3:
            img = np.ascontiguousarray(img.transpose(2, 0, 1))
        elif len(img.shape) != 5:
            # 5-dim input is video data with the shape (B, C, T, H, W) and
            # is passed through unchanged.
            raise ValueError('img should be 2, 3 or 5 dimensional, '
                             f'instead of {len(img.shape)} dimensional.')
        return to_tensor(img)

    def transform(
        self, results: Dict
    ) -> Dict[str, Union[List[torch.Tensor], SelfSupDataSample]]:
        """Method to pack the data.

        Args:
            results (Dict): Result dict from the data pipeline.

        Returns:
            Dict:

            - ``inputs`` (List[torch.Tensor]): The forward data of models.
              Only present when ``self.key`` is in ``results``.
            - ``data_samples`` (SelfSupDataSample): The annotation info of
              the forward data.
        """
        packed_results = dict()
        if self.key in results:
            img = results[self.key]
            # A single image is wrapped into a list; a list/tuple of images
            # (e.g. several views) is packed element-wise.
            imgs = img if isinstance(img, (list, tuple)) else [img]
            # Build a new list rather than overwriting results[self.key] in
            # place, so the caller's data is not mutated.
            packed_results['inputs'] = [self._format_img(i) for i in imgs]

        data_sample = SelfSupDataSample()
        if self.pseudo_label_keys:
            data_sample.pseudo_label = InstanceData()

        # gt_label, sample_idx, mask, pred_label will be set here
        for key in self.algorithm_keys:
            self.set_algorithm_keys(data_sample, key, results)

        # keys, except for gt_label, sample_idx, mask, pred_label, will be
        # set as the attributes of pseudo_label (converted to torch.Tensor)
        for key in self.pseudo_label_keys:
            setattr(data_sample.pseudo_label, key, to_tensor(results[key]))

        img_meta = {key: results[key] for key in self.meta_keys}
        data_sample.set_metainfo(img_meta)
        packed_results['data_samples'] = data_sample
        return packed_results

    @classmethod
    def set_algorithm_keys(cls, data_sample: SelfSupDataSample, key: str,
                           results: dict) -> None:
        """Set the algorithm keys of SelfSupDataSample.

        Args:
            data_sample (SelfSupDataSample): An instance of SelfSupDataSample.
            key (str): The key, which may be used by the algorithm, such as
                gt_label, sample_idx, mask, pred_label. For more keys, please
                refer to the attribute of SelfSupDataSample.
            results (dict): The results from the data pipeline.

        Raises:
            AttributeError: If ``key`` is not one of the supported algorithm
                keys.
        """
        wrapper = cls._ALGORITHM_KEY_WRAPPERS.get(key)
        # Validate before touching results[key], so an unsupported key always
        # surfaces as AttributeError rather than a KeyError/TypeError.
        if wrapper is None:
            raise AttributeError(
                f'{key} is not an attribute of SelfSupDataSample')
        setattr(data_sample, key, wrapper(value=to_tensor(results[key])))

    def __repr__(self) -> str:
        # Implicit concatenation keeps source indentation out of the output.
        return (f'{self.__class__.__name__}(key={self.key}, '
                f'algorithm_keys={self.algorithm_keys}, '
                f'pseudo_label_keys={self.pseudo_label_keys}, '
                f'meta_keys={self.meta_keys})')