Shortcuts

Source code for mmselfsup.models.utils.gather_layer

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Tuple

import torch
from mmengine.dist import all_gather, get_rank


class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backward propagation.

    ``GatherLayer.apply(x)`` returns a tuple holding one tensor per process,
    in the order produced by ``mmengine.dist.all_gather``. Because this is an
    autograd ``Function``, the gathered outputs remain differentiable with
    respect to the local ``input``.

    In the backward pass only the gradient flowing into this process's own
    slot (index ``get_rank()``) is routed back to ``input``; gradients that
    arrive for the other processes' slots are discarded.

    NOTE(review): dropping the cross-process terms yields the full gradient
    only when every process computes the same loss over the whole gathered
    set (so the local slot already carries the complete gradient). If each
    rank computed a *different* loss over the gathered tensors, the incoming
    grads would need to be summed across ranks before selecting the local
    slot. Confirm against callers before reusing this elsewhere.
    """

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Gather ``input`` from every process.

        Args:
            ctx: Autograd context; ``input`` is saved on it for backward.
            input: The local tensor to gather.

        Returns:
            Tuple[torch.Tensor, ...]: The gathered tensors, one per process,
            in the order returned by ``all_gather``.
        """
        ctx.save_for_backward(input)
        output = all_gather(input)
        return tuple(output)

    @staticmethod
    def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor:
        """Route the gradient of this process's own slot back to ``input``.

        Args:
            ctx: Autograd context holding the tensor saved in ``forward``.
            *grads: One incoming gradient per output of ``forward``, i.e.
                one per process.

        Returns:
            torch.Tensor: A tensor with ``input``'s shape, dtype and device
            holding ``grads[get_rank()]``.
        """
        input, = ctx.saved_tensors
        # Allocate like ``input`` and copy in only the local slot's gradient.
        # ``copy_`` (rather than ``grad_out[:] = ...``) also works for 0-dim
        # tensors, where slice indexing raises IndexError.
        grad_out = torch.zeros_like(input)
        grad_out.copy_(grads[get_rank()])
        return grad_out
Read the Docs v: latest
Versions
latest
stable
1.x
0.x
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.