lipreal.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. ###############################################################################
  2. # Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
  3. # email: lipku@foxmail.com
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. ###############################################################################
  17. import math
  18. import torch
  19. import numpy as np
  20. #from .utils import *
  21. import os
  22. import time
  23. import cv2
  24. import glob
  25. import pickle
  26. import copy
  27. import queue
  28. from queue import Queue
  29. from threading import Thread, Event
  30. import torch.multiprocessing as mp
  31. from lipasr import LipASR
  32. import asyncio
  33. from av import AudioFrame, VideoFrame
  34. from wav2lip.models import Wav2Lip
  35. from basereal import BaseReal
  36. #from imgcache import ImgCache
  37. from tqdm import tqdm
  38. from logger import logger
  39. device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
  40. print('Using {} for inference.'.format(device))
  41. def _load(checkpoint_path):
  42. if device == 'cuda':
  43. checkpoint = torch.load(checkpoint_path) #,weights_only=True
  44. else:
  45. checkpoint = torch.load(checkpoint_path,
  46. map_location=lambda storage, loc: storage)
  47. return checkpoint
  48. def load_model(path):
  49. model = Wav2Lip()
  50. logger.info("Load checkpoint from: {}".format(path))
  51. checkpoint = _load(path)
  52. s = checkpoint["state_dict"]
  53. new_s = {}
  54. for k, v in s.items():
  55. new_s[k.replace('module.', '')] = v
  56. model.load_state_dict(new_s)
  57. model = model.to(device)
  58. return model.eval()
  59. def load_avatar(avatar_id):
  60. avatar_path = f"./data/avatars/{avatar_id}"
  61. full_imgs_path = f"{avatar_path}/full_imgs"
  62. face_imgs_path = f"{avatar_path}/face_imgs"
  63. coords_path = f"{avatar_path}/coords.pkl"
  64. with open(coords_path, 'rb') as f:
  65. coord_list_cycle = pickle.load(f)
  66. input_img_list = glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
  67. input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
  68. frame_list_cycle = read_imgs(input_img_list)
  69. #self.imagecache = ImgCache(len(self.coord_list_cycle),self.full_imgs_path,1000)
  70. input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
  71. input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
  72. face_list_cycle = read_imgs(input_face_list)
  73. return frame_list_cycle,face_list_cycle,coord_list_cycle
  74. @torch.no_grad()
  75. def warm_up(batch_size,model,modelres):
  76. # 预热函数
  77. logger.info('warmup model...')
  78. img_batch = torch.ones(batch_size, 6, modelres, modelres).to(device)
  79. mel_batch = torch.ones(batch_size, 1, 80, 16).to(device)
  80. model(mel_batch, img_batch)
  81. def read_imgs(img_list):
  82. frames = []
  83. logger.info('reading images...')
  84. for img_path in tqdm(img_list):
  85. frame = cv2.imread(img_path)
  86. frames.append(frame)
  87. return frames
  88. def __mirror_index(size, index):
  89. #size = len(self.coord_list_cycle)
  90. turn = index // size
  91. res = index % size
  92. if turn % 2 == 0:
  93. return res
  94. else:
  95. return size - res - 1
def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_queue,res_frame_queue,model):
    """Worker loop: turn audio features into lip-synced face frames.

    Repeatedly pulls a mel-feature batch from audio_feat_queue and the
    matching raw audio chunks from audio_out_queue (two audio chunks per
    video frame), runs Wav2Lip when any chunk contains speech, and pushes
    (pred_frame_or_None, frame_index, [audio_chunk, audio_chunk]) tuples
    onto res_frame_queue. Runs until quit_event is set.

    quit_event       -- threading/mp Event that stops the loop when set.
    batch_size       -- number of video frames produced per iteration.
    face_list_cycle  -- list of cropped face images cycled via __mirror_index.
    audio_feat_queue -- queue of mel-spectrogram batches (one per iteration).
    audio_out_queue  -- queue of (frame, type, eventpoint) audio chunks;
                        type==0 appears to mark a speaking chunk — the batch
                        is skipped only when no chunk has type 0.
    res_frame_queue  -- output queue consumed by the frame compositor.
    model            -- Wav2Lip model already on the global device.
    """
    length = len(face_list_cycle)
    index = 0
    # Rolling counters for the periodic FPS log below.
    count=0
    counttime=0
    logger.info('start inference')
    while not quit_event.is_set():
        starttime=time.perf_counter()
        mel_batch = []
        try:
            # Timeout keeps the loop responsive to quit_event when idle.
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue
        is_all_silence=True
        audio_frames = []
        # Two audio chunks accompany each of the batch_size video frames.
        # NOTE(review): loop variable `type` shadows the builtin.
        for _ in range(batch_size*2):
            frame,type,eventpoint = audio_out_queue.get()
            audio_frames.append((frame,type,eventpoint))
            if type==0:
                is_all_silence=False
        if is_all_silence:
            # Pure silence: emit None so the consumer reuses the original
            # avatar frame at the mirrored index — no model call needed.
            for i in range(batch_size):
                res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                index = index + 1
        else:
            t=time.perf_counter()
            # Gather the face crops this batch will be rendered onto.
            img_batch = []
            for i in range(batch_size):
                idx = __mirror_index(length,index+i)
                face = face_list_cycle[idx]
                img_batch.append(face)
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
            # Wav2Lip input: original image concatenated channel-wise with a
            # copy whose lower half is zeroed (the region to be generated).
            img_masked = img_batch.copy()
            img_masked[:, face.shape[0]//2:] = 0
            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
            # NHWC -> NCHW for torch.
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
            with torch.no_grad():
                pred = model(mel_batch, img_batch)
            # Back to NHWC uint8-range pixels for OpenCV-side compositing.
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            counttime += (time.perf_counter() - t)
            count += batch_size
            # Log the measured inference FPS every ~100 frames, then reset.
            if count>=100:
                logger.info(f"------actual avg infer fps:{count/counttime:.4f}")
                count=0
                counttime=0
            for i,res_frame in enumerate(pred):
                res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                index = index + 1
    logger.info('lipreal inference processor stop')
  156. class LipReal(BaseReal):
  157. @torch.no_grad()
  158. def __init__(self, opt, model, avatar):
  159. super().__init__(opt,model, avatar)
  160. #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
  161. # self.W = opt.W
  162. # self.H = opt.H
  163. self.fps = opt.fps # 20 ms per frame
  164. self.batch_size = opt.batch_size
  165. self.idx = 0
  166. self.res_frame_queue = Queue(self.batch_size*2) #mp.Queue
  167. #self.__loadavatar()
  168. self.model = model
  169. self.frame_list_cycle,self.face_list_cycle,self.coord_list_cycle = avatar
  170. self.asr = LipASR(opt,self)
  171. self.asr.warm_up()
  172. self.render_event = mp.Event()
  173. # def __del__(self):
  174. # logger.info(f'lipreal({self.sessionid}) delete')
  175. import copy
  176. import cv2
  177. import numpy as np
  178. import copy
  179. import cv2
  180. import numpy as np
  181. def paste_back_frame(self, pred_frame, idx: int):
  182. """
  183. 解决唇形缩小+位置偏移 + 修复脖子旁背景色彩不一致问题
  184. """
  185. # ========== 保留调试代码 ==========
  186. # try:
  187. # import os
  188. # debug_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'debug')
  189. # os.makedirs(debug_dir, exist_ok=True)
  190. # bbox = self.coord_list_cycle[idx]
  191. # y1, y2, x1, x2 = bbox
  192. # debug_frame = copy.deepcopy(self.frame_list_cycle[idx])
  193. # cv2.rectangle(debug_frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
  194. # cv2.imwrite(os.path.join(debug_dir, f'frame_{idx}_bbox.jpg'), debug_frame)
  195. # cv2.imwrite(os.path.join(debug_dir, f'frame_{idx}_pred.jpg'), pred_frame.astype(np.uint8))
  196. # logger.info(f"帧{idx}:bbox坐标(y1={y1},y2={y2},x1={x1},x2={x2}) | 生成帧尺寸{pred_frame.shape}")
  197. # except Exception as e:
  198. # logger.error(f"调试代码报错:{str(e)}")
  199. # ==============================================
  200. # ========== 基础坐标处理 ==========
  201. bbox = self.coord_list_cycle[idx]
  202. combine_frame = copy.deepcopy(self.frame_list_cycle[idx])
  203. y1, y2, x1, x2 = bbox
  204. h_ori, w_ori = combine_frame.shape[:2]
  205. # 精准校准原始嘴巴区域坐标(仅框住嘴唇,不包含脖子/背景)
  206. y1 = int(y1) if (y1 > 0) else 280 # 仅框嘴唇上沿
  207. y2 = int(y2) if (y2 < h_ori) else 350 # 仅框嘴唇下沿
  208. x1 = int(x1) if (x1 > 0) else 180 # 仅框嘴唇左沿
  209. x2 = int(x2) if (x2 < w_ori) else 320 # 仅框嘴唇右沿
  210. if (x2 - x1) <= 0 or (y2 - y1) <= 0:
  211. return combine_frame
  212. # ========== 1. 强制1:1缩放(保证唇形尺寸) ==========
  213. target_w = x2 - x1
  214. target_h = y2 - y1
  215. res_frame = cv2.resize(
  216. pred_frame.astype(np.uint8),
  217. (target_w, target_h),
  218. interpolation=cv2.INTER_CUBIC
  219. )
  220. # ========== 2. 色彩匹配(仅匹配唇部,不影响背景) ==========
  221. ori_face_region = combine_frame[y1:y2, x1:x2]
  222. def simple_color_match(source, target):
  223. source = source.astype(np.float32)
  224. target = target.astype(np.float32)
  225. for i in range(3):
  226. src_mean = np.mean(source[:, :, i])
  227. trg_mean = np.mean(target[:, :, i])
  228. # 降低色彩校正强度,减少对背景的影响
  229. source[:, :, i] = source[:, :, i] * 0.9 + (source[:, :, i] - src_mean + trg_mean) * 0.1
  230. return np.clip(source, 0, 255).astype(np.uint8)
  231. color_matched_frame = simple_color_match(res_frame, ori_face_region)
  232. # ========== 3. 核心优化:唇部专属掩码(仅融合嘴唇,保留背景) ==========
  233. # 1. 创建椭圆掩码(仅覆盖嘴唇区域,避开脖子/背景)
  234. mask = np.zeros((target_h, target_w), dtype=np.float32)
  235. center = (target_w // 2, target_h // 2)
  236. # 椭圆尺寸:仅覆盖嘴唇核心区域(不超出嘴唇范围)
  237. axes = (int(target_w * 0.4), int(target_h * 0.35))
  238. cv2.ellipse(mask, center, axes, 0, 0, 360, 1, -1)
  239. # 2. 轻量羽化(仅嘴唇边缘,不扩散到背景)
  240. feather_width = min(4, target_w // 20, target_h // 20) # 进一步缩小羽化范围
  241. if feather_width > 0:
  242. mask = cv2.GaussianBlur(mask, (feather_width * 2 + 1, feather_width * 2 + 1), feather_width)
  243. mask_3ch = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
  244. # ========== 4. 精准对齐 + 背景保留 ==========
  245. paste_x1 = x1
  246. paste_x2 = x2
  247. paste_y1 = y1
  248. paste_y2 = y2
  249. # ========== 5. 最终融合(仅替换嘴唇,背景完全用原图) ==========
  250. ori_paste_region = combine_frame[paste_y1:paste_y2, paste_x1:paste_x2].astype(np.float32)
  251. color_matched_frame = color_matched_frame.astype(np.float32)
  252. # 核心:掩码只作用于嘴唇,背景区域完全保留原图
  253. fused_region = mask_3ch * color_matched_frame + (1 - mask_3ch) * ori_paste_region
  254. fused_region = np.clip(fused_region, 0, 255).astype(np.uint8)
  255. # 仅替换嘴唇区域,脖子/背景区域完全不动
  256. combine_frame[paste_y1:paste_y2, paste_x1:paste_x2] = fused_region
  257. return combine_frame
  258. def render(self,quit_event,loop=None,audio_track=None,video_track=None):
  259. #if self.opt.asr:
  260. # self.asr.warm_up()
  261. self.init_customindex()
  262. self.tts.render(quit_event)
  263. infer_quit_event = Event()
  264. infer_thread = Thread(target=inference, args=(infer_quit_event,self.batch_size,self.face_list_cycle,
  265. self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
  266. self.model,)) #mp.Process
  267. infer_thread.start()
  268. process_quit_event = Event()
  269. process_thread = Thread(target=self.process_frames, args=(process_quit_event,loop,audio_track,video_track))
  270. process_thread.start()
  271. #self.render_event.set() #start infer process render
  272. count=0
  273. totaltime=0
  274. _starttime=time.perf_counter()
  275. #_totalframe=0
  276. while not quit_event.is_set():
  277. # update texture every frame
  278. # audio stream thread...
  279. t = time.perf_counter()
  280. self.asr.run_step()
  281. # if video_track._queue.qsize()>=2*self.opt.batch_size:
  282. # print('sleep qsize=',video_track._queue.qsize())
  283. # time.sleep(0.04*video_track._queue.qsize()*0.8)
  284. if video_track and video_track._queue.qsize()>=5:
  285. logger.debug('sleep qsize=%d',video_track._queue.qsize())
  286. time.sleep(0.04*video_track._queue.qsize()*0.8)
  287. # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
  288. # if delay > 0:
  289. # time.sleep(delay)
  290. #self.render_event.clear() #end infer process render
  291. logger.info('lipreal thread stop')
  292. infer_quit_event.set()
  293. infer_thread.join()
  294. process_quit_event.set()
  295. process_thread.join()