注意
您正在阅读 MMSelfSup 0.x 版本的文档，而 MMSelfSup 0.x 版本将会在 2022 年末开始逐步停止维护。我们建议您及时升级到 MMSelfSup 1.0.0rc 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。请阅读 MMSelfSup 1.0.0rc 的发版日志、代码和文档以获取更多信息。
mmselfsup.core.hooks.swav_hook 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import torch
import torch.distributed as dist
from mmcv.runner import HOOKS, Hook
@HOOKS.register_module()
class SwAVHook(Hook):
    """Hook for SwAV.

    This hook builds the queue in SwAV according to ``epoch_queue_starts``.
    The queue is saved in ``runner.work_dir`` (one file per rank when
    distributed) and is loaded again in ``before_run`` if such a file
    already exists.

    Args:
        batch_size (int): the batch size per GPU for computing.
        epoch_queue_starts (int, optional): from this epoch, starts to use the
            queue. Defaults to 15.
        crops_for_assign (list[int], optional): list of crops id used for
            computing assignments. Defaults to None, which means [0, 1].
        feat_dim (int, optional): feature dimension of output vector.
            Defaults to 128.
        queue_length (int, optional): length of the queue (0 for no queue).
            Defaults to 0.
        interval (int, optional): the interval (in epochs) to save the queue.
            Defaults to 1.
        **kwargs: accepted for config compatibility and ignored.
    """

    def __init__(self,
                 batch_size,
                 epoch_queue_starts=15,
                 crops_for_assign=None,
                 feat_dim=128,
                 queue_length=0,
                 interval=1,
                 **kwargs):
        # ``batch_size`` is per GPU; when running distributed, scale it to
        # the global batch size so ``queue_length`` can be aligned to it in
        # ``before_run``.
        if dist.is_initialized():
            self.batch_size = batch_size * dist.get_world_size()
        else:
            self.batch_size = batch_size
        self.epoch_queue_starts = epoch_queue_starts
        # ``None`` default avoids a single list object being shared by every
        # hook instance (mutable default argument pitfall).
        self.crops_for_assign = (
            crops_for_assign if crops_for_assign is not None else [0, 1])
        self.feat_dim = feat_dim
        self.queue_length = queue_length
        self.interval = interval
        # Created lazily in ``before_train_epoch`` or restored from disk in
        # ``before_run``.
        self.queue = None

    def before_run(self, runner):
        """Pick the queue file path and restore a previously saved queue.

        Args:
            runner: the mmcv runner; ``runner.work_dir`` and
                ``runner.model.module.head`` are accessed.
        """
        # Each rank keeps its own queue, so each rank gets its own file.
        if dist.is_initialized():
            queue_file = 'queue' + str(dist.get_rank()) + '.pth'
        else:
            queue_file = 'queue.pth'
        self.queue_path = osp.join(runner.work_dir, queue_file)

        # Resume: reuse a queue saved by ``after_train_epoch`` earlier.
        if osp.isfile(self.queue_path):
            self.queue = torch.load(self.queue_path)['queue']
            runner.model.module.head.queue = self.queue

        # The queue length must be a multiple of the (global) batch size.
        self.queue_length -= self.queue_length % self.batch_size

    def before_train_epoch(self, runner):
        """Create the queue once ``epoch_queue_starts`` is reached.

        The queue is created only if ``queue_length > 0`` and no queue exists
        yet (neither created earlier nor restored from disk). Its shape is
        ``(len(crops_for_assign), queue_length // world_size, feat_dim)``,
        i.e. the total queue length is split evenly across ranks.
        """
        should_create = (
            self.queue_length > 0
            and runner.epoch >= self.epoch_queue_starts
            and self.queue is None)
        if should_create:
            self.queue = torch.zeros(
                len(self.crops_for_assign),
                self.queue_length // runner.world_size,
                self.feat_dim,
            ).cuda()

        # Hand the (possibly None) queue to the head and reset its flag.
        # NOTE(review): ``use_queue`` is reset to False every epoch; the head
        # presumably switches it on itself when the queue is used — confirm.
        runner.model.module.head.queue = self.queue
        runner.model.module.head.use_queue = False

    def after_train_epoch(self, runner):
        """Sync the queue back from the head and save it every ``interval``
        epochs."""
        # The head may have replaced the queue tensor during the epoch, so
        # read it back before saving.
        self.queue = runner.model.module.head.queue
        if self.queue is not None and self.every_n_epochs(
                runner, self.interval):
            torch.save({'queue': self.queue}, self.queue_path)