# genavatar.py
  1. from os import listdir, path
  2. import numpy as np
  3. import scipy, cv2, os, sys, argparse
  4. import json, subprocess, random, string
  5. from tqdm import tqdm
  6. from glob import glob
  7. import torch
  8. import pickle
  9. import face_detection
# Command-line interface for avatar generation; defaults target the Wav2Lip pipeline.
parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
# Side length of the square face crops written to face_imgs (model input resolution).
parser.add_argument('--img_size', default=96, type=int)
# Avatar id — becomes the output directory name under ../data/avatars/.
parser.add_argument('--avatar_id', default='wav2lip_avatar1', type=str)
# Source video that is decoded into frames.
parser.add_argument('--video_path', default='', type=str)
parser.add_argument('--nosmooth', default=False, action='store_true',
                    help='Prevent smoothing face detections over a short temporal window')
parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
                    help='Padding (top, bottom, left, right). Please adjust to include chin at least')
parser.add_argument('--face_det_batch_size', type=int,
                    help='Batch size for face detection', default=16)
args = parser.parse_args()

# Face detection runs on GPU when available, otherwise falls back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))
  23. def osmakedirs(path_list):
  24. for path in path_list:
  25. os.makedirs(path) if not os.path.exists(path) else None
  26. def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
  27. cap = cv2.VideoCapture(vid_path)
  28. count = 0
  29. while True:
  30. if count > cut_frame:
  31. break
  32. ret, frame = cap.read()
  33. if ret:
  34. #cv2.putText(frame, "LiveTalking", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (128,128,128), 1)
  35. cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
  36. count += 1
  37. else:
  38. break
  39. def read_imgs(img_list):
  40. frames = []
  41. print('reading images...')
  42. for img_path in tqdm(img_list):
  43. frame = cv2.imread(img_path)
  44. frames.append(frame)
  45. return frames
  46. def get_smoothened_boxes(boxes, T):
  47. for i in range(len(boxes)):
  48. if i + T > len(boxes):
  49. window = boxes[len(boxes) - T:]
  50. else:
  51. window = boxes[i : i + T]
  52. boxes[i] = np.mean(window, axis=0)
  53. return boxes
def face_detect(images):
    """Detect one face per frame and return padded crops with coordinates.

    Args:
        images: list of BGR frames (numpy arrays); mutated in place when a
            frame's shape differs from the first frame's.

    Returns:
        list of ``[face_crop, (y1, y2, x1, x2)]`` per frame, where the
        coordinates are padded by ``args.pads`` and optionally smoothed
        over a 5-frame window unless ``--nosmooth`` was given.

    Raises:
        ValueError: if any frame contains no detectable face (the frame is
            also dumped to temp/faulty_frame.jpg for inspection).
        RuntimeError: if detection fails even at batch size 1.
    """
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)
    batch_size = args.face_det_batch_size
    # Make all frames the same size as the first frame so they can be
    # stacked into one numpy batch below.
    if len(images) > 0:
        target_shape = images[0].shape
        for idx, img in enumerate(images):
            if img.shape != target_shape:
                # Resize mismatched frames to match the first frame.
                images[idx] = cv2.resize(img, (target_shape[1], target_shape[0]))
    # NOTE(review): if images is empty, target_shape is never bound; the
    # batch loop below would not execute, but this is fragile — confirm
    # callers never pass an empty list.
    # Retry loop: on GPU OOM (RuntimeError), halve the batch size and
    # restart detection from the first frame.
    while 1:
        predictions = []
        try:
            for i in tqdm(range(0, len(images), batch_size)):
                batch_images = images[i:i + batch_size]
                # Defensive second pass: force every image in the batch to
                # the common shape before stacking.
                batch_images = [cv2.resize(img, (target_shape[1], target_shape[0])) if img.shape != target_shape else img for img in batch_images]
                predictions.extend(detector.get_detections_for_batch(np.array(batch_images)))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break
    results = []
    # Padding order matches the --pads help text: top, bottom, left, right.
    pady1, pady2, padx1, padx2 = args.pads
    for rect, image in zip(predictions, images):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
        # Expand the detector's box by the pads, clamped to the frame.
        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        results.append([x1, y1, x2, y2])
    boxes = np.array(results)
    # Temporal smoothing suppresses jitter in the detected boxes.
    if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
    # Crop each frame; note the returned coordinate tuple is (y1, y2, x1, x2).
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
    del detector
    return results
  96. if __name__ == "__main__":
  97. avatar_path = f"../data/avatars/{args.avatar_id}"
  98. full_imgs_path = f"{avatar_path}/full_imgs"
  99. face_imgs_path = f"{avatar_path}/face_imgs"
  100. coords_path = f"{avatar_path}/coords.pkl"
  101. osmakedirs([avatar_path,full_imgs_path,face_imgs_path])
  102. print(args)
  103. #if os.path.isfile(args.video_path):
  104. video2imgs(args.video_path, full_imgs_path, ext = 'png')
  105. input_img_list = sorted(glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
  106. frames = read_imgs(input_img_list)
  107. face_det_results = face_detect(frames)
  108. coord_list = []
  109. idx = 0
  110. for frame,coords in face_det_results:
  111. #x1, y1, x2, y2 = bbox
  112. resized_crop_frame = cv2.resize(frame,(args.img_size, args.img_size)) #,interpolation = cv2.INTER_LANCZOS4)
  113. cv2.imwrite(f"{face_imgs_path}/{idx:08d}.png", resized_crop_frame)
  114. coord_list.append(coords)
  115. idx = idx + 1
  116. with open(coords_path, 'wb') as f:
  117. pickle.dump(coord_list, f)