# detect.py — S3FD face-detection helpers (cleaned from a line-numbered scrape).
  1. import torch
  2. import torch.nn.functional as F
  3. import os
  4. import sys
  5. import cv2
  6. import random
  7. import datetime
  8. import math
  9. import argparse
  10. import numpy as np
  11. import scipy.io as sio
  12. import zipfile
  13. from .net_s3fd import s3fd
  14. from .bbox import *
  15. def detect(net, img, device):
  16. img = img - np.array([104, 117, 123])
  17. img = img.transpose(2, 0, 1)
  18. img = img.reshape((1,) + img.shape)
  19. if 'cuda' in device:
  20. torch.backends.cudnn.benchmark = True
  21. img = torch.from_numpy(img).float().to(device)
  22. BB, CC, HH, WW = img.size()
  23. with torch.no_grad():
  24. olist = net(img)
  25. bboxlist = []
  26. for i in range(len(olist) // 2):
  27. olist[i * 2] = F.softmax(olist[i * 2], dim=1)
  28. olist = [oelem.data.cpu() for oelem in olist]
  29. for i in range(len(olist) // 2):
  30. ocls, oreg = olist[i * 2], olist[i * 2 + 1]
  31. FB, FC, FH, FW = ocls.size() # feature map size
  32. stride = 2**(i + 2) # 4,8,16,32,64,128
  33. anchor = stride * 4
  34. poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
  35. for Iindex, hindex, windex in poss:
  36. axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
  37. score = ocls[0, 1, hindex, windex]
  38. loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
  39. priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
  40. variances = [0.1, 0.2]
  41. box = decode(loc, priors, variances)
  42. x1, y1, x2, y2 = box[0] * 1.0
  43. # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
  44. bboxlist.append([x1, y1, x2, y2, score])
  45. bboxlist = np.array(bboxlist)
  46. if 0 == len(bboxlist):
  47. bboxlist = np.zeros((1, 5))
  48. return bboxlist
  49. def batch_detect(net, imgs, device):
  50. imgs = imgs - np.array([104, 117, 123])
  51. imgs = imgs.transpose(0, 3, 1, 2)
  52. if 'cuda' in device:
  53. torch.backends.cudnn.benchmark = True
  54. imgs = torch.from_numpy(imgs).float().to(device)
  55. BB, CC, HH, WW = imgs.size()
  56. with torch.no_grad():
  57. olist = net(imgs)
  58. bboxlist = []
  59. for i in range(len(olist) // 2):
  60. olist[i * 2] = F.softmax(olist[i * 2], dim=1)
  61. olist = [oelem.data.cpu() for oelem in olist]
  62. for i in range(len(olist) // 2):
  63. ocls, oreg = olist[i * 2], olist[i * 2 + 1]
  64. FB, FC, FH, FW = ocls.size() # feature map size
  65. stride = 2**(i + 2) # 4,8,16,32,64,128
  66. anchor = stride * 4
  67. poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
  68. for Iindex, hindex, windex in poss:
  69. axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
  70. score = ocls[:, 1, hindex, windex]
  71. loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
  72. priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
  73. variances = [0.1, 0.2]
  74. box = batch_decode(loc, priors, variances)
  75. box = box[:, 0] * 1.0
  76. # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
  77. bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
  78. bboxlist = np.array(bboxlist)
  79. if 0 == len(bboxlist):
  80. bboxlist = np.zeros((1, BB, 5))
  81. return bboxlist
  82. def flip_detect(net, img, device):
  83. img = cv2.flip(img, 1)
  84. b = detect(net, img, device)
  85. bboxlist = np.zeros(b.shape)
  86. bboxlist[:, 0] = img.shape[1] - b[:, 2]
  87. bboxlist[:, 1] = b[:, 1]
  88. bboxlist[:, 2] = img.shape[1] - b[:, 0]
  89. bboxlist[:, 3] = b[:, 3]
  90. bboxlist[:, 4] = b[:, 4]
  91. return bboxlist
  92. def pts_to_bb(pts):
  93. min_x, min_y = np.min(pts, axis=0)
  94. max_x, max_y = np.max(pts, axis=0)
  95. return np.array([min_x, min_y, max_x, max_y])