
PreciseBNHook

class mmpretrain.engine.hooks.PreciseBNHook(num_samples=8192, interval=1)[source]

Precise BN hook.

Recompute and update the batch norm stats to make them more precise. During training, both the BN stats and the weights change after every iteration, so the running average cannot precisely reflect the actual stats of the current model.

With this hook, the BN stats are recomputed with fixed weights, to make the running average more precise. Specifically, it computes the true average of per-batch mean/variance instead of the running average. See Sec. 3 of the paper Rethinking Batch in BatchNorm <https://arxiv.org/abs/2105.07576> for details.
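The difference between the two estimates can be seen on a single channel. The sketch below is a pure-Python illustration, not the mmpretrain implementation: it treats each batch as a list of scalar activations, uses population variance, and assumes a PyTorch-style running average (EMA with momentum 0.1, initialized to mean 0 and variance 1) for the training-time estimate.

```python
from statistics import fmean, pvariance
from typing import Sequence, Tuple


def batch_stats(batch: Sequence[float]) -> Tuple[float, float]:
    """Mean and population variance of one channel within one batch."""
    return fmean(batch), pvariance(batch)


def ema_stats(batches, momentum=0.1, init=(0.0, 1.0)):
    """Running estimate as updated during training: an exponential moving average."""
    mean, var = init
    for batch in batches:
        m, v = batch_stats(batch)
        mean = (1 - momentum) * mean + momentum * m
        var = (1 - momentum) * var + momentum * v
    return mean, var


def precise_stats(batches):
    """Equal-weight average of per-batch mean/variance (model weights held fixed)."""
    mean, var = 0.0, 0.0
    for i, batch in enumerate(batches):
        m, v = batch_stats(batch)
        # Incremental form of the arithmetic mean over the first i + 1 batches.
        mean += (m - mean) / (i + 1)
        var += (v - var) / (i + 1)
    return mean, var


batches = [[1.0, 3.0], [2.0, 4.0], [3.0, 5.0], [4.0, 6.0]]
print(precise_stats(batches))  # (3.5, 1.0)
print(ema_stats(batches))
```

The equal-weight average gives every batch the same influence, whereas the EMA is dominated by the most recent batches and, after only four updates, still largely by its initial value.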

This hook updates the BN stats, so it should be executed before CheckpointHook and EMAHook; generally, set its priority to “ABOVE_NORMAL”.
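A minimal configuration sketch, assuming the usual MMEngine `custom_hooks` convention in which each hook is declared as a dict with a `type` and an optional `priority`. The numeric priority table is a local copy added only for illustration (intended to mirror MMEngine's priority levels, where a smaller value runs earlier), and the priorities listed for EMAHook and CheckpointHook are assumptions to verify against your own config.

```python
# Hook registration as it would appear in a training config (MMEngine style).
custom_hooks = [
    dict(
        type='PreciseBNHook',
        num_samples=8192,
        interval=1,
        priority='ABOVE_NORMAL',
    ),
]

# Local illustration only: priority names mapped to numbers (smaller runs first).
PRIORITY_VALUES = {
    'HIGHEST': 0, 'VERY_HIGH': 10, 'HIGH': 30, 'ABOVE_NORMAL': 40,
    'NORMAL': 50, 'BELOW_NORMAL': 60, 'LOW': 70, 'VERY_LOW': 90, 'LOWEST': 100,
}

# Assumed priorities for the other hooks; check them in your actual config.
hook_priorities = {
    'CheckpointHook': 'VERY_LOW',
    'EMAHook': 'NORMAL',
    'PreciseBNHook': 'ABOVE_NORMAL',
}

# Order in which hooks sharing the same callback would be invoked.
execution_order = sorted(
    hook_priorities, key=lambda name: PRIORITY_VALUES[hook_priorities[name]]
)
print(execution_order)  # ['PreciseBNHook', 'EMAHook', 'CheckpointHook']
```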

Parameters:
  • num_samples (int) – The number of samples used to update the BN stats. Defaults to 8192.

  • interval (int) – The interval at which precise BN is performed. If the train loop is EpochBasedTrainLoop or by_epoch=True, its unit is ‘epoch’; if the train loop is IterBasedTrainLoop or by_epoch=False, its unit is ‘iter’. Defaults to 1.
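To get a feel for the cost of `num_samples`, the sketch below assumes that each precise-BN pass draws `num_samples` training samples split evenly across all GPUs and rounds down to whole iterations. The helper `precise_bn_iters` is hypothetical, not part of mmpretrain.

```python
def precise_bn_iters(num_samples: int, batch_size: int, world_size: int) -> int:
    """Hypothetical helper: forward passes per GPU needed to cover num_samples."""
    # Samples consumed by all GPUs together in one step.
    samples_per_iter = batch_size * world_size
    # Assumed to round down to whole iterations.
    return num_samples // samples_per_iter


# Default num_samples=8192 with 32 images per GPU on 8 GPUs.
print(precise_bn_iters(8192, batch_size=32, world_size=8))  # 32
# The same budget on a single GPU with batch size 64.
print(precise_bn_iters(8192, batch_size=64, world_size=1))  # 128
```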

after_train_epoch(runner)[source]

Calculate precise BN and broadcast BN stats across GPUs.

Parameters:
  • runner (Runner) – The runner of the training process.

after_train_iter(runner, batch_idx, data_batch=None, outputs=None)[source]

Calculate precise BN and broadcast BN stats across GPUs.

Parameters:
  • runner (Runner) – The runner of the training process.

  • batch_idx (int) – The index of the current batch in the train loop.

  • data_batch (Sequence[dict], optional) – Data from dataloader. Defaults to None.
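The two callbacks split the work by loop type. As a sketch, assuming that `after_train_epoch` acts only for epoch-based loops and `after_train_iter` only for iteration-based loops, with the same `interval` check in both, the gating logic could look like this (`should_run_precise_bn`, `fired_steps`, and the `loop_type` strings are hypothetical):

```python
from typing import List


def should_run_precise_bn(callback: str, loop_type: str, step: int, interval: int) -> bool:
    """Hypothetical gate for the two callbacks.

    callback: 'after_train_epoch' or 'after_train_iter'
    loop_type: 'epoch' (EpochBasedTrainLoop) or 'iter' (IterBasedTrainLoop)
    step: 0-based index of the epoch or iteration that just finished
    """
    # Each callback only acts for the matching loop type.
    expected = 'epoch' if callback == 'after_train_epoch' else 'iter'
    if loop_type != expected:
        return False
    # Fire after every `interval`-th epoch or iteration (1-based count).
    return (step + 1) % interval == 0


def fired_steps(callback: str, loop_type: str, total: int, interval: int) -> List[int]:
    """1-based epochs or iterations after which precise BN would run."""
    return [s + 1 for s in range(total)
            if should_run_precise_bn(callback, loop_type, s, interval)]


print(fired_steps('after_train_epoch', 'epoch', total=10, interval=3))    # [3, 6, 9]
print(fired_steps('after_train_iter', 'epoch', total=10, interval=3))     # []
print(fired_steps('after_train_iter', 'iter', total=3000, interval=1000))  # [1000, 2000, 3000]
```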
