- #!/usr/bin/env python3
- # -*- coding:utf-8 -*-
- import torch
- from torch.nn import Module, Sequential, Conv2d, BatchNorm2d, ReLU
- import math
- import torch.nn as nn
- import torch.nn.functional as F
- from typing import Optional, List, Tuple
def Conv_Block(in_channel, out_channel, kernel_size, stride, padding, group=1, has_bn=True, is_linear=False):
    """Build a Conv2d -> (BatchNorm2d) -> (ReLU) sequential block.

    :param in_channel: Number of input channels.
    :param out_channel: Number of output channels.
    :param kernel_size: Convolution kernel size.
    :param stride: Convolution stride.
    :param padding: Zero-padding size.
    :param group: Number of conv groups (group == in_channel gives a depthwise conv).
    :param has_bn: If False, the BatchNorm layer is skipped.
    :param is_linear: If True, the ReLU activation is skipped (linear output).
    :return: ``nn.Sequential`` containing the assembled layers.
    """
    return Sequential(
        # bias=False because BatchNorm (when present) absorbs any bias term.
        Conv2d(in_channel, out_channel, kernel_size, stride, padding=padding, groups=group, bias=False),
        # nn.Identity is the idiomatic no-op placeholder; the original used an
        # empty Sequential, which works but obscures intent.
        BatchNorm2d(out_channel) if has_bn else nn.Identity(),
        ReLU(inplace=True) if not is_linear else nn.Identity(),
    )
class InvertedResidual(Module):
    """MobileNetV2-style inverted residual: 1x1 expand -> 3x3 depthwise -> 1x1 linear project."""

    def __init__(self, in_channel, out_channel, stride, use_res_connect, expand_ratio):
        """
        :param in_channel: Number of input channels.
        :param out_channel: Number of output channels.
        :param stride: Stride of the depthwise conv (1 or 2).
        :param use_res_connect: If True, add the input back onto the block output.
        :param expand_ratio: Channel expansion factor for the hidden layer.
        """
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden = in_channel * expand_ratio
        self.use_res_connect = use_res_connect
        self.inv_res = Sequential(
            # Pointwise expansion.
            Conv_Block(in_channel=in_channel, out_channel=hidden, kernel_size=1, stride=1, padding=0),
            # Depthwise 3x3 (group == channels), carries the stride.
            Conv_Block(in_channel=hidden, out_channel=hidden, kernel_size=3, stride=stride, padding=1, group=hidden),
            # Linear pointwise projection (no activation).
            Conv_Block(in_channel=hidden, out_channel=out_channel, kernel_size=1, stride=1, padding=0, is_linear=True),
        )

    def forward(self, x):
        out = self.inv_res(x)
        return x + out if self.use_res_connect else out
class GhostModule(Module):
    """GhostNet module: a cheap depthwise conv generates 'ghost' features
    from a primary 1x1 conv's output (https://arxiv.org/abs/1911.11907).
    """

    def __init__(self, in_channel, out_channel, is_linear=False):
        """
        :param in_channel: Number of input channels.
        :param out_channel: Number of output channels after trimming.
        :param is_linear: If True, both convs are built without activation.
        """
        super(GhostModule, self).__init__()
        self.out_channel = out_channel
        base_channel = math.ceil(out_channel / 2)
        # Primary conv produces half the features; the depthwise 3x3
        # "cheap operation" derives the other (ghost) half from them.
        self.primary_conv = Conv_Block(in_channel, base_channel, 1, 1, 0, is_linear=is_linear)
        self.cheap_operation = Conv_Block(base_channel, base_channel, 3, 1, 1, group=base_channel, is_linear=is_linear)

    def forward(self, x):
        primary = self.primary_conv(x)
        ghost = self.cheap_operation(primary)
        # Concatenate and trim: for odd out_channel the halves overshoot by one.
        return torch.cat([primary, ghost], dim=1)[:, :self.out_channel, :, :]
class GhostBottleneck(Module):
    """GhostNet bottleneck: Ghost expand -> optional stride-2 depthwise -> linear Ghost project, plus shortcut."""

    def __init__(self, in_channel, hidden_channel, out_channel, stride):
        """
        :param in_channel: Number of input channels.
        :param hidden_channel: Expanded channel count inside the bottleneck.
        :param out_channel: Number of output channels.
        :param stride: Spatial stride (1 or 2).
        """
        super(GhostBottleneck, self).__init__()
        assert stride in [1, 2]
        downsample = stride == 2
        self.ghost_conv = Sequential(
            # Expansion ghost module (with activation).
            GhostModule(in_channel, hidden_channel, is_linear=False),
            # Depthwise conv inserted only when spatially downsampling.
            Conv_Block(hidden_channel, hidden_channel, 3, stride, 1,
                       group=hidden_channel, is_linear=True) if downsample else Sequential(),
            # Projection ghost module (linear).
            GhostModule(hidden_channel, out_channel, is_linear=True),
        )
        if stride == 1 and in_channel == out_channel:
            # Identity shortcut when shapes already match.
            self.shortcut = Sequential()
        else:
            # Projection shortcut: depthwise handles the stride, pointwise the channels.
            self.shortcut = Sequential(
                Conv_Block(in_channel, in_channel, 3, stride, 1, group=in_channel, is_linear=True),
                Conv_Block(in_channel, out_channel, 1, 1, 0, is_linear=True),
            )

    def forward(self, x):
        return self.ghost_conv(x) + self.shortcut(x)
class GhostOneModule(Module):
    """Ghost module built from re-parameterizable MobileOneBlock convs.

    NOTE(review): unlike GhostModule, the output is NOT trimmed to
    ``out_channel``; for odd ``out_channel`` the concatenation yields
    ``2 * ceil(out_channel / 2)`` channels — confirm callers use even widths.
    """

    def __init__(self, in_channel, out_channel, is_linear=False, inference_mode=False, num_conv_branches=1):
        """
        :param in_channel: Number of input channels.
        :param out_channel: Nominal number of output channels.
        :param is_linear: If True, both blocks are built without activation.
        :param inference_mode: Build the MobileOne blocks already fused.
        :param num_conv_branches: Train-time conv branch count per block.
        """
        super(GhostOneModule, self).__init__()
        self.out_channel = out_channel
        half_channel = math.ceil(out_channel / 2)
        self.inference_mode = inference_mode
        self.num_conv_branches = num_conv_branches
        # Primary 1x1 conv producing the first half of the features.
        self.primary_conv = MobileOneBlock(
            in_channels=in_channel,
            out_channels=half_channel,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            inference_mode=self.inference_mode,
            use_se=False,
            num_conv_branches=self.num_conv_branches,
            is_linear=is_linear,
        )
        # Cheap depthwise 3x3 conv generating the ghost half.
        self.cheap_operation = MobileOneBlock(
            in_channels=half_channel,
            out_channels=half_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=half_channel,
            inference_mode=self.inference_mode,
            use_se=False,
            num_conv_branches=self.num_conv_branches,
            is_linear=is_linear,
        )

    def forward(self, x):
        primary = self.primary_conv(x)
        ghost = self.cheap_operation(primary)
        return torch.cat([primary, ghost], dim=1)
class GhostOneBottleneck(Module):
    """GhostBottleneck variant built on MobileOneBlock; note it has NO shortcut branch."""

    def __init__(self, in_channel, hidden_channel, out_channel, stride, inference_mode=False, num_conv_branches=1):
        """
        :param in_channel: Number of input channels.
        :param hidden_channel: Expanded channel count inside the bottleneck.
        :param out_channel: Number of output channels.
        :param stride: Spatial stride (1 or 2).
        :param inference_mode: Build the MobileOne blocks already fused.
        :param num_conv_branches: Train-time conv branch count per block.
        """
        super(GhostOneBottleneck, self).__init__()
        assert stride in [1, 2]
        self.inference_mode = inference_mode
        self.num_conv_branches = num_conv_branches
        downsample = stride == 2
        self.ghost_conv = Sequential(
            # Expansion ghost module (with activation).
            GhostOneModule(in_channel, hidden_channel, is_linear=False,
                           inference_mode=self.inference_mode,
                           num_conv_branches=self.num_conv_branches),
            # Depthwise conv inserted only when spatially downsampling.
            MobileOneBlock(in_channels=hidden_channel,
                           out_channels=hidden_channel,
                           kernel_size=3,
                           stride=stride,
                           padding=1,
                           groups=hidden_channel,
                           inference_mode=self.inference_mode,
                           use_se=False,
                           num_conv_branches=self.num_conv_branches,
                           is_linear=True) if downsample else Sequential(),
            # Projection ghost module (linear).
            GhostOneModule(hidden_channel, out_channel, is_linear=True,
                           inference_mode=self.inference_mode,
                           num_conv_branches=self.num_conv_branches),
        )

    def forward(self, x):
        return self.ghost_conv(x)
class SEBlock(nn.Module):
    """ Squeeze and Excite module.

    Pytorch implementation of `Squeeze-and-Excitation Networks` -
    https://arxiv.org/pdf/1709.01507.pdf
    """

    def __init__(self,
                 in_channels: int,
                 rd_ratio: float = 0.0625) -> None:
        """ Construct a Squeeze and Excite Module.

        :param in_channels: Number of input channels.
        :param rd_ratio: Input channel reduction ratio.
        """
        super(SEBlock, self).__init__()
        squeezed = int(in_channels * rd_ratio)
        # Bottleneck pair of 1x1 convs: squeeze then restore the channel count.
        self.reduce = nn.Conv2d(in_channels=in_channels, out_channels=squeezed,
                                kernel_size=1, stride=1, bias=True)
        self.expand = nn.Conv2d(in_channels=squeezed, out_channels=in_channels,
                                kernel_size=1, stride=1, bias=True)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """ Rescale each channel of ``inputs`` by its learned attention weight. """
        batch, channels, height, width = inputs.size()
        # Global average pool -> reduce -> ReLU -> expand -> sigmoid gate.
        attn = F.avg_pool2d(inputs, kernel_size=[height, width])
        attn = F.relu(self.reduce(attn))
        attn = torch.sigmoid(self.expand(attn))
        return inputs * attn.view(-1, channels, 1, 1)
class MobileOneBlock(nn.Module):
    """ MobileOne building block.

    This block has a multi-branched architecture at train-time
    and plain-CNN style architecture at inference time
    For more details, please refer to our paper:
    `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 inference_mode: bool = False,
                 use_se: bool = False,
                 num_conv_branches: int = 1,
                 is_linear: bool = False) -> None:
        """ Construct a MobileOneBlock module.

        :param in_channels: Number of channels in the input.
        :param out_channels: Number of channels produced by the block.
        :param kernel_size: Size of the convolution kernel.
        :param stride: Stride size.
        :param padding: Zero-padding size.
        :param dilation: Kernel dilation factor.
        :param groups: Group number.
        :param inference_mode: If True, instantiates model in inference mode.
        :param use_se: Whether to use SE-ReLU activations.
        :param num_conv_branches: Number of linear conv branches.
        :param is_linear: If True, no activation is applied after the block.
        """
        super(MobileOneBlock, self).__init__()
        self.inference_mode = inference_mode
        self.groups = groups
        self.stride = stride
        # BUGFIX: `dilation` was accepted but silently ignored by the train-time
        # branches (`_conv_bn` always built dilation=1 convs) while the fused
        # `reparam_conv` did use it — any dilation != 1 produced branch outputs
        # of mismatched spatial size and train/inference inconsistency. Store it
        # and thread it through `_conv_bn`.
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_conv_branches = num_conv_branches

        # Check if SE-ReLU is requested.
        if use_se:
            self.se = SEBlock(out_channels)
        else:
            self.se = nn.Identity()
        if is_linear:
            self.activation = nn.Identity()
        else:
            self.activation = nn.ReLU()

        if inference_mode:
            # Single fused conv used after re-parameterization.
            self.reparam_conv = nn.Conv2d(in_channels=in_channels,
                                          out_channels=out_channels,
                                          kernel_size=kernel_size,
                                          stride=stride,
                                          padding=padding,
                                          dilation=dilation,
                                          groups=groups,
                                          bias=True)
        else:
            # Re-parameterizable skip connection: a BN-only branch, possible
            # only when the block preserves shape.
            self.rbr_skip = nn.BatchNorm2d(num_features=in_channels) \
                if out_channels == in_channels and stride == 1 else None

            # Re-parameterizable kxk conv branches.
            rbr_conv = list()
            for _ in range(self.num_conv_branches):
                rbr_conv.append(self._conv_bn(kernel_size=kernel_size,
                                              padding=padding))
            self.rbr_conv = nn.ModuleList(rbr_conv)

            # Re-parameterizable 1x1 scale branch (only meaningful for k > 1).
            self.rbr_scale = None
            if kernel_size > 1:
                self.rbr_scale = self._conv_bn(kernel_size=1,
                                               padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Apply forward pass. """
        # Inference mode forward pass: single fused conv.
        if self.inference_mode:
            return self.activation(self.se(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        # Skip branch output.
        identity_out = 0
        if self.rbr_skip is not None:
            identity_out = self.rbr_skip(x)

        # Scale branch output.
        scale_out = 0
        if self.rbr_scale is not None:
            scale_out = self.rbr_scale(x)

        # Sum all branches before the (optional) SE and activation.
        out = scale_out + identity_out
        for ix in range(self.num_conv_branches):
            out = out + self.rbr_conv[ix](x)
        return self.activation(self.se(out))

    def reparameterize(self):
        """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference. Idempotent: a second call is a no-op.
        """
        if self.inference_mode:
            return
        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(in_channels=self.rbr_conv[0].conv.in_channels,
                                      out_channels=self.rbr_conv[0].conv.out_channels,
                                      kernel_size=self.rbr_conv[0].conv.kernel_size,
                                      stride=self.rbr_conv[0].conv.stride,
                                      padding=self.rbr_conv[0].conv.padding,
                                      dilation=self.rbr_conv[0].conv.dilation,
                                      groups=self.rbr_conv[0].conv.groups,
                                      bias=True)
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Detach everything and delete the now-unused train-time branches.
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_conv')
        self.__delattr__('rbr_scale')
        if hasattr(self, 'rbr_skip'):
            self.__delattr__('rbr_skip')

        self.inference_mode = True

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83

        :return: Tuple of (kernel, bias) after fusing branches.
        """
        # Get weights and bias of scale branch.
        kernel_scale = 0
        bias_scale = 0
        if self.rbr_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
            # Pad the 1x1 scale kernel to the kxk conv-branch kernel size so
            # the kernels can be summed elementwise. The center tap of a
            # (possibly dilated) kxk kernel hits the same input pixel as a
            # 1x1 kernel, so this fusion is valid for any dilation.
            pad = self.kernel_size // 2
            kernel_scale = torch.nn.functional.pad(kernel_scale,
                                                   [pad, pad, pad, pad])

        # Get weights and bias of skip (BN-only identity) branch.
        kernel_identity = 0
        bias_identity = 0
        if self.rbr_skip is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)

        # Sum the fused weights of all conv branches.
        kernel_conv = 0
        bias_conv = 0
        for ix in range(self.num_conv_branches):
            _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
            kernel_conv += _kernel
            bias_conv += _bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Method to fuse batchnorm layer with preceeding conv layer.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95

        :param branch: Either a conv+bn nn.Sequential or a bare nn.BatchNorm2d
            (the identity branch).
        :return: Tuple of (kernel, bias) after fusing batchnorm.
        """
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                # Build (and cache) an identity kernel: a kxk kernel whose
                # center tap is 1 for the channel it passes through, honoring
                # grouped convolutions.
                input_dim = self.in_channels // self.groups
                kernel_value = torch.zeros((self.in_channels,
                                            input_dim,
                                            self.kernel_size,
                                            self.kernel_size),
                                           dtype=branch.weight.dtype,
                                           device=branch.weight.device)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim,
                                 self.kernel_size // 2,
                                 self.kernel_size // 2] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        # Standard conv+BN fusion: W' = W * gamma/std, b' = beta - mean*gamma/std.
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(self,
                 kernel_size: int,
                 padding: int) -> nn.Sequential:
        """ Helper method to construct conv-batchnorm layers.

        :param kernel_size: Size of the convolution kernel.
        :param padding: Zero-padding size.
        :return: Conv-BN module.
        """
        mod_list = nn.Sequential()
        # bias=False: the following BatchNorm absorbs any bias.
        # `dilation` is threaded through so train-time branches match the
        # fused inference conv (see BUGFIX note in __init__). It is harmless
        # for the 1x1 scale branch, where dilation has no effect.
        mod_list.add_module('conv', nn.Conv2d(in_channels=self.in_channels,
                                              out_channels=self.out_channels,
                                              kernel_size=kernel_size,
                                              stride=self.stride,
                                              padding=padding,
                                              dilation=self.dilation,
                                              groups=self.groups,
                                              bias=False))
        mod_list.add_module('bn', nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list