# training_utils.py

import os
import json
import logging

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from transformers import WhisperModel
from omegaconf import OmegaConf
from einops import rearrange

from musetalk.models.syncnet import SyncNet
from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
from musetalk.loss.basic_loss import Interpolate
import musetalk.loss.vgg_face as vgg_face
from musetalk.data.dataset import PortraitDataset
from musetalk.utils.utils import (
    get_image_pred,
    process_audio_features,
    process_and_save_images,
)

logger = logging.getLogger(__name__)
class Net(nn.Module):
    """Wrapper module that runs the audio-conditioned UNet denoising step.

    The audio embeddings are passed to the UNet as `encoder_hidden_states`.
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
    ):
        super().__init__()
        self.unet = unet

    def forward(
        self,
        input_latents,
        timesteps,
        audio_prompts,
    ):
        model_pred = self.unet(
            input_latents,
            timesteps,
            encoder_hidden_states=audio_prompts,
        ).sample
        return model_pred
def initialize_models_and_optimizers(cfg, accelerator, weight_dtype):
    """Initialize the VAE, UNet, audio encoder, optimizer, and LR scheduler."""
    model_dict = {
        'vae': None,
        'unet': None,
        'net': None,
        'wav2vec': None,
        'optimizer': None,
        'lr_scheduler': None,
        'scheduler_max_steps': None,
        'trainable_params': None,
    }

    # Frozen VAE used only to encode/decode latents.
    model_dict['vae'] = AutoencoderKL.from_pretrained(
        cfg.pretrained_model_name_or_path,
        subfolder=cfg.vae_type,
    )

    # Build the UNet from its JSON config, optionally loading pretrained weights.
    unet_config_file = os.path.join(
        cfg.pretrained_model_name_or_path,
        cfg.unet_sub_folder,
        "musetalk.json",
    )
    with open(unet_config_file, 'r') as f:
        unet_config = json.load(f)
    model_dict['unet'] = UNet2DConditionModel(**unet_config)

    if not cfg.random_init_unet:
        pretrained_unet_path = os.path.join(
            cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin"
        )
        print(f"### Loading existing unet weights from {pretrained_unet_path}. ###")
        checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device)
        model_dict['unet'].load_state_dict(checkpoint)

    unet_params = [p.numel() for p in model_dict['unet'].parameters()]
    logger.info(f"unet has {sum(unet_params) / 1e6:.1f}M parameters")

    # Only the UNet is trained; the VAE and audio encoder stay frozen.
    model_dict['vae'].requires_grad_(False)
    model_dict['unet'].requires_grad_(True)
    model_dict['vae'].to(accelerator.device, dtype=weight_dtype)
    model_dict['net'] = Net(model_dict['unet'])

    model_dict['wav2vec'] = WhisperModel.from_pretrained(cfg.whisper_path).to(
        device=accelerator.device, dtype=weight_dtype).eval()
    model_dict['wav2vec'].requires_grad_(False)

    if cfg.solver.gradient_checkpointing:
        model_dict['unet'].enable_gradient_checkpointing()

    # Optionally scale the learning rate with the effective batch size.
    if cfg.solver.scale_lr:
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate

    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam: `pip install bitsandbytes`."
            )
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    model_dict['trainable_params'] = list(
        filter(lambda p: p.requires_grad, model_dict['net'].parameters())
    )
    if accelerator.is_main_process:
        print('trainable params:')
        for n, p in model_dict['net'].named_parameters():
            if p.requires_grad:
                print(n)

    model_dict['optimizer'] = optimizer_cls(
        model_dict['trainable_params'],
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )

    model_dict['scheduler_max_steps'] = (
        cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps
    )
    model_dict['lr_scheduler'] = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=model_dict['optimizer'],
        num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
        num_training_steps=model_dict['scheduler_max_steps'],
    )
    return model_dict
def initialize_dataloaders(cfg):
    """Initialize training and validation dataloaders."""
    dataloader_dict = {
        'train_dataset': None,
        'val_dataset': None,
        'train_dataloader': None,
        'val_dataloader': None,
    }

    # Train and validation datasets share the same configuration.
    dataset_cfg = {
        'image_size': cfg.data.image_size,
        'T': cfg.data.n_sample_frames,
        'sample_method': cfg.data.sample_method,
        'top_k_ratio': cfg.data.top_k_ratio,
        'contorl_face_min_size': cfg.data.contorl_face_min_size,  # key spelling expected by PortraitDataset
        'dataset_key': cfg.data.dataset_key,
        'padding_pixel_mouth': cfg.padding_pixel_mouth,
        'whisper_path': cfg.whisper_path,
        'min_face_size': cfg.data.min_face_size,
        'cropping_jaw2edge_margin_mean': cfg.cropping_jaw2edge_margin_mean,
        'cropping_jaw2edge_margin_std': cfg.cropping_jaw2edge_margin_std,
        'crop_type': cfg.crop_type,
        'random_margin_method': cfg.random_margin_method,
    }

    dataloader_dict['train_dataset'] = PortraitDataset(cfg=dataset_cfg)
    dataloader_dict['train_dataloader'] = torch.utils.data.DataLoader(
        dataloader_dict['train_dataset'],
        batch_size=cfg.data.train_bs,
        shuffle=True,
        num_workers=cfg.data.num_workers,
    )

    dataloader_dict['val_dataset'] = PortraitDataset(cfg=dataset_cfg)
    # Validation is shuffled as well, so each validation pass samples a
    # different batch (only one batch is used; see `validation` below).
    dataloader_dict['val_dataloader'] = torch.utils.data.DataLoader(
        dataloader_dict['val_dataset'],
        batch_size=cfg.data.train_bs,
        shuffle=True,
        num_workers=1,
    )
    return dataloader_dict
def initialize_loss_functions(cfg, accelerator, scheduler_max_steps):
    """Initialize loss functions, discriminators, and their optimizers/schedulers."""
    loss_dict = {
        'L1_loss': nn.L1Loss(reduction='mean'),
        'discriminator': None,
        'mouth_discriminator': None,
        'optimizer_D': None,
        'mouth_optimizer_D': None,
        'scheduler_D': None,
        'mouth_scheduler_D': None,
        'disc_scales': None,
        'discriminator_full': None,
        'mouth_discriminator_full': None,
    }

    # Full-image GAN loss.
    if cfg.loss_params.gan_loss > 0:
        loss_dict['discriminator'] = MultiScaleDiscriminator(
            **cfg.model_params.discriminator_params).to(accelerator.device)
        loss_dict['discriminator_full'] = DiscriminatorFullModel(loss_dict['discriminator'])
        loss_dict['disc_scales'] = cfg.model_params.discriminator_params.scales
        loss_dict['optimizer_D'] = optim.AdamW(
            loss_dict['discriminator'].parameters(),
            lr=cfg.discriminator_train_params.lr,
            weight_decay=cfg.discriminator_train_params.weight_decay,
            betas=cfg.discriminator_train_params.betas,
            eps=cfg.discriminator_train_params.eps,
        )
        loss_dict['scheduler_D'] = CosineAnnealingLR(
            loss_dict['optimizer_D'],
            T_max=scheduler_max_steps,
            eta_min=1e-6,
        )

    # GAN loss restricted to the mouth region.
    if cfg.loss_params.mouth_gan_loss > 0:
        loss_dict['mouth_discriminator'] = MultiScaleDiscriminator(
            **cfg.model_params.discriminator_params).to(accelerator.device)
        loss_dict['mouth_discriminator_full'] = DiscriminatorFullModel(loss_dict['mouth_discriminator'])
        loss_dict['mouth_optimizer_D'] = optim.AdamW(
            loss_dict['mouth_discriminator'].parameters(),
            lr=cfg.discriminator_train_params.lr,
            weight_decay=cfg.discriminator_train_params.weight_decay,
            betas=cfg.discriminator_train_params.betas,
            eps=cfg.discriminator_train_params.eps,
        )
        loss_dict['mouth_scheduler_D'] = CosineAnnealingLR(
            loss_dict['mouth_optimizer_D'],
            T_max=scheduler_max_steps,
            eta_min=1e-6,
        )

    return loss_dict
def initialize_syncnet(cfg, accelerator, weight_dtype):
    """Initialize the frozen SyncNet model used for the sync loss."""
    if cfg.loss_params.sync_loss > 0 or cfg.use_adapted_weight:
        if cfg.data.n_sample_frames != 16:
            raise ValueError(
                f"Invalid n_sample_frames {cfg.data.n_sample_frames} for sync_loss; it should be 16."
            )
        syncnet_config = OmegaConf.load(cfg.syncnet_config_path)
        syncnet = SyncNet(OmegaConf.to_container(
            syncnet_config.model)).to(accelerator.device)
        print(f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}")
        checkpoint = torch.load(
            syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device)
        syncnet.load_state_dict(checkpoint["state_dict"])
        # SyncNet only scores lip sync; it is never trained here.
        syncnet.to(dtype=weight_dtype)
        syncnet.requires_grad_(False)
        syncnet.eval()
        return syncnet
    return None
def initialize_vgg(cfg, accelerator):
    """Initialize the VGG-based perceptual loss components."""
    if cfg.loss_params.vgg_loss > 0:
        vgg_IN = vgg_face.Vgg19().to(accelerator.device)
        pyramid = vgg_face.ImagePyramide(
            cfg.loss_params.pyramid_scale, 3).to(accelerator.device)
        vgg_IN.eval()
        # Downsample predictions to VGG's expected 224x224 input resolution.
        downsampler = Interpolate(
            size=(224, 224), mode='bilinear', align_corners=False).to(accelerator.device)
        return vgg_IN, pyramid, downsampler
    return None, None, None
def validation(
    cfg,
    val_dataloader,
    net,
    vae,
    wav2vec,
    accelerator,
    save_dir,
    global_step,
    weight_dtype,
    syncnet_score=1,
):
    """Run one validation batch and save the predicted images."""
    net.eval()  # Set the model to evaluation mode
    for batch in val_dataloader:
        ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
            accelerator.device, non_blocking=True
        )
        pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
            accelerator.device, non_blocking=True
        )
        bsz, num_frames, c, h, w = ref_pixel_values.shape

        # Audio features for the UNet, flattened to one conditioning
        # sequence per frame.
        audio_prompts = process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype)
        audio_prompts = rearrange(audio_prompts, 'b f c h w -> (b f) (c h) w')

        # Same reference latents, different masked latents: one prediction from
        # the ground-truth frames (training-style) and one from the reference
        # frames only (inference-style).
        image_pred_train = get_image_pred(
            pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
        image_pred_infer = get_image_pred(
            ref_pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)

        process_and_save_images(
            batch,
            image_pred_train,
            image_pred_infer,
            save_dir,
            global_step,
            accelerator,
            cfg.num_images_to_keep,
            syncnet_score,
        )
        # Only run a single batch during validation.
        break
    net.train()  # Set the model back to training mode
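
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): one
# plausible way to wire the helpers above into a training entry point. The
# config path "configs/train.yaml", the `mixed_precision` value, and the
# exact set of objects passed to `accelerator.prepare` are assumptions here,
# not prescriptions from this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from accelerate import Accelerator

    cfg = OmegaConf.load("configs/train.yaml")  # hypothetical config path
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        mixed_precision="fp16",  # assumed; match your training setup
    )
    weight_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32

    models = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
    loaders = initialize_dataloaders(cfg)
    losses = initialize_loss_functions(cfg, accelerator, models['scheduler_max_steps'])
    syncnet = initialize_syncnet(cfg, accelerator, weight_dtype)
    vgg_IN, pyramid, downsampler = initialize_vgg(cfg, accelerator)

    # Hand the trainable pieces to Accelerate; the frozen models (vae,
    # wav2vec, syncnet, vgg_IN) were already moved to the right device/dtype
    # by the initializers above.
    net, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        models['net'], models['optimizer'],
        loaders['train_dataloader'], models['lr_scheduler'],
    )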