import os
import json
import logging

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from transformers import WhisperModel
from omegaconf import OmegaConf
from einops import rearrange

from musetalk.models.syncnet import SyncNet
from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
from musetalk.loss.basic_loss import Interpolate
import musetalk.loss.vgg_face as vgg_face
from musetalk.data.dataset import PortraitDataset
from musetalk.utils.utils import (
    get_image_pred,
    process_audio_features,
    process_and_save_images,
)
logger = logging.getLogger(__name__)


class Net(nn.Module):
    """Thin wrapper exposing the UNet's denoising forward pass."""

    def __init__(self, unet: UNet2DConditionModel):
        super().__init__()
        self.unet = unet

    def forward(self, input_latents, timesteps, audio_prompts):
        # Audio embeddings condition the UNet via cross-attention
        # (passed as encoder_hidden_states).
        model_pred = self.unet(
            input_latents,
            timesteps,
            encoder_hidden_states=audio_prompts,
        ).sample
        return model_pred
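
# A minimal shape sketch, not the repo's actual configuration: Net.forward
# expects latents, a timestep, and audio embeddings that the UNet cross-attends
# over. The channel counts and the audio token shape below are illustrative
# assumptions; the real values come from musetalk.json and the Whisper features.
def _example_net_forward():
    """Hypothetical smoke test for Net.forward; all sizes are assumptions."""
    toy_unet = UNet2DConditionModel(
        sample_size=32,
        in_channels=8,   # e.g. masked latents + reference latents, concatenated
        out_channels=4,
        layers_per_block=1,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=384,
    )
    net = Net(toy_unet)
    latents = torch.randn(2, 8, 32, 32)            # (B, C_in, H/8, W/8)
    timesteps = torch.zeros(2, dtype=torch.long)   # single fixed denoising step
    audio = torch.randn(2, 50, 384)                # (B, tokens, cross_attention_dim)
    pred = net(latents, timesteps, audio)
    assert pred.shape == (2, 4, 32, 32)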
def initialize_models_and_optimizers(cfg, accelerator, weight_dtype):
    """Initialize models and optimizers."""
    model_dict = {
        'vae': None,
        'unet': None,
        'net': None,
        'wav2vec': None,
        'optimizer': None,
        'lr_scheduler': None,
        'scheduler_max_steps': None,
        'trainable_params': None,
    }

    model_dict['vae'] = AutoencoderKL.from_pretrained(
        cfg.pretrained_model_name_or_path,
        subfolder=cfg.vae_type,
    )

    unet_config_file = os.path.join(
        cfg.pretrained_model_name_or_path,
        cfg.unet_sub_folder,
        "musetalk.json",
    )
    with open(unet_config_file, 'r') as f:
        unet_config = json.load(f)
    model_dict['unet'] = UNet2DConditionModel(**unet_config)

    if not cfg.random_init_unet:
        pretrained_unet_path = os.path.join(
            cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin"
        )
        print(f"### Loading existing unet weights from {pretrained_unet_path}. ###")
        checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device)
        model_dict['unet'].load_state_dict(checkpoint)

    unet_params = [p.numel() for p in model_dict['unet'].parameters()]
    logger.info(f"unet has {sum(unet_params) / 1e6:.1f}M parameters")

    # The VAE and the audio encoder are frozen; only the UNet is trained.
    model_dict['vae'].requires_grad_(False)
    model_dict['unet'].requires_grad_(True)
    model_dict['vae'].to(accelerator.device, dtype=weight_dtype)
    model_dict['net'] = Net(model_dict['unet'])

    # Use the accelerator-assigned device rather than a hard-coded "cuda",
    # so multi-process training places the encoder on the right GPU.
    model_dict['wav2vec'] = WhisperModel.from_pretrained(cfg.whisper_path).to(
        device=accelerator.device, dtype=weight_dtype).eval()
    model_dict['wav2vec'].requires_grad_(False)

    if cfg.solver.gradient_checkpointing:
        model_dict['unet'].enable_gradient_checkpointing()

    if cfg.solver.scale_lr:
        # Scale the base learning rate by the effective (global) batch size.
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate

    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. "
                "You can do so by running `pip install bitsandbytes`."
            )
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    model_dict['trainable_params'] = list(
        filter(lambda p: p.requires_grad, model_dict['net'].parameters())
    )
    if accelerator.is_main_process:
        print('trainable params')
        for n, p in model_dict['net'].named_parameters():
            if p.requires_grad:
                print(n)

    model_dict['optimizer'] = optimizer_cls(
        model_dict['trainable_params'],
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )
    model_dict['scheduler_max_steps'] = (
        cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps
    )
    model_dict['lr_scheduler'] = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=model_dict['optimizer'],
        num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
        num_training_steps=model_dict['scheduler_max_steps'],
    )
    return model_dict
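
# A hypothetical wiring sketch (the config path and Accelerator options are
# assumptions, not part of this file): the dict returned above is typically
# handed to accelerate's prepare() before training starts.
def _example_prepare_with_accelerate():
    """Illustrative only; the real driver lives elsewhere."""
    from accelerate import Accelerator

    accelerator = Accelerator()
    cfg = OmegaConf.load("configs/train.yaml")  # hypothetical config path
    weight_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32
    model_dict = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
    net, optimizer, lr_scheduler = accelerator.prepare(
        model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'])
    return net, optimizer, lr_scheduler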
def initialize_dataloaders(cfg):
    """Initialize training and validation dataloaders."""
    # The train and validation datasets share the same configuration, so build
    # it once. Note: "contorl_face_min_size" is kept as spelled, since
    # PortraitDataset expects this exact key.
    dataset_cfg = {
        'image_size': cfg.data.image_size,
        'T': cfg.data.n_sample_frames,
        'sample_method': cfg.data.sample_method,
        'top_k_ratio': cfg.data.top_k_ratio,
        'contorl_face_min_size': cfg.data.contorl_face_min_size,
        'dataset_key': cfg.data.dataset_key,
        'padding_pixel_mouth': cfg.padding_pixel_mouth,
        'whisper_path': cfg.whisper_path,
        'min_face_size': cfg.data.min_face_size,
        'cropping_jaw2edge_margin_mean': cfg.cropping_jaw2edge_margin_mean,
        'cropping_jaw2edge_margin_std': cfg.cropping_jaw2edge_margin_std,
        'crop_type': cfg.crop_type,
        'random_margin_method': cfg.random_margin_method,
    }

    dataloader_dict = {
        'train_dataset': PortraitDataset(cfg=dataset_cfg),
        'val_dataset': PortraitDataset(cfg=dataset_cfg),
        'train_dataloader': None,
        'val_dataloader': None,
    }
    dataloader_dict['train_dataloader'] = torch.utils.data.DataLoader(
        dataloader_dict['train_dataset'],
        batch_size=cfg.data.train_bs,
        shuffle=True,
        num_workers=cfg.data.num_workers,
    )
    dataloader_dict['val_dataloader'] = torch.utils.data.DataLoader(
        dataloader_dict['val_dataset'],
        batch_size=cfg.data.train_bs,
        shuffle=True,
        num_workers=1,
    )
    return dataloader_dict
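
# A small inspection sketch based on the batch keys consumed in validation()
# below; the (B, T, C, H, W) layout is an assumption inferred from how the
# tensors are unpacked there.
def _example_inspect_batch(cfg):
    """Hypothetical helper showing the expected batch layout."""
    loaders = initialize_dataloaders(cfg)
    batch = next(iter(loaders['train_dataloader']))
    ref = batch["pixel_values_ref_img"]  # reference frames, (B, T, C, H, W)
    vid = batch["pixel_values_vid"]      # target frames,    (B, T, C, H, W)
    print(ref.shape, vid.shape)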
def initialize_loss_functions(cfg, accelerator, scheduler_max_steps):
    """Initialize loss functions and discriminators."""
    loss_dict = {
        'L1_loss': nn.L1Loss(reduction='mean'),
        'discriminator': None,
        'mouth_discriminator': None,
        'optimizer_D': None,
        'mouth_optimizer_D': None,
        'scheduler_D': None,
        'mouth_scheduler_D': None,
        'disc_scales': None,
        'discriminator_full': None,
        'mouth_discriminator_full': None,
    }

    # GAN loss on the full image.
    if cfg.loss_params.gan_loss > 0:
        loss_dict['discriminator'] = MultiScaleDiscriminator(
            **cfg.model_params.discriminator_params).to(accelerator.device)
        loss_dict['discriminator_full'] = DiscriminatorFullModel(loss_dict['discriminator'])
        loss_dict['disc_scales'] = cfg.model_params.discriminator_params.scales
        loss_dict['optimizer_D'] = optim.AdamW(
            loss_dict['discriminator'].parameters(),
            lr=cfg.discriminator_train_params.lr,
            weight_decay=cfg.discriminator_train_params.weight_decay,
            betas=cfg.discriminator_train_params.betas,
            eps=cfg.discriminator_train_params.eps)
        loss_dict['scheduler_D'] = CosineAnnealingLR(
            loss_dict['optimizer_D'],
            T_max=scheduler_max_steps,
            eta_min=1e-6,
        )

    # GAN loss restricted to the mouth region.
    if cfg.loss_params.mouth_gan_loss > 0:
        loss_dict['mouth_discriminator'] = MultiScaleDiscriminator(
            **cfg.model_params.discriminator_params).to(accelerator.device)
        loss_dict['mouth_discriminator_full'] = DiscriminatorFullModel(loss_dict['mouth_discriminator'])
        loss_dict['mouth_optimizer_D'] = optim.AdamW(
            loss_dict['mouth_discriminator'].parameters(),
            lr=cfg.discriminator_train_params.lr,
            weight_decay=cfg.discriminator_train_params.weight_decay,
            betas=cfg.discriminator_train_params.betas,
            eps=cfg.discriminator_train_params.eps)
        loss_dict['mouth_scheduler_D'] = CosineAnnealingLR(
            loss_dict['mouth_optimizer_D'],
            T_max=scheduler_max_steps,
            eta_min=1e-6,
        )

    return loss_dict
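
# A generic sketch of how a multi-scale discriminator is usually alternated
# with the generator step. This is NOT the repo's exact API: the interface of
# DiscriminatorFullModel is defined elsewhere. The sketch assumes `disc(x)`
# returns one logits tensor per scale and uses a least-squares GAN objective.
def _example_gan_step(disc, optimizer_D, real, fake):
    """Hypothetical LSGAN update for one discriminator step."""
    optimizer_D.zero_grad()
    d_loss = 0.0
    for logits_real, logits_fake in zip(disc(real), disc(fake.detach())):
        # Real images pushed toward 1, generated images toward 0.
        d_loss = d_loss + ((logits_real - 1) ** 2).mean() + (logits_fake ** 2).mean()
    d_loss.backward()
    optimizer_D.step()
    # Generator side: try to fool the discriminator (logits toward 1).
    g_gan_loss = sum(((lf - 1) ** 2).mean() for lf in disc(fake))
    return d_loss, g_gan_loss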
def initialize_syncnet(cfg, accelerator, weight_dtype):
    """Initialize SyncNet model."""
    if cfg.loss_params.sync_loss > 0 or cfg.use_adapted_weight:
        if cfg.data.n_sample_frames != 16:
            raise ValueError(
                f"Invalid n_sample_frames {cfg.data.n_sample_frames} for sync_loss, it should be 16."
            )
        syncnet_config = OmegaConf.load(cfg.syncnet_config_path)
        syncnet = SyncNet(OmegaConf.to_container(
            syncnet_config.model)).to(accelerator.device)
        print(f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}")
        checkpoint = torch.load(
            syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device)
        syncnet.load_state_dict(checkpoint["state_dict"])
        syncnet.to(dtype=weight_dtype)
        syncnet.requires_grad_(False)
        syncnet.eval()
        return syncnet
    return None
def initialize_vgg(cfg, accelerator):
    """Initialize VGG model."""
    if cfg.loss_params.vgg_loss > 0:
        vgg_IN = vgg_face.Vgg19().to(accelerator.device)
        pyramid = vgg_face.ImagePyramide(
            cfg.loss_params.pyramid_scale, 3).to(accelerator.device)
        vgg_IN.eval()
        downsampler = Interpolate(
            size=(224, 224), mode='bilinear', align_corners=False).to(accelerator.device)
        return vgg_IN, pyramid, downsampler
    return None, None, None
def validation(
    cfg,
    val_dataloader,
    net,
    vae,
    wav2vec,
    accelerator,
    save_dir,
    global_step,
    weight_dtype,
    syncnet_score=1,
):
    """Validation function for model evaluation."""
    net.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # No gradients are needed during validation.
        for batch in val_dataloader:
            # The same reference frames are used for both prediction modes below.
            ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
                accelerator.device, non_blocking=True
            )
            pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
                accelerator.device, non_blocking=True
            )
            bsz, num_frames, c, h, w = ref_pixel_values.shape

            # Audio features for the UNet: fold frames into the batch dimension
            # and flatten each per-frame feature map into a token sequence.
            audio_prompts = process_audio_features(
                cfg, batch, wav2vec, bsz, num_frames, weight_dtype)
            audio_prompts = rearrange(
                audio_prompts, 'b f c h w -> (b f) (c h) w')

            # Different masked latents: once with the target frames
            # (training-style inputs) and once with the reference frames only
            # (inference-style inputs).
            image_pred_train = get_image_pred(
                pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
            image_pred_infer = get_image_pred(
                ref_pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)

            process_and_save_images(
                batch,
                image_pred_train,
                image_pred_infer,
                save_dir,
                global_step,
                accelerator,
                cfg.num_images_to_keep,
                syncnet_score,
            )
            # Only infer one batch in validation.
            break
    net.train()  # Set the model back to training mode
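
# A hypothetical end-to-end sketch continuing _example_prepare_with_accelerate
# above; the validation cadence and output directory are assumptions, not part
# of this file. It shows how validation() is typically driven from the loop.
def _example_periodic_validation(cfg, accelerator, model_dict, data_dict, weight_dtype, net):
    """Illustrative wiring only; the real training loop lives elsewhere."""
    for global_step in range(cfg.solver.max_train_steps):
        # ... one optimization step on the training batch would go here ...
        if global_step % 1000 == 0:  # validation cadence is an assumption
            validation(
                cfg,
                data_dict['val_dataloader'],
                net,
                model_dict['vae'],
                model_dict['wav2vec'],
                accelerator,
                save_dir="./val_samples",  # hypothetical output directory
                global_step=global_step,
                weight_dtype=weight_dtype,
            )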