Shortcuts

Source code for mmpretrain.engine.hooks.retriever_hooks

# Copyright (c) OpenMMLab. All rights reserved
import warnings

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmpretrain.models import BaseRetriever
from mmpretrain.registry import HOOKS


@HOOKS.register_module()
class PrepareProtoBeforeValLoopHook(Hook):
    """The hook to prepare the prototype in retrievers.

    Since the encoders of the retriever change during training, the
    prototype changes accordingly. So the `prototype_vecs` needs to be
    regenerated before the validation loop.
    """

    def before_val(self, runner) -> None:
        """Regenerate the retriever prototype before validation starts.

        Args:
            runner: The object driving validation. Only its ``model``
                attribute is read here.
        """
        model = runner.model
        # A wrapped model exposes the underlying model as ``.module``.
        if is_model_wrapper(model):
            model = model.module

        if not isinstance(model, BaseRetriever):
            # Name this hook as it is actually registered so the warning
            # can be matched against the user's config.
            warnings.warn(
                'Only the `mmpretrain.models.retrievers.BaseRetriever` '
                'can execute `PrepareProtoBeforeValLoopHook`, but got '
                f'`{type(model)}`')
            return

        # A retriever without `prepare_prototype` is skipped silently.
        if hasattr(model, 'prepare_prototype'):
            model.prepare_prototype()
Read the Docs v: latest
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.