
EMAHook

class mmpretrain.engine.hooks.EMAHook(ema_type='ExponentialMovingAverage', strict_load=False, begin_iter=0, begin_epoch=0, evaluate_on_ema=True, evaluate_on_origin=False, **kwargs)[source]

A Hook to apply Exponential Moving Average (EMA) on the model during training.

Compared with mmengine.hooks.EMAHook, this hook additionally accepts the evaluate_on_ema and evaluate_on_origin arguments. By default, only evaluate_on_ema is enabled; to run validation and testing on both the original and the EMA models, set both arguments to True.
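As an illustration, the fragment below sketches how both evaluations might be switched on from an mmpretrain-style config. It assumes the usual custom_hooks registration; the momentum and priority values are example assumptions (momentum is forwarded through **kwargs to the EMA strategy), not recommended settings.

```python
# Hypothetical config fragment: evaluate both the EMA and the original weights.
custom_hooks = [
    dict(
        type='EMAHook',
        momentum=1e-4,            # assumed value, passed via **kwargs to ExponentialMovingAverage
        evaluate_on_ema=True,     # default: validate/test on the EMA model
        evaluate_on_origin=True,  # also validate/test on the original model
        priority='ABOVE_NORMAL',  # assumed hook priority
    )
]
```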

Note

  • EMAHook takes priority over CheckpointHook.

  • After training, the original model parameters are actually stored in the ema field of the saved checkpoint, while the regular model weights hold the EMA parameters.

  • begin_iter and begin_epoch cannot be set at the same time.
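The rule that only one start point may be configured can be sketched as follows. ema_started is a hypothetical helper written for illustration, and the 0-based counters compared with a "+ 1" offset are an assumed convention, not necessarily the exact check EMAHook performs.

```python
def ema_started(cur_iter: int, cur_epoch: int,
                begin_iter: int = 0, begin_epoch: int = 0) -> bool:
    """Hypothetical helper: decide whether EMA updates are active.

    cur_iter and cur_epoch are assumed to be 0-based counters.
    """
    if begin_iter != 0 and begin_epoch != 0:
        raise ValueError('begin_iter and begin_epoch cannot be set at the same time')
    if begin_epoch != 0:
        # Epoch-based start: enabled once the (1-based) epoch reaches begin_epoch.
        return cur_epoch + 1 >= begin_epoch
    # Iteration-based start (begin_iter=0 means "from the first iteration").
    return cur_iter + 1 >= begin_iter
```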

Parameters:
  • ema_type (str) – The type of EMA strategy to use. You can find the supported strategies in mmengine.model.averaged_model. Defaults to ‘ExponentialMovingAverage’.

  • strict_load (bool) – Whether to strictly enforce that the keys of state_dict in checkpoint match the keys returned by self.module.state_dict. Defaults to False. Changed in v0.3.0.

  • begin_iter (int) – The iteration at which EMAHook is enabled. Defaults to 0.

  • begin_epoch (int) – The epoch at which EMAHook is enabled. Defaults to 0.

  • evaluate_on_ema (bool) – Whether to evaluate (validate and test) on EMA model during val-loop and test-loop. Defaults to True.

  • evaluate_on_origin (bool) – Whether to evaluate (validate and test) on the original model during val-loop and test-loop. Defaults to False.

  • **kwargs – Keyword arguments passed to the subclass of BaseAveragedModel selected by ema_type.
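For the default ema_type, mmengine's ExponentialMovingAverage is, as far as we know, documented as updating ema = (1 - momentum) * ema + momentum * param, with momentum defaulting to 0.0002. Note that momentum weights the new parameter here. The sketch below applies that rule to plain floats instead of tensors; ema_update is a hypothetical name.

```python
from typing import Dict


def ema_update(ema_params: Dict[str, float],
               model_params: Dict[str, float],
               momentum: float = 0.0002) -> None:
    """Hypothetical sketch of one in-place EMA step:
    ema = (1 - momentum) * ema + momentum * param
    """
    for name, value in model_params.items():
        ema_params[name] = (1.0 - momentum) * ema_params[name] + momentum * value


ema = {'w': 0.0}
model = {'w': 1.0}
ema_update(ema, model, momentum=0.5)  # ema['w'] becomes 0.5
ema_update(ema, model, momentum=0.5)  # ema['w'] becomes 0.75
```

A small momentum makes the EMA weights change slowly, which is why the default is far below 1.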

after_load_checkpoint(runner, checkpoint)[source]

Resume the EMA parameters from the checkpoint.

Parameters:
  • runner (Runner) – The runner that loads the checkpoint.

  • checkpoint (dict) – The loaded model checkpoint.
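A rough sketch of the resume logic, assuming mmengine's checkpoint layout with state_dict and ema_state_dict keys: because the saved ema field holds the original weights, the two fields are swapped back before the EMA state is restored, and a checkpoint without EMA state falls back to the model weights. Both helper names are hypothetical, and key prefixes such as module. are ignored.

```python
import copy
from typing import Any, Dict


def swap_state_dicts(checkpoint: Dict[str, Any]) -> None:
    """Exchange matching entries of the model and EMA state dicts in place."""
    model_sd = checkpoint['state_dict']       # assumed key
    ema_sd = checkpoint['ema_state_dict']     # assumed key
    for key in model_sd:
        if key in ema_sd:
            model_sd[key], ema_sd[key] = ema_sd[key], model_sd[key]


def resume_ema_state(checkpoint: Dict[str, Any], resume: bool) -> Dict[str, Any]:
    """Hypothetical helper: return the state used to initialize the EMA model."""
    if resume and 'ema_state_dict' in checkpoint:
        # The ema field holds the original weights: swap them back first.
        swap_state_dicts(checkpoint)
        return copy.deepcopy(checkpoint['ema_state_dict'])
    # No EMA state (or not resuming): start EMA from the loaded model weights.
    return copy.deepcopy(checkpoint['state_dict'])


# As saved: EMA weights (2.0) in state_dict, original weights (1.0) in the ema field.
ckpt = {'state_dict': {'w': 2.0}, 'ema_state_dict': {'w': 1.0}}
ema_state = resume_ema_state(ckpt, resume=True)
```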

after_test_epoch(runner, metrics=None)[source]

Recover the source model’s parameters from the EMA model after testing.

Parameters:
  • runner (Runner) – The runner of the testing process.

  • metrics (Dict[str, float], optional) – Evaluation results of all metrics on the test dataset. The keys are the metric names and the values are the corresponding results.

after_val_epoch(runner, metrics=None)[source]

Recover the source model’s parameters from the EMA model after validation.

Parameters:
  • runner (Runner) – The runner of the validation process.

  • metrics (Dict[str, float], optional) – Evaluation results of all metrics on the validation dataset. The keys are the metric names and the values are the corresponding results.

before_test_epoch(runner)[source]

Load parameter values from the EMA model into the source model before testing.

Parameters:
  • runner (Runner) – The runner of the testing process.

before_val_epoch(runner)[source]

Load parameter values from the EMA model into the source model before validation.

Parameters:
  • runner (Runner) – The runner of the validation process.
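The four evaluation hooks amount to exchanging weights around each validation or test loop. A minimal sketch, assuming a swap-based strategy on plain dicts; swap_parameters is a hypothetical name, and the real hook works on model tensors.

```python
from typing import Dict


def swap_parameters(model_params: Dict[str, float],
                    ema_params: Dict[str, float]) -> None:
    """Exchange model and EMA values in place; a second call undoes the first."""
    for name in model_params:
        model_params[name], ema_params[name] = ema_params[name], model_params[name]


model = {'w': 1.0}
ema = {'w': 0.9}
swap_parameters(model, ema)    # before_val_epoch / before_test_epoch: model holds EMA weights
evaluated_weight = model['w']  # stands in for running the val/test loop
swap_parameters(model, ema)    # after_val_epoch / after_test_epoch: original weights restored
```

Swapping rather than copying keeps the source weights intact, so the after-epoch step can restore them without a separate backup.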
