Note
You are reading the documentation for MMSelfSup 0.x, which will be deprecated at the end of 2022. We recommend that you upgrade to the MMSelfSup 1.0.0rc versions to enjoy the fruitful new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code, and documentation of MMSelfSup 1.0.0rc for more details.
Source code for mmselfsup.datasets.relative_loc
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torchvision.transforms.functional as TF
from mmcv.utils import build_from_cfg
from torchvision.transforms import Compose, RandomCrop
from .base import BaseDataset
from .builder import DATASETS, PIPELINES
from .utils import to_numpy
def image_to_patches(img):
    """Split an image into a 3x3 grid of jittered patches.

    The image is divided into ``3 x 3`` equal grid cells, and from each
    cell a slightly smaller patch is taken at a random offset (the
    "jitter"), so the patches do not tile the image exactly.

    Args:
        img (PIL Image): Input image.

    Returns:
        list[PIL Image]: The 9 cropped patches, in row-major order.
    """
    n_grids = 3  # grid cells per image side
    jitter = 21  # pixels of slack inside each grid cell
    # NOTE(review): PIL's ``Image.size`` is (width, height), so ``h`` here is
    # actually the width — kept as-is to match the original behavior.
    h, w = img.size
    grid_h, grid_w = h // n_grids, w // n_grids
    crop_h, crop_w = grid_h - jitter, grid_w - jitter
    assert crop_h > 0 and crop_w > 0
    # One RandomCrop instance is enough; each call draws a fresh offset.
    jitter_crop = RandomCrop((crop_h, crop_w))
    return [
        jitter_crop(TF.crop(img, row * grid_h, col * grid_w, grid_h, grid_w))
        for row in range(n_grids)
        for col in range(n_grids)
    ]
@DATASETS.register_module()
class RelativeLocDataset(BaseDataset):
    """Dataset for relative patch location.

    The dataset crops each image into 9 patches (a 3x3 grid) and
    concatenates every surrounding patch with the center one along the
    channel dimension. It also outputs the corresponding labels
    ``0, 1, 2, 3, 4, 5, 6, 7``, one per surrounding position.

    Args:
        data_source (dict): Data source defined in
            `mmselfsup.datasets.data_sources`.
        pipeline (list[dict]): A list of dict, where each element represents
            an operation defined in `mmselfsup.datasets.pipelines`.
        format_pipeline (list[dict]): A list of dict, it converts input format
            from PIL.Image to Tensor. The operation is defined in
            `mmselfsup.datasets.pipelines`.
        prefetch (bool, optional): Whether to prefetch data. Defaults to False.
    """

    def __init__(self, data_source, pipeline, format_pipeline, prefetch=False):
        super(RelativeLocDataset, self).__init__(data_source, pipeline,
                                                 prefetch)
        format_pipeline = [
            build_from_cfg(p, PIPELINES) for p in format_pipeline
        ]
        self.format_pipeline = Compose(format_pipeline)

    def __getitem__(self, idx):
        """Return 8 (surrounding, center) patch pairs and their labels.

        Args:
            idx (int): Index of the image in the data source.

        Returns:
            dict: ``img`` is a tensor of shape 8x(2C)xHxW holding the
            channel-concatenated patch pairs; ``patch_label`` is a
            LongTensor with the 8 relative-position labels.
        """
        img = self.data_source.get_img(idx)
        img = self.pipeline(img)
        patches = image_to_patches(img)
        if self.prefetch:
            # Prefetch mode keeps raw numpy-backed tensors; formatting
            # (e.g. normalization) is expected to happen downstream.
            patches = [torch.from_numpy(to_numpy(p)) for p in patches]
        else:
            patches = [self.format_pipeline(p) for p in patches]
        # Pair each of the 8 surrounding patches with the center patch
        # (index 4 in row-major order of the 3x3 grid).
        perms = [
            torch.cat((patches[i], patches[4]), dim=0)
            for i in range(9) if i != 4
        ]
        # One label per surrounding position, in the same order as perms.
        patch_labels = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7])
        return dict(
            img=torch.stack(perms), patch_label=patch_labels)  # 8(2C)HW, 8

    def evaluate(self, results, logger=None):
        """Evaluation is not supported for this pretext dataset.

        NOTE(review): returns the ``NotImplemented`` constant (upstream
        behavior) rather than raising ``NotImplementedError``.
        """
        return NotImplemented