# genavatar.py — build a Wav2Lip avatar (full frames, face crops, face coords) from a video.
  1. from os import listdir, path
  2. import numpy as np
  3. import scipy, cv2, os, sys, argparse
  4. import json, subprocess, random, string
  5. from tqdm import tqdm
  6. from glob import glob
  7. import torch
  8. import pickle
  9. import face_detection
  10. parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
  11. parser.add_argument('--img_size', default=96, type=int)
  12. parser.add_argument('--avatar_id', default='wav2lip_avatar1', type=str)
  13. parser.add_argument('--video_path', default='', type=str)
  14. parser.add_argument('--nosmooth', default=False, action='store_true',
  15. help='Prevent smoothing face detections over a short temporal window')
  16. parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
  17. help='Padding (top, bottom, left, right). Please adjust to include chin at least')
  18. parser.add_argument('--face_det_batch_size', type=int,
  19. help='Batch size for face detection', default=16)
  20. args = parser.parse_args()
  21. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  22. print('Using {} for inference.'.format(device))
  23. def osmakedirs(path_list):
  24. for path in path_list:
  25. os.makedirs(path) if not os.path.exists(path) else None
  26. def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
  27. cap = cv2.VideoCapture(vid_path)
  28. count = 0
  29. while True:
  30. if count > cut_frame:
  31. break
  32. ret, frame = cap.read()
  33. if ret:
  34. cv2.putText(frame, "LiveTalking", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (128,128,128), 1)
  35. cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
  36. count += 1
  37. else:
  38. break
  39. def read_imgs(img_list):
  40. frames = []
  41. print('reading images...')
  42. for img_path in tqdm(img_list):
  43. frame = cv2.imread(img_path)
  44. frames.append(frame)
  45. return frames
  46. def get_smoothened_boxes(boxes, T):
  47. for i in range(len(boxes)):
  48. if i + T > len(boxes):
  49. window = boxes[len(boxes) - T:]
  50. else:
  51. window = boxes[i : i + T]
  52. boxes[i] = np.mean(window, axis=0)
  53. return boxes
  54. def face_detect(images):
  55. detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
  56. flip_input=False, device=device)
  57. batch_size = args.face_det_batch_size
  58. while 1:
  59. predictions = []
  60. try:
  61. for i in tqdm(range(0, len(images), batch_size)):
  62. predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
  63. except RuntimeError:
  64. if batch_size == 1:
  65. raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
  66. batch_size //= 2
  67. print('Recovering from OOM error; New batch size: {}'.format(batch_size))
  68. continue
  69. break
  70. results = []
  71. pady1, pady2, padx1, padx2 = args.pads
  72. for rect, image in zip(predictions, images):
  73. if rect is None:
  74. cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
  75. raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
  76. y1 = max(0, rect[1] - pady1)
  77. y2 = min(image.shape[0], rect[3] + pady2)
  78. x1 = max(0, rect[0] - padx1)
  79. x2 = min(image.shape[1], rect[2] + padx2)
  80. results.append([x1, y1, x2, y2])
  81. boxes = np.array(results)
  82. if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
  83. results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
  84. del detector
  85. return results
  86. if __name__ == "__main__":
  87. avatar_path = f"../data/avatars/{args.avatar_id}"
  88. full_imgs_path = f"{avatar_path}/full_imgs"
  89. face_imgs_path = f"{avatar_path}/face_imgs"
  90. coords_path = f"{avatar_path}/coords.pkl"
  91. osmakedirs([avatar_path,full_imgs_path,face_imgs_path])
  92. print(args)
  93. #if os.path.isfile(args.video_path):
  94. video2imgs(args.video_path, full_imgs_path, ext = 'png')
  95. input_img_list = sorted(glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
  96. frames = read_imgs(input_img_list)
  97. face_det_results = face_detect(frames)
  98. coord_list = []
  99. idx = 0
  100. for frame,coords in face_det_results:
  101. #x1, y1, x2, y2 = bbox
  102. resized_crop_frame = cv2.resize(frame,(args.img_size, args.img_size)) #,interpolation = cv2.INTER_LANCZOS4)
  103. cv2.imwrite(f"{face_imgs_path}/{idx:08d}.png", resized_crop_frame)
  104. coord_list.append(coords)
  105. idx = idx + 1
  106. with open(coords_path, 'wb') as f:
  107. pickle.dump(coord_list, f)