__init__.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import torch
  2. import time
  3. import os
  4. import cv2
  5. import numpy as np
  6. from PIL import Image
  7. from .model import BiSeNet
  8. import torchvision.transforms as transforms
  9. class FaceParsing():
  10. def __init__(self, left_cheek_width=80, right_cheek_width=80):
  11. self.net = self.model_init()
  12. self.preprocess = self.image_preprocess()
  13. # Ensure all size parameters are integers
  14. cone_height = 21
  15. tail_height = 12
  16. total_size = cone_height + tail_height
  17. # Create kernel with explicit integer dimensions
  18. kernel = np.zeros((total_size, total_size), dtype=np.uint8)
  19. center_x = total_size // 2 # Ensure center coordinates are integers
  20. # Cone part
  21. for row in range(cone_height):
  22. if row < cone_height//2:
  23. continue
  24. width = int(2 * (row - cone_height//2) + 1)
  25. start = int(center_x - (width // 2))
  26. end = int(center_x + (width // 2) + 1)
  27. kernel[row, start:end] = 1
  28. # Vertical extension part
  29. if cone_height > 0:
  30. base_width = int(kernel[cone_height-1].sum())
  31. else:
  32. base_width = 1
  33. for row in range(cone_height, total_size):
  34. start = max(0, int(center_x - (base_width//2)))
  35. end = min(total_size, int(center_x + (base_width//2) + 1))
  36. kernel[row, start:end] = 1
  37. self.kernel = kernel
  38. # Modify cheek erosion kernel to be flatter ellipse
  39. self.cheek_kernel = cv2.getStructuringElement(
  40. cv2.MORPH_ELLIPSE, (35, 3))
  41. # Add cheek area mask (protect chin area)
  42. self.cheek_mask = self._create_cheek_mask(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)
  43. def _create_cheek_mask(self, left_cheek_width=80, right_cheek_width=80):
  44. """Create cheek area mask (1/4 area on both sides)"""
  45. mask = np.zeros((512, 512), dtype=np.uint8)
  46. center = 512 // 2
  47. cv2.rectangle(mask, (0, 0), (center - left_cheek_width, 512), 255, -1) # Left cheek
  48. cv2.rectangle(mask, (center + right_cheek_width, 0), (512, 512), 255, -1) # Right cheek
  49. return mask
  50. def model_init(self,
  51. resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
  52. model_pth='./models/face-parse-bisent/79999_iter.pth'):
  53. net = BiSeNet(resnet_path)
  54. if torch.cuda.is_available():
  55. net.cuda()
  56. net.load_state_dict(torch.load(model_pth))
  57. else:
  58. net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
  59. net.eval()
  60. return net
  61. def image_preprocess(self):
  62. return transforms.Compose([
  63. transforms.ToTensor(),
  64. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  65. ])
  66. def __call__(self, image, size=(512, 512), mode="raw"):
  67. if isinstance(image, str):
  68. image = Image.open(image)
  69. width, height = image.size
  70. with torch.no_grad():
  71. image = image.resize(size, Image.BILINEAR)
  72. img = self.preprocess(image)
  73. if torch.cuda.is_available():
  74. img = torch.unsqueeze(img, 0).cuda()
  75. else:
  76. img = torch.unsqueeze(img, 0)
  77. out = self.net(img)[0]
  78. parsing = out.squeeze(0).cpu().numpy().argmax(0)
  79. # Add 14:neck, remove 10:nose and 7:8:9
  80. if mode == "neck":
  81. parsing[np.isin(parsing, [1, 11, 12, 13, 14])] = 255
  82. parsing[np.where(parsing!=255)] = 0
  83. elif mode == "jaw":
  84. face_region = np.isin(parsing, [1])*255
  85. face_region = face_region.astype(np.uint8)
  86. original_dilated = cv2.dilate(face_region, self.kernel, iterations=1)
  87. eroded = cv2.erode(original_dilated, self.cheek_kernel, iterations=2)
  88. face_region = cv2.bitwise_and(eroded, self.cheek_mask)
  89. face_region = cv2.bitwise_or(face_region, cv2.bitwise_and(original_dilated, ~self.cheek_mask))
  90. parsing[(face_region==255) & (~np.isin(parsing, [10]))] = 255
  91. parsing[np.isin(parsing, [11, 12, 13])] = 255
  92. parsing[np.where(parsing!=255)] = 0
  93. else:
  94. parsing[np.isin(parsing, [1, 11, 12, 13])] = 255
  95. parsing[np.where(parsing!=255)] = 0
  96. parsing = Image.fromarray(parsing.astype(np.uint8))
  97. return parsing
  98. if __name__ == "__main__":
  99. fp = FaceParsing()
  100. segmap = fp('154_small.png')
  101. segmap.save('res.png')