

You are reading the documentation for MMSelfSup 0.x, which will be deprecated by the end of 2022. We recommend upgrading to the MMSelfSup 1.0.0rc versions to enjoy the many new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code, and documentation of MMSelfSup 1.0.0rc for more details.

Source code for mmselfsup.core.hooks.swav_hook

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import torch
import torch.distributed as dist
from mmcv.runner import HOOKS, Hook

@HOOKS.register_module()
class SwAVHook(Hook):
    """Hook for SwAV.

    This hook builds the queue in SwAV according to ``epoch_queue_starts``.
    The queue is saved in ``runner.work_dir`` (one file per rank when
    distributed) and loaded before the run starts if such a file already
    exists there.

    Args:
        batch_size (int): the batch size per GPU for computing.
        epoch_queue_starts (int, optional): from this epoch, starts to use
            the queue. Defaults to 15.
        crops_for_assign (list[int], optional): list of crops id used for
            computing assignments. ``None`` means ``[0, 1]``.
            Defaults to None.
        feat_dim (int, optional): feature dimension of output vector.
            Defaults to 128.
        queue_length (int, optional): length of the queue (0 for no queue).
            Defaults to 0.
        interval (int, optional): the interval (in epochs) to save the
            queue. Defaults to 1.
    """

    def __init__(self,
                 batch_size,
                 epoch_queue_starts=15,
                 crops_for_assign=None,
                 feat_dim=128,
                 queue_length=0,
                 interval=1,
                 **kwargs):
        # Under torch.distributed the stored batch size is the global one
        # (per-GPU size times the number of processes).
        if dist.is_initialized():
            self.batch_size = batch_size * dist.get_world_size()
        else:
            self.batch_size = batch_size
        self.epoch_queue_starts = epoch_queue_starts
        # Use a None sentinel instead of a shared mutable list default.
        self.crops_for_assign = [0, 1] if crops_for_assign is None \
            else crops_for_assign
        self.feat_dim = feat_dim
        self.queue_length = queue_length
        self.interval = interval
        # Created lazily in ``before_train_epoch`` or loaded in ``before_run``.
        self.queue = None

    def before_run(self, runner):
        """Resolve the queue file path and restore a saved queue if any."""
        # Each rank keeps its own queue file when running distributed.
        if dist.is_initialized():
            self.queue_path = osp.join(
                runner.work_dir, 'queue' + str(dist.get_rank()) + '.pth')
        else:
            self.queue_path = osp.join(runner.work_dir, 'queue.pth')

        # build the queue
        if osp.isfile(self.queue_path):
            self.queue = torch.load(self.queue_path)['queue']
            # NOTE(review): ``runner.model.module`` assumes the model is
            # wrapped (e.g. DataParallel/DDP) - confirm for all launchers.
            runner.model.module.head.queue = self.queue

        # The queue needs to be divisible by the (global) batch size, so
        # round the requested length down to a multiple of it.
        self.queue_length -= self.queue_length % self.batch_size

    def before_train_epoch(self, runner):
        """Create the queue once ``epoch_queue_starts`` is reached."""
        # Optionally start a queue: only when a positive length was asked
        # for, the start epoch has been reached, and no queue exists yet.
        if self.queue_length > 0 \
                and runner.epoch >= self.epoch_queue_starts \
                and self.queue is None:
            # Shape: (num assignment crops, per-rank share of the queue,
            # feature dim).
            self.queue = torch.zeros(
                len(self.crops_for_assign),
                self.queue_length // runner.world_size,
                self.feat_dim,
            ).cuda()

        # Hand the (possibly None) queue to the head and reset its
        # ``use_queue`` flag to False at the start of every epoch.
        # NOTE(review): the head presumably switches ``use_queue`` on itself
        # once the queue is filled - confirm in the SwAV head.
        runner.model.module.head.queue = self.queue
        runner.model.module.head.use_queue = False

    def after_train_epoch(self, runner):
        """Fetch the head's queue and save it every ``interval`` epochs."""
        # Read back the head's queue, which may have been updated or
        # replaced during the epoch.
        self.queue = runner.model.module.head.queue

        if self.queue is not None and self.every_n_epochs(
                runner, self.interval):
            # Same layout that ``before_run`` reads back via ['queue'].
            torch.save({'queue': self.queue}, self.queue_path)
Read the Docs v: 0.x
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.