###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# --- standard library ---
import asyncio
import copy
import glob
import math
import os
import pickle
import queue
import subprocess
import time
from queue import Queue
from threading import Thread, Event

# --- third-party ---
import cv2
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from av import AudioFrame, VideoFrame
from tqdm import tqdm

# --- project ---
from basereal import BaseReal
from logger import logger
from museasr import MuseASR
from musetalk.myutil import get_image_blending
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.whisper.audio2feature import Audio2Feature
  44. def load_model():
  45. # load model weights
  46. vae, unet, pe = load_all_model()
  47. device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
  48. timesteps = torch.tensor([0], device=device)
  49. pe = pe.half().to(device)
  50. vae.vae = vae.vae.half().to(device)
  51. #vae.vae.share_memory().to(device)
  52. unet.model = unet.model.half().to(device)
  53. #unet.model.share_memory()
  54. # Initialize audio processor and Whisper model
  55. audio_processor = Audio2Feature(model_path="./models/whisper")
  56. return vae, unet, pe, timesteps, audio_processor
  57. def load_avatar(avatar_id):
  58. #self.video_path = '' #video_path
  59. #self.bbox_shift = opt.bbox_shift
  60. avatar_path = f"./data/avatars/{avatar_id}"
  61. full_imgs_path = f"{avatar_path}/full_imgs"
  62. coords_path = f"{avatar_path}/coords.pkl"
  63. latents_out_path= f"{avatar_path}/latents.pt"
  64. video_out_path = f"{avatar_path}/vid_output/"
  65. mask_out_path =f"{avatar_path}/mask"
  66. mask_coords_path =f"{avatar_path}/mask_coords.pkl"
  67. avatar_info_path = f"{avatar_path}/avator_info.json"
  68. # self.avatar_info = {
  69. # "avatar_id":self.avatar_id,
  70. # "video_path":self.video_path,
  71. # "bbox_shift":self.bbox_shift
  72. # }
  73. input_latent_list_cycle = torch.load(latents_out_path) #,weights_only=True
  74. with open(coords_path, 'rb') as f:
  75. coord_list_cycle = pickle.load(f)
  76. input_img_list = glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
  77. input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
  78. frame_list_cycle = read_imgs(input_img_list)
  79. with open(mask_coords_path, 'rb') as f:
  80. mask_coords_list_cycle = pickle.load(f)
  81. input_mask_list = glob.glob(os.path.join(mask_out_path, '*.[jpJP][pnPN]*[gG]'))
  82. input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
  83. mask_list_cycle = read_imgs(input_mask_list)
  84. return frame_list_cycle,mask_list_cycle,coord_list_cycle,mask_coords_list_cycle,input_latent_list_cycle
  85. @torch.no_grad()
  86. def warm_up(batch_size,model):
  87. # 预热函数
  88. logger.info('warmup model...')
  89. vae, unet, pe, timesteps, audio_processor = model
  90. #batch_size = 16
  91. #timesteps = torch.tensor([0], device=unet.device)
  92. whisper_batch = np.ones((batch_size, 50, 384), dtype=np.uint8)
  93. latent_batch = torch.ones(batch_size, 8, 32, 32).to(unet.device)
  94. audio_feature_batch = torch.from_numpy(whisper_batch)
  95. audio_feature_batch = audio_feature_batch.to(device=unet.device, dtype=unet.model.dtype)
  96. audio_feature_batch = pe(audio_feature_batch)
  97. latent_batch = latent_batch.to(dtype=unet.model.dtype)
  98. pred_latents = unet.model(latent_batch,
  99. timesteps,
  100. encoder_hidden_states=audio_feature_batch).sample
  101. vae.decode_latents(pred_latents)
  102. def read_imgs(img_list):
  103. frames = []
  104. logger.info('reading images...')
  105. for img_path in tqdm(img_list):
  106. frame = cv2.imread(img_path)
  107. frames.append(frame)
  108. return frames
  109. def __mirror_index(size, index):
  110. #size = len(self.coord_list_cycle)
  111. turn = index // size
  112. res = index % size
  113. if turn % 2 == 0:
  114. return res
  115. else:
  116. return size - res - 1
@torch.no_grad()
def inference(quit_event,batch_size,input_latent_list_cycle,audio_feat_queue,audio_out_queue,res_frame_queue,
              vae, unet, pe,timesteps):
    """Inference worker loop (runs in its own thread, started by render()).

    Pulls one batch of Whisper feature chunks from audio_feat_queue plus the
    matching raw audio frames from audio_out_queue (two audio frames per video
    frame), runs PE + UNet + VAE when there is voiced audio, and pushes
    (frame_or_None, mirrored_frame_index, [2 audio frames]) tuples into
    res_frame_queue for the frame-compositing consumer. Exits when quit_event
    is set.
    """
    length = len(input_latent_list_cycle)
    index = 0
    count=0
    counttime=0
    logger.info('start inference')
    while not quit_event.is_set():
        starttime=time.perf_counter()
        try:
            # 1s timeout so quit_event is re-checked even when no audio flows.
            whisper_chunks = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue
        is_all_silence=True
        audio_frames = []
        # Two audio frames per video frame, hence batch_size*2 per batch.
        # A chunk with type==0 marks real (non-silent) audio.
        for _ in range(batch_size*2):
            frame,type,eventpoint = audio_out_queue.get()
            audio_frames.append((frame,type,eventpoint))
            if type==0:
                is_all_silence=False
        if is_all_silence:
            # Pure silence: skip the GPU entirely; a None frame tells the
            # consumer to fall back to the idle avatar frame at this index
            # (presumably handled in process_frames — confirm in BaseReal).
            for i in range(batch_size):
                res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                index = index + 1
        else:
            t=time.perf_counter()
            whisper_batch = np.stack(whisper_chunks)
            # Gather the latents for the next batch_size frames of the
            # ping-pong avatar cycle.
            latent_batch = []
            for i in range(batch_size):
                idx = __mirror_index(length,index+i)
                latent = input_latent_list_cycle[idx]
                latent_batch.append(latent)
            latent_batch = torch.cat(latent_batch, dim=0)
            audio_feature_batch = torch.from_numpy(whisper_batch)
            audio_feature_batch = audio_feature_batch.to(device=unet.device,
                                                         dtype=unet.model.dtype)
            audio_feature_batch = pe(audio_feature_batch)
            latent_batch = latent_batch.to(dtype=unet.model.dtype)
            # Single-step UNet prediction conditioned on the audio features,
            # then decode the latents back to image space.
            pred_latents = unet.model(latent_batch,
                                      timesteps,
                                      encoder_hidden_states=audio_feature_batch).sample
            recon = vae.decode_latents(pred_latents)
            # Rolling FPS accounting over ~100 generated frames.
            counttime += (time.perf_counter() - t)
            count += batch_size
            if count>=100:
                logger.info(f"------actual avg infer fps:{count/counttime:.4f}")
                count=0
                counttime=0
            # Emit each reconstructed frame with its index and its pair of
            # audio frames; index advances once per video frame.
            for i,res_frame in enumerate(recon):
                res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                index = index + 1
    logger.info('musereal inference processor stop')
class MuseReal(BaseReal):
    """Real-time MuseTalk avatar renderer.

    Owns the preloaded models and avatar assets, feeds audio through MuseASR,
    and drives two worker threads: the GPU inference loop (module-level
    inference()) and the frame compositing/output loop (process_frames,
    inherited from BaseReal).
    """

    @torch.no_grad()
    def __init__(self, opt, model, avatar):
        """
        Args:
            opt: runtime options (at least fps and batch_size are read here).
            model: tuple from load_model(): (vae, unet, pe, timesteps,
                audio_processor).
            avatar: tuple from load_avatar(): frames, masks, coords,
                mask coords and latents, index-aligned per cycle frame.
        """
        super().__init__(opt, model, avatar)
        self.fps = opt.fps  # 20 ms per frame
        self.batch_size = opt.batch_size
        self.idx = 0
        # Bounded queue between the inference thread and process_frames;
        # 2 batches deep so the producer cannot run far ahead.
        self.res_frame_queue = mp.Queue(self.batch_size*2)
        self.vae, self.unet, self.pe, self.timesteps, self.audio_processor = model
        self.frame_list_cycle,self.mask_list_cycle,self.coord_list_cycle,self.mask_coords_list_cycle, self.input_latent_list_cycle = avatar
        self.asr = MuseASR(opt,self,self.audio_processor)
        self.asr.warm_up()
        self.render_event = mp.Event()

    def __mirror_index(self, index):
        """Map a monotonically increasing index onto a forward/backward
        (ping-pong) sweep over the avatar frame cycle."""
        size = len(self.coord_list_cycle)
        turn = index // size
        res = index % size
        if turn % 2 == 0:
            return res
        else:
            return size - res - 1

    def __warm_up(self):
        """Run one real ASR step plus a full model forward pass to prime the
        pipeline; the reconstructed frames are discarded."""
        self.asr.run_step()
        whisper_chunks = self.asr.get_next_feat()
        whisper_batch = np.stack(whisper_chunks)
        latent_batch = []
        for i in range(self.batch_size):
            idx = self.__mirror_index(self.idx+i)
            latent = self.input_latent_list_cycle[idx]
            latent_batch.append(latent)
        latent_batch = torch.cat(latent_batch, dim=0)
        logger.info('infer=======')
        audio_feature_batch = torch.from_numpy(whisper_batch)
        audio_feature_batch = audio_feature_batch.to(device=self.unet.device,
                                                     dtype=self.unet.model.dtype)
        audio_feature_batch = self.pe(audio_feature_batch)
        latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
        pred_latents = self.unet.model(latent_batch,
                                       self.timesteps,
                                       encoder_hidden_states=audio_feature_batch).sample
        recon = self.vae.decode_latents(pred_latents)

    def paste_back_frame(self, pred_frame, idx: int):
        """Composite one generated face crop back onto the original avatar
        frame at cycle position idx.

        Args:
            pred_frame: generated face image (resized here to the bbox size).
            idx: frame index into the avatar cycle lists.

        Returns:
            The blended full frame.
        """
        bbox = self.coord_list_cycle[idx]
        # Deep-copy so the cached cycle frame is never mutated by blending.
        ori_frame = copy.deepcopy(self.frame_list_cycle[idx])
        x1, y1, x2, y2 = bbox
        res_frame = cv2.resize(pred_frame.astype(np.uint8),(x2-x1,y2-y1))
        mask = self.mask_list_cycle[idx]
        mask_crop_box = self.mask_coords_list_cycle[idx]
        combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
        return combine_frame

    def render(self, quit_event, loop=None, audio_track=None, video_track=None):
        """Main render loop: start the TTS, inference and frame-processing
        workers, then pump ASR steps until quit_event is set, throttling when
        the video track's output queue backs up. Joins both workers on exit.
        """
        self.init_customindex()
        self.tts.render(quit_event)
        infer_quit_event = Event()
        infer_thread = Thread(target=inference, args=(infer_quit_event,self.batch_size,self.input_latent_list_cycle,
                                                      self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
                                                      self.vae, self.unet, self.pe,self.timesteps))
        infer_thread.start()
        process_quit_event = Event()
        process_thread = Thread(target=self.process_frames, args=(process_quit_event,loop,audio_track,video_track))
        process_thread.start()
        # NOTE(review): count/totaltime/_starttime are currently unused
        # leftovers of disabled FPS accounting.
        count=0
        totaltime=0
        _starttime=time.perf_counter()
        while not quit_event.is_set():
            t = time.perf_counter()
            self.asr.run_step()
            # Back-pressure: if the consumer's queue holds more than ~1.5
            # batches of frames, sleep proportionally (0.04s per frame, x0.8)
            # so we do not keep racing ahead of playback.
            if video_track and video_track._queue.qsize()>=1.5*self.opt.batch_size:
                logger.debug('sleep qsize=%d',video_track._queue.qsize())
                time.sleep(0.04*video_track._queue.qsize()*0.8)
        logger.info('musereal thread stop')
        # Orderly shutdown: stop the producer before the consumer.
        infer_quit_event.set()
        infer_thread.join()
        process_quit_event.set()
        process_thread.join()