# get_landmark.py
  1. import argparse
  2. from os import wait3
  3. import numpy as np
  4. import cv2
  5. import math
  6. import torch
  7. import torchvision
  8. from .detect_face import SCRFD
  9. # from models.pfld_lite import PFLDInference
  10. # from models.pfld import PFLDInference
  11. from .pfld_mobileone import PFLD_GhostOne as PFLDInference
  12. def face_det(img, model):
  13. cropped_imgs = []
  14. boxes_list = []
  15. center_list = []
  16. alpha_list = []
  17. height, width = img.shape[:2]
  18. bboxes, indices, kps = model.detect(img)
  19. for i in indices:
  20. x1, y1, x2, y2 = int(bboxes[i, 0]), int(bboxes[i, 1]), int(bboxes[i, 0] + bboxes[i, 2]), int(bboxes[i, 1] + bboxes[i, 3])
  21. p1 = kps[i,0]
  22. p2 = kps[i,1]
  23. w = x2 - x1
  24. h = y2 - y1
  25. cx = (x2+x1)//2
  26. cy = (y2+y1)//2
  27. wh = np.asarray([w,h])
  28. boxsize = int(np.max(wh)*1.05)
  29. size = boxsize
  30. xy = np.asarray((cx - size // 2, cy - size//2), dtype=np.int32)
  31. x1, y1 = xy
  32. x2, y2 = xy + size
  33. height, width, _ = img.shape
  34. dx = max(0, -x1)
  35. dy = max(0, -y1)
  36. x1 = max(0, x1)
  37. y1 = max(0, y1)
  38. edx = max(0, x2 - width)
  39. edy = max(0, y2 - height)
  40. x2 = min(width, x2)
  41. y2 = min(height, y2)
  42. cropped = img[y1:y2, x1:x2]
  43. if (dx > 0 or dy > 0 or edx >0 or edy > 0):
  44. cropped = cv2.copyMakeBorder(cropped, dy, edy, dx, edx, cv2.BORDER_CONSTANT, 0)
  45. y1 = y1-dy
  46. x1 = x1-dx
  47. center = (int((x2-x1)//2), int((y2-y1)//2))
  48. boxes_list.append([x1,y1,x2,y2])
  49. center_list.append(center)
  50. alpha = math.atan2(p2[1]-p1[1], p2[0]-p1[0]) * 180 / math.pi
  51. rot_mat = cv2.getRotationMatrix2D(center, alpha, 1)
  52. # img_rotated_by_alpha = cv2.warpAffine(cropped, rot_mat,
  53. # (cropped.shape[1], cropped.shape[0]))
  54. # cropped_imgs.append(img_rotated_by_alpha)
  55. cropped_imgs.append(cropped)
  56. alpha_list.append(alpha)
  57. break
  58. return cropped_imgs, boxes_list, center_list, alpha_list
  59. class Landmark:
  60. def __init__(self):
  61. with open('./face_detect_utils/mean_face.txt', 'r') as f_mean_face:
  62. mean_face = f_mean_face.read()
  63. self.mean_face = np.asarray(mean_face.split(' '), dtype=np.float32)
  64. self.det_net = SCRFD('./face_detect_utils/scrfd_2.5g_kps.onnx', confThreshold=0.1, nmsThreshold=0.5)
  65. checkpoint = torch.load('./face_detect_utils/checkpoint_epoch_335.pth.tar')
  66. self.pfld_backbone = PFLDInference().cuda()
  67. self.pfld_backbone.load_state_dict(checkpoint['pfld_backbone'])
  68. self.pfld_backbone.eval()
  69. def detect(self, img_path):
  70. img = cv2.imread(img_path)
  71. img_ori = img.copy()
  72. h,w = img_ori.shape[:2]
  73. cropped_imgs, boxes_list, center_list, alpha_list = face_det(img, self.det_net)
  74. cropped = cropped_imgs[0]
  75. # cv2.imshow("cropped", cropped)
  76. h,w = cropped.shape[:2]
  77. x1, y1, x2, y2 = boxes_list[0]
  78. transform = torchvision.transforms.Compose(
  79. [torchvision.transforms.ToTensor()])
  80. input = cv2.resize(cropped, (192, 192))
  81. input = np.asarray(input, dtype=np.float32) / 255.0
  82. input = input.transpose(2,0,1)
  83. input = torch.from_numpy(input)[None]
  84. input = input.cuda()
  85. # print(input)
  86. # asd
  87. # input = transform(input).unsqueeze(0).cuda()
  88. landmarks = self.pfld_backbone(input)
  89. pre_landmark = landmarks[0]
  90. pre_landmark = pre_landmark.cpu().detach().numpy()
  91. pre_landmark = pre_landmark + self.mean_face
  92. pre_landmark = pre_landmark.reshape(-1, 2)
  93. pre_landmark[:,0] *= w
  94. pre_landmark[:,1] *= h
  95. pre_landmark = pre_landmark.astype(np.int32)
  96. return pre_landmark, x1, y1