| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195 |
- # coding: utf-8
- """
- Benchmark the inference speed of each module in LivePortrait.
- TODO: heavy GPT style, need to refactor
- """
- import torch
- torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
- import yaml
- import time
- import numpy as np
- from src.utils.helper import load_model, concat_feat
- from src.config.inference_config import InferenceConfig
def initialize_inputs(batch_size=1, device_id=0):
    """
    Create the random half-precision tensors that stand in for real inputs during benchmarking

    Args:
        batch_size: size of the leading dimension of every generated tensor.
        device_id: target passed to `Tensor.to` for every generated tensor.

    Returns:
        dict with the keys 'feature_3d', 'kp_source', 'kp_driving', 'source_image',
        'generator_input', 'feat_stitching', 'feat_eye' and 'feat_lip'.
    """
    def _random_half(*shape):
        # Sample first, then move to the target device and cast to fp16
        sampled = torch.randn(*shape)
        return sampled.to(device_id).half()

    feature_3d = _random_half(batch_size, 32, 16, 64, 64)
    kp_source = _random_half(batch_size, 21, 3)
    kp_driving = _random_half(batch_size, 21, 3)
    source_image = _random_half(batch_size, 3, 256, 256)
    generator_input = _random_half(batch_size, 256, 64, 64)

    # The close ratios are only needed to build the retargeting features below;
    # they are not returned on their own
    eye_close_ratio = _random_half(batch_size, 3)
    lip_close_ratio = _random_half(batch_size, 2)

    return dict(
        feature_3d=feature_3d,
        kp_source=kp_source,
        kp_driving=kp_driving,
        source_image=source_image,
        generator_input=generator_input,
        feat_stitching=concat_feat(kp_source, kp_driving).half(),
        feat_eye=concat_feat(kp_source, eye_close_ratio).half(),
        feat_lip=concat_feat(kp_source, lip_close_ratio).half(),
    )
def load_and_compile_models(cfg, model_config):
    """
    Load every checkpoint, cast it to half precision and wrap it with torch.compile

    Args:
        cfg: config object providing checkpoint_F/M/W/G/S and device_id.
        model_config: model configuration forwarded unchanged to load_model.

    Returns:
        (compiled_models, stitching_retargeting_module): a dict of the four compiled
        models keyed by display name, and the stitching/retargeting container whose
        'stitching', 'eye' and 'lip' entries have been replaced by compiled modules.
    """
    def _to_compiled_half(module):
        # Cast to fp16, compile, then switch the compiled wrapper to evaluation mode
        halved = module.half()
        compiled = torch.compile(halved, mode='max-autotune')
        compiled.eval()
        return compiled

    # (display name, checkpoint attribute on cfg, model key understood by load_model)
    model_specs = (
        ('Appearance Feature Extractor', 'checkpoint_F', 'appearance_feature_extractor'),
        ('Motion Extractor', 'checkpoint_M', 'motion_extractor'),
        ('Warping Network', 'checkpoint_W', 'warping_module'),
        ('SPADE Decoder', 'checkpoint_G', 'spade_generator'),
    )

    # Load all checkpoints up front, before any compilation starts
    loaded = [
        (label, load_model(getattr(cfg, ckpt_attr), model_config, cfg.device_id, model_key))
        for label, ckpt_attr, model_key in model_specs
    ]
    retargeting = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')

    compiled_models = {label: _to_compiled_half(model) for label, model in loaded}

    # The retargeting container is updated in place, one head at a time
    for part in ('stitching', 'eye', 'lip'):
        retargeting[part] = _to_compiled_half(retargeting[part])

    return compiled_models, retargeting
def warm_up_models(compiled_models, stitching_retargeting_module, inputs):
    """
    Run every module ten times with gradients disabled before the timed benchmark

    Args:
        compiled_models: mapping holding the callables 'Appearance Feature Extractor',
            'Motion Extractor', 'Warping Network' and 'SPADE Decoder'.
        stitching_retargeting_module: mapping holding the callables 'stitching', 'eye' and 'lip'.
        inputs: mapping from input name to the argument passed to the callables.
    """
    # (container, key of the callable, names of the inputs passed positionally)
    schedule = (
        (compiled_models, 'Appearance Feature Extractor', ('source_image',)),
        (compiled_models, 'Motion Extractor', ('source_image',)),
        (compiled_models, 'Warping Network', ('feature_3d', 'kp_driving', 'kp_source')),
        (compiled_models, 'SPADE Decoder', ('generator_input',)),
        (stitching_retargeting_module, 'stitching', ('feat_stitching',)),
        (stitching_retargeting_module, 'eye', ('feat_eye',)),
        (stitching_retargeting_module, 'lip', ('feat_lip',)),
    )

    print("Warm up start!")
    with torch.no_grad():
        for _round in range(10):
            for container, model_key, input_names in schedule:
                # Outputs are discarded; only the execution itself matters here
                container[model_key](*(inputs[name] for name in input_names))
    print("Warm up end!")
def measure_inference_times(compiled_models, stitching_retargeting_module, inputs, num_runs=100):
    """
    Measure inference times for each model

    Every run times the four compiled models one after another, then the three
    stitching/retargeting heads together as a single stage, and finally records
    the elapsed time of the whole run.

    Args:
        compiled_models: mapping holding the callables 'Appearance Feature Extractor',
            'Motion Extractor', 'Warping Network' and 'SPADE Decoder'.
        stitching_retargeting_module: mapping holding the callables 'stitching', 'eye' and 'lip'.
        inputs: mapping from input name to the argument passed to the callables.
        num_runs: number of timed runs (default 100).

    Returns:
        (times, overall_times): `times` maps every key of `compiled_models` plus
        'Stitching and Retargeting Modules' to a list of per-run durations in seconds;
        `overall_times` holds the duration of each complete run in seconds.
    """
    retargeting_label = 'Stitching and Retargeting Modules'

    def run_retargeting():
        # The three heads are timed together as one stage
        stitching_retargeting_module['stitching'](inputs['feat_stitching'])
        stitching_retargeting_module['eye'](inputs['feat_eye'])
        stitching_retargeting_module['lip'](inputs['feat_lip'])

    stages = (
        ('Appearance Feature Extractor',
         lambda: compiled_models['Appearance Feature Extractor'](inputs['source_image'])),
        ('Motion Extractor',
         lambda: compiled_models['Motion Extractor'](inputs['source_image'])),
        ('Warping Network',
         lambda: compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])),
        ('SPADE Decoder',
         lambda: compiled_models['SPADE Decoder'](inputs['generator_input'])),
        (retargeting_label, run_retargeting),
    )

    # CUDA work is queued asynchronously, so the clock must only be read after a
    # synchronize. Without CUDA there is nothing to wait for, and calling
    # torch.cuda.synchronize() would raise, so fall back to a no-op.
    wait_for_device = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    # perf_counter is monotonic and high resolution; time.time() follows the wall
    # clock and can jump when the system time is adjusted
    clock = time.perf_counter

    times = {name: [] for name in compiled_models}
    times[retargeting_label] = []
    overall_times = []

    with torch.no_grad():
        for _ in range(num_runs):
            # Drain any pending work so it is not attributed to the first stage
            wait_for_device()
            overall_start = clock()
            for name, run_stage in stages:
                start = clock()
                run_stage()
                wait_for_device()
                times[name].append(clock() - start)
            overall_times.append(clock() - overall_start)
    return times, overall_times
def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
    """
    Print benchmark results with average and standard deviation of inference times

    Args:
        compiled_models: mapping from display name to a model exposing `parameters()`.
        stitching_retargeting_module: mapping indexed by the entries of `retargeting_models`;
            every value exposes `parameters()`.
        retargeting_models: keys of the stitching/retargeting parts to report, in order.
        times: mapping from stage name to a sequence of durations in seconds.
        overall_times: sequence of whole-run durations in seconds; when empty,
            no overall line is printed.
    """
    def _params_in_millions(module):
        return sum(p.numel() for p in module.parameters()) / 1e6

    # Statistics are computed before anything is printed; durations are converted
    # from seconds to milliseconds
    stats = {
        name: (np.mean(samples) * 1000, np.std(samples) * 1000, len(samples))
        for name, samples in times.items()
    }

    for name, model in compiled_models.items():
        print(f"Number of parameters for {name}: {_params_in_millions(model):.2f} M")

    for index, retarget in enumerate(retargeting_models):
        part = stitching_retargeting_module[retarget]
        print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {_params_in_millions(part):.2f} M")

    for name, (avg_time, std_time, num_runs) in stats.items():
        # Report the real sample count instead of assuming 100 runs
        print(f"Average inference time for {name} over {num_runs} runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")

    # Whole-run timings used to be accepted but never reported
    if len(overall_times) > 0:
        overall_avg = np.mean(overall_times) * 1000
        overall_std = np.std(overall_times) * 1000
        print(f"Average overall inference time over {len(overall_times)} runs: {overall_avg:.2f} ms (std: {overall_std:.2f} ms)")
def main():
    """
    Entry point: benchmark the speed of every module and report its parameter count
    """
    cfg = InferenceConfig()

    # Parse the YAML file that cfg.models_config points to
    with open(cfg.models_config, 'r') as config_file:
        model_config = yaml.safe_load(config_file)

    # Random stand-in inputs placed on the configured device
    bench_inputs = initialize_inputs(device_id=cfg.device_id)

    # Load + compile, warm up, then run the timed passes
    models, retargeting = load_and_compile_models(cfg, model_config)
    warm_up_models(models, retargeting, bench_inputs)
    per_module_times, total_times = measure_inference_times(models, retargeting, bench_inputs)

    # Report parameter counts and timing statistics
    retargeting_parts = ['stitching', 'eye', 'lip']
    print_benchmark_results(models, retargeting, retargeting_parts, per_module_times, total_times)


if __name__ == "__main__":
    main()
|