# genavatar-bak.py — generate avatar frames, face crops, and crop coordinates
# from a dataset using a UNet lip-sync model (silent audio features).
import argparse
import os
import pickle
import shutil
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from unet import Model
# from unet2 import Model
# from unet_att import Model
  15. def osmakedirs(path_list):
  16. for path in path_list:
  17. os.makedirs(path) if not os.path.exists(path) else None
  18. parser = argparse.ArgumentParser(description='Train',
  19. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  20. parser.add_argument('--dataset', type=str, default="")
  21. #parser.add_argument('--save_path', type=str, default="") # end with .mp4 please
  22. parser.add_argument('--checkpoint', type=str, default="")
  23. parser.add_argument('--avatar_id', default='ultralight_avatar1', type=str)
  24. args = parser.parse_args()
  25. checkpoint = args.checkpoint
  26. dataset_dir = args.dataset
  27. img_dir = os.path.join(dataset_dir, "full_body_img/")
  28. lms_dir = os.path.join(dataset_dir, "landmarks/")
  29. avatar_path = f"./results/avatars/{args.avatar_id}"
  30. full_imgs_path = f"{avatar_path}/full_imgs"
  31. face_imgs_path = f"{avatar_path}/face_imgs"
  32. coords_path = f"{avatar_path}/coords.pkl"
  33. pth_path = f"{avatar_path}/ultralight.pth"
  34. osmakedirs([avatar_path,full_imgs_path,face_imgs_path])
  35. len_img = len(os.listdir(img_dir)) - 1
  36. exm_img = cv2.imread(img_dir+"0.jpg")
  37. h, w = exm_img.shape[:2]
  38. step_stride = 0
  39. img_idx = 0
  40. coord_list = []
  41. net = Model(6, 'hubert').cuda()
  42. net.load_state_dict(torch.load(checkpoint))
  43. net.eval()
  44. for i in range(len_img):
  45. if img_idx>len_img - 1:
  46. step_stride = -1
  47. if img_idx<1:
  48. step_stride = 1
  49. img_idx += step_stride
  50. img_path = img_dir + str(img_idx)+'.jpg'
  51. lms_path = lms_dir + str(img_idx)+'.lms'
  52. img = cv2.imread(img_path)
  53. lms_list = []
  54. with open(lms_path, "r") as f:
  55. lines = f.read().splitlines()
  56. for line in lines:
  57. arr = line.split(" ")
  58. arr = np.array(arr, dtype=np.float32)
  59. lms_list.append(arr)
  60. lms = np.array(lms_list, dtype=np.int32)
  61. xmin = lms[1][0]
  62. ymin = lms[52][1]
  63. xmax = lms[31][0]
  64. width = xmax - xmin
  65. ymax = ymin + width
  66. crop_img = img[ymin:ymax, xmin:xmax]
  67. h, w = crop_img.shape[:2]
  68. crop_img = cv2.resize(crop_img, (168, 168), cv2.INTER_AREA)
  69. crop_img_ori = crop_img.copy()
  70. img_real_ex = crop_img[4:164, 4:164].copy()
  71. img_real_ex_ori = img_real_ex.copy()
  72. img_masked = cv2.rectangle(img_real_ex_ori,(5,5,150,145),(0,0,0),-1)
  73. img_masked = img_masked.transpose(2,0,1).astype(np.float32)
  74. img_real_ex = img_real_ex.transpose(2,0,1).astype(np.float32)
  75. img_real_ex_T = torch.from_numpy(img_real_ex / 255.0)
  76. img_masked_T = torch.from_numpy(img_masked / 255.0)
  77. img_concat_T = torch.cat([img_real_ex_T, img_masked_T], axis=0)[None]
  78. audio_feat = torch.zeros(1, 32, 32, 32)
  79. #print('audio_feat:',audio_feat.shape)
  80. audio_feat = audio_feat.cuda()
  81. img_concat_T = img_concat_T.cuda()
  82. #print('img_concat_T:',img_concat_T.shape)
  83. with torch.no_grad():
  84. pred = net(img_concat_T, audio_feat)[0]
  85. pred = pred.cpu().numpy().transpose(1,2,0)*255
  86. pred = np.array(pred, dtype=np.uint8)
  87. crop_img_ori[4:164, 4:164] = pred
  88. crop_img_ori = cv2.resize(crop_img_ori, (w, h))
  89. img[ymin:ymax, xmin:xmax] = crop_img_ori
  90. cv2.putText(img, "LiveTalking", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (128,128,128), 1)
  91. cv2.imwrite(f"{full_imgs_path}/{img_idx:08d}.png", img)
  92. cv2.imwrite(f"{face_imgs_path}/{img_idx:08d}.png", crop_img)
  93. coord_list.append((xmin, ymin, xmin+w, ymin+h))
  94. with open(coords_path, 'wb') as f:
  95. pickle.dump(coord_list, f)
  96. os.system(f"cp {checkpoint} {pth_path}")
  97. # ffmpeg -i test_video.mp4 -i test_audio.pcm -c:v libx264 -c:a aac result_test.mp4