# lightreal.py
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# stdlib
import asyncio
import copy
import glob
import math
import os
import pickle
import queue
import time
from queue import Queue
from threading import Thread, Event

# third-party
import cv2
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from av import AudioFrame, VideoFrame
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Wav2Vec2Processor, HubertModel

# local
#from .utils import *
#from imgcache import ImgCache
from basereal import BaseReal
from hubertasr import HubertASR
from logger import logger
from ultralight.audio2feature import Audio2Feature
from ultralight.unet import Model
  50. device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
  51. print('Using {} for inference.'.format(device))
  52. def load_model(opt):
  53. audio_processor = Audio2Feature()
  54. return audio_processor
  55. def load_avatar(avatar_id):
  56. avatar_path = f"./data/avatars/{avatar_id}"
  57. full_imgs_path = f"{avatar_path}/full_imgs"
  58. face_imgs_path = f"{avatar_path}/face_imgs"
  59. coords_path = f"{avatar_path}/coords.pkl"
  60. model = Model(6, 'hubert').to(device) # 假设Model是你自定义的类
  61. model.load_state_dict(torch.load(f"{avatar_path}/ultralight.pth"))
  62. with open(coords_path, 'rb') as f:
  63. coord_list_cycle = pickle.load(f)
  64. input_img_list = glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
  65. input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
  66. frame_list_cycle = read_imgs(input_img_list)
  67. #self.imagecache = ImgCache(len(self.coord_list_cycle),self.full_imgs_path,1000)
  68. input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
  69. input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
  70. face_list_cycle = read_imgs(input_face_list)
  71. return model.eval(),frame_list_cycle,face_list_cycle,coord_list_cycle
  72. @torch.no_grad()
  73. def warm_up(batch_size,avatar,modelres):
  74. logger.info('warmup model...')
  75. model,_,_,_ = avatar
  76. img_batch = torch.ones(batch_size, 6, modelres, modelres).to(device)
  77. mel_batch = torch.ones(batch_size, 16, 32, 32).to(device)
  78. model(img_batch, mel_batch)
  79. def read_imgs(img_list):
  80. frames = []
  81. logger.info('reading images...')
  82. for img_path in tqdm(img_list):
  83. frame = cv2.imread(img_path)
  84. frames.append(frame)
  85. return frames
  86. def get_audio_features(features, index):
  87. left = index - 8
  88. right = index + 8
  89. pad_left = 0
  90. pad_right = 0
  91. if left < 0:
  92. pad_left = -left
  93. left = 0
  94. if right > features.shape[0]:
  95. pad_right = right - features.shape[0]
  96. right = features.shape[0]
  97. auds = torch.from_numpy(features[left:right])
  98. if pad_left > 0:
  99. auds = torch.cat([torch.zeros_like(auds[:pad_left]), auds], dim=0)
  100. if pad_right > 0:
  101. auds = torch.cat([auds, torch.zeros_like(auds[:pad_right])], dim=0) # [8, 16]
  102. return auds
  103. def read_lms(lms_list):
  104. land_marks = []
  105. logger.info('reading lms...')
  106. for lms_path in tqdm(lms_list):
  107. file_landmarks = [] # Store landmarks for this file
  108. with open(lms_path, "r") as f:
  109. lines = f.read().splitlines()
  110. for line in lines:
  111. arr = list(filter(None, line.split(" ")))
  112. if arr:
  113. arr = np.array(arr, dtype=np.float32)
  114. file_landmarks.append(arr)
  115. land_marks.append(file_landmarks) # Add the file's landmarks to the overall list
  116. return land_marks
  117. def __mirror_index(size, index):
  118. #size = len(self.coord_list_cycle)
  119. turn = index // size
  120. res = index % size
  121. if turn % 2 == 0:
  122. return res
  123. else:
  124. return size - res - 1
def inference(quit_event, batch_size, face_list_cycle, audio_feat_queue, audio_out_queue, res_frame_queue, model):
    """Inference worker loop: consume audio features, emit synthesized mouth frames.

    Each iteration pulls one feature batch from `audio_feat_queue` and the
    matching 2*batch_size audio frames from `audio_out_queue` (two audio
    frames per video frame). If every audio frame is silence, `None` frames
    are queued so the consumer can reuse the unmodified avatar image;
    otherwise the model predicts mouth crops for the batch. Each result is
    pushed to `res_frame_queue` as
    (pred_frame_or_None, mirrored_frame_index, [two audio frames]).

    Runs until `quit_event` is set.
    """
    length = len(face_list_cycle)
    index = 0      # monotonically increasing output frame counter
    count = 0      # frames since last fps log
    counttime = 0  # accumulated pure-inference seconds since last fps log
    logger.info('start inference')
    while not quit_event.is_set():
        starttime=time.perf_counter()
        try:
            # Block at most 1s so quit_event is re-checked regularly.
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue
        is_all_silence=True
        audio_frames = []
        # Collect two audio frames per video frame; type_==0 marks speech.
        for _ in range(batch_size*2):
            frame,type_,eventpoint = audio_out_queue.get()
            audio_frames.append((frame,type_,eventpoint))
            if type_==0:
                is_all_silence=False
        if is_all_silence:
            # No speech in this batch: skip the model, emit None placeholders.
            for i in range(batch_size):
                res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                index = index + 1
        else:
            t = time.perf_counter()
            img_batch = []
            for i in range(batch_size):
                idx = __mirror_index(length, index + i)
                crop_img = face_list_cycle[idx]
                # 160x160 working region; assumes face crops are at least
                # 168x168 (consistent with paste_back_frame) — TODO confirm.
                img_real_ex = crop_img[4:164, 4:164].copy()
                img_real_ex_ori = img_real_ex.copy()
                # Black out the mouth area; 4-tuple is cv2.rectangle's
                # (x, y, w, h) rect overload, filled (-1).
                img_masked = cv2.rectangle(img_real_ex_ori,(5,5,150,145),(0,0,0),-1)
                # HWC uint8 -> CHW float32 in [0, 1].
                img_masked = img_masked.transpose(2,0,1).astype(np.float32)
                img_real_ex = img_real_ex.transpose(2,0,1).astype(np.float32)
                img_real_ex_T = torch.from_numpy(img_real_ex / 255.0)
                img_masked_T = torch.from_numpy(img_masked / 255.0)
                # 6-channel input: reference crop stacked on masked crop.
                img_concat_T = torch.cat([img_real_ex_T, img_masked_T], axis=0)[None]
                img_batch.append(img_concat_T)
            reshaped_mel_batch = [arr.reshape(16, 32, 32) for arr in mel_batch]
            mel_batch = torch.stack([torch.from_numpy(arr) for arr in reshaped_mel_batch])
            img_batch = torch.stack(img_batch).squeeze(1)
            with torch.no_grad():
                pred = model(img_batch.to(device),mel_batch.to(device))
            # NCHW [0,1] -> NHWC [0,255] float for paste-back.
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            counttime += (time.perf_counter() - t)
            count += batch_size
            if count >= 100:
                logger.info(f"------actual avg infer fps:{count / counttime:.4f}")
                count = 0
                counttime = 0
            for i,res_frame in enumerate(pred):
                res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                index = index + 1
    logger.info('lightreal inference processor stop')
class LightReal(BaseReal):
    """Real-time talking-head renderer backed by the ultralight unet model.

    Wires a HubertASR audio front-end to the module-level `inference` worker
    thread, and pastes predicted mouth crops back into the full avatar
    frames for output via BaseReal's frame-processing pipeline.
    """
    @torch.no_grad()
    def __init__(self, opt, model, avatar):
        """
        Args:
            opt: runtime options; this class reads `opt.fps` and
                `opt.batch_size` (BaseReal presumably reads more — verify).
            model: audio processor returned by `load_model` (Audio2Feature).
            avatar: tuple from `load_avatar`:
                (unet model, full frames, face crops, crop coords).
        """
        super().__init__(opt, model, avatar)
        self.fps = opt.fps # 20 ms per frame
        self.batch_size = opt.batch_size
        self.idx = 0
        # Bounded hand-off between the inference thread and process_frames.
        self.res_frame_queue = Queue(self.batch_size*2)  #mp.Queue
        audio_processor = model
        self.model,self.frame_list_cycle,self.face_list_cycle,self.coord_list_cycle = avatar
        self.asr = HubertASR(opt,self,audio_processor,audio_feat_length =[4,4])
        self.asr.warm_up()
        # NOTE(review): only used by commented-out render code below.
        self.render_event = mp.Event()

    def paste_back_frame(self,pred_frame,idx:int):
        """Paste a predicted 160x160 mouth region back into full frame `idx`.

        The prediction replaces the [4:164, 4:164] region of the face crop;
        the patched crop is resized to the original face bounding box and
        written into a deep copy of the full frame, which is returned.
        """
        bbox = self.coord_list_cycle[idx]
        # Deep-copy so the cached background frame is never mutated.
        combine_frame = copy.deepcopy(self.frame_list_cycle[idx])
        x1, y1, x2, y2 = bbox
        crop_img = self.face_list_cycle[idx]
        crop_img_ori = crop_img.copy()
        crop_img_ori[4:164, 4:164] = pred_frame.astype(np.uint8)
        crop_img_ori = cv2.resize(crop_img_ori, (x2-x1,y2-y1))
        combine_frame[y1:y2, x1:x2] = crop_img_ori
        return combine_frame

    def render(self,quit_event,loop=None,audio_track=None,video_track=None):
        """Main render loop: start worker threads and drive ASR until quit.

        Spawns the `inference` thread and the frame-processing thread, then
        repeatedly steps the ASR; on `quit_event` both workers are signalled
        and joined. `loop`/`audio_track`/`video_track` are forwarded to
        BaseReal.process_frames.
        """
        self.init_customindex()
        # NOTE(review): presumably starts TTS audio production — defined in BaseReal.
        self.tts.render(quit_event)
        infer_quit_event = Event()
        infer_thread = Thread(target=inference, args=(infer_quit_event,self.batch_size,self.face_list_cycle,self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
                                                      self.model,)) #mp.Process
        infer_thread.start()
        process_quit_event = Event()
        process_thread = Thread(target=self.process_frames, args=(process_quit_event,loop,audio_track,video_track))
        process_thread.start()
        count=0
        totaltime=0
        _starttime=time.perf_counter()
        while not quit_event.is_set():
            t = time.perf_counter()
            self.asr.run_step()
            # Throttle when the video track's queue backs up, so we do not
            # outrun the consumer (0.04s per queued frame, scaled by 0.8).
            if video_track and video_track._queue.qsize()>=5:
                logger.debug('sleep qsize=%d',video_track._queue.qsize())
                time.sleep(0.04*video_track._queue.qsize()*0.8)
        logger.info('lightreal thread stop')
        # Signal workers to exit, then wait for them.
        infer_quit_event.set()
        infer_thread.join()
        process_quit_event.set()
        process_thread.join()