Shortcuts

Source code for mmselfsup.datasets.transforms.formatting

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
from mmengine.structures import InstanceData, LabelData

from mmselfsup.registry import TRANSFORMS
from mmselfsup.structures import SelfSupDataSample


@TRANSFORMS.register_module()
class PackSelfSupInputs(BaseTransform):
    """Pack data into the format compatible with the inputs of algorithm.

    Required Keys:

    - ``key`` (``'img'`` by default): a single image array or a list of
      image arrays. If it is absent, ``inputs`` is simply not produced.
    - every entry of ``algorithm_keys``, ``pseudo_label_keys`` and
      ``meta_keys``.

    Added Keys:

    - data_samples
    - inputs

    Args:
        key (str): The key of image inputted into the model.
            Defaults to 'img'.
        algorithm_keys (List[str], optional): Keys of elements related to
            algorithms, e.g. mask. Supported keys are ``sample_idx``,
            ``mask``, ``gt_label`` and ``pred_label``. Defaults to None,
            meaning no algorithm keys.
        pseudo_label_keys (List[str], optional): Keys set to be the
            attributes of pseudo_label. Defaults to None, meaning none.
        meta_keys (List[str], optional): The keys of meta info of an image.
            Defaults to None, meaning none.
    """

    # Supported algorithm keys and the data container each value is wrapped
    # in before being attached to the ``SelfSupDataSample``.
    _ALGORITHM_KEY_TYPES = {
        'sample_idx': InstanceData,
        'mask': InstanceData,
        'gt_label': LabelData,
        'pred_label': LabelData,
    }

    def __init__(self,
                 key: str = 'img',
                 algorithm_keys: Optional[List[str]] = None,
                 pseudo_label_keys: Optional[List[str]] = None,
                 meta_keys: Optional[List[str]] = None) -> None:
        assert isinstance(key, str), \
            f'key should be the type of str, instead of {type(key)}.'
        self.key = key
        # Copy into fresh lists so that instances never share a default list
        # or alias a list still owned (and possibly mutated) by the caller.
        self.algorithm_keys = list(algorithm_keys or [])
        self.pseudo_label_keys = list(pseudo_label_keys or [])
        self.meta_keys = list(meta_keys or [])

    @staticmethod
    def _img_to_tensor(img: np.ndarray) -> torch.Tensor:
        """Convert one image array into a tensor laid out for the model.

        2-D arrays get a trailing channel axis, 3-D arrays are transposed
        from (H, W, C) to (C, H, W), and 5-D arrays (video data, presumably
        (B, C, T, H, W)) are passed through unchanged.

        Raises:
            ValueError: If the array is not 2, 3 or 5 dimensional.
        """
        if img.ndim == 2:
            # Single-channel image: add the channel axis so it is (H, W, 1).
            img = np.expand_dims(img, -1)
        if img.ndim == 3:
            # HWC -> CHW; make contiguous so the tensor owns a dense layout.
            img = np.ascontiguousarray(img.transpose(2, 0, 1))
        elif img.ndim != 5:
            raise ValueError(
                'img should be 2, 3 or 5 dimensional, '
                f'instead of {img.ndim} dimensional.')
        return to_tensor(img)

    def transform(self, results: Dict) -> Dict[str, Any]:
        """Method to pack the data.

        Args:
            results (Dict): Result dict from the data pipeline. It is not
                modified.

        Returns:
            Dict[str, Any]:

            - ``inputs`` (List[torch.Tensor]): The forward data of models.
              Only present when ``self.key`` is in ``results``.
            - ``data_samples`` (SelfSupDataSample): The annotation info of
              the forward data.
        """
        packed_results = dict()
        if self.key in results:
            img = results[self.key]
            # A single image is wrapped so ``inputs`` is always a list.
            imgs = img if isinstance(img, list) else [img]
            # Build a new list instead of overwriting the pipeline's own list
            # in ``results`` with tensors.
            packed_results['inputs'] = [
                self._img_to_tensor(img_) for img_ in imgs
            ]

        data_sample = SelfSupDataSample()
        if self.pseudo_label_keys:
            data_sample.pseudo_label = InstanceData()

        # gt_label, sample_idx, mask, pred_label will be set here
        for key in self.algorithm_keys:
            self.set_algorithm_keys(data_sample, key, results)

        # Every other requested key becomes an attribute of pseudo_label,
        # converted to torch.Tensor.
        for key in self.pseudo_label_keys:
            setattr(data_sample.pseudo_label, key, to_tensor(results[key]))

        img_meta = {key: results[key] for key in self.meta_keys}
        data_sample.set_metainfo(img_meta)
        packed_results['data_samples'] = data_sample
        return packed_results

    @classmethod
    def set_algorithm_keys(cls, data_sample: SelfSupDataSample, key: str,
                           results: dict) -> None:
        """Set the algorithm keys of SelfSupDataSample.

        ``results[key]`` is converted to a tensor, wrapped in the container
        registered for ``key`` (``InstanceData`` or ``LabelData``) and set
        as the attribute ``key`` of ``data_sample``.

        Args:
            data_sample (SelfSupDataSample): An instance of SelfSupDataSample.
            key (str): The key, which may be used by the algorithm, such as
                gt_label, sample_idx, mask, pred_label. For more keys, please
                refer to the attribute of SelfSupDataSample.
            results (dict): The results from the data pipeline.

        Raises:
            AttributeError: If ``key`` is not one of the supported keys.
        """
        container_type = cls._ALGORITHM_KEY_TYPES.get(key)
        # Validate the key before touching ``results`` so a misconfigured
        # key is reported clearly.
        if container_type is None:
            raise AttributeError(
                f'{key} is not an attribute of SelfSupDataSample')
        value = to_tensor(results[key])
        setattr(data_sample, key, container_type(value=value))

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'key={self.key}, '
                f'algorithm_keys={self.algorithm_keys}, '
                f'pseudo_label_keys={self.pseudo_label_keys}, '
                f'meta_keys={self.meta_keys})')
Read the Docs v: stable
Versions
latest
stable
1.x
dev-1.x
0.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.