Shortcuts

Source code for mmselfsup.engine.hooks.odc_hook

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence

import numpy as np
import torch
from mmengine.hooks import Hook
from mmengine.logging import print_log

from mmselfsup.registry import HOOKS


@HOOKS.register_module()
class ODCHook(Hook):
    """Hook for ODC (Online Deep Clustering).

    This hook drives the online clustering process of ODC during
    training: it periodically updates the cluster centroids, reassigns
    samples belonging to small clusters, re-weights the classification
    loss by cluster size, and logs clustering statistics.

    Args:
        centroids_update_interval (int): Frequency of iterations to
            update centroids.
        deal_with_small_clusters_interval (int): Frequency of iterations
            to deal with small clusters.
        evaluate_interval (int): Frequency of iterations to evaluate
            clusters.
        reweight (bool): Whether to perform loss re-weighting.
        reweight_pow (float): The power of re-weighting.
        dist_mode (bool): Use distributed training or not. Only the
            distributed mode is implemented. Defaults to True.
    """

    def __init__(self,
                 centroids_update_interval: int,
                 deal_with_small_clusters_interval: int,
                 evaluate_interval: int,
                 reweight: bool,
                 reweight_pow: float,
                 dist_mode: bool = True) -> None:
        assert dist_mode, 'non-dist mode is not implemented'
        self.centroids_update_interval = centroids_update_interval
        self.deal_with_small_clusters_interval = \
            deal_with_small_clusters_interval
        self.evaluate_interval = evaluate_interval
        self.reweight = reweight
        self.reweight_pow = reweight_pow

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: Optional[Sequence[dict]] = None,
                         outputs: Optional[dict] = None) -> None:
        """Update cluster centroids and the loss re-weighting.

        Args:
            runner (mmengine.Runner): The runner of the training process.
            batch_idx (int): Index of the current batch.
            data_batch (Sequence[dict], optional): Data from the
                dataloader. Unused here. Defaults to None.
            outputs (dict, optional): Outputs from the model. Unused
                here. Defaults to None.
        """
        # Periodically update cluster centroids from the memory bank.
        if self.every_n_train_iters(runner, self.centroids_update_interval):
            runner.model.module.memory_bank.update_centroids_memory()

        # Periodically reassign samples stuck in small clusters.
        if self.every_n_train_iters(
                runner, self.deal_with_small_clusters_interval):
            runner.model.module.memory_bank.deal_with_small_clusters()

        # BUGFIX: the original called ``self.set_reweight(runner)``
        # unconditionally, ignoring both the configured ``reweight`` flag
        # and ``reweight_pow`` (the hard-coded default 0.5 was always
        # used). Honor the constructor configuration instead.
        if self.reweight:
            self.set_reweight(runner, reweight_pow=self.reweight_pow)

        # Periodically log clustering statistics.
        if self.every_n_train_iters(runner, self.evaluate_interval):
            new_labels = runner.model.module.memory_bank.label_bank
            if new_labels.is_cuda:
                new_labels = new_labels.cpu()
            self.evaluate(runner, new_labels.numpy())

    def after_train_epoch(self, runner) -> None:
        """Save the cluster assignments every 10 epochs (rank 0 only)."""
        if self.every_n_epochs(runner, 10) and runner.rank == 0:
            new_labels = runner.model.module.memory_bank.label_bank
            if new_labels.is_cuda:
                new_labels = new_labels.cpu()
            np.save(f'{runner.work_dir}/cluster_epoch_{runner.epoch + 1}.npy',
                    new_labels.numpy())

    def evaluate(self, runner, new_labels: np.ndarray) -> None:
        """Log a histogram summary of the current cluster assignments.

        Args:
            runner (mmengine.Runner): The runner of the training process.
            new_labels (np.ndarray): Flat array of cluster assignments.
        """
        histogram = np.bincount(
            new_labels,
            minlength=runner.model.module.memory_bank.num_classes)
        empty_cls = (histogram == 0).sum()
        minimal_cls_size, maximal_cls_size = histogram.min(), histogram.max()
        if runner.rank == 0:
            print_log(
                f'empty_num: {empty_cls.item()}\t'
                f'min_cluster: {minimal_cls_size.item()}\t'
                f'max_cluster:{maximal_cls_size.item()}',
                logger='current')

    def set_reweight(self,
                     runner,
                     labels: Optional[np.ndarray] = None,
                     reweight_pow: float = 0.5):
        """Loss re-weighting.

        Re-weight the loss according to the number of samples in each
        class: smaller clusters receive larger weights.

        Args:
            runner (mmengine.Runner): mmengine Runner.
            labels (numpy.ndarray, optional): Label assignments. If None,
                they are read from the model's memory bank. Defaults to
                None.
            reweight_pow (float, optional): The power of re-weighting.
                Defaults to 0.5.
        """
        if labels is None:
            label_bank = runner.model.module.memory_bank.label_bank
            if label_bank.is_cuda:
                labels = label_bank.cpu().numpy()
            else:
                labels = label_bank.numpy()
        histogram = np.bincount(
            labels,
            minlength=runner.model.module.memory_bank.num_classes).astype(
                np.float32)
        # The 1e-10 term avoids division by zero for empty clusters.
        inv_histogram = (1. / (histogram + 1e-10))**reweight_pow
        weight = inv_histogram / inv_histogram.sum()
        # Copy in place so modules holding a reference to loss_weight
        # observe the update.
        runner.model.module.loss_weight.copy_(torch.from_numpy(weight))
        runner.model.module.head.loss.class_weight = \
            runner.model.module.loss_weight
Read the Docs v: stable
Versions
latest
stable
1.x
dev-1.x
0.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.