"""Model definitions for face alignment: FAN (stacked hourglass) and a
ResNet-based depth regressor."""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
  5. def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
  6. "3x3 convolution with padding"
  7. return nn.Conv2d(in_planes, out_planes, kernel_size=3,
  8. stride=strd, padding=padding, bias=bias)
  9. class ConvBlock(nn.Module):
  10. def __init__(self, in_planes, out_planes):
  11. super(ConvBlock, self).__init__()
  12. self.bn1 = nn.BatchNorm2d(in_planes)
  13. self.conv1 = conv3x3(in_planes, int(out_planes / 2))
  14. self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
  15. self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
  16. self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
  17. self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
  18. if in_planes != out_planes:
  19. self.downsample = nn.Sequential(
  20. nn.BatchNorm2d(in_planes),
  21. nn.ReLU(True),
  22. nn.Conv2d(in_planes, out_planes,
  23. kernel_size=1, stride=1, bias=False),
  24. )
  25. else:
  26. self.downsample = None
  27. def forward(self, x):
  28. residual = x
  29. out1 = self.bn1(x)
  30. out1 = F.relu(out1, True)
  31. out1 = self.conv1(out1)
  32. out2 = self.bn2(out1)
  33. out2 = F.relu(out2, True)
  34. out2 = self.conv2(out2)
  35. out3 = self.bn3(out2)
  36. out3 = F.relu(out3, True)
  37. out3 = self.conv3(out3)
  38. out3 = torch.cat((out1, out2, out3), 1)
  39. if self.downsample is not None:
  40. residual = self.downsample(residual)
  41. out3 += residual
  42. return out3
  43. class Bottleneck(nn.Module):
  44. expansion = 4
  45. def __init__(self, inplanes, planes, stride=1, downsample=None):
  46. super(Bottleneck, self).__init__()
  47. self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
  48. self.bn1 = nn.BatchNorm2d(planes)
  49. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
  50. padding=1, bias=False)
  51. self.bn2 = nn.BatchNorm2d(planes)
  52. self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
  53. self.bn3 = nn.BatchNorm2d(planes * 4)
  54. self.relu = nn.ReLU(inplace=True)
  55. self.downsample = downsample
  56. self.stride = stride
  57. def forward(self, x):
  58. residual = x
  59. out = self.conv1(x)
  60. out = self.bn1(out)
  61. out = self.relu(out)
  62. out = self.conv2(out)
  63. out = self.bn2(out)
  64. out = self.relu(out)
  65. out = self.conv3(out)
  66. out = self.bn3(out)
  67. if self.downsample is not None:
  68. residual = self.downsample(x)
  69. out += residual
  70. out = self.relu(out)
  71. return out
  72. class HourGlass(nn.Module):
  73. def __init__(self, num_modules, depth, num_features):
  74. super(HourGlass, self).__init__()
  75. self.num_modules = num_modules
  76. self.depth = depth
  77. self.features = num_features
  78. self._generate_network(self.depth)
  79. def _generate_network(self, level):
  80. self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
  81. self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
  82. if level > 1:
  83. self._generate_network(level - 1)
  84. else:
  85. self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
  86. self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
  87. def _forward(self, level, inp):
  88. # Upper branch
  89. up1 = inp
  90. up1 = self._modules['b1_' + str(level)](up1)
  91. # Lower branch
  92. low1 = F.avg_pool2d(inp, 2, stride=2)
  93. low1 = self._modules['b2_' + str(level)](low1)
  94. if level > 1:
  95. low2 = self._forward(level - 1, low1)
  96. else:
  97. low2 = low1
  98. low2 = self._modules['b2_plus_' + str(level)](low2)
  99. low3 = low2
  100. low3 = self._modules['b3_' + str(level)](low3)
  101. up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
  102. return up1 + up2
  103. def forward(self, x):
  104. return self._forward(self.depth, x)
  105. class FAN(nn.Module):
  106. def __init__(self, num_modules=1):
  107. super(FAN, self).__init__()
  108. self.num_modules = num_modules
  109. # Base part
  110. self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
  111. self.bn1 = nn.BatchNorm2d(64)
  112. self.conv2 = ConvBlock(64, 128)
  113. self.conv3 = ConvBlock(128, 128)
  114. self.conv4 = ConvBlock(128, 256)
  115. # Stacking part
  116. for hg_module in range(self.num_modules):
  117. self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
  118. self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
  119. self.add_module('conv_last' + str(hg_module),
  120. nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
  121. self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
  122. self.add_module('l' + str(hg_module), nn.Conv2d(256,
  123. 68, kernel_size=1, stride=1, padding=0))
  124. if hg_module < self.num_modules - 1:
  125. self.add_module(
  126. 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
  127. self.add_module('al' + str(hg_module), nn.Conv2d(68,
  128. 256, kernel_size=1, stride=1, padding=0))
  129. def forward(self, x):
  130. x = F.relu(self.bn1(self.conv1(x)), True)
  131. x = F.avg_pool2d(self.conv2(x), 2, stride=2)
  132. x = self.conv3(x)
  133. x = self.conv4(x)
  134. previous = x
  135. outputs = []
  136. for i in range(self.num_modules):
  137. hg = self._modules['m' + str(i)](previous)
  138. ll = hg
  139. ll = self._modules['top_m_' + str(i)](ll)
  140. ll = F.relu(self._modules['bn_end' + str(i)]
  141. (self._modules['conv_last' + str(i)](ll)), True)
  142. # Predict heatmaps
  143. tmp_out = self._modules['l' + str(i)](ll)
  144. outputs.append(tmp_out)
  145. if i < self.num_modules - 1:
  146. ll = self._modules['bl' + str(i)](ll)
  147. tmp_out_ = self._modules['al' + str(i)](tmp_out)
  148. previous = previous + ll + tmp_out_
  149. return outputs
  150. class ResNetDepth(nn.Module):
  151. def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
  152. self.inplanes = 64
  153. super(ResNetDepth, self).__init__()
  154. self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
  155. bias=False)
  156. self.bn1 = nn.BatchNorm2d(64)
  157. self.relu = nn.ReLU(inplace=True)
  158. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  159. self.layer1 = self._make_layer(block, 64, layers[0])
  160. self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
  161. self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
  162. self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
  163. self.avgpool = nn.AvgPool2d(7)
  164. self.fc = nn.Linear(512 * block.expansion, num_classes)
  165. for m in self.modules():
  166. if isinstance(m, nn.Conv2d):
  167. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  168. m.weight.data.normal_(0, math.sqrt(2. / n))
  169. elif isinstance(m, nn.BatchNorm2d):
  170. m.weight.data.fill_(1)
  171. m.bias.data.zero_()
  172. def _make_layer(self, block, planes, blocks, stride=1):
  173. downsample = None
  174. if stride != 1 or self.inplanes != planes * block.expansion:
  175. downsample = nn.Sequential(
  176. nn.Conv2d(self.inplanes, planes * block.expansion,
  177. kernel_size=1, stride=stride, bias=False),
  178. nn.BatchNorm2d(planes * block.expansion),
  179. )
  180. layers = []
  181. layers.append(block(self.inplanes, planes, stride, downsample))
  182. self.inplanes = planes * block.expansion
  183. for i in range(1, blocks):
  184. layers.append(block(self.inplanes, planes))
  185. return nn.Sequential(*layers)
  186. def forward(self, x):
  187. x = self.conv1(x)
  188. x = self.bn1(x)
  189. x = self.relu(x)
  190. x = self.maxpool(x)
  191. x = self.layer1(x)
  192. x = self.layer2(x)
  193. x = self.layer3(x)
  194. x = self.layer4(x)
  195. x = self.avgpool(x)
  196. x = x.view(x.size(0), -1)
  197. x = self.fc(x)
  198. return x