# unet.py
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
  6. class InvertedResidual(nn.Module):
  7. def __init__(self, inp, oup, stride, use_res_connect, expand_ratio=6):
  8. super(InvertedResidual, self).__init__()
  9. self.stride = stride
  10. assert stride in [1, 2]
  11. self.use_res_connect = use_res_connect
  12. self.conv = nn.Sequential(
  13. nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False),
  14. nn.BatchNorm2d(inp * expand_ratio),
  15. nn.ReLU(inplace=True),
  16. nn.Conv2d(inp * expand_ratio,
  17. inp * expand_ratio,
  18. 3,
  19. stride,
  20. 1,
  21. groups=inp * expand_ratio,
  22. bias=False),
  23. nn.BatchNorm2d(inp * expand_ratio),
  24. nn.ReLU(inplace=True),
  25. nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False),
  26. nn.BatchNorm2d(oup),
  27. )
  28. def forward(self, x):
  29. if self.use_res_connect:
  30. return x + self.conv(x)
  31. else:
  32. return self.conv(x)
  33. class DoubleConvDW(nn.Module):
  34. def __init__(self, in_channels, out_channels, stride=2):
  35. super(DoubleConvDW, self).__init__()
  36. self.double_conv = nn.Sequential(
  37. InvertedResidual(in_channels, out_channels, stride=stride, use_res_connect=False, expand_ratio=2),
  38. InvertedResidual(out_channels, out_channels, stride=1, use_res_connect=True, expand_ratio=2)
  39. )
  40. def forward(self, x):
  41. return self.double_conv(x)
  42. class InConvDw(nn.Module):
  43. def __init__(self, in_channels, out_channels):
  44. super(InConvDw, self).__init__()
  45. self.inconv = nn.Sequential(
  46. InvertedResidual(in_channels, out_channels, stride=1, use_res_connect=False, expand_ratio=2)
  47. )
  48. def forward(self, x):
  49. return self.inconv(x)
  50. class Down(nn.Module):
  51. def __init__(self, in_channels, out_channels):
  52. super(Down, self).__init__()
  53. self.maxpool_conv = nn.Sequential(
  54. DoubleConvDW(in_channels, out_channels, stride=2)
  55. )
  56. def forward(self, x):
  57. return self.maxpool_conv(x)
  58. class Up(nn.Module):
  59. def __init__(self, in_channels, out_channels):
  60. super(Up, self).__init__()
  61. self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
  62. self.conv = DoubleConvDW(in_channels, out_channels, stride=1)
  63. def forward(self, x1, x2):
  64. x1 = self.up(x1)
  65. diffY = x2.shape[2] - x1.shape[2]
  66. diffX = x2.shape[3] - x1.shape[3]
  67. x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
  68. x = torch.cat([x1, x2], axis=1)
  69. return self.conv(x)
  70. class OutConv(nn.Module):
  71. def __init__(self, in_channels, out_channels):
  72. super(OutConv, self).__init__()
  73. self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
  74. def forward(self, x):
  75. return self.conv(x)
  76. class AudioConvWenet(nn.Module):
  77. def __init__(self):
  78. super(AudioConvWenet, self).__init__()
  79. # ch = [16, 32, 64, 128, 256] # if you want to run this model on a mobile device, use this.
  80. ch = [32, 64, 128, 256, 512]
  81. self.conv1 = InvertedResidual(ch[2], ch[3], stride=1, use_res_connect=False, expand_ratio=2)
  82. self.conv2 = InvertedResidual(ch[3], ch[3], stride=1, use_res_connect=True, expand_ratio=2)
  83. self.conv3 = nn.Conv2d(ch[3], ch[3], kernel_size=3, padding=1, stride=(1,2))
  84. self.bn3 = nn.BatchNorm2d(ch[3])
  85. self.conv4 = InvertedResidual(ch[3], ch[3], stride=1, use_res_connect=True, expand_ratio=2)
  86. self.conv5 = nn.Conv2d(ch[3], ch[4], kernel_size=3, padding=3, stride=2)
  87. self.bn5 = nn.BatchNorm2d(ch[4])
  88. self.relu = nn.ReLU()
  89. self.conv6 = InvertedResidual(ch[4], ch[4], stride=1, use_res_connect=True, expand_ratio=2)
  90. self.conv7 = InvertedResidual(ch[4], ch[4], stride=1, use_res_connect=True, expand_ratio=2)
  91. def forward(self, x):
  92. x = self.conv1(x)
  93. x = self.conv2(x)
  94. x = self.relu(self.bn3(self.conv3(x)))
  95. x = self.conv4(x)
  96. x = self.relu(self.bn5(self.conv5(x)))
  97. x = self.conv6(x)
  98. x = self.conv7(x)
  99. return x
  100. class AudioConvHubert(nn.Module):
  101. def __init__(self):
  102. super(AudioConvHubert, self).__init__()
  103. # ch = [16, 32, 64, 128, 256] # if you want to run this model on a mobile device, use this.
  104. ch = [32, 64, 128, 256, 512]
  105. self.conv1 = InvertedResidual(16, ch[1], stride=1, use_res_connect=False, expand_ratio=2)
  106. self.conv2 = InvertedResidual(ch[1], ch[2], stride=1, use_res_connect=False, expand_ratio=2)
  107. self.conv3 = nn.Conv2d(ch[2], ch[3], kernel_size=3, padding=1, stride=(2,2))
  108. self.bn3 = nn.BatchNorm2d(ch[3])
  109. self.conv4 = InvertedResidual(ch[3], ch[3], stride=1, use_res_connect=True, expand_ratio=2)
  110. self.conv5 = nn.Conv2d(ch[3], ch[4], kernel_size=3, padding=3, stride=2)
  111. self.bn5 = nn.BatchNorm2d(ch[4])
  112. self.relu = nn.ReLU()
  113. self.conv6 = InvertedResidual(ch[4], ch[4], stride=1, use_res_connect=True, expand_ratio=2)
  114. self.conv7 = InvertedResidual(ch[4], ch[4], stride=1, use_res_connect=True, expand_ratio=2)
  115. def forward(self, x):
  116. x = self.conv1(x)
  117. x = self.conv2(x)
  118. x = self.relu(self.bn3(self.conv3(x)))
  119. x = self.conv4(x)
  120. x = self.relu(self.bn5(self.conv5(x)))
  121. x = self.conv6(x)
  122. x = self.conv7(x)
  123. return x
  124. class Model(nn.Module):
  125. def __init__(self,n_channels=6, mode='wenet'):
  126. super(Model, self).__init__()
  127. self.n_channels = n_channels #BGR
  128. # ch = [16, 32, 64, 128, 256] # if you want to run this model on a mobile device, use this.
  129. ch = [32, 64, 128, 256, 512]
  130. if mode=='hubert':
  131. self.audio_model = AudioConvHubert()
  132. if mode=='wenet':
  133. self.audio_model = AudioConvWenet()
  134. self.fuse_conv = nn.Sequential(
  135. DoubleConvDW(ch[4]*2, ch[4], stride=1),
  136. DoubleConvDW(ch[4], ch[3], stride=1)
  137. )
  138. self.inc = InConvDw(n_channels, ch[0])
  139. self.down1 = Down(ch[0], ch[1])
  140. self.down2 = Down(ch[1], ch[2])
  141. self.down3 = Down(ch[2], ch[3])
  142. self.down4 = Down(ch[3], ch[4])
  143. self.up1 = Up(ch[4], ch[3]//2)
  144. self.up2 = Up(ch[3], ch[2]//2)
  145. self.up3 = Up(ch[2], ch[1]//2)
  146. self.up4 = Up(ch[1], ch[0])
  147. self.outc = OutConv(ch[0], 3)
  148. def forward(self, x, audio_feat):
  149. x1 = self.inc(x)
  150. x2 = self.down1(x1)
  151. x3 = self.down2(x2)
  152. x4 = self.down3(x3)
  153. x5 = self.down4(x4)
  154. audio_feat = self.audio_model(audio_feat)
  155. x5 = torch.cat([x5, audio_feat], axis=1)
  156. x5 = self.fuse_conv(x5)
  157. x = self.up1(x5, x4)
  158. x = self.up2(x, x3)
  159. x = self.up3(x, x2)
  160. x = self.up4(x, x1)
  161. out = self.outc(x)
  162. out = F.sigmoid(out)
  163. return out
if __name__ == '__main__':
    # Smoke test / ONNX export script. Requires CUDA plus the third-party
    # packages onnx, onnxruntime, and thop.
    import time
    import copy
    import onnx
    import numpy as np
    onnx_path = "./unet.onnx"
    from thop import profile, clever_format

    def reparameterize_model(model: torch.nn.Module) -> torch.nn.Module:
        """ Method returns a model where a multi-branched structure
        used in training is re-parameterized into a single branch
        for inference.
        :param model: MobileOne model in train mode.
        :return: MobileOne model in inference mode.
        """
        # Avoid editing original graph
        model = copy.deepcopy(model)
        # Any submodule exposing a reparameterize() hook is collapsed in
        # place; modules without the hook are left untouched.
        for module in model.modules():
            if hasattr(module, 'reparameterize'):
                module.reparameterize()
        return model

    device = torch.device("cuda")

    def check_onnx(torch_out, torch_in, audio):
        # Validate the exported graph, run it under ONNX Runtime, and compare
        # numerically against the PyTorch output.
        onnx_model = onnx.load(onnx_path)
        onnx.checker.check_model(onnx_model)
        import onnxruntime
        providers = ["CUDAExecutionProvider"]
        ort_session = onnxruntime.InferenceSession(onnx_path, providers=providers)
        print(ort_session.get_providers())
        # Inputs are matched positionally to the exported input names.
        ort_inputs = {ort_session.get_inputs()[0].name: torch_in.cpu().numpy(), ort_session.get_inputs()[1].name: audio.cpu().numpy()}
        ort_outs = ort_session.run(None, ort_inputs)
        np.testing.assert_allclose(torch_out[0].cpu().numpy(), ort_outs[0][0], rtol=1e-03, atol=1e-05)
        print("Exported model has been tested with ONNXRuntime, and the result looks good!")

    # NOTE(review): Model(6) uses the default mode='wenet', whose conv1
    # expects ch[2]=128 input channels, but `audio` below has 16 channels —
    # which matches AudioConvHubert's conv1. Confirm the intended mode.
    net = Model(6).eval().to(device)
    img = torch.zeros([1, 6, 160, 160]).to(device)
    audio = torch.zeros([1, 16, 32, 32]).to(device)
    # net = reparameterize_model(net)
    # FLOPs / parameter count report via thop.
    flops, params = profile(net, (img,audio))
    macs, params = clever_format([flops, params], "%3f")
    print(macs, params)
    # dynamic_axes= {'input':[2, 3], 'output':[2, 3]}
    input_dict = {"input": img, "audio": audio}
    with torch.no_grad():
        torch_out = net(img, audio)
        print(torch_out.shape)
        torch.onnx.export(net, (img, audio), onnx_path, input_names=['input', "audio"],
                          output_names=['output'],
                          # dynamic_axes=dynamic_axes,
                          # example_outputs=torch_out,
                          opset_version=11,
                          export_params=True)
    check_onnx(torch_out, img, audio)
    # Optional latency benchmark / checkpoint dump, kept for reference:
    # img = torch.zeros([1, 6, 160, 160]).to(device)
    # audio = torch.zeros([1, 16, 32, 32]).to(device)
    # with torch.no_grad():
    #     for i in range(100000):
    #         t1 = time.time()
    #         out = net(img, audio)
    #         t2 = time.time()
    #         # print(out.shape)
    #         print('time cost::', t2-t1)
    # torch.save(net.state_dict(), '1.pth')