# detect.py — S3FD face detection helpers (single image, batch, and flipped variants).
  1. import torch
  2. import torch.nn.functional as F
  3. import os
  4. import sys
  5. import cv2
  6. import random
  7. import datetime
  8. import math
  9. import argparse
  10. import numpy as np
  11. import scipy.io as sio
  12. import zipfile
  13. from .net_s3fd import s3fd
  14. from .bbox import *
  15. def detect(net, img, device):
  16. img = img - np.array([104, 117, 123])
  17. img = img.transpose(2, 0, 1)
  18. img = img.reshape((1,) + img.shape)
  19. if 'cuda' in device:
  20. torch.backends.cudnn.benchmark = True
  21. img = torch.from_numpy(img).float().to(device)
  22. BB, CC, HH, WW = img.size()
  23. with torch.no_grad():
  24. olist = net(img)
  25. bboxlist = []
  26. for i in range(len(olist) // 2):
  27. olist[i * 2] = F.softmax(olist[i * 2], dim=1)
  28. olist = [oelem.data.cpu() for oelem in olist]
  29. for i in range(len(olist) // 2):
  30. ocls, oreg = olist[i * 2], olist[i * 2 + 1]
  31. FB, FC, FH, FW = ocls.size() # feature map size
  32. stride = 2**(i + 2) # 4,8,16,32,64,128
  33. anchor = stride * 4
  34. poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
  35. for Iindex, hindex, windex in poss:
  36. axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
  37. score = ocls[0, 1, hindex, windex]
  38. loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
  39. priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
  40. variances = [0.1, 0.2]
  41. box = decode(loc, priors, variances)
  42. x1, y1, x2, y2 = box[0] * 1.0
  43. # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
  44. bboxlist.append([x1, y1, x2, y2, score])
  45. bboxlist = np.array(bboxlist)
  46. if 0 == len(bboxlist):
  47. bboxlist = np.zeros((1, 5))
  48. return bboxlist
  49. def batch_detect(net, imgs, device):
  50. imgs = imgs - np.array([104, 117, 123])
  51. imgs = imgs.transpose(0, 3, 1, 2)
  52. if 'cuda' in device:
  53. torch.backends.cudnn.benchmark = True
  54. imgs = torch.from_numpy(imgs).float().to(device)
  55. BB, CC, HH, WW = imgs.size()
  56. with torch.no_grad():
  57. olist = net(imgs)
  58. # print(olist)
  59. bboxlist = []
  60. for i in range(len(olist) // 2):
  61. olist[i * 2] = F.softmax(olist[i * 2], dim=1)
  62. olist = [oelem.cpu() for oelem in olist]
  63. for i in range(len(olist) // 2):
  64. ocls, oreg = olist[i * 2], olist[i * 2 + 1]
  65. FB, FC, FH, FW = ocls.size() # feature map size
  66. stride = 2**(i + 2) # 4,8,16,32,64,128
  67. anchor = stride * 4
  68. poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
  69. for Iindex, hindex, windex in poss:
  70. axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
  71. score = ocls[:, 1, hindex, windex]
  72. loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
  73. priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
  74. variances = [0.1, 0.2]
  75. box = batch_decode(loc, priors, variances)
  76. box = box[:, 0] * 1.0
  77. # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
  78. bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
  79. bboxlist = np.array(bboxlist)
  80. if 0 == len(bboxlist):
  81. bboxlist = np.zeros((1, BB, 5))
  82. return bboxlist
  83. def flip_detect(net, img, device):
  84. img = cv2.flip(img, 1)
  85. b = detect(net, img, device)
  86. bboxlist = np.zeros(b.shape)
  87. bboxlist[:, 0] = img.shape[1] - b[:, 2]
  88. bboxlist[:, 1] = b[:, 1]
  89. bboxlist[:, 2] = img.shape[1] - b[:, 0]
  90. bboxlist[:, 3] = b[:, 3]
  91. bboxlist[:, 4] = b[:, 4]
  92. return bboxlist
  93. def pts_to_bb(pts):
  94. min_x, min_y = np.min(pts, axis=0)
  95. max_x, max_y = np.max(pts, axis=0)
  96. return np.array([min_x, min_y, max_x, max_y])