Note
You are reading the documentation for MMSelfSup 0.x, which will be deprecated by the end of 2022. We recommend upgrading to the MMSelfSup 1.0.0rc versions to benefit from the new features and better performance brought by OpenMMLab 2.0. See the changelog, code and documentation of MMSelfSup 1.0.0rc for more details.
Source code for mmselfsup.models.utils.gather_layer
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.distributed as dist
class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backward propagation.

    ``GatherLayer.apply(x)`` returns a tuple with one tensor per rank, where
    element ``k`` holds the ``x`` contributed by rank ``k``. Unlike a bare
    ``dist.all_gather`` call, the result stays attached to the autograd graph,
    so gradients can flow back into this rank's ``x``.

    When ``torch.distributed`` is unavailable or no default process group has
    been initialized, the layer behaves like a single-process gather: it
    returns a one-element tuple holding a copy of ``x``.

    Note:
        This issues a collective, so every rank must call it in the same order
        with a tensor of the same shape and dtype.
    """

    @staticmethod
    def forward(ctx, input):
        """Gather ``input`` from every rank.

        Args:
            ctx: Autograd context; this process's rank is stored on it for
                use in :meth:`backward`.
            input (torch.Tensor): The local tensor to share with other ranks.

        Returns:
            tuple[torch.Tensor]: One tensor per rank, ordered by rank.
        """
        if not (dist.is_available() and dist.is_initialized()):
            # Single-process fallback (rank 0 of a world of size 1). Return a
            # fresh buffer, as the distributed path does.
            ctx.rank = 0
            return (input.clone(), )

        ctx.rank = dist.get_rank()
        # torch.distributed collectives expect contiguous tensors. Making the
        # local tensor contiguous also keeps ``empty_like`` from copying a
        # non-contiguous stride layout into the receive buffers.
        input = input.contiguous()
        # ``all_gather`` overwrites every buffer, so zero-filling is unneeded.
        output = [
            torch.empty_like(input) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        """Return the incoming gradient for this rank's slice of the gather.

        Args:
            ctx: Autograd context populated by :meth:`forward`.
            *grads (torch.Tensor): One gradient per gathered output, ordered
                by rank.

        Returns:
            torch.Tensor: A copy of ``grads[rank]`` for this process.

        Note:
            No cross-rank reduction is performed: only the gradient this rank
            computed for its own slice is returned. That slice is the complete
            gradient when every rank evaluates the same loss over the whole
            gathered batch. If ranks compute different losses, contributions
            from other ranks are not included.
        """
        # Hand back a fresh tensor rather than the incoming gradient buffer.
        return grads[ctx.rank].clone()