Note

You are reading the documentation for MMSelfSup 0.x, which will be deprecated by the end of 2022. We recommend you upgrade to the MMSelfSup 1.0.0rc versions to enjoy the fruitful new features and better performance brought by OpenMMLab 2.0. Check out the changelog, code, and documentation of MMSelfSup 1.0.0rc for more details.

Source code for mmselfsup.utils.collect

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch

from .gather import gather_tensors_batch


def nondist_forward_collect(func, data_loader, length):
    """Forward and collect network outputs.

    This function performs forward propagation and collects outputs.
    It can be used to collect results, features, losses, etc.

    Args:
        func (function): The function to process data. The output must be
            a dictionary of CPU tensors.
        data_loader (Dataloader): the torch Dataloader to yield data.
        length (int): Expected length of output arrays.

    Returns:
        results_all (dict(np.ndarray)): The concatenated outputs.
    """
    results = []
    prog_bar = mmcv.ProgressBar(len(data_loader))
    for i, data in enumerate(data_loader):
        input_data = dict(img=data['img'])
        with torch.no_grad():
            result = func(**input_data)  # feat_dict
        results.append(result)  # list of feat_dict
        prog_bar.update()

    results_all = {}
    for k in results[0].keys():
        results_all[k] = np.concatenate(
            [batch[k].numpy() for batch in results], axis=0)
        assert results_all[k].shape[0] == length
    return results_all
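For reference, here is a minimal, self-contained sketch of how nondist_forward_collect could be driven. It is not part of the original module: the toy dataset, model, and the 'feat' key are hypothetical stand-ins, chosen only so that func returns the required dictionary of CPU tensors.

# Hypothetical usage sketch (not from the original source).
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Yields dicts with an 'img' key, which the collector expects."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return dict(img=torch.randn(3, 32, 32))


model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 16)).eval()


def forward_func(img):
    # The collector requires the output to be a dict of CPU tensors.
    return dict(feat=model(img).cpu())


loader = DataLoader(ToyDataset(), batch_size=4)
# length must equal the total number of samples, so the per-key
# assertion on the concatenated array passes.
features = nondist_forward_collect(forward_func, loader, length=8)
print(features['feat'].shape)  # (8, 16)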
def dist_forward_collect(func, data_loader, rank, length, ret_rank=-1):
    """Forward and collect network outputs in a distributed manner.

    This function performs forward propagation and collects outputs.
    It can be used to collect results, features, losses, etc.

    Args:
        func (function): The function to process data. The output must be
            a dictionary of CPU tensors.
        data_loader (Dataloader): the torch Dataloader to yield data.
        rank (int): This process id.
        length (int): Expected length of output arrays.
        ret_rank (int): The process that returns. Other processes will
            return None.

    Returns:
        results_all (dict(np.ndarray)): The concatenated outputs.
    """
    results = []
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(data_loader))
    for idx, data in enumerate(data_loader):
        with torch.no_grad():
            result = func(**data)  # dict{key: tensor}
        results.append(result)
        if rank == 0:
            prog_bar.update()

    results_all = {}
    for k in results[0].keys():
        results_cat = np.concatenate(
            [batch[k].numpy() for batch in results], axis=0)
        if ret_rank == -1:
            # Gather from all ranks; every process keeps the full array.
            results_gathered = gather_tensors_batch(results_cat, part_size=20)
            results_strip = np.concatenate(results_gathered, axis=0)[:length]
        else:
            # Gather onto the specified rank only; other ranks get None.
            results_gathered = gather_tensors_batch(
                results_cat, part_size=20, ret_rank=ret_rank)
            if rank == ret_rank:
                results_strip = np.concatenate(
                    results_gathered, axis=0)[:length]
            else:
                results_strip = None
        results_all[k] = results_strip
    return results_all
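Below is a hedged usage sketch for dist_forward_collect, not part of the original module. It assumes an initialized torch.distributed process group; for demonstration it sets up a single-process gloo group, whereas a real run under a launcher such as torchrun would have world_size > 1 and shard the data with a DistributedSampler so each rank sees a disjoint subset. The toy dataset, model, and 'feat' key are hypothetical.

# Hypothetical usage sketch (not from the original source).
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Single-process group for demonstration only; assumes gloo gathering
# also works with world_size == 1.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)


class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return dict(img=torch.randn(3, 32, 32))


model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 16)).eval()


def forward_func(img, **kwargs):
    # func receives the whole collated batch dict (func(**data)), so
    # absorb any extra keys; the output must be a dict of CPU tensors.
    return dict(feat=model(img).cpu())


loader = DataLoader(ToyDataset(), batch_size=4)
rank = dist.get_rank()
# length is the total dataset size across all ranks; the gathered and
# concatenated arrays are truncated to it. With ret_rank=0, only rank 0
# receives the arrays and every other rank gets None.
features = dist_forward_collect(
    forward_func, loader, rank, length=8, ret_rank=0)
if rank == 0:
    print(features['feat'].shape)  # (8, 16)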