pfld_mobileone.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from torch.nn import Module, AvgPool2d, Linear
  7. from .base_module import MobileOneBlock, GhostOneBottleneck, Conv_Block
  8. class PFLD_GhostOne(Module):
  9. def __init__(self, width_factor=0.5, input_size=192, landmark_number=110, inference_mode=False):
  10. super(PFLD_GhostOne, self).__init__()
  11. self.inference_mode = inference_mode
  12. self.num_conv_branches = 6
  13. self.conv1 = MobileOneBlock(in_channels=3,
  14. out_channels=int(64 * width_factor),
  15. kernel_size=3,
  16. stride=2,
  17. padding=1,
  18. groups=1,
  19. inference_mode=self.inference_mode,
  20. use_se=False,
  21. num_conv_branches=self.num_conv_branches,
  22. is_linear=False)
  23. self.conv2 = MobileOneBlock(in_channels=int(64 * width_factor),
  24. out_channels=int(64 * width_factor),
  25. kernel_size=3,
  26. stride=1,
  27. padding=1,
  28. groups=int(64 * width_factor),
  29. inference_mode=self.inference_mode,
  30. use_se=False,
  31. num_conv_branches=self.num_conv_branches,
  32. is_linear=False)
  33. # def _make_bottlenecks(self):
  34. # modules = OrderedDict()
  35. # stage_name = "Bottlenecks"
  36. # # First module is the only one with t=1
  37. # bottleneck1 = self._make_stage(inplanes=self.c[0], outplanes=self.c[1], n=self.n[1], stride=self.s[1], t=1,
  38. # stage=0)
  39. # modules[stage_name + "_0"] = bottleneck1
  40. # # add more LinearBottleneck depending on number of repeats
  41. # for i in range(1, len(self.c) - 1):
  42. # name = stage_name + "_{}".format(i)
  43. # module = self._make_stage(inplanes=self.c[i], outplanes=self.c[i + 1], n=self.n[i + 1],
  44. # stride=self.s[i + 1],
  45. # t=self.t, stage=i)
  46. # modules[name] = module
  47. # return nn.Sequential(modules)
  48. self.conv3_1 = GhostOneBottleneck(int(64 * width_factor), int(96 * width_factor), int(80 * width_factor), stride=2, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  49. self.conv3_2 = GhostOneBottleneck(int(80 * width_factor), int(120 * width_factor), int(80 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  50. self.conv3_3 = GhostOneBottleneck(int(80 * width_factor), int(120 * width_factor), int(80 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  51. self.conv4_1 = GhostOneBottleneck(int(80 * width_factor), int(200 * width_factor), int(96 * width_factor), stride=2, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  52. self.conv4_2 = GhostOneBottleneck(int(96 * width_factor), int(240 * width_factor), int(96 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  53. self.conv4_3 = GhostOneBottleneck(int(96 * width_factor), int(240 * width_factor), int(96 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  54. self.conv5_1 = GhostOneBottleneck(int(96 * width_factor), int(336 * width_factor), int(144 * width_factor), stride=2, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  55. self.conv5_2 = GhostOneBottleneck(int(144 * width_factor), int(504 * width_factor), int(144 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  56. self.conv5_3 = GhostOneBottleneck(int(144 * width_factor), int(504 * width_factor), int(144 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  57. self.conv5_4 = GhostOneBottleneck(int(144 * width_factor), int(504 * width_factor), int(144 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  58. self.conv6 = GhostOneBottleneck(int(144 * width_factor), int(216 * width_factor), int(16 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  59. self.conv7 = MobileOneBlock(in_channels=int(16 * width_factor),
  60. out_channels=int(32 * width_factor),
  61. kernel_size=3,
  62. stride=1,
  63. padding=1,
  64. groups=1,
  65. inference_mode=self.inference_mode,
  66. use_se=False,
  67. num_conv_branches=self.num_conv_branches,
  68. is_linear=False)
  69. self.conv8 = Conv_Block(int(32 * width_factor), int(128 * width_factor), input_size // 16, 1, 0, has_bn=False)
  70. self.avg_pool1 = AvgPool2d(input_size // 2)
  71. self.avg_pool2 = AvgPool2d(input_size // 4)
  72. self.avg_pool3 = AvgPool2d(input_size // 8)
  73. self.avg_pool4 = AvgPool2d(input_size // 16)
  74. self.conv_out = nn.Conv2d(int(512*width_factor), landmark_number*2, 1, 1, 0) # 这个大小需要改
  75. self.localization = nn.Sequential(
  76. nn.Conv2d(1, 8, kernel_size=7),
  77. nn.MaxPool2d(2, stride=2),
  78. nn.ReLU(True),
  79. nn.Conv2d(8, 10, kernel_size=5),
  80. nn.MaxPool2d(2, stride=2),
  81. nn.ReLU(True)
  82. )
  83. def forward(self, x):
  84. x = self.conv1(x)
  85. x = self.conv2(x)
  86. x1 = self.avg_pool1(x)
  87. # x1 = x1.view(x1.size(0), -1)
  88. x = self.conv3_1(x)
  89. x = self.conv3_2(x)
  90. x = self.conv3_3(x)
  91. x2 = self.avg_pool2(x)
  92. # x2 = x2.view(x2.size(0), -1)
  93. x = self.conv4_1(x)
  94. x = self.conv4_2(x)
  95. x = self.conv4_3(x)
  96. x3 = self.avg_pool3(x)
  97. # x3 = x3.view(x3.size(0), -1)
  98. x = self.conv5_1(x)
  99. x = self.conv5_2(x)
  100. x = self.conv5_3(x)
  101. x = self.conv5_4(x)
  102. x4 = self.avg_pool4(x)
  103. # x4 = x4.view(x4.size(0), -1)
  104. x = self.conv6(x)
  105. x = self.conv7(x)
  106. x5 = self.conv8(x)
  107. # x5 = x5.view(x5.size(0), -1)
  108. multi_scale = torch.cat([x1, x2, x3, x4, x5], 1)
  109. landmarks = self.conv_out(multi_scale)
  110. landmarks = landmarks.view(landmarks.size(0), -1)
  111. return landmarks
  112. class PFLD_GhostOne_WithSTN(Module):
  113. def __init__(self, width_factor=0.5, input_size=112, landmark_number=110, inference_mode=False):
  114. super(PFLD_GhostOne, self).__init__()
  115. self.inference_mode = inference_mode
  116. self.num_conv_branches = 6
  117. self.conv1 = MobileOneBlock(in_channels=3,
  118. out_channels=int(64 * width_factor),
  119. kernel_size=3,
  120. stride=2,
  121. padding=1,
  122. groups=1,
  123. inference_mode=self.inference_mode,
  124. use_se=False,
  125. num_conv_branches=self.num_conv_branches,
  126. is_linear=False)
  127. self.conv2 = MobileOneBlock(in_channels=int(64 * width_factor),
  128. out_channels=int(64 * width_factor),
  129. kernel_size=3,
  130. stride=1,
  131. padding=1,
  132. groups=int(64 * width_factor),
  133. inference_mode=self.inference_mode,
  134. use_se=False,
  135. num_conv_branches=self.num_conv_branches,
  136. is_linear=False)
  137. # def _make_bottlenecks(self):
  138. # modules = OrderedDict()
  139. # stage_name = "Bottlenecks"
  140. # # First module is the only one with t=1
  141. # bottleneck1 = self._make_stage(inplanes=self.c[0], outplanes=self.c[1], n=self.n[1], stride=self.s[1], t=1,
  142. # stage=0)
  143. # modules[stage_name + "_0"] = bottleneck1
  144. # # add more LinearBottleneck depending on number of repeats
  145. # for i in range(1, len(self.c) - 1):
  146. # name = stage_name + "_{}".format(i)
  147. # module = self._make_stage(inplanes=self.c[i], outplanes=self.c[i + 1], n=self.n[i + 1],
  148. # stride=self.s[i + 1],
  149. # t=self.t, stage=i)
  150. # modules[name] = module
  151. # return nn.Sequential(modules)
  152. self.conv3_1 = GhostOneBottleneck(int(64 * width_factor), int(96 * width_factor), int(80 * width_factor), stride=2, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  153. self.conv3_2 = GhostOneBottleneck(int(80 * width_factor), int(120 * width_factor), int(80 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  154. self.conv3_3 = GhostOneBottleneck(int(80 * width_factor), int(120 * width_factor), int(80 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  155. self.conv4_1 = GhostOneBottleneck(int(80 * width_factor), int(200 * width_factor), int(96 * width_factor), stride=2, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  156. self.conv4_2 = GhostOneBottleneck(int(96 * width_factor), int(240 * width_factor), int(96 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  157. self.conv4_3 = GhostOneBottleneck(int(96 * width_factor), int(240 * width_factor), int(96 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  158. self.conv5_1 = GhostOneBottleneck(int(96 * width_factor), int(336 * width_factor), int(144 * width_factor), stride=2, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  159. self.conv5_2 = GhostOneBottleneck(int(144 * width_factor), int(504 * width_factor), int(144 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  160. self.conv5_3 = GhostOneBottleneck(int(144 * width_factor), int(504 * width_factor), int(144 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  161. self.conv5_4 = GhostOneBottleneck(int(144 * width_factor), int(504 * width_factor), int(144 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  162. self.conv6 = GhostOneBottleneck(int(144 * width_factor), int(216 * width_factor), int(16 * width_factor), stride=1, inference_mode=self.inference_mode, num_conv_branches=self.num_conv_branches)
  163. self.conv7 = MobileOneBlock(in_channels=int(16 * width_factor),
  164. out_channels=int(32 * width_factor),
  165. kernel_size=3,
  166. stride=1,
  167. padding=1,
  168. groups=1,
  169. inference_mode=self.inference_mode,
  170. use_se=False,
  171. num_conv_branches=self.num_conv_branches,
  172. is_linear=False)
  173. self.conv8 = Conv_Block(int(32 * width_factor), int(128 * width_factor), input_size // 16, 1, 0, has_bn=False)
  174. self.avg_pool1 = AvgPool2d(input_size // 2)
  175. self.avg_pool2 = AvgPool2d(input_size // 4)
  176. self.avg_pool3 = AvgPool2d(input_size // 8)
  177. self.avg_pool4 = AvgPool2d(input_size // 16)
  178. self.conv_out = nn.Conv2d(int(512*width_factor), landmark_number*2, 1, 1, 0) # 这个大小需要改
  179. def forward(self, x):
  180. x = self.conv1(x)
  181. x = self.conv2(x)
  182. x1 = self.avg_pool1(x)
  183. # x1 = x1.view(x1.size(0), -1)
  184. x = self.conv3_1(x)
  185. x = self.conv3_2(x)
  186. x = self.conv3_3(x)
  187. x2 = self.avg_pool2(x)
  188. # x2 = x2.view(x2.size(0), -1)
  189. x = self.conv4_1(x)
  190. x = self.conv4_2(x)
  191. x = self.conv4_3(x)
  192. x3 = self.avg_pool3(x)
  193. # x3 = x3.view(x3.size(0), -1)
  194. x = self.conv5_1(x)
  195. x = self.conv5_2(x)
  196. x = self.conv5_3(x)
  197. x = self.conv5_4(x)
  198. x4 = self.avg_pool4(x)
  199. # x4 = x4.view(x4.size(0), -1)
  200. x = self.conv6(x)
  201. x = self.conv7(x)
  202. x5 = self.conv8(x)
  203. # x5 = x5.view(x5.size(0), -1)
  204. multi_scale = torch.cat([x1, x2, x3, x4, x5], 1)
  205. landmarks = self.conv_out(multi_scale)
  206. landmarks = landmarks.view(landmarks.size(0), -1)
  207. return landmarks
  208. class AuxiliaryNet(Module):
  209. def __init__(self, width_factor=1):
  210. super(AuxiliaryNet, self).__init__()
  211. self.conv1 = Conv_Block(int(64 * width_factor), int(64 * width_factor), 1, 1, 0)
  212. self.conv2 = Conv_Block(int(80 * width_factor), int(64 * width_factor), 1, 1, 0)
  213. self.conv3 = Conv_Block(int(96 * width_factor), int(64 * width_factor), 1, 1, 0)
  214. self.conv4 = Conv_Block(int(144 * width_factor), int(64 * width_factor), 1, 1, 0)
  215. self.merge1 = Conv_Block(int(64 * width_factor), int(64 * width_factor), 3, 1, 1)
  216. self.merge2 = Conv_Block(int(64 * width_factor), int(64 * width_factor), 3, 1, 1)
  217. self.merge3 = Conv_Block(int(64 * width_factor), int(64 * width_factor), 3, 1, 1)
  218. self.conv_out = Conv_Block(int(64 * width_factor), 1, 1, 1, 0)
  219. def forward(self, out1, out2, out3, out4):
  220. output1 = self.conv1(out1)
  221. output2 = self.conv2(out2)
  222. output3 = self.conv3(out3)
  223. output4 = self.conv4(out4)
  224. up4 = F.interpolate(output4, size=[output3.size(2), output3.size(3)], mode="nearest")
  225. output3 = output3 + up4
  226. output3 = self.merge3(output3)
  227. up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
  228. output2 = output2 + up3
  229. output2 = self.merge2(output2)
  230. up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
  231. output1 = output1 + up2
  232. output1 = self.merge1(output1)
  233. output1 = self.conv_out(output1)
  234. return output1
  235. if __name__ == "__main__":
  236. import time
  237. import onnx
  238. import numpy as np
  239. from thop import profile
  240. INPUT_SIZE = 256
  241. net = PFLD_GhostOne(0.5, INPUT_SIZE, 110, True)
  242. torch_in = torch.zeros([1, 3, INPUT_SIZE, INPUT_SIZE])
  243. flops, params = profile(net, (torch_in,))
  244. print(flops)
  245. for i in range(11):
  246. t1 = time.time()
  247. _ = net(torch_in)
  248. t2 = time.time()
  249. print(t2-t1)
  250. def check_onnx(torch_out, torch_in):
  251. onnx_model = onnx.load(onnx_path)
  252. onnx.checker.check_model(onnx_model)
  253. import onnxruntime
  254. ort_session = onnxruntime.InferenceSession(onnx_path)
  255. ort_inputs = {ort_session.get_inputs()[0].name: torch_in.cpu().numpy()}
  256. ort_outs = ort_session.run(None, ort_inputs)
  257. np.testing.assert_allclose(torch_out[0].cpu().numpy(), ort_outs[0][0], rtol=1e-03, atol=1e-05)
  258. print("Exported model has been tested with ONNXRuntime, and the result looks good!")
  259. source_file = './1.pth'
  260. onnx_path = './pfld_mobileone_256.onnx'
  261. torch.save(net.state_dict(), source_file)
  262. input_size = 256
  263. print("=====> load pytorch checkpoint...")
  264. # checkpoint = torch.load(source_file, map_location=torch.device('cpu'))
  265. dummy_input = torch.randn(1, 3, input_size, input_size)
  266. # input_names = ["input"]
  267. # output_names = ["output"]
  268. # net.load_state_dict(checkpoint)
  269. torch_in = torch.zeros([1,3,input_size,input_size])
  270. with torch.no_grad():
  271. torch_out = net(torch_in)
  272. print(torch_out)
  273. torch.onnx.export(net, torch_in, onnx_path, input_names=['input'],
  274. output_names=['output'],
  275. # example_outputs=torch_out,
  276. opset_version=11,
  277. export_params=True)