# utils.py
  1. import os
  2. import cv2
  3. import numpy as np
  4. import torch
  5. from typing import Union, List
  6. import torch.nn.functional as F
  7. from einops import rearrange
  8. import shutil
  9. import os.path as osp
  10. from musetalk.models.vae import VAE
  11. from musetalk.models.unet import UNet,PositionalEncoding
  12. def load_all_model(
  13. unet_model_path=os.path.join("models", "musetalkV15", "unet.pth"),
  14. vae_type="sd-vae",
  15. unet_config=os.path.join("models", "musetalkV15", "musetalk.json"),
  16. device=None,
  17. ):
  18. vae = VAE(
  19. model_path = os.path.join("models", vae_type),
  20. )
  21. print(f"load unet model from {unet_model_path}")
  22. unet = UNet(
  23. unet_config=unet_config,
  24. model_path=unet_model_path,
  25. device=device
  26. )
  27. pe = PositionalEncoding(d_model=384)
  28. return vae, unet, pe
  29. def get_file_type(video_path):
  30. _, ext = os.path.splitext(video_path)
  31. if ext.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
  32. return 'image'
  33. elif ext.lower() in ['.avi', '.mp4', '.mov', '.flv', '.mkv']:
  34. return 'video'
  35. else:
  36. return 'unsupported'
  37. def get_video_fps(video_path):
  38. video = cv2.VideoCapture(video_path)
  39. fps = video.get(cv2.CAP_PROP_FPS)
  40. video.release()
  41. return fps
  42. def datagen(
  43. whisper_chunks,
  44. vae_encode_latents,
  45. batch_size=8,
  46. delay_frame=0,
  47. device="cuda:0",
  48. ):
  49. whisper_batch, latent_batch = [], []
  50. for i, w in enumerate(whisper_chunks):
  51. idx = (i+delay_frame)%len(vae_encode_latents)
  52. latent = vae_encode_latents[idx]
  53. whisper_batch.append(w)
  54. latent_batch.append(latent)
  55. if len(latent_batch) >= batch_size:
  56. whisper_batch = torch.stack(whisper_batch)
  57. latent_batch = torch.cat(latent_batch, dim=0)
  58. yield whisper_batch, latent_batch
  59. whisper_batch, latent_batch = [], []
  60. # the last batch may smaller than batch size
  61. if len(latent_batch) > 0:
  62. whisper_batch = torch.stack(whisper_batch)
  63. latent_batch = torch.cat(latent_batch, dim=0)
  64. yield whisper_batch.to(device), latent_batch.to(device)
  65. def cast_training_params(
  66. model: Union[torch.nn.Module, List[torch.nn.Module]],
  67. dtype=torch.float32,
  68. ):
  69. if not isinstance(model, list):
  70. model = [model]
  71. for m in model:
  72. for param in m.parameters():
  73. # only upcast trainable parameters into fp32
  74. if param.requires_grad:
  75. param.data = param.to(dtype)
  76. def rand_log_normal(
  77. shape,
  78. loc=0.,
  79. scale=1.,
  80. device='cpu',
  81. dtype=torch.float32,
  82. generator=None
  83. ):
  84. """Draws samples from an lognormal distribution."""
  85. rnd_normal = torch.randn(
  86. shape, device=device, dtype=dtype, generator=generator) # N(0, I)
  87. sigma = (rnd_normal * scale + loc).exp()
  88. return sigma
  89. def get_mouth_region(frames, image_pred, pixel_values_face_mask):
  90. # Initialize lists to store the results for each image in the batch
  91. mouth_real_list = []
  92. mouth_generated_list = []
  93. # Process each image in the batch
  94. for b in range(frames.shape[0]):
  95. # Find the non-zero area in the face mask
  96. non_zero_indices = torch.nonzero(pixel_values_face_mask[b])
  97. # If there are no non-zero indices, skip this image
  98. if non_zero_indices.numel() == 0:
  99. continue
  100. min_y, max_y = torch.min(non_zero_indices[:, 1]), torch.max(
  101. non_zero_indices[:, 1])
  102. min_x, max_x = torch.min(non_zero_indices[:, 2]), torch.max(
  103. non_zero_indices[:, 2])
  104. # Crop the frames and image_pred according to the non-zero area
  105. frames_cropped = frames[b, :, min_y:max_y, min_x:max_x]
  106. image_pred_cropped = image_pred[b, :, min_y:max_y, min_x:max_x]
  107. # Resize the cropped images to 256*256
  108. frames_resized = F.interpolate(frames_cropped.unsqueeze(
  109. 0), size=(256, 256), mode='bilinear', align_corners=False)
  110. image_pred_resized = F.interpolate(image_pred_cropped.unsqueeze(
  111. 0), size=(256, 256), mode='bilinear', align_corners=False)
  112. # Append the resized images to the result lists
  113. mouth_real_list.append(frames_resized)
  114. mouth_generated_list.append(image_pred_resized)
  115. # Convert the lists to tensors if they are not empty
  116. mouth_real = torch.cat(mouth_real_list, dim=0) if mouth_real_list else None
  117. mouth_generated = torch.cat(
  118. mouth_generated_list, dim=0) if mouth_generated_list else None
  119. return mouth_real, mouth_generated
  120. def get_image_pred(pixel_values,
  121. ref_pixel_values,
  122. audio_prompts,
  123. vae,
  124. net,
  125. weight_dtype):
  126. with torch.no_grad():
  127. bsz, num_frames, c, h, w = pixel_values.shape
  128. masked_pixel_values = pixel_values.clone()
  129. masked_pixel_values[:, :, :, h//2:, :] = -1
  130. masked_frames = rearrange(
  131. masked_pixel_values, 'b f c h w -> (b f) c h w')
  132. masked_latents = vae.encode(masked_frames).latent_dist.mode()
  133. masked_latents = masked_latents * vae.config.scaling_factor
  134. masked_latents = masked_latents.float()
  135. ref_frames = rearrange(ref_pixel_values, 'b f c h w-> (b f) c h w')
  136. ref_latents = vae.encode(ref_frames).latent_dist.mode()
  137. ref_latents = ref_latents * vae.config.scaling_factor
  138. ref_latents = ref_latents.float()
  139. input_latents = torch.cat([masked_latents, ref_latents], dim=1)
  140. input_latents = input_latents.to(weight_dtype)
  141. timesteps = torch.tensor([0], device=input_latents.device)
  142. latents_pred = net(
  143. input_latents,
  144. timesteps,
  145. audio_prompts,
  146. )
  147. latents_pred = (1 / vae.config.scaling_factor) * latents_pred
  148. image_pred = vae.decode(latents_pred).sample
  149. image_pred = image_pred.float()
  150. return image_pred
def process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype):
    # Extract one fixed-length window of wav2vec features per video frame.
    #
    # Args:
    #   cfg: config object; reads cfg.data.audio_padding_length_left/right.
    #   batch: dict with 'audio_feature' (encoder input), 'audio_offset' and
    #          'audio_step' (per-sample frame-to-audio index mapping).
    #   wav2vec: model whose .encoder returns all hidden states.
    #   bsz: batch size; num_frames: video frames per sample.
    #   weight_dtype: dtype the features are cast to.
    # Returns: per-frame audio windows concatenated over the batch
    #          (inline comment says [B, T, 10, 5, 384]; exact shape depends
    #          on the encoder depth and config — not verifiable from here).
    with torch.no_grad():
        # Window length per frame: left + right context plus the current
        # position, times 2.
        # NOTE(review): the *2 factors here and below imply two audio
        # feature steps per video frame — confirm against the dataset code.
        audio_feature_length_per_frame = 2 * \
            (cfg.data.audio_padding_length_left +
             cfg.data.audio_padding_length_right + 1)
        audio_feats = batch['audio_feature'].to(weight_dtype)
        # Keep every encoder layer's hidden state, stacked on a new axis.
        audio_feats = wav2vec.encoder(
            audio_feats, output_hidden_states=True).hidden_states
        audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype) # [B, T, 10, 5, 384]
        start_ts = batch['audio_offset']
        step_ts = batch['audio_step']
        # Zero-pad the time axis on both ends so windows near the clip
        # boundaries can still be sliced at full length.
        audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_left]),
                                 audio_feats,
                                 torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_right])], 1)
        audio_prompts = []
        for bb in range(bsz):
            audio_feats_list = []
            for f in range(num_frames):
                # Map video frame f to its starting audio time index.
                cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
                audio_clip = audio_feats[bb:bb+1,
                                         cur_t: cur_t+audio_feature_length_per_frame]
                audio_feats_list.append(audio_clip)
            # Stack this sample's per-frame windows along a new frame axis.
            audio_feats_list = torch.stack(audio_feats_list, 1)
            audio_prompts.append(audio_feats_list)
        audio_prompts = torch.cat(audio_prompts) # B, T, 10, 5, 384
        return audio_prompts
  177. def save_checkpoint(model, save_dir, ckpt_num, name="appearance_net", total_limit=None, logger=None):
  178. save_path = os.path.join(save_dir, f"{name}-{ckpt_num}.pth")
  179. if total_limit is not None:
  180. checkpoints = os.listdir(save_dir)
  181. checkpoints = [d for d in checkpoints if d.endswith(".pth")]
  182. checkpoints = [d for d in checkpoints if name in d]
  183. checkpoints = sorted(
  184. checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
  185. )
  186. if len(checkpoints) >= total_limit:
  187. num_to_remove = len(checkpoints) - total_limit + 1
  188. removing_checkpoints = checkpoints[0:num_to_remove]
  189. logger.info(
  190. f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
  191. )
  192. logger.info(
  193. f"removing checkpoints: {', '.join(removing_checkpoints)}")
  194. for removing_checkpoint in removing_checkpoints:
  195. removing_checkpoint = os.path.join(
  196. save_dir, removing_checkpoint)
  197. os.remove(removing_checkpoint)
  198. state_dict = model.state_dict()
  199. torch.save(state_dict, save_path)
  200. def save_models(accelerator, net, save_dir, global_step, cfg, logger=None):
  201. unwarp_net = accelerator.unwrap_model(net)
  202. save_checkpoint(
  203. unwarp_net.unet,
  204. save_dir,
  205. global_step,
  206. name="unet",
  207. total_limit=cfg.total_limit,
  208. logger=logger
  209. )
  210. def delete_additional_ckpt(base_path, num_keep):
  211. dirs = []
  212. for d in os.listdir(base_path):
  213. if d.startswith("checkpoint-"):
  214. dirs.append(d)
  215. num_tot = len(dirs)
  216. if num_tot <= num_keep:
  217. return
  218. # ensure ckpt is sorted and delete the ealier!
  219. del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
  220. for d in del_dirs:
  221. path_to_dir = osp.join(base_path, d)
  222. if osp.exists(path_to_dir):
  223. shutil.rmtree(path_to_dir)
  224. def seed_everything(seed):
  225. import random
  226. import numpy as np
  227. torch.manual_seed(seed)
  228. torch.cuda.manual_seed_all(seed)
  229. np.random.seed(seed % (2**32))
  230. random.seed(seed)
  231. def process_and_save_images(
  232. batch,
  233. image_pred,
  234. image_pred_infer,
  235. save_dir,
  236. global_step,
  237. accelerator,
  238. num_images_to_keep=10,
  239. syncnet_score=1
  240. ):
  241. # Rearrange the tensors
  242. print("image_pred.shape: ", image_pred.shape)
  243. pixel_values_ref_img = rearrange(batch['pixel_values_ref_img'], "b f c h w -> (b f) c h w")
  244. pixel_values = rearrange(batch["pixel_values_vid"], 'b f c h w -> (b f) c h w')
  245. # Create masked pixel values
  246. masked_pixel_values = batch["pixel_values_vid"].clone()
  247. _, _, _, h, _ = batch["pixel_values_vid"].shape
  248. masked_pixel_values[:, :, :, h//2:, :] = -1
  249. masked_pixel_values = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
  250. # Keep only the specified number of images
  251. pixel_values = pixel_values[:num_images_to_keep, :, :, :]
  252. masked_pixel_values = masked_pixel_values[:num_images_to_keep, :, :, :]
  253. pixel_values_ref_img = pixel_values_ref_img[:num_images_to_keep, :, :, :]
  254. image_pred = image_pred.detach()[:num_images_to_keep, :, :, :]
  255. image_pred_infer = image_pred_infer.detach()[:num_images_to_keep, :, :, :]
  256. # Concatenate images
  257. concat = torch.cat([
  258. masked_pixel_values * 0.5 + 0.5,
  259. pixel_values_ref_img * 0.5 + 0.5,
  260. image_pred * 0.5 + 0.5,
  261. pixel_values * 0.5 + 0.5,
  262. image_pred_infer * 0.5 + 0.5,
  263. ], dim=2)
  264. print("concat.shape: ", concat.shape)
  265. # Create the save directory if it doesn't exist
  266. os.makedirs(f'{save_dir}/samples/', exist_ok=True)
  267. # Try to save the concatenated image
  268. try:
  269. # Concatenate images horizontally and convert to numpy array
  270. final_image = torch.cat([concat[i] for i in range(concat.shape[0])], dim=-1).permute(1, 2, 0).cpu().numpy()[:, :, [2, 1, 0]] * 255
  271. # Save the image
  272. cv2.imwrite(f'{save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg', final_image)
  273. print(f"Image saved successfully: {save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg")
  274. except Exception as e:
  275. print(f"Failed to save image: {e}")