
Source code for mmselfsup.utils.batch_shuffle

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
from mmengine.dist import broadcast, get_rank

from .gather import concat_all_gather


@torch.no_grad()
def batch_shuffle_ddp(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Batch shuffle, for making use of BatchNorm.

    Args:
        x (torch.Tensor): Data in each GPU.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Output of shuffle operation.
            - x_gather[idx_this]: Shuffled data.
            - idx_unshuffle: Index for restoring.
    """
    # gather from all gpus
    batch_size_this = x.shape[0]
    x_gather = concat_all_gather(x)
    batch_size_all = x_gather.shape[0]

    num_gpus = batch_size_all // batch_size_this

    # random shuffle index
    idx_shuffle = torch.randperm(batch_size_all)

    # broadcast to all gpus
    broadcast(idx_shuffle, src=0)

    # index for restoring
    idx_unshuffle = torch.argsort(idx_shuffle)

    # shuffled index for this gpu
    gpu_idx = get_rank()
    idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

    return x_gather[idx_this], idx_unshuffle
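
The shuffle is pure index arithmetic: every rank draws the same permutation (made identical by the broadcast from rank 0), keeps one contiguous row of it, and torch.argsort yields the inverse permutation used later for restoring. A minimal single-process sketch of that round trip, emulating the gather with a plain tensor; num_gpus and the batch sizes here are made-up toy values, not part of this module:

import torch

# Toy values standing in for a real DDP run: 2 "GPUs", 4 samples each.
num_gpus, batch_size_this = 2, 4
batch_size_all = num_gpus * batch_size_this
x_gather = torch.arange(batch_size_all)       # stand-in for concat_all_gather(x)

idx_shuffle = torch.randperm(batch_size_all)  # identical on all ranks after broadcast
idx_unshuffle = torch.argsort(idx_shuffle)    # inverse permutation

# Each rank keeps one row of the reshaped permutation, i.e. a contiguous slice.
shuffled = [x_gather[idx_shuffle.view(num_gpus, -1)[r]] for r in range(num_gpus)]

# Re-gathering the shuffled slices and indexing with idx_unshuffle restores the order.
assert torch.equal(torch.cat(shuffled)[idx_unshuffle], x_gather)
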
@torch.no_grad()
def batch_unshuffle_ddp(x: torch.Tensor,
                        idx_unshuffle: torch.Tensor) -> torch.Tensor:
    """Undo batch shuffle.

    Args:
        x (torch.Tensor): Data in each GPU.
        idx_unshuffle (torch.Tensor): Index for restoring.

    Returns:
        torch.Tensor: Output of unshuffle operation.
    """
    # gather from all gpus
    batch_size_this = x.shape[0]
    x_gather = concat_all_gather(x)
    batch_size_all = x_gather.shape[0]

    num_gpus = batch_size_all // batch_size_this

    # restored index for this gpu
    gpu_idx = get_rank()
    idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

    return x_gather[idx_this]
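
In MoCo-style training these two helpers bracket the momentum encoder's forward pass so that its BatchNorm statistics are computed over samples mixed across GPUs (ShuffleBN) rather than over each GPU's local batch. A minimal sketch, assuming a distributed run with one process per GPU and that the package re-exports these helpers from mmselfsup.utils; encoder_k and momentum_forward are hypothetical names, not part of this module:

import torch

from mmselfsup.utils import batch_shuffle_ddp, batch_unshuffle_ddp


@torch.no_grad()
def momentum_forward(encoder_k: torch.nn.Module,
                     im_k: torch.Tensor) -> torch.Tensor:
    # Shuffle samples across GPUs so each BatchNorm layer in the key
    # encoder normalizes a mixed batch rather than a per-GPU one.
    im_k, idx_unshuffle = batch_shuffle_ddp(im_k)
    k = encoder_k(im_k)
    # Restore the original sample order before the keys enter the loss.
    return batch_unshuffle_ddp(k, idx_unshuffle)
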