Shortcuts

Source code for mmpretrain.engine.hooks.swav_hook

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence

import torch
from mmengine.device import get_device
from mmengine.dist import get_rank, get_world_size, is_distributed
from mmengine.hooks import Hook
from mmengine.logging import MMLogger

from mmpretrain.registry import HOOKS
from mmpretrain.utils import get_ori_model


@HOOKS.register_module()
class SwAVHook(Hook):
    """Hook for SwAV.

    This hook builds the queue in SwAV according to ``epoch_queue_starts``.
    The queue will be saved in ``runner.work_dir`` or loaded at start epoch
    if the path folder has queues saved before.

    Args:
        batch_size (int): the batch size per GPU for computing.
        epoch_queue_starts (int, optional): from this epoch, starts to use the
            queue. Defaults to 15.
        crops_for_assign (list[int], optional): list of crops id used for
            computing assignments. Defaults to [0, 1].
        feat_dim (int, optional): feature dimension of output vector.
            Defaults to 128.
        queue_length (int, optional): length of the queue (0 for no queue).
            Defaults to 0.
        interval (int, optional): the interval to save the queue.
            Defaults to 1.
        frozen_layers_cfg (dict, optional): Dict to config frozen layers.
            The key-value pair is layer name and its frozen iters. If
            frozen, the layers don't need gradient. Every parameter whose
            name contains the layer name is frozen while
            ``runner.iter < frozen_iters``. Defaults to dict().
    """

    def __init__(
        self,
        batch_size: int,
        epoch_queue_starts: Optional[int] = 15,
        crops_for_assign: Optional[List[int]] = [0, 1],
        feat_dim: Optional[int] = 128,
        queue_length: Optional[int] = 0,
        interval: Optional[int] = 1,
        frozen_layers_cfg: Optional[Dict] = dict()
    ) -> None:
        # Global batch size: the per-GPU batch size times the world size.
        self.batch_size = batch_size * get_world_size()
        self.epoch_queue_starts = epoch_queue_starts
        # Store copies so instances never share (or mutate) the literal
        # default list/dict from the signature.
        self.crops_for_assign = ([0, 1] if crops_for_assign is None else
                                 list(crops_for_assign))
        self.feat_dim = feat_dim
        self.queue_length = queue_length
        self.interval = interval
        self.frozen_layers_cfg = dict(frozen_layers_cfg or {})
        # Per-layer "currently trainable" state. A single shared flag breaks
        # as soon as more than one layer is configured.
        self._layer_requires_grad = {
            layer: True
            for layer in self.frozen_layers_cfg
        }
        # Kept for backward compatibility: True iff every configured layer
        # is currently trainable.
        self.requires_grad = True
        self.queue = None

    def before_run(self, runner) -> None:
        """Check whether the queues exist locally or not.

        Sets ``self.queue_path`` (one file per rank when distributed). If the
        file already exists, loads the queue from it and attaches it to the
        model's ``head.loss_module``. Finally, trims ``queue_length`` down to
        a multiple of the global batch size.
        """
        if is_distributed():
            self.queue_path = osp.join(runner.work_dir,
                                       'queue' + str(get_rank()) + '.pth')
        else:
            self.queue_path = osp.join(runner.work_dir, 'queue.pth')

        # load the queues if queues exist locally
        if osp.isfile(self.queue_path):
            self.queue = torch.load(self.queue_path)['queue']
            get_ori_model(runner.model).head.loss_module.queue = self.queue
            MMLogger.get_current_instance().info(
                f'Load queue from file: {self.queue_path}')

        # the queue needs to be divisible by the batch size
        self.queue_length -= self.queue_length % self.batch_size

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: Optional[Sequence[dict]] = None) -> None:
        """Freeze layers before specific iters according to the config.

        For each ``(layer, frozen_iters)`` entry, parameters whose name
        contains ``layer`` get ``requires_grad=False`` while
        ``runner.iter < frozen_iters`` and ``requires_grad=True`` afterwards.
        Parameters are only touched when a layer's state actually changes.
        """
        for layer, frozen_iters in self.frozen_layers_cfg.items():
            should_train = runner.iter >= frozen_iters
            if self._layer_requires_grad.get(layer, True) == should_train:
                # Already in the desired state; nothing to do.
                continue
            self._layer_requires_grad[layer] = should_train
            for name, p in get_ori_model(runner.model).named_parameters():
                if layer in name:
                    p.requires_grad = should_train
        self.requires_grad = all(self._layer_requires_grad.values())

    def before_train_epoch(self, runner) -> None:
        """Check the queues' state.

        Allocates a zero queue once ``queue_length > 0``, the current epoch
        has reached ``epoch_queue_starts``, and no queue exists yet. Then
        attaches the queue to the loss module and resets its ``use_queue``
        flag to False.
        """
        # optionally starts a queue
        if (self.queue_length > 0
                and runner.epoch >= self.epoch_queue_starts
                and self.queue is None):
            self.queue = torch.zeros(
                len(self.crops_for_assign),
                # each rank holds its share of the global queue
                self.queue_length // runner.world_size,
                self.feat_dim,
                device=get_device(),
            )

        # set the boolean type of use_the_queue
        get_ori_model(runner.model).head.loss_module.queue = self.queue
        get_ori_model(runner.model).head.loss_module.use_queue = False

    def after_train_epoch(self, runner) -> None:
        """Save the queues locally.

        Reads the (possibly updated) queue back from the loss module and
        writes it to ``self.queue_path`` every ``interval`` epochs.
        """
        self.queue = get_ori_model(runner.model).head.loss_module.queue

        if self.queue is not None and self.every_n_epochs(
                runner, self.interval):
            torch.save({'queue': self.queue}, self.queue_path)
Read the Docs v: dev
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.