genavatar.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. import argparse
  2. import glob
  3. import json
  4. import os
  5. import pickle
  6. import shutil
  7. import cv2
  8. import numpy as np
  9. import torch
  10. # import torchvision.transforms as transforms
  11. # from PIL import Image
  12. # from diffusers import AutoencoderKL
  13. # from face_alignment import NetworkSize
  14. # from mmpose.apis import inference_topdown, init_model
  15. # from mmpose.structures import merge_data_samples
  16. from tqdm import tqdm
  17. from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs
  18. from musetalk.utils.blending import get_image_prepare_material
  19. from musetalk.utils.utils import load_all_model
  20. try:
  21. from utils.face_parsing import FaceParsing
  22. except ModuleNotFoundError:
  23. from musetalk.utils.face_parsing import FaceParsing
  def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
      """Split a video into numbered, watermarked image files.

      Every decoded frame gets a small grey "LiveTalking" label drawn at
      (10, 20) and is written to ``{save_path}/{index:08d}{ext}``, with the
      index starting at 0.

      :param vid_path: Path of the video opened with ``cv2.VideoCapture``.
      :param save_path: Existing directory that receives the frames.
      :param ext: Image extension. A missing leading dot is added, so both
          ``'.png'`` and ``'png'`` produce ``.png`` files; an empty value
          falls back to ``.png``.
      :param cut_frame: Stop once the frame index exceeds this value, i.e. at
          most ``cut_frame + 1`` frames are written.
      """
      # Normalise the extension; previously the argument was ignored and
      # ".png" was always used, while the caller passes "png" (no dot).
      ext = '.' + (ext.lstrip('.') or 'png')
      cap = cv2.VideoCapture(vid_path)
      try:
          count = 0
          while count <= cut_frame:
              ret, frame = cap.read()
              if not ret:  # end of stream or decode failure
                  break
              cv2.putText(frame, "LiveTalking", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (128, 128, 128), 1)
              cv2.imwrite(f"{save_path}/{count:08d}{ext}", frame)
              count += 1
      finally:
          # Always release the decoder handle, even if drawing/writing raises.
          cap.release()
  # NOTE: an unused triple-quoted string used to sit here, holding an older
  # in-file implementation (read_imgs, get_landmark_and_bbox, FaceAlignment,
  # preprocess/latent helpers, get_image_prepare_material). It was never
  # executed and referenced names that are not defined in this module
  # (`fa`, `model`, `NetworkSize`, `transforms`, `Image`). The code that actually
  # runs uses read_imgs / get_landmark_and_bbox / get_image_prepare_material
  # imported from `musetalk.utils` and `vae.get_latents_for_unet`; recover the
  # old text from version control if it is ever needed.
  # TODO: this only inspects the file-name suffix; switch to content sniffing
  # (e.g. python-magic) if a more precise check is needed.
  def is_video_file(file_path):
      """Return True when *file_path* ends in one of the known video suffixes.

      The suffix is taken with ``os.path.splitext`` and compared
      case-insensitively, so ``'.mp4'`` on its own (a dot-file with no
      extension) is not treated as a video.
      """
      _, suffix = os.path.splitext(file_path)
      known_video_suffixes = ('.mp4', '.mkv', '.flv', '.avi', '.mov')  # extend as needed
      return suffix.lower() in known_video_suffixes
  def create_dir(dir_path):
      """Create *dir_path* and any missing parents; do nothing if it already exists.

      ``exist_ok=True`` replaces the separate existence check, removing the
      window in which another process could create the directory between the
      check and ``os.makedirs`` (which would then raise).

      :raises FileExistsError: if *dir_path* exists but is not a directory,
          instead of silently returning and failing later on the first write.
      """
      os.makedirs(dir_path, exist_ok=True)
  # Absolute directory of this script; the avatar output paths are built from it.
  current_dir = os.path.split(os.path.abspath(__file__))[0]
  def create_musetalk_human(file, avatar_id):
      """Build a MuseTalk avatar directory from a video, an image, or an image folder.

      Everything is written under ``<current_dir>/data/avatars/<avatar_id>/``:

      * ``avator_info.json`` - avatar id, source path and ``bbox_shift``.
      * ``full_imgs/{i:08d}.png`` - one image per processed frame.
      * ``mask/{i:08d}.png`` - the blending mask for each frame.
      * ``coords.pkl`` - face bounding box per frame.
      * ``mask_coords.pkl`` - crop box returned for each mask.
      * ``latents.pt`` - VAE latents for every frame whose bbox is not the
        ``(0.0, 0.0, 0.0, 0.0)`` placeholder.

      Reads the module-level ``args``, ``vae`` and ``fp`` objects created in the
      ``__main__`` section, so it only works when this file is run as a script.

      :param file: A video file, a single image file, or a directory containing
          ``.png`` / ``.jpg`` / ``.jpeg`` images.
      :param avatar_id: Name of the avatar directory to create or update.
      """
      save_path = os.path.join(current_dir, 'data', 'avatars', avatar_id)
      save_full_path = os.path.join(save_path, 'full_imgs')
      mask_out_path = os.path.join(save_path, 'mask')
      for path in (save_path, save_full_path, mask_out_path):
          create_dir(path)

      # save_path is already absolute, so there is no need to re-join current_dir.
      mask_coords_path = os.path.join(save_path, 'mask_coords.pkl')
      coords_path = os.path.join(save_path, 'coords.pkl')
      latents_out_path = os.path.join(save_path, 'latents.pt')

      with open(os.path.join(save_path, 'avator_info.json'), "w") as f:
          json.dump({
              "avatar_id": avatar_id,
              "video_path": file,
              "bbox_shift": args.bbox_shift
          }, f)

      _copy_source_frames(file, save_full_path)

      input_img_list = sorted(glob.glob(os.path.join(save_full_path, '*.[jpJP][pnPN]*[gG]')))
      print("extracting landmarks...")
      coord_list, frame_list = get_landmark_and_bbox(input_img_list, args.bbox_shift)

      # Placeholder bbox for frames without a usable face; such frames get no latent.
      coord_placeholder = (0.0, 0.0, 0.0, 0.0)
      input_latent_list = []
      for idx, (bbox, frame) in enumerate(zip(coord_list, frame_list)):
          if bbox == coord_placeholder:
              continue
          x1, y1, x2, y2 = bbox
          if args.version == "v15":
              # v1.5 extends the crop downward by extra_margin, clamped to the frame height,
              # and stores the adjusted bbox back into coord_list.
              y2 = min(y2 + args.extra_margin, frame.shape[0])
              coord_list[idx] = [x1, y1, x2, y2]
          crop_frame = frame[y1:y2, x1:x2]
          resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
          input_latent_list.append(vae.get_latents_for_unet(resized_crop_frame))

      # Frames are used in their original order (no reversed copy is appended).
      mode = args.parsing_mode if args.version == "v15" else "raw"
      mask_coords_list = []
      for i, frame in enumerate(tqdm(frame_list)):
          frame_name = f"{i:08d}.png"
          cv2.imwrite(os.path.join(save_full_path, frame_name), frame)
          x1, y1, x2, y2 = coord_list[i]
          # NOTE(review): frames without a face still carry the placeholder box here;
          # confirm get_image_prepare_material copes with (0, 0, 0, 0).
          mask, crop_box = get_image_prepare_material(frame, [x1, y1, x2, y2], fp=fp, mode=mode)
          cv2.imwrite(os.path.join(mask_out_path, frame_name), mask)
          mask_coords_list.append(crop_box)

      # Frames were just rewritten as {i:08d}.png; drop any input copy with a different
      # name so full_imgs holds exactly one image per entry in coords.pkl.
      _remove_stale_frames(input_img_list, len(frame_list))

      with open(mask_coords_path, 'wb') as f:
          pickle.dump(mask_coords_list, f)
      with open(coords_path, 'wb') as f:
          pickle.dump(coord_list, f)
      torch.save(input_latent_list, latents_out_path)


  def _copy_source_frames(file, save_full_path):
      """Place the source frames for *file* into *save_full_path*.

      A video is split with ``video2imgs``; any other single file is copied as-is;
      a directory has its ``.png`` / ``.jpg`` / ``.jpeg`` files (any case) copied.
      """
      if os.path.isfile(file):
          if is_video_file(file):
              video2imgs(file, save_full_path, ext='png')
          else:
              shutil.copyfile(file, os.path.join(save_full_path, os.path.basename(file)))
          return
      # Accept the same image types the later glob picks up, not only lowercase ".png".
      image_exts = ('.png', '.jpg', '.jpeg')
      for filename in sorted(os.listdir(file)):
          if os.path.splitext(filename)[1].lower() in image_exts:
              shutil.copyfile(os.path.join(file, filename), os.path.join(save_full_path, filename))


  def _remove_stale_frames(img_paths, n_frames):
      """Delete every path in *img_paths* whose file name is not one of the
      canonical ``00000000.png`` ... names written for the first *n_frames* frames."""
      canonical_names = {f"{i:08d}.png" for i in range(n_frames)}
      for path in img_paths:
          if os.path.basename(path) not in canonical_names:
              os.remove(path)
  283. # initialize the mmpose model
  284. # device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
  285. # fa = FaceAlignment(1, flip_input=False, device=device)
  286. # config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py')
  287. # checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth'))
  288. # model = init_model(config_file, checkpoint_file, device=device)
  289. # vae = AutoencoderKL.from_pretrained(os.path.abspath(os.path.join(current_dir, '../models/sd-vae-ft-mse')))
  290. # vae.to(device)
  291. # fp = FaceParsing(os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/resnet18-5c106cde.pth')),
  292. # os.path.abspath(os.path.join(current_dir, '../models/face-parse-bisent/79999_iter.pth')))
  if __name__ == '__main__':
      parser = argparse.ArgumentParser(description="Generate a MuseTalk avatar (frames, masks, coords, latents).")
      # The default is a local sample path; pass --file explicitly in practice.
      parser.add_argument("--file",
                          type=str,
                          default=r'D:\ok\00000000.png',
                          help="Video file, image file, or directory of images",
                          )
      parser.add_argument("--avatar_id",
                          type=str,
                          default='musetalk_avatar1',
                          help="Name of the output directory under data/avatars",
                          )
      parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Version of MuseTalk: v1 or v15")
      parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
      parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
      parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
      parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
      parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
      parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
      # args / vae / fp stay module-level: create_musetalk_human reads them as globals.
      args = parser.parse_args()

      # Set computing device
      device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")

      # Load model weights; only the VAE is used to build an avatar.
      vae, unet, pe = load_all_model(
          device=device
      )
      # Drop our references to the unused models so they can be garbage-collected
      # (assuming load_all_model keeps no other reference to them).
      del unet, pe
      # fp16 halves VAE memory on GPU, but half-precision CPU kernels are missing in
      # some PyTorch releases, so keep float32 when falling back to CPU.
      if device.type == "cuda":
          vae.vae = vae.vae.half()
      vae.vae = vae.vae.to(device)

      # Initialize face parser with configurable parameters based on version
      if args.version == "v15":
          fp = FaceParsing(
              left_cheek_width=args.left_cheek_width,
              right_cheek_width=args.right_cheek_width
          )
      else:  # v1
          fp = FaceParsing()

      create_musetalk_human(args.file, args.avatar_id)