Note
You are reading the documentation for MMSelfSup 0.x, which will be deprecated at the end of 2022. We recommend that you upgrade to the MMSelfSup 1.0.0rc releases to benefit from the new features and better performance brought by OpenMMLab 2.0. See the changelog, code, and documentation of MMSelfSup 1.0.0rc for more details.
Source code for mmselfsup.utils.collect
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch
from .gather import gather_tensors_batch
def nondist_forward_collect(func, data_loader, length):
    """Run ``func`` over a data loader and concatenate its outputs.

    Every batch yielded by ``data_loader`` is passed through ``func`` with
    gradient tracking disabled. The per-batch output dictionaries are then
    merged key by key into one array per key.

    Args:
        func (function): The function to process data. It is called as
            ``func(img=batch['img'])``; only the ``'img'`` entry of each
            batch is forwarded. The output must be a dictionary of CPU
            tensors.
        data_loader (Dataloader): The torch Dataloader to yield data.
        length (int): Expected length (size of the first axis) of every
            output array.

    Returns:
        dict(np.ndarray): For each key of the first batch output, the
        batch outputs concatenated along axis 0.

    Raises:
        AssertionError: If a concatenated array does not have exactly
            ``length`` entries along its first axis.
    """
    progress = mmcv.ProgressBar(len(data_loader))
    batch_outputs = []
    for data in data_loader:
        inputs = {'img': data['img']}
        with torch.no_grad():
            output = func(**inputs)
        batch_outputs.append(output)
        progress.update()

    collected = {}
    # The keys of the first batch define which outputs are collected.
    for key in batch_outputs[0]:
        arrays = [output[key].numpy() for output in batch_outputs]
        merged = np.concatenate(arrays, axis=0)
        collected[key] = merged
        assert merged.shape[0] == length
    return collected
def dist_forward_collect(func, data_loader, rank, length, ret_rank=-1):
    """Run ``func`` over a data loader and gather outputs across processes.

    Each process forwards its own batches through ``func`` with gradient
    tracking disabled, concatenates its local outputs per key, and then
    exchanges them with ``gather_tensors_batch``. The gathered arrays are
    concatenated along axis 0 and cut to the first ``length`` entries.

    Args:
        func (function): The function to process data. It is called as
            ``func(**batch)``. The output must be a dictionary of CPU
            tensors.
        data_loader (Dataloader): The torch Dataloader to yield data.
        rank (int): This process id. Only rank 0 shows a progress bar.
        length (int): Expected length of output arrays; gathered results
            are truncated to this many entries.
        ret_rank (int): The process that returns. Other processes will
            get ``None`` values. With the default ``-1`` every process
            builds the full result.

    Returns:
        dict: For each key of the first batch output, either the gathered
        ``np.ndarray`` or ``None`` (when ``ret_rank`` is not ``-1`` and
        this process is not ``ret_rank``).
    """
    show_progress = rank == 0
    progress = None
    if show_progress:
        progress = mmcv.ProgressBar(len(data_loader))

    batch_outputs = []
    for data in data_loader:
        with torch.no_grad():
            output = func(**data)
        batch_outputs.append(output)
        if show_progress:
            progress.update()

    gather_everywhere = ret_rank == -1
    collected = {}
    # The keys of the first batch define which outputs are collected.
    for key in batch_outputs[0]:
        local = np.concatenate(
            [output[key].numpy() for output in batch_outputs], axis=0)

        if gather_everywhere:
            gathered = gather_tensors_batch(local, part_size=20)
            keep_result = True
        else:
            gathered = gather_tensors_batch(
                local, part_size=20, ret_rank=ret_rank)
            keep_result = rank == ret_rank

        if keep_result:
            # Entries beyond ``length`` are dropped (presumably padding
            # added by the distributed sampler -- confirm with caller).
            collected[key] = np.concatenate(gathered, axis=0)[:length]
        else:
            collected[key] = None
    return collected