Note
You are reading the documentation for MMSelfSup 0.x, which will be deprecated by the end of 2022. We recommend upgrading to the MMSelfSup 1.0.0rc versions to enjoy the new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code, and documentation of MMSelfSup 1.0.0rc for more details.
Source code for mmselfsup.datasets.data_sources.imagenet
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import numpy as np
from ..builder import DATASOURCES
from .base import BaseDataSource
def has_file_allowed_extension(filename, extensions):
    """Check whether a file name ends with one of the allowed extensions.

    The comparison is case-insensitive on both sides, so ``'a.JPG'`` matches
    ``'.jpg'`` and ``'a.jpg'`` matches ``'.JPG'``.

    Args:
        filename (str): Path or name of a file.
        extensions (Iterable[str]): Allowed extensions,
            e.g. ``('.jpg', '.png')``.

    Returns:
        bool: True if ``filename`` ends with any of ``extensions``; False when
        ``extensions`` is empty.
    """
    # Lower-case the extensions too: lower-casing only the file name would
    # make any upper-case extension impossible to match.
    suffixes = tuple(ext.lower() for ext in extensions)
    # str.endswith accepts a tuple and tests every suffix in a single call.
    return filename.lower().endswith(suffixes)
def find_folders(root):
    """Map each immediate sub-directory of ``root`` to a class index.

    Sub-directory names are sorted lexicographically and numbered from 0, so
    the mapping is deterministic for a given directory listing. Plain files
    directly under ``root`` are ignored.

    Args:
        root (str): Directory whose sub-directories are the classes.

    Returns:
        dict[str, int]: Mapping from folder name to class index.
    """
    folders = sorted(
        name for name in os.listdir(root) if osp.isdir(osp.join(root, name)))
    return {name: idx for idx, name in enumerate(folders)}
def get_samples(root, folder_to_idx, extensions):
    """Collect ``(relative_path, class_idx)`` pairs for every image under root.

    Each class folder is walked recursively. Returned paths are relative to
    ``root`` and keep any nested sub-directories (e.g.
    ``'cat/kitten/001.jpg'``), so joining them back onto ``root`` locates the
    file.

    Args:
        root (str): Root directory containing one folder per class. A leading
            ``~`` is expanded.
        folder_to_idx (dict[str, int]): Mapping from class folder name to
            class index.
        extensions (Iterable[str]): Allowed file extensions.

    Returns:
        list[tuple[str, int]]: ``(path, label)`` pairs, ordered by folder
        name, then by walked directory, then by file name.
    """
    samples = []
    root = osp.expanduser(root)
    for folder_name in sorted(folder_to_idx):
        label = folder_to_idx[folder_name]
        class_dir = osp.join(root, folder_name)
        # Sorting the walk output keeps the sample order deterministic
        # regardless of the order the filesystem returns entries in.
        for dirpath, _, fns in sorted(os.walk(class_dir)):
            # Files in nested directories must keep their sub-directory in the
            # recorded path; ``folder_name/fn`` alone would not point at them.
            sub_dir = osp.relpath(dirpath, class_dir)
            if sub_dir == os.curdir:
                prefix = folder_name
            else:
                prefix = osp.join(folder_name, sub_dir)
            for fn in sorted(fns):
                if has_file_allowed_extension(fn, extensions):
                    samples.append((osp.join(prefix, fn), label))
    return samples
@DATASOURCES.register_module()
class ImageNet(BaseDataSource):
    """`ImageNet <http://www.image-net.org>`_ Dataset.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py

    Samples come either from ``ann_file`` (one ``"<filename> <label>"`` entry
    per line) or, when ``ann_file`` is None, from scanning ``data_prefix``,
    where every sub-directory is one class.
    """  # noqa: E501

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')

    def load_annotations(self):
        """Build one info dict per sample.

        Side effects: sets ``self.samples``; when ``ann_file`` is None, also
        sets ``self.folder_to_idx``.

        Returns:
            list[dict]: Each dict has keys ``'img_prefix'``, ``'img_info'``
            (``{'filename': ...}``), ``'gt_label'`` (int64 numpy array) and
            ``'idx'`` (position in ``self.samples``).

        Raises:
            RuntimeError: If scanning ``data_prefix`` finds no image files.
            TypeError: If ``ann_file`` is neither a str nor None.
        """
        if self.ann_file is None:
            folder_to_idx = find_folders(self.data_prefix)
            samples = get_samples(
                self.data_prefix,
                folder_to_idx,
                extensions=self.IMG_EXTENSIONS)
            if not samples:
                raise RuntimeError('Found 0 files in subfolders of: '
                                   f'{self.data_prefix}. '
                                   'Supported extensions are: '
                                   f'{",".join(self.IMG_EXTENSIONS)}')
            self.folder_to_idx = folder_to_idx
        elif isinstance(self.ann_file, str):
            with open(self.ann_file) as f:
                stripped = (line.strip() for line in f)
                # Skip blank lines (e.g. a stray empty line at the end of the
                # file): they carry no sample and would make the
                # ``(filename, gt_label)`` unpacking below raise ValueError.
                # The label is split off the last space so file names may
                # themselves contain spaces.
                samples = [entry.rsplit(' ', 1) for entry in stripped if entry]
        else:
            raise TypeError('ann_file must be a str or None')

        self.samples = samples

        data_infos = []
        for i, (filename, gt_label) in enumerate(self.samples):
            data_infos.append({
                'img_prefix': self.data_prefix,
                'img_info': {'filename': filename},
                # Labels read from ann_file are strings; numpy parses them.
                'gt_label': np.array(gt_label, dtype=np.int64),
                'idx': i,
            })
        return data_infos