###############################################################################
#  Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
#  email: lipku@foxmail.com
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
###############################################################################
import math
import torch
import numpy as np

import os
import time
import cv2
import glob
import pickle
import copy
import queue
from queue import Queue
from threading import Thread, Event
import torch.multiprocessing as mp

from lipasr import LipASR
import asyncio
from av import AudioFrame, VideoFrame
from wav2lip.models import Wav2Lip
from basereal import BaseReal
from tqdm import tqdm
from logger import logger

# Prefer CUDA, then Apple MPS, then plain CPU.
device = "cuda" if torch.cuda.is_available() else (
    "mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
print('Using {} for inference.'.format(device))


def _load(checkpoint_path):
    """Load a torch checkpoint, mapping storages to CPU when not on CUDA."""
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)  # weights_only=True once all checkpoints are plain tensors
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint


def load_model(path):
    """Build a Wav2Lip model, load weights from `path` and return it in eval
    mode on `device`.  Any DataParallel 'module.' prefix is stripped from the
    state-dict keys so single-GPU/CPU loading works.
    """
    model = Wav2Lip()
    logger.info("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    state = checkpoint["state_dict"]
    cleaned = {k.replace('module.', ''): v for k, v in state.items()}
    model.load_state_dict(cleaned)
    model = model.to(device)
    return model.eval()


def load_avatar(avatar_id):
    """Load the prebuilt avatar assets for `avatar_id` from ./data/avatars.

    Returns (frame_list_cycle, face_list_cycle, coord_list_cycle): full
    background frames, cropped face images, and the face bounding boxes
    (pickled list, one entry per frame).
    """
    avatar_path = f"./data/avatars/{avatar_id}"
    full_imgs_path = f"{avatar_path}/full_imgs"
    face_imgs_path = f"{avatar_path}/face_imgs"
    coords_path = f"{avatar_path}/coords.pkl"

    with open(coords_path, 'rb') as f:
        coord_list_cycle = pickle.load(f)

    # Image files are numbered; sort numerically so frame order is stable.
    input_img_list = glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
    input_img_list = sorted(input_img_list,
                            key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
    frame_list_cycle = read_imgs(input_img_list)

    input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
    input_face_list = sorted(input_face_list,
                             key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
    face_list_cycle = read_imgs(input_face_list)

    return frame_list_cycle, face_list_cycle, coord_list_cycle


@torch.no_grad()
def warm_up(batch_size, model, modelres):
    """Run one dummy forward pass so the first real inference is not slow."""
    logger.info('warmup model...')
    img_batch = torch.ones(batch_size, 6, modelres, modelres).to(device)
    mel_batch = torch.ones(batch_size, 1, 80, 16).to(device)
    model(mel_batch, img_batch)


def read_imgs(img_list):
    """Read every image path in `img_list` with cv2 and return the frames."""
    frames = []
    logger.info('reading images...')
    for img_path in tqdm(img_list):
        frames.append(cv2.imread(img_path))
    return frames


def __mirror_index(size, index):
    """Map a monotonically increasing `index` onto a forward/backward
    ping-pong walk over [0, size) so the avatar loops without a visible jump.
    """
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    else:
        return size - res - 1


def inference(quit_event, batch_size, face_list_cycle, audio_feat_queue,
              audio_out_queue, res_frame_queue, model):
    """Inference worker loop (runs in its own thread).

    Pulls mel-spectrogram batches from `audio_feat_queue` and the matching
    audio frames (2 per video frame) from `audio_out_queue`, runs Wav2Lip on
    non-silent batches, and pushes (pred_frame_or_None, frame_index,
    audio_frame_pair) tuples onto `res_frame_queue` until `quit_event` is set.
    """
    length = len(face_list_cycle)
    index = 0
    count = 0
    counttime = 0
    logger.info('start inference')
    while not quit_event.is_set():
        starttime = time.perf_counter()
        mel_batch = []
        try:
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue

        # Two audio frames per video frame; frame type 0 means real speech.
        is_all_silence = True
        audio_frames = []
        for _ in range(batch_size * 2):
            frame, frame_type, eventpoint = audio_out_queue.get()
            audio_frames.append((frame, frame_type, eventpoint))
            if frame_type == 0:
                is_all_silence = False

        if is_all_silence:
            # Skip the model entirely: downstream substitutes the idle frame.
            for i in range(batch_size):
                res_frame_queue.put((None, __mirror_index(length, index),
                                     audio_frames[i * 2:i * 2 + 2]))
                index = index + 1
        else:
            t = time.perf_counter()
            img_batch = []
            for i in range(batch_size):
                idx = __mirror_index(length, index + i)
                face = face_list_cycle[idx]
                img_batch.append(face)
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            # Wav2Lip input: mask the lower half of the reference face and
            # stack masked/unmasked along the channel axis (assumes all faces
            # in the cycle share one resolution — `face` is the last one).
            img_masked = img_batch.copy()
            img_masked[:, face.shape[0] // 2:] = 0
            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(
                mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, img_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            counttime += (time.perf_counter() - t)
            count += batch_size
            if count >= 100:
                logger.info(f"------actual avg infer fps:{count/counttime:.4f}")
                count = 0
                counttime = 0
            for i, res_frame in enumerate(pred):
                res_frame_queue.put((res_frame, __mirror_index(length, index),
                                     audio_frames[i * 2:i * 2 + 2]))
                index = index + 1
    logger.info('lipreal inference processor stop')


class LipReal(BaseReal):
    """Real-time Wav2Lip avatar: an ASR feature thread feeds `inference`,
    whose results are composited back onto the avatar frames and streamed.
    """

    @torch.no_grad()
    def __init__(self, opt, model, avatar):
        super().__init__(opt, model, avatar)
        self.fps = opt.fps  # e.g. 50 fps -> 20 ms per frame
        self.batch_size = opt.batch_size
        self.idx = 0
        self.res_frame_queue = Queue(self.batch_size * 2)

        self.model = model
        self.frame_list_cycle, self.face_list_cycle, self.coord_list_cycle = avatar

        self.asr = LipASR(opt, self)
        self.asr.warm_up()

        self.render_event = mp.Event()

    def paste_back_frame(self, pred_frame, idx: int):
        """Paste the predicted mouth crop back into full avatar frame `idx`.

        Fixes lip shrink / position offset and avoids color mismatch next to
        the neck by blending only inside a feathered elliptical lip mask —
        everything outside the mask keeps the original pixels.
        """
        bbox = self.coord_list_cycle[idx]
        combine_frame = copy.deepcopy(self.frame_list_cycle[idx])
        y1, y2, x1, x2 = bbox
        h_ori, w_ori = combine_frame.shape[:2]

        # Clamp the mouth box to the frame; out-of-range coords fall back to
        # fixed lip-edge values.
        # NOTE(review): the fallbacks (280/350/180/320) assume a specific
        # avatar resolution — confirm when changing avatars.
        y1 = int(y1) if (y1 > 0) else 280       # lip top edge
        y2 = int(y2) if (y2 < h_ori) else 350   # lip bottom edge
        x1 = int(x1) if (x1 > 0) else 180       # lip left edge
        x2 = int(x2) if (x2 < w_ori) else 320   # lip right edge

        if (x2 - x1) <= 0 or (y2 - y1) <= 0:
            return combine_frame

        # 1. Resize the prediction 1:1 into the mouth box (keeps lip size).
        target_w = x2 - x1
        target_h = y2 - y1
        res_frame = cv2.resize(
            pred_frame.astype(np.uint8),
            (target_w, target_h),
            interpolation=cv2.INTER_CUBIC
        )

        # 2. Gentle per-channel color match toward the original face region.
        ori_face_region = combine_frame[y1:y2, x1:x2]

        def simple_color_match(source, target):
            source = source.astype(np.float32)
            target = target.astype(np.float32)
            for i in range(3):
                src_mean = np.mean(source[:, :, i])
                trg_mean = np.mean(target[:, :, i])
                # Low-strength correction to limit impact on the background.
                source[:, :, i] = source[:, :, i] * 0.9 + (source[:, :, i] - src_mean + trg_mean) * 0.1
            return np.clip(source, 0, 255).astype(np.uint8)

        color_matched_frame = simple_color_match(res_frame, ori_face_region)

        # 3. Lip-only elliptical mask, lightly feathered at the edge so the
        #    blend does not bleed into the neck/background.
        mask = np.zeros((target_h, target_w), dtype=np.float32)
        center = (target_w // 2, target_h // 2)
        axes = (int(target_w * 0.4), int(target_h * 0.35))
        cv2.ellipse(mask, center, axes, 0, 0, 360, 1, -1)

        feather_width = min(4, target_w // 20, target_h // 20)
        if feather_width > 0:
            mask = cv2.GaussianBlur(
                mask, (feather_width * 2 + 1, feather_width * 2 + 1), feather_width)
        mask_3ch = np.repeat(mask[:, :, np.newaxis], 3, axis=2)

        # 4./5. Final blend: mask selects the lips, the rest of the paste
        # region keeps the original frame untouched.
        paste_y1, paste_y2 = y1, y2
        paste_x1, paste_x2 = x1, x2
        ori_paste_region = combine_frame[paste_y1:paste_y2, paste_x1:paste_x2].astype(np.float32)
        color_matched_frame = color_matched_frame.astype(np.float32)
        fused_region = mask_3ch * color_matched_frame + (1 - mask_3ch) * ori_paste_region
        fused_region = np.clip(fused_region, 0, 255).astype(np.uint8)
        combine_frame[paste_y1:paste_y2, paste_x1:paste_x2] = fused_region
        return combine_frame

    def render(self, quit_event, loop=None, audio_track=None, video_track=None):
        """Main render loop: start the inference and frame-processing threads,
        then step the ASR until `quit_event` is set, throttling when the video
        track's queue backs up.  Joins both worker threads on exit.
        """
        self.init_customindex()
        self.tts.render(quit_event)

        infer_quit_event = Event()
        infer_thread = Thread(target=inference,
                              args=(infer_quit_event, self.batch_size, self.face_list_cycle,
                                    self.asr.feat_queue, self.asr.output_queue,
                                    self.res_frame_queue, self.model,))
        infer_thread.start()

        process_quit_event = Event()
        process_thread = Thread(target=self.process_frames,
                                args=(process_quit_event, loop, audio_track, video_track))
        process_thread.start()

        while not quit_event.is_set():
            t = time.perf_counter()
            self.asr.run_step()
            # Back off when playback falls behind so we do not overfill the
            # video queue (each queued frame represents ~40 ms).
            if video_track and video_track._queue.qsize() >= 5:
                logger.debug('sleep qsize=%d', video_track._queue.qsize())
                time.sleep(0.04 * video_track._queue.qsize() * 0.8)

        logger.info('lipreal thread stop')
        infer_quit_event.set()
        infer_thread.join()
        process_quit_event.set()
        process_thread.join()