Source code for mmpretrain.models.backbones.sparse_resnet
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import List, Optional, Tuple, Union
import torch.nn as nn
from mmpretrain.models.utils.sparse_modules import (SparseAvgPooling,
SparseBatchNorm2d,
SparseConv2d,
SparseMaxPooling,
SparseSyncBatchNorm2d)
from mmpretrain.registry import MODELS
from .resnet import ResNet
@MODELS.register_module()
class SparseResNet(ResNet):
"""ResNet with sparse module conversion function.
Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py
Args:
depth (int): Network depth, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Defaults to 3.
stem_channels (int): Output channels of the stem layer. Defaults to 64.
        base_channels (int): Middle channels of the first stage.
            Defaults to 64.
        expansion (int, optional): The expansion ratio of a residual block.
            If not specified, it is inferred from the block type: 1 for
            ``BasicBlock`` and 4 for ``Bottleneck``. Defaults to None.
        num_stages (int): Stages of the network. Defaults to 4.
strides (Sequence[int]): Strides of the first block of each stage.
Defaults to ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Defaults to ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages.
Defaults to ``(3, )``.
        style (str): ``'pytorch'`` or ``'caffe'``. If set to ``'pytorch'``,
            the stride-two layer is the 3x3 conv layer; otherwise the
            stride-two layer is the first 1x1 conv layer.
        deep_stem (bool): Replace the 7x7 conv in the input stem with
            three 3x3 convs. Defaults to False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
conv_cfg (dict | None): The config dict for conv layers.
Defaults to None.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='SparseSyncBatchNorm2d')``.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: this only affects
            Batch Norm and its variants. Defaults to False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
        zero_init_residual (bool): Whether to use zero init for the last
            norm layer in resblocks to let them behave as identity.
            Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config.
            Defaults to Kaiming init for conv layers and constant init for
            batch-norm and group-norm layers.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
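
    Example:
        A minimal usage sketch (illustrative): in practice the backbone is
        built from a config by the SparK pre-training pipeline, which also
        provides the active mask consumed by the sparse modules, so no
        forward pass is shown here.

        >>> from mmpretrain.registry import MODELS
        >>> cfg = dict(type='SparseResNet', depth=50, out_indices=(3, ))
        >>> model = MODELS.build(cfg)
        >>> type(model.conv1).__name__
        'SparseConv2d'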
"""
def __init__(self,
depth: int,
in_channels: int = 3,
stem_channels: int = 64,
base_channels: int = 64,
expansion: Optional[int] = None,
num_stages: int = 4,
                 strides: Tuple[int, ...] = (1, 2, 2, 2),
                 dilations: Tuple[int, ...] = (1, 1, 1, 1),
                 out_indices: Tuple[int, ...] = (3, ),
style: str = 'pytorch',
deep_stem: bool = False,
avg_down: bool = False,
frozen_stages: int = -1,
conv_cfg: Optional[dict] = None,
norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'),
norm_eval: bool = False,
with_cp: bool = False,
zero_init_residual: bool = False,
                 init_cfg: Optional[Union[dict, List[dict]]] = [
dict(type='Kaiming', layer=['Conv2d']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
],
drop_path_rate: float = 0,
**kwargs):
super().__init__(
depth=depth,
in_channels=in_channels,
stem_channels=stem_channels,
base_channels=base_channels,
expansion=expansion,
num_stages=num_stages,
strides=strides,
dilations=dilations,
out_indices=out_indices,
style=style,
deep_stem=deep_stem,
avg_down=avg_down,
frozen_stages=frozen_stages,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
norm_eval=norm_eval,
with_cp=with_cp,
zero_init_residual=zero_init_residual,
init_cfg=init_cfg,
drop_path_rate=drop_path_rate,
**kwargs)
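        # Decide which sparse batch-norm flavour the dense BN layers are
        # converted to: any norm type whose name contains 'Sync'
        # (e.g. 'SparseSyncBatchNorm2d') maps to the synchronized variant.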
norm_type = norm_cfg['type']
        enable_sync_bn = re.search('Sync', norm_type) is not None
self.dense_model_to_sparse(m=self, enable_sync_bn=enable_sync_bn)
    def dense_model_to_sparse(self, m: nn.Module,
                              enable_sync_bn: bool) -> nn.Module:
        """Convert regular dense modules to sparse modules, recursively.

        Args:
            m (nn.Module): The dense module to convert.
            enable_sync_bn (bool): Whether to replace batch-norm layers
                with ``SparseSyncBatchNorm2d`` instead of
                ``SparseBatchNorm2d``.

        Returns:
            nn.Module: The converted sparse module.
        """
output = m
if isinstance(m, nn.Conv2d):
m: nn.Conv2d
bias = m.bias is not None
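            # Re-create the layer as a SparseConv2d with identical
            # hyper-parameters, then copy the learned weights over.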
output = SparseConv2d(
m.in_channels,
m.out_channels,
kernel_size=m.kernel_size,
stride=m.stride,
padding=m.padding,
dilation=m.dilation,
groups=m.groups,
bias=bias,
padding_mode=m.padding_mode,
)
output.weight.data.copy_(m.weight.data)
if bias:
output.bias.data.copy_(m.bias.data)
elif isinstance(m, nn.MaxPool2d):
m: nn.MaxPool2d
output = SparseMaxPooling(
m.kernel_size,
stride=m.stride,
padding=m.padding,
dilation=m.dilation,
return_indices=m.return_indices,
ceil_mode=m.ceil_mode)
elif isinstance(m, nn.AvgPool2d):
m: nn.AvgPool2d
output = SparseAvgPooling(
m.kernel_size,
m.stride,
m.padding,
ceil_mode=m.ceil_mode,
count_include_pad=m.count_include_pad,
divisor_override=m.divisor_override)
elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
m: nn.BatchNorm2d
output = (SparseSyncBatchNorm2d
if enable_sync_bn else SparseBatchNorm2d)(
m.weight.shape[0],
eps=m.eps,
momentum=m.momentum,
affine=m.affine,
track_running_stats=m.track_running_stats)
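            # Copy the affine parameters and running statistics so the
            # converted layer is numerically identical to the original.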
output.weight.data.copy_(m.weight.data)
output.bias.data.copy_(m.bias.data)
output.running_mean.data.copy_(m.running_mean.data)
output.running_var.data.copy_(m.running_var.data)
output.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
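        # There is no sparse counterpart for 1D convolutions here, so fail
        # loudly instead of silently keeping a dense layer.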
elif isinstance(m, (nn.Conv1d, )):
raise NotImplementedError
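        # Recurse into the children so nested submodules (stem, stages,
        # residual blocks) are converted in place as well.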
for name, child in m.named_children():
output.add_module(
name,
self.dense_model_to_sparse(
child, enable_sync_bn=enable_sync_bn))
del m
return output
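

if __name__ == '__main__':
    # A minimal, illustrative sketch of the conversion performed by
    # ``dense_model_to_sparse`` (not part of the original module):
    # ``dense_stem`` is a hypothetical dense module, and
    # ``enable_sync_bn=False`` keeps the demo free of any distributed
    # process-group requirement.
    from mmengine.registry import init_default_scope

    # Resolve 'SparseSyncBatchNorm2d' and friends from mmpretrain's
    # registry when constructing the backbone directly.
    init_default_scope('mmpretrain')

    backbone = SparseResNet(depth=18)
    dense_stem = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.MaxPool2d(kernel_size=2),
    )
    sparse_stem = backbone.dense_model_to_sparse(
        dense_stem, enable_sync_bn=False)
    # Expected: ['SparseConv2d', 'SparseBatchNorm2d', 'SparseMaxPooling']
    print([type(mod).__name__ for mod in sparse_stem])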