net_s3fd.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class L2Norm(nn.Module):
  5. def __init__(self, n_channels, scale=1.0):
  6. super(L2Norm, self).__init__()
  7. self.n_channels = n_channels
  8. self.scale = scale
  9. self.eps = 1e-10
  10. self.weight = nn.Parameter(torch.Tensor(self.n_channels))
  11. self.weight.data *= 0.0
  12. self.weight.data += self.scale
  13. def forward(self, x):
  14. norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
  15. x = x / norm * self.weight.view(1, -1, 1, 1)
  16. return x
  17. class s3fd(nn.Module):
  18. def __init__(self):
  19. super(s3fd, self).__init__()
  20. self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
  21. self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
  22. self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
  23. self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
  24. self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
  25. self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
  26. self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
  27. self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
  28. self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
  29. self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
  30. self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
  31. self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
  32. self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
  33. self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
  34. self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
  35. self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
  36. self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
  37. self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
  38. self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
  39. self.conv3_3_norm = L2Norm(256, scale=10)
  40. self.conv4_3_norm = L2Norm(512, scale=8)
  41. self.conv5_3_norm = L2Norm(512, scale=5)
  42. self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
  43. self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
  44. self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
  45. self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
  46. self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
  47. self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
  48. self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
  49. self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
  50. self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
  51. self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
  52. self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
  53. self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
  54. def forward(self, x):
  55. h = F.relu(self.conv1_1(x))
  56. h = F.relu(self.conv1_2(h))
  57. h = F.max_pool2d(h, 2, 2)
  58. h = F.relu(self.conv2_1(h))
  59. h = F.relu(self.conv2_2(h))
  60. h = F.max_pool2d(h, 2, 2)
  61. h = F.relu(self.conv3_1(h))
  62. h = F.relu(self.conv3_2(h))
  63. h = F.relu(self.conv3_3(h))
  64. f3_3 = h
  65. h = F.max_pool2d(h, 2, 2)
  66. h = F.relu(self.conv4_1(h))
  67. h = F.relu(self.conv4_2(h))
  68. h = F.relu(self.conv4_3(h))
  69. f4_3 = h
  70. h = F.max_pool2d(h, 2, 2)
  71. h = F.relu(self.conv5_1(h))
  72. h = F.relu(self.conv5_2(h))
  73. h = F.relu(self.conv5_3(h))
  74. f5_3 = h
  75. h = F.max_pool2d(h, 2, 2)
  76. h = F.relu(self.fc6(h))
  77. h = F.relu(self.fc7(h))
  78. ffc7 = h
  79. h = F.relu(self.conv6_1(h))
  80. h = F.relu(self.conv6_2(h))
  81. f6_2 = h
  82. h = F.relu(self.conv7_1(h))
  83. h = F.relu(self.conv7_2(h))
  84. f7_2 = h
  85. f3_3 = self.conv3_3_norm(f3_3)
  86. f4_3 = self.conv4_3_norm(f4_3)
  87. f5_3 = self.conv5_3_norm(f5_3)
  88. cls1 = self.conv3_3_norm_mbox_conf(f3_3)
  89. reg1 = self.conv3_3_norm_mbox_loc(f3_3)
  90. cls2 = self.conv4_3_norm_mbox_conf(f4_3)
  91. reg2 = self.conv4_3_norm_mbox_loc(f4_3)
  92. cls3 = self.conv5_3_norm_mbox_conf(f5_3)
  93. reg3 = self.conv5_3_norm_mbox_loc(f5_3)
  94. cls4 = self.fc7_mbox_conf(ffc7)
  95. reg4 = self.fc7_mbox_loc(ffc7)
  96. cls5 = self.conv6_2_mbox_conf(f6_2)
  97. reg5 = self.conv6_2_mbox_loc(f6_2)
  98. cls6 = self.conv7_2_mbox_conf(f7_2)
  99. reg6 = self.conv7_2_mbox_loc(f7_2)
  100. # max-out background label
  101. chunk = torch.chunk(cls1, 4, 1)
  102. bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
  103. cls1 = torch.cat([bmax, chunk[3]], dim=1)
  104. return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]