Source code for mmpretrain.models.backbones.sparse_resnet

# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import List, Optional, Tuple

import torch.nn as nn

from mmpretrain.models.utils.sparse_modules import (SparseAvgPooling,
                                                    SparseBatchNorm2d,
                                                    SparseConv2d,
                                                    SparseMaxPooling,
                                                    SparseSyncBatchNorm2d)
from mmpretrain.registry import MODELS
from .resnet import ResNet


@MODELS.register_module()
class SparseResNet(ResNet):
    """ResNet with sparse module conversion function.

    Modified from
    https://github.com/keyu-tian/SparK/blob/main/encoder.py

    Args:
        depth (int): Network depth, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Defaults to 3.
        stem_channels (int): Output channels of the stem layer. Defaults
            to 64.
        base_channels (int): Middle channels of the first stage.
            Defaults to 64.
        num_stages (int): Stages of the network. Defaults to 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Defaults to ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Defaults to ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to ``(3, )``.
        style (str): ``pytorch`` or ``caffe``. If set to "pytorch", the
            stride-two layer is the 3x3 conv layer, otherwise the
            stride-two layer is the first 1x1 conv layer.
        deep_stem (bool): Replace the 7x7 conv in the input stem with
            three 3x3 convs. Defaults to False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Defaults to None.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='SparseSyncBatchNorm2d')``.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch
            Norm and its variants only. Defaults to False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Defaults to False.
        zero_init_residual (bool): Whether to use zero init for the last
            norm layer in resblocks to let them behave as identity.
            Defaults to False.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
    """

    def __init__(self,
                 depth: int,
                 in_channels: int = 3,
                 stem_channels: int = 64,
                 base_channels: int = 64,
                 expansion: Optional[int] = None,
                 num_stages: int = 4,
                 strides: Tuple[int, ...] = (1, 2, 2, 2),
                 dilations: Tuple[int, ...] = (1, 1, 1, 1),
                 out_indices: Tuple[int, ...] = (3, ),
                 style: str = 'pytorch',
                 deep_stem: bool = False,
                 avg_down: bool = False,
                 frozen_stages: int = -1,
                 conv_cfg: Optional[dict] = None,
                 norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'),
                 norm_eval: bool = False,
                 with_cp: bool = False,
                 zero_init_residual: bool = False,
                 init_cfg: Optional[List[dict]] = [
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ],
                 drop_path_rate: float = 0,
                 **kwargs):
        super().__init__(
            depth=depth,
            in_channels=in_channels,
            stem_channels=stem_channels,
            base_channels=base_channels,
            expansion=expansion,
            num_stages=num_stages,
            strides=strides,
            dilations=dilations,
            out_indices=out_indices,
            style=style,
            deep_stem=deep_stem,
            avg_down=avg_down,
            frozen_stages=frozen_stages,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            norm_eval=norm_eval,
            with_cp=with_cp,
            zero_init_residual=zero_init_residual,
            init_cfg=init_cfg,
            drop_path_rate=drop_path_rate,
            **kwargs)

        # Any norm type whose name contains 'Sync' (e.g.
        # 'SparseSyncBatchNorm2d') makes the converter emit the
        # SyncBN-based sparse modules.
        norm_type = norm_cfg['type']
        enable_sync_bn = re.search('Sync', norm_type) is not None
        self.dense_model_to_sparse(m=self, enable_sync_bn=enable_sync_bn)
    def dense_model_to_sparse(self, m: nn.Module,
                              enable_sync_bn: bool) -> nn.Module:
        """Convert regular dense modules to sparse modules."""
        output = m
        if isinstance(m, nn.Conv2d):
            m: nn.Conv2d
            bias = m.bias is not None
            output = SparseConv2d(
                m.in_channels,
                m.out_channels,
                kernel_size=m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                groups=m.groups,
                bias=bias,
                padding_mode=m.padding_mode,
            )
            # Carry the existing parameters over to the sparse layer.
            output.weight.data.copy_(m.weight.data)
            if bias:
                output.bias.data.copy_(m.bias.data)
        elif isinstance(m, nn.MaxPool2d):
            m: nn.MaxPool2d
            output = SparseMaxPooling(
                m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                return_indices=m.return_indices,
                ceil_mode=m.ceil_mode)
        elif isinstance(m, nn.AvgPool2d):
            m: nn.AvgPool2d
            output = SparseAvgPooling(
                m.kernel_size,
                m.stride,
                m.padding,
                ceil_mode=m.ceil_mode,
                count_include_pad=m.count_include_pad,
                divisor_override=m.divisor_override)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            m: nn.BatchNorm2d
            output = (SparseSyncBatchNorm2d
                      if enable_sync_bn else SparseBatchNorm2d)(
                          m.weight.shape[0],
                          eps=m.eps,
                          momentum=m.momentum,
                          affine=m.affine,
                          track_running_stats=m.track_running_stats)
            # Copy both the affine parameters and the running statistics.
            output.weight.data.copy_(m.weight.data)
            output.bias.data.copy_(m.bias.data)
            output.running_mean.data.copy_(m.running_mean.data)
            output.running_var.data.copy_(m.running_var.data)
            output.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
        elif isinstance(m, (nn.Conv1d, )):
            raise NotImplementedError

        # Recurse into children and re-register the converted modules.
        for name, child in m.named_children():
            output.add_module(
                name,
                self.dense_model_to_sparse(
                    child, enable_sync_bn=enable_sync_bn))
        del m
        return output
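
A minimal construction sketch follows. It is an illustration rather than part of the upstream file, and it assumes the sparse modules imported above are registered in MODELS so that ``norm_cfg`` can name them; it uses the plain-BN variant to avoid the distributed process group that SyncBatchNorm requires.

# A minimal sketch: build the backbone through the registry, as a SparK
# config would, and check that the dense-to-sparse conversion ran.
import torch.nn as nn

from mmpretrain.models.utils.sparse_modules import (SparseBatchNorm2d,
                                                    SparseConv2d)
from mmpretrain.registry import MODELS

backbone = MODELS.build(
    dict(
        type='SparseResNet',
        depth=50,
        out_indices=(3, ),
        # Assumption: 'SparseBatchNorm2d' is resolvable via the MODELS
        # registry, like the default 'SparseSyncBatchNorm2d'.
        norm_cfg=dict(type='SparseBatchNorm2d')))

# __init__ already walked the whole module tree, so the dense stem conv
# is now its sparse counterpart and no plain BatchNorm2d remains.
assert isinstance(backbone.conv1, SparseConv2d)
assert not any(type(m) is nn.BatchNorm2d for m in backbone.modules())

The sketch stops at construction: during SparK pretraining the masking pipeline supplies the active-site mask that these sparse layers consume in their forward pass, so a bare forward call outside that pipeline is not meaningful.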