- ###############################################################################
- # Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
- # email: lipku@foxmail.com
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- ###############################################################################
- import math
- import torch
- import numpy as np
- #from .utils import *
- import os
- import time
- import cv2
- import glob
- import pickle
- import copy
- import queue
- from queue import Queue
- from threading import Thread, Event
- import torch.multiprocessing as mp
- from lipasr import LipASR
- import asyncio
- from av import AudioFrame, VideoFrame
- from wav2lip.models import Wav2Lip
- from basereal import BaseReal
- #from imgcache import ImgCache
- from tqdm import tqdm
- from logger import logger
# Pick the best available inference device: CUDA first, Apple-Silicon MPS
# second, CPU as the fallback.
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print('Using {} for inference.'.format(device))
def _load(checkpoint_path):
    """Load a raw torch checkpoint.

    On CUDA the storages are loaded onto the GPU directly; on any other
    device every storage is remapped to CPU so the file stays loadable.
    """
    if device == 'cuda':
        return torch.load(checkpoint_path)  # ,weights_only=True
    return torch.load(checkpoint_path,
                      map_location=lambda storage, loc: storage)
def load_model(path):
    """Build a Wav2Lip model, load the checkpoint at *path*, and return it
    on the inference device in eval mode.

    DataParallel checkpoints prefix every key with ``module.``; the prefix
    is stripped so the weights fit a plain (unwrapped) model.
    """
    logger.info("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    cleaned_state = {key.replace('module.', ''): weight
                     for key, weight in checkpoint["state_dict"].items()}
    model = Wav2Lip()
    model.load_state_dict(cleaned_state)
    return model.to(device).eval()
def load_avatar(avatar_id):
    """Load the pre-baked avatar assets from ``./data/avatars/<avatar_id>``.

    Returns a tuple ``(frame_list_cycle, face_list_cycle, coord_list_cycle)``:
    the full background frames, the cropped face frames fed to the model, and
    the pickled face bounding-box coordinates for pasting results back.
    """
    avatar_path = f"./data/avatars/{avatar_id}"

    with open(f"{avatar_path}/coords.pkl", 'rb') as f:
        coord_list_cycle = pickle.load(f)

    def _numbered_images(folder):
        # Image files are named by frame number; sort numerically on the stem.
        paths = glob.glob(os.path.join(folder, '*.[jpJP][pnPN]*[gG]'))
        return sorted(paths, key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))

    frame_list_cycle = read_imgs(_numbered_images(f"{avatar_path}/full_imgs"))
    #self.imagecache = ImgCache(len(self.coord_list_cycle),self.full_imgs_path,1000)
    face_list_cycle = read_imgs(_numbered_images(f"{avatar_path}/face_imgs"))
    return frame_list_cycle, face_list_cycle, coord_list_cycle
@torch.no_grad()
def warm_up(batch_size, model, modelres):
    """Run one throwaway forward pass so kernels/allocations are primed
    before the first real inference (avoids a latency spike on frame one)."""
    logger.info('warmup model...')
    dummy_imgs = torch.ones(batch_size, 6, modelres, modelres).to(device)
    dummy_mels = torch.ones(batch_size, 1, 80, 16).to(device)
    model(dummy_mels, dummy_imgs)
def read_imgs(img_list):
    """Decode every image path in *img_list* with OpenCV, preserving order.

    Args:
        img_list: sequence of image file paths.
    Returns:
        list of BGR ``numpy`` frames, one per path, same order as the input.
    Raises:
        ValueError: if a file is missing or cannot be decoded. Failing fast
            here beats silently appending ``None`` (the original behavior),
            which only crashed much later inside the render pipeline with an
            opaque error.
    """
    frames = []
    logger.info('reading images...')
    for img_path in tqdm(img_list):
        frame = cv2.imread(img_path)
        if frame is None:
            raise ValueError(f'failed to read image: {img_path}')
        frames.append(frame)
    return frames
def __mirror_index(size, index):
    """Map a monotonically increasing *index* onto a ping-pong traversal of
    ``range(size)``: forward on even passes, backward on odd passes, so the
    frame cycle sweeps back and forth without a visible jump at the ends."""
    pass_no, offset = divmod(index, size)
    return offset if pass_no % 2 == 0 else size - 1 - offset
def inference(quit_event, batch_size, face_list_cycle, audio_feat_queue, audio_out_queue, res_frame_queue, model):
    """Worker loop: consume mel-spectrogram batches and paired audio frames,
    run Wav2Lip, and push results to *res_frame_queue*.

    Args:
        quit_event: threading.Event; loop exits when set.
        batch_size: number of video frames inferred per step (two audio
            frames are consumed per video frame).
        face_list_cycle: list of cropped face images cycled via __mirror_index.
        audio_feat_queue: yields mel batches (timeout keeps quit_event responsive).
        audio_out_queue: yields ``(frame, frame_type, eventpoint)`` tuples;
            frame_type 0 marks real speech audio.
        res_frame_queue: receives ``(pred_frame_or_None, frame_index, audio_pair)``.
        model: the Wav2Lip network.

    Fixes vs. original: the loop variable ``type`` shadowed the builtin
    (renamed ``frame_type``), and the dead ``mel_batch = []`` /
    ``starttime`` locals are removed.
    """
    length = len(face_list_cycle)
    index = 0
    count = 0
    counttime = 0
    logger.info('start inference')
    while not quit_event.is_set():
        try:
            # Short timeout so a set quit_event is noticed within one second.
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue

        is_all_silence = True
        audio_frames = []
        for _ in range(batch_size * 2):  # two 20 ms audio frames per video frame
            frame, frame_type, eventpoint = audio_out_queue.get()
            audio_frames.append((frame, frame_type, eventpoint))
            if frame_type == 0:  # 0 == speech; anything present disables the silence fast path
                is_all_silence = False

        if is_all_silence:
            # Silence: skip the model entirely, downstream reuses the base frame.
            for i in range(batch_size):
                res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
                index = index + 1
        else:
            t = time.perf_counter()
            img_batch = []
            for i in range(batch_size):
                idx = __mirror_index(length, index + i)
                face = face_list_cycle[idx]
                img_batch.append(face)
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            # Wav2Lip input: original face concatenated with a copy whose lower
            # half is zeroed (the region the model must inpaint).
            img_masked = img_batch.copy()
            img_masked[:, face.shape[0]//2:] = 0
            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            # HWC -> CHW for torch.
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
            with torch.no_grad():
                pred = model(mel_batch, img_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            counttime += (time.perf_counter() - t)
            count += batch_size
            if count >= 100:
                logger.info(f"------actual avg infer fps:{count/counttime:.4f}")
                count = 0
                counttime = 0
            for i, res_frame in enumerate(pred):
                res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
                index = index + 1
    logger.info('lipreal inference processor stop')
class LipReal(BaseReal):
    """Real-time Wav2Lip avatar session.

    Owns the avatar assets, the ASR front-end, and the render loop that
    spawns the inference and frame-processing worker threads.

    Fixes vs. original: the class body contained three duplicated copies of
    ``import copy / cv2 / numpy`` (already imported at module level) -- removed;
    dead locals in ``render`` and unused coordinate aliases plus a large
    commented-out debug block in ``paste_back_frame`` are also removed.
    """

    @torch.no_grad()
    def __init__(self, opt, model, avatar):
        super().__init__(opt, model, avatar)
        #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.fps = opt.fps  # 20 ms per frame

        self.batch_size = opt.batch_size
        self.idx = 0
        # Bounded handoff queue between the inference thread and process_frames.
        self.res_frame_queue = Queue(self.batch_size * 2)  # mp.Queue
        self.model = model
        # Avatar assets as produced by load_avatar().
        self.frame_list_cycle, self.face_list_cycle, self.coord_list_cycle = avatar

        self.asr = LipASR(opt, self)
        self.asr.warm_up()

        self.render_event = mp.Event()

    # def __del__(self):
    #     logger.info(f'lipreal({self.sessionid}) delete')

    def paste_back_frame(self, pred_frame, idx: int):
        """
        Paste the predicted mouth crop back into the full base frame.

        Addresses lip shrinking/offset and color mismatch next to the neck:
        1:1 resize of the prediction into the face bbox, mild per-channel
        color matching against the original crop, then blending through a
        feathered elliptical mask so only the lip area is replaced while the
        neck/background keep the original pixels.

        Args:
            pred_frame: model output for frame *idx* (float array, 0-255 range).
            idx: index into the avatar frame/coord cycles.
        Returns:
            The composited full-resolution BGR frame.
        """
        # ---- base coordinates ----
        bbox = self.coord_list_cycle[idx]
        combine_frame = copy.deepcopy(self.frame_list_cycle[idx])
        y1, y2, x1, x2 = bbox
        h_ori, w_ori = combine_frame.shape[:2]
        # Repair out-of-range coords with fallbacks that frame only the lips
        # (magic constants tuned for this avatar's resolution -- TODO confirm).
        y1 = int(y1) if (y1 > 0) else 280       # lip upper edge
        y2 = int(y2) if (y2 < h_ori) else 350   # lip lower edge
        x1 = int(x1) if (x1 > 0) else 180       # lip left edge
        x2 = int(x2) if (x2 < w_ori) else 320   # lip right edge
        if (x2 - x1) <= 0 or (y2 - y1) <= 0:
            return combine_frame

        # ---- 1. strict 1:1 resize so the lips keep their true size ----
        target_w = x2 - x1
        target_h = y2 - y1
        res_frame = cv2.resize(
            pred_frame.astype(np.uint8),
            (target_w, target_h),
            interpolation=cv2.INTER_CUBIC
        )

        # ---- 2. mild color matching (lips only, barely touches background) ----
        ori_face_region = combine_frame[y1:y2, x1:x2]

        def simple_color_match(source, target):
            # Nudge each channel's mean 10% toward the target mean; the low
            # strength keeps the correction from bleeding into the background.
            source = source.astype(np.float32)
            target = target.astype(np.float32)
            for i in range(3):
                src_mean = np.mean(source[:, :, i])
                trg_mean = np.mean(target[:, :, i])
                source[:, :, i] = source[:, :, i] * 0.9 + (source[:, :, i] - src_mean + trg_mean) * 0.1
            return np.clip(source, 0, 255).astype(np.uint8)

        color_matched_frame = simple_color_match(res_frame, ori_face_region)

        # ---- 3. lips-only elliptical mask (background stays untouched) ----
        mask = np.zeros((target_h, target_w), dtype=np.float32)
        center = (target_w // 2, target_h // 2)
        # Ellipse covers only the central lip area, avoiding neck/background.
        axes = (int(target_w * 0.4), int(target_h * 0.35))
        cv2.ellipse(mask, center, axes, 0, 0, 360, 1, -1)
        # Light feathering restricted to the lip edge.
        feather_width = min(4, target_w // 20, target_h // 20)
        if feather_width > 0:
            mask = cv2.GaussianBlur(mask, (feather_width * 2 + 1, feather_width * 2 + 1), feather_width)
        mask_3ch = np.repeat(mask[:, :, np.newaxis], 3, axis=2)

        # ---- 4. blend: mask selects the lips, everything else is original ----
        ori_paste_region = combine_frame[y1:y2, x1:x2].astype(np.float32)
        color_matched_frame = color_matched_frame.astype(np.float32)
        fused_region = mask_3ch * color_matched_frame + (1 - mask_3ch) * ori_paste_region
        fused_region = np.clip(fused_region, 0, 255).astype(np.uint8)
        combine_frame[y1:y2, x1:x2] = fused_region
        return combine_frame

    def render(self, quit_event, loop=None, audio_track=None, video_track=None):
        """Main render loop: start the inference and frame-processing worker
        threads, pump ASR steps until *quit_event* is set, then stop and join
        both workers."""
        #if self.opt.asr:
        #     self.asr.warm_up()
        self.init_customindex()
        self.tts.render(quit_event)

        infer_quit_event = Event()
        infer_thread = Thread(target=inference, args=(infer_quit_event, self.batch_size, self.face_list_cycle,
                                                      self.asr.feat_queue, self.asr.output_queue, self.res_frame_queue,
                                                      self.model,))  # mp.Process
        infer_thread.start()

        process_quit_event = Event()
        process_thread = Thread(target=self.process_frames, args=(process_quit_event, loop, audio_track, video_track))
        process_thread.start()

        while not quit_event.is_set():
            self.asr.run_step()
            # Back off when the video queue builds up so we don't outrun playback.
            if video_track and video_track._queue.qsize() >= 5:
                logger.debug('sleep qsize=%d', video_track._queue.qsize())
                time.sleep(0.04 * video_track._queue.qsize() * 0.8)

        logger.info('lipreal thread stop')
        infer_quit_event.set()
        infer_thread.join()
        process_quit_event.set()
        process_thread.join()
-
|