base_module.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from torch.nn import Module, Sequential, Conv2d, BatchNorm2d, ReLU
  5. import math
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from typing import Optional, List, Tuple
  9. def Conv_Block(in_channel, out_channel, kernel_size, stride, padding, group=1, has_bn=True, is_linear=False):
  10. return Sequential(
  11. Conv2d(in_channel, out_channel, kernel_size, stride, padding=padding, groups=group, bias=False),
  12. BatchNorm2d(out_channel) if has_bn else Sequential(),
  13. ReLU(inplace=True) if not is_linear else Sequential()
  14. )
  15. class InvertedResidual(Module):
  16. def __init__(self, in_channel, out_channel, stride, use_res_connect, expand_ratio):
  17. super(InvertedResidual, self).__init__()
  18. self.stride = stride
  19. assert stride in [1, 2]
  20. exp_channel = in_channel * expand_ratio
  21. self.use_res_connect = use_res_connect
  22. self.inv_res = Sequential(
  23. Conv_Block(in_channel=in_channel, out_channel=exp_channel, kernel_size=1, stride=1, padding=0),
  24. Conv_Block(in_channel=exp_channel, out_channel=exp_channel, kernel_size=3, stride=stride, padding=1,
  25. group=exp_channel),
  26. Conv_Block(in_channel=exp_channel, out_channel=out_channel, kernel_size=1, stride=1, padding=0,
  27. is_linear=True)
  28. )
  29. def forward(self, x):
  30. if self.use_res_connect:
  31. return x + self.inv_res(x)
  32. else:
  33. return self.inv_res(x)
  34. class GhostModule(Module):
  35. def __init__(self, in_channel, out_channel, is_linear=False):
  36. super(GhostModule, self).__init__()
  37. self.out_channel = out_channel
  38. init_channel = math.ceil(out_channel / 2)
  39. new_channel = init_channel
  40. self.primary_conv = Conv_Block(in_channel, init_channel, 1, 1, 0, is_linear=is_linear)
  41. self.cheap_operation = Conv_Block(init_channel, new_channel, 3, 1, 1, group=init_channel, is_linear=is_linear)
  42. def forward(self, x):
  43. x1 = self.primary_conv(x)
  44. x2 = self.cheap_operation(x1)
  45. out = torch.cat([x1, x2], dim=1)
  46. return out[:, :self.out_channel, :, :]
  47. class GhostBottleneck(Module):
  48. def __init__(self, in_channel, hidden_channel, out_channel, stride):
  49. super(GhostBottleneck, self).__init__()
  50. assert stride in [1, 2]
  51. self.ghost_conv = Sequential(
  52. # GhostModule
  53. GhostModule(in_channel, hidden_channel, is_linear=False),
  54. # DepthwiseConv-linear
  55. Conv_Block(hidden_channel, hidden_channel, 3, stride, 1, group=hidden_channel,
  56. is_linear=True) if stride == 2 else Sequential(),
  57. # GhostModule-linear
  58. GhostModule(hidden_channel, out_channel, is_linear=True)
  59. )
  60. if stride == 1 and in_channel == out_channel:
  61. self.shortcut = Sequential()
  62. else:
  63. self.shortcut = Sequential(
  64. Conv_Block(in_channel, in_channel, 3, stride, 1, group=in_channel, is_linear=True),
  65. Conv_Block(in_channel, out_channel, 1, 1, 0, is_linear=True)
  66. )
  67. def forward(self, x):
  68. return self.ghost_conv(x) + self.shortcut(x)
  69. class GhostOneModule(Module):
  70. def __init__(self, in_channel, out_channel, is_linear=False, inference_mode=False, num_conv_branches=1):
  71. super(GhostOneModule, self).__init__()
  72. self.out_channel = out_channel
  73. half_outchannel = math.ceil(out_channel / 2)
  74. self.inference_mode = inference_mode
  75. self.num_conv_branches = num_conv_branches
  76. self.primary_conv = MobileOneBlock(in_channels=in_channel,
  77. out_channels=half_outchannel,
  78. kernel_size=1,
  79. stride=1,
  80. padding=0,
  81. groups=1,
  82. inference_mode=self.inference_mode,
  83. use_se=False,
  84. num_conv_branches=self.num_conv_branches,
  85. is_linear=is_linear)
  86. self.cheap_operation = MobileOneBlock(in_channels=half_outchannel,
  87. out_channels=half_outchannel,
  88. kernel_size=3,
  89. stride=1,
  90. padding=1,
  91. groups=half_outchannel,
  92. inference_mode=self.inference_mode,
  93. use_se=False,
  94. num_conv_branches=self.num_conv_branches,
  95. is_linear=is_linear)
  96. def forward(self, x):
  97. x1 = self.primary_conv(x)
  98. x2 = self.cheap_operation(x1)
  99. out = torch.cat([x1, x2], dim=1)
  100. return out
  101. class GhostOneBottleneck(Module):
  102. def __init__(self, in_channel, hidden_channel, out_channel, stride, inference_mode=False, num_conv_branches=1):
  103. super(GhostOneBottleneck, self).__init__()
  104. assert stride in [1, 2]
  105. self.inference_mode = inference_mode
  106. self.num_conv_branches = num_conv_branches
  107. self.ghost_conv = Sequential(
  108. # GhostModule
  109. GhostOneModule(in_channel, hidden_channel, is_linear=False, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches),
  110. # DepthwiseConv-linear
  111. MobileOneBlock(in_channels=hidden_channel,
  112. out_channels=hidden_channel,
  113. kernel_size=3,
  114. stride=stride,
  115. padding=1,
  116. groups=hidden_channel,
  117. inference_mode=self.inference_mode,
  118. use_se=False,
  119. num_conv_branches=self.num_conv_branches,
  120. is_linear=True) if stride == 2 else Sequential(),
  121. # GhostModule-linear
  122. GhostOneModule(hidden_channel, out_channel, is_linear=True, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  123. )
  124. def forward(self, x):
  125. return self.ghost_conv(x)
  126. class SEBlock(nn.Module):
  127. """ Squeeze and Excite module.
  128. Pytorch implementation of `Squeeze-and-Excitation Networks` -
  129. https://arxiv.org/pdf/1709.01507.pdf
  130. """
  131. def __init__(self,
  132. in_channels: int,
  133. rd_ratio: float = 0.0625) -> None:
  134. """ Construct a Squeeze and Excite Module.
  135. :param in_channels: Number of input channels.
  136. :param rd_ratio: Input channel reduction ratio.
  137. """
  138. super(SEBlock, self).__init__()
  139. self.reduce = nn.Conv2d(in_channels=in_channels,
  140. out_channels=int(in_channels * rd_ratio),
  141. kernel_size=1,
  142. stride=1,
  143. bias=True)
  144. self.expand = nn.Conv2d(in_channels=int(in_channels * rd_ratio),
  145. out_channels=in_channels,
  146. kernel_size=1,
  147. stride=1,
  148. bias=True)
  149. def forward(self, inputs: torch.Tensor) -> torch.Tensor:
  150. """ Apply forward pass. """
  151. b, c, h, w = inputs.size()
  152. x = F.avg_pool2d(inputs, kernel_size=[h, w])
  153. x = self.reduce(x)
  154. x = F.relu(x)
  155. x = self.expand(x)
  156. x = torch.sigmoid(x)
  157. x = x.view(-1, c, 1, 1)
  158. return inputs * x
  159. class MobileOneBlock(nn.Module):
  160. """ MobileOne building block.
  161. This block has a multi-branched architecture at train-time
  162. and plain-CNN style architecture at inference time
  163. For more details, please refer to our paper:
  164. `An Improved One millisecond Mobile Backbone` -
  165. https://arxiv.org/pdf/2206.04040.pdf
  166. """
  167. def __init__(self,
  168. in_channels: int,
  169. out_channels: int,
  170. kernel_size: int,
  171. stride: int = 1,
  172. padding: int = 0,
  173. dilation: int = 1,
  174. groups: int = 1,
  175. inference_mode: bool = False,
  176. use_se: bool = False,
  177. num_conv_branches: int = 1,
  178. is_linear: bool = False) -> None:
  179. """ Construct a MobileOneBlock module.
  180. :param in_channels: Number of channels in the input.
  181. :param out_channels: Number of channels produced by the block.
  182. :param kernel_size: Size of the convolution kernel.
  183. :param stride: Stride size.
  184. :param padding: Zero-padding size.
  185. :param dilation: Kernel dilation factor.
  186. :param groups: Group number.
  187. :param inference_mode: If True, instantiates model in inference mode.
  188. :param use_se: Whether to use SE-ReLU activations.
  189. :param num_conv_branches: Number of linear conv branches.
  190. """
  191. super(MobileOneBlock, self).__init__()
  192. self.inference_mode = inference_mode
  193. self.groups = groups
  194. self.stride = stride
  195. self.kernel_size = kernel_size
  196. self.in_channels = in_channels
  197. self.out_channels = out_channels
  198. self.num_conv_branches = num_conv_branches
  199. # Check if SE-ReLU is requested
  200. if use_se:
  201. self.se = SEBlock(out_channels)
  202. else:
  203. self.se = nn.Identity()
  204. if is_linear:
  205. self.activation = nn.Identity()
  206. else:
  207. self.activation = nn.ReLU()
  208. if inference_mode:
  209. self.reparam_conv = nn.Conv2d(in_channels=in_channels,
  210. out_channels=out_channels,
  211. kernel_size=kernel_size,
  212. stride=stride,
  213. padding=padding,
  214. dilation=dilation,
  215. groups=groups,
  216. bias=True)
  217. else:
  218. # Re-parameterizable skip connection
  219. self.rbr_skip = nn.BatchNorm2d(num_features=in_channels) \
  220. if out_channels == in_channels and stride == 1 else None
  221. # Re-parameterizable conv branches
  222. rbr_conv = list()
  223. for _ in range(self.num_conv_branches):
  224. rbr_conv.append(self._conv_bn(kernel_size=kernel_size,
  225. padding=padding))
  226. self.rbr_conv = nn.ModuleList(rbr_conv)
  227. # Re-parameterizable scale branch
  228. self.rbr_scale = None
  229. if kernel_size > 1:
  230. self.rbr_scale = self._conv_bn(kernel_size=1,
  231. padding=0)
  232. def forward(self, x: torch.Tensor) -> torch.Tensor:
  233. """ Apply forward pass. """
  234. # Inference mode forward pass.
  235. if self.inference_mode:
  236. return self.activation(self.se(self.reparam_conv(x)))
  237. # Multi-branched train-time forward pass.
  238. # Skip branch output
  239. identity_out = 0
  240. if self.rbr_skip is not None:
  241. identity_out = self.rbr_skip(x)
  242. # Scale branch output
  243. scale_out = 0
  244. if self.rbr_scale is not None:
  245. scale_out = self.rbr_scale(x)
  246. # Other branches
  247. out = scale_out + identity_out
  248. for ix in range(self.num_conv_branches):
  249. out += self.rbr_conv[ix](x)
  250. return self.activation(self.se(out))
  251. def reparameterize(self):
  252. """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
  253. https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
  254. architecture used at training time to obtain a plain CNN-like structure
  255. for inference.
  256. """
  257. if self.inference_mode:
  258. return
  259. kernel, bias = self._get_kernel_bias()
  260. self.reparam_conv = nn.Conv2d(in_channels=self.rbr_conv[0].conv.in_channels,
  261. out_channels=self.rbr_conv[0].conv.out_channels,
  262. kernel_size=self.rbr_conv[0].conv.kernel_size,
  263. stride=self.rbr_conv[0].conv.stride,
  264. padding=self.rbr_conv[0].conv.padding,
  265. dilation=self.rbr_conv[0].conv.dilation,
  266. groups=self.rbr_conv[0].conv.groups,
  267. bias=True)
  268. self.reparam_conv.weight.data = kernel
  269. self.reparam_conv.bias.data = bias
  270. # Delete un-used branches
  271. for para in self.parameters():
  272. para.detach_()
  273. self.__delattr__('rbr_conv')
  274. self.__delattr__('rbr_scale')
  275. if hasattr(self, 'rbr_skip'):
  276. self.__delattr__('rbr_skip')
  277. self.inference_mode = True
  278. def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
  279. """ Method to obtain re-parameterized kernel and bias.
  280. Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
  281. :return: Tuple of (kernel, bias) after fusing branches.
  282. """
  283. # get weights and bias of scale branch
  284. kernel_scale = 0
  285. bias_scale = 0
  286. if self.rbr_scale is not None:
  287. kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
  288. # Pad scale branch kernel to match conv branch kernel size.
  289. pad = self.kernel_size // 2
  290. kernel_scale = torch.nn.functional.pad(kernel_scale,
  291. [pad, pad, pad, pad])
  292. # get weights and bias of skip branch
  293. kernel_identity = 0
  294. bias_identity = 0
  295. if self.rbr_skip is not None:
  296. kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
  297. # get weights and bias of conv branches
  298. kernel_conv = 0
  299. bias_conv = 0
  300. for ix in range(self.num_conv_branches):
  301. _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
  302. kernel_conv += _kernel
  303. bias_conv += _bias
  304. kernel_final = kernel_conv + kernel_scale + kernel_identity
  305. bias_final = bias_conv + bias_scale + bias_identity
  306. return kernel_final, bias_final
  307. def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
  308. """ Method to fuse batchnorm layer with preceeding conv layer.
  309. Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
  310. :param branch:
  311. :return: Tuple of (kernel, bias) after fusing batchnorm.
  312. """
  313. if isinstance(branch, nn.Sequential):
  314. kernel = branch.conv.weight
  315. running_mean = branch.bn.running_mean
  316. running_var = branch.bn.running_var
  317. gamma = branch.bn.weight
  318. beta = branch.bn.bias
  319. eps = branch.bn.eps
  320. else:
  321. assert isinstance(branch, nn.BatchNorm2d)
  322. if not hasattr(self, 'id_tensor'):
  323. input_dim = self.in_channels // self.groups
  324. kernel_value = torch.zeros((self.in_channels,
  325. input_dim,
  326. self.kernel_size,
  327. self.kernel_size),
  328. dtype=branch.weight.dtype,
  329. device=branch.weight.device)
  330. for i in range(self.in_channels):
  331. kernel_value[i, i % input_dim,
  332. self.kernel_size // 2,
  333. self.kernel_size // 2] = 1
  334. self.id_tensor = kernel_value
  335. kernel = self.id_tensor
  336. running_mean = branch.running_mean
  337. running_var = branch.running_var
  338. gamma = branch.weight
  339. beta = branch.bias
  340. eps = branch.eps
  341. std = (running_var + eps).sqrt()
  342. t = (gamma / std).reshape(-1, 1, 1, 1)
  343. return kernel * t, beta - running_mean * gamma / std
  344. def _conv_bn(self,
  345. kernel_size: int,
  346. padding: int) -> nn.Sequential:
  347. """ Helper method to construct conv-batchnorm layers.
  348. :param kernel_size: Size of the convolution kernel.
  349. :param padding: Zero-padding size.
  350. :return: Conv-BN module.
  351. """
  352. mod_list = nn.Sequential()
  353. mod_list.add_module('conv', nn.Conv2d(in_channels=self.in_channels,
  354. out_channels=self.out_channels,
  355. kernel_size=kernel_size,
  356. stride=self.stride,
  357. padding=padding,
  358. groups=self.groups,
  359. bias=False))
  360. mod_list.add_module('bn', nn.BatchNorm2d(num_features=self.out_channels))
  361. return mod_list