

You are reading the documentation for MMSelfSup 0.x, which will be deprecated by the end of 2022. We recommend upgrading to the MMSelfSup 1.0.0rc versions to benefit from the new features and better performance brought by OpenMMLab 2.0. See the changelog, code and documentation of MMSelfSup 1.0.0rc for more details.

Source code for mmselfsup.datasets.relative_loc

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torchvision.transforms.functional as TF
from mmcv.utils import build_from_cfg
from torchvision.transforms import Compose, RandomCrop

from .base import BaseDataset
from .builder import DATASETS, PIPELINES
from .utils import to_numpy

def image_to_patches(img, split_per_side=3, patch_jitter=21):
    """Crop ``split_per_side x split_per_side`` patches from an input image.

    The image is divided into a regular grid of ``split_per_side`` cells per
    side. From each cell, a patch that is ``patch_jitter`` pixels smaller in
    both height and width is randomly cropped.

    Args:
        img (PIL Image): input image.
        split_per_side (int): Number of grid cells per image side.
            Defaults to 3.
        patch_jitter (int): Number of pixels by which each patch is smaller
            than its grid cell, per dimension. Defaults to 21.

    Returns:
        list[PIL Image]: ``split_per_side ** 2`` cropped patches in row-major
        grid order (top row first, left to right within a row).
    """
    # PIL reports ``size`` as (width, height), not (height, width).
    w, h = img.size
    h_grid = h // split_per_side
    w_grid = w // split_per_side
    h_patch = h_grid - patch_jitter
    w_patch = w_grid - patch_jitter
    assert h_patch > 0 and w_patch > 0, (
        f'image of size {img.size} is too small for a {split_per_side}x'
        f'{split_per_side} grid with patch_jitter={patch_jitter}')
    # RandomCrop takes its size as (height, width).
    random_crop = RandomCrop((h_patch, w_patch))
    patches = []
    for i in range(split_per_side):
        for j in range(split_per_side):
            # TF.crop takes (top, left, height, width).
            cell = TF.crop(img, i * h_grid, j * w_grid, h_grid, w_grid)
            patches.append(random_crop(cell))
    return patches

@DATASETS.register_module()
class RelativeLocDataset(BaseDataset):
    """Dataset for relative patch location.

    The dataset crops the image into a 3x3 grid of patches and concatenates
    every surrounding patch with the center one. It also outputs the
    corresponding labels `0, 1, 2, 3, 4, 5, 6, 7`, one per surrounding patch.

    Args:
        data_source (dict): Data source defined in
            `mmselfsup.datasets.data_sources`.
        pipeline (list[dict]): A list of dict, where each element represents
            an operation defined in `mmselfsup.datasets.pipelines`.
        format_pipeline (list[dict]): A list of dict, it converts input format
            from PIL.Image to Tensor. The operation is defined in
            `mmselfsup.datasets.pipelines`.
        prefetch (bool, optional): Whether to prefetch data. Defaults to False.
    """

    def __init__(self, data_source, pipeline, format_pipeline, prefetch=False):
        super().__init__(data_source, pipeline, prefetch)
        self.format_pipeline = Compose(
            [build_from_cfg(p, PIPELINES) for p in format_pipeline])

    def __getitem__(self, idx):
        """Return the 8 (surrounding, center) patch pairs and their labels."""
        img = self.data_source.get_img(idx)
        img = self.pipeline(img)
        # Default grid is 3x3 -> 9 patches in row-major order; index 4 is
        # the center patch.
        patches = image_to_patches(img)
        if self.prefetch:
            patches = [torch.from_numpy(to_numpy(p)) for p in patches]
        else:
            patches = [self.format_pipeline(p) for p in patches]

        center = patches[4]
        # Pair each surrounding patch with the center patch by concatenating
        # along dim 0 (assumed to be the channel dim of CHW tensors; TODO
        # confirm against format_pipeline / to_numpy). Grid order is kept and
        # the center itself is skipped.
        perms = [
            torch.cat((patches[i], center), dim=0) for i in range(9) if i != 4
        ]
        # Label k identifies the k-th surrounding position in that order.
        patch_labels = torch.arange(len(perms), dtype=torch.long)
        return dict(
            img=torch.stack(perms), patch_label=patch_labels)  # 8(2C)HW, 8

    def evaluate(self, results, logger=None):
        # NOTE(review): evaluation is not supported for this pretext-task
        # dataset; returning the NotImplemented sentinel (rather than raising
        # NotImplementedError) is kept for backward compatibility.
        return NotImplemented
Read the Docs v: 0.x
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.