# genavatar.py
import argparse
import os
import pickle
import shutil
import time

from glob import glob

import cv2
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm

from face_detect_utils.get_landmark import Landmark
# from unet2 import Model
# from unet_att import Model
  15. def osmakedirs(path_list):
  16. for path in path_list:
  17. os.makedirs(path) if not os.path.exists(path) else None
  18. parser = argparse.ArgumentParser(description='Train',
  19. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  20. parser.add_argument('--video_path', default='', type=str)
  21. parser.add_argument('--img_size', default=168, type=int)
  22. parser.add_argument('--checkpoint', type=str, default="")
  23. parser.add_argument('--avatar_id', default='ultralight_avatar1', type=str)
  24. args = parser.parse_args()
  25. def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
  26. print(f"即将使用OpenCV将视频: {vid_path} 转换为图片")
  27. cap = cv2.VideoCapture(vid_path)
  28. count = 0
  29. while True:
  30. if count > cut_frame:
  31. break
  32. ret, frame = cap.read()
  33. if ret:
  34. cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
  35. count += 1
  36. else:
  37. break
  38. print("视频转换完成")
  39. def read_imgs(img_list):
  40. frames = []
  41. print('读取图片到内存...')
  42. for img_path in tqdm(img_list):
  43. frame = cv2.imread(img_path)
  44. frames.append(frame)
  45. return frames
  46. # ffmpeg -i test_video.mp4 -i test_audio.pcm -c:v libx264 -c:a aac result_test.mp4
  47. if __name__ == "__main__":
  48. avatar_path = f"./results/avatars/{args.avatar_id}"
  49. full_imgs_path = f"{avatar_path}/full_imgs"
  50. face_imgs_path = f"{avatar_path}/face_imgs"
  51. coords_path = f"{avatar_path}/coords.pkl"
  52. pth_path = f"{avatar_path}/ultralight.pth"
  53. osmakedirs([avatar_path,full_imgs_path,face_imgs_path])
  54. print(args)
  55. video2imgs(args.video_path, full_imgs_path, ext = 'png')
  56. input_img_list = sorted(glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
  57. #frames = read_imgs(input_img_list)
  58. #face_det_results = face_detect(frames)
  59. coord_list = []
  60. idx = 0
  61. print(f"开始人脸检测")
  62. landmark = Landmark()
  63. target_size = args.img_size
  64. for i in tqdm(range(len(input_img_list))):
  65. img = cv2.imread(input_img_list[i])
  66. lms, x1, y1 = landmark.detect(input_img_list[i])
  67. xmin = lms[1][0]+x1
  68. ymin = lms[52][1]+y1
  69. xmax = lms[31][0]+x1
  70. width = xmax - xmin
  71. ymax = ymin + width
  72. crop_img = img[ymin:ymax, xmin:xmax]
  73. h, w = crop_img.shape[:2]
  74. crop_img = cv2.resize(crop_img, (target_size, target_size), cv2.INTER_AREA)
  75. # cv2.imwrite(f"{full_imgs_path}/{idx:08d}.png", img)
  76. cv2.imwrite(f"{face_imgs_path}/{idx:08d}.png", crop_img)
  77. coord_list.append((xmin, ymin, xmin+w, ymin+h))
  78. idx = idx + 1
  79. print(f"共检测到{idx}张人脸")
  80. print(f"写入数据到坐标文件:{coords_path}")
  81. with open(coords_path, 'wb') as f:
  82. pickle.dump(coord_list, f)
  83. os.system(f"cp {args.checkpoint} {pth_path}")