# speed.py
  1. # coding: utf-8
  2. """
  3. Benchmark the inference speed of each module in LivePortrait.
  4. TODO: heavy GPT style, need to refactor
  5. """
  6. import torch
  7. torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
  8. import yaml
  9. import time
  10. import numpy as np
  11. from src.utils.helper import load_model, concat_feat
  12. from src.config.inference_config import InferenceConfig
  13. def initialize_inputs(batch_size=1, device_id=0):
  14. """
  15. Generate random input tensors and move them to GPU
  16. """
  17. feature_3d = torch.randn(batch_size, 32, 16, 64, 64).to(device_id).half()
  18. kp_source = torch.randn(batch_size, 21, 3).to(device_id).half()
  19. kp_driving = torch.randn(batch_size, 21, 3).to(device_id).half()
  20. source_image = torch.randn(batch_size, 3, 256, 256).to(device_id).half()
  21. generator_input = torch.randn(batch_size, 256, 64, 64).to(device_id).half()
  22. eye_close_ratio = torch.randn(batch_size, 3).to(device_id).half()
  23. lip_close_ratio = torch.randn(batch_size, 2).to(device_id).half()
  24. feat_stitching = concat_feat(kp_source, kp_driving).half()
  25. feat_eye = concat_feat(kp_source, eye_close_ratio).half()
  26. feat_lip = concat_feat(kp_source, lip_close_ratio).half()
  27. inputs = {
  28. 'feature_3d': feature_3d,
  29. 'kp_source': kp_source,
  30. 'kp_driving': kp_driving,
  31. 'source_image': source_image,
  32. 'generator_input': generator_input,
  33. 'feat_stitching': feat_stitching,
  34. 'feat_eye': feat_eye,
  35. 'feat_lip': feat_lip
  36. }
  37. return inputs
  38. def load_and_compile_models(cfg, model_config):
  39. """
  40. Load and compile models for inference
  41. """
  42. appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
  43. motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
  44. warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
  45. spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
  46. stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
  47. models_with_params = [
  48. ('Appearance Feature Extractor', appearance_feature_extractor),
  49. ('Motion Extractor', motion_extractor),
  50. ('Warping Network', warping_module),
  51. ('SPADE Decoder', spade_generator)
  52. ]
  53. compiled_models = {}
  54. for name, model in models_with_params:
  55. model = model.half()
  56. model = torch.compile(model, mode='max-autotune') # Optimize for inference
  57. model.eval() # Switch to evaluation mode
  58. compiled_models[name] = model
  59. retargeting_models = ['stitching', 'eye', 'lip']
  60. for retarget in retargeting_models:
  61. module = stitching_retargeting_module[retarget].half()
  62. module = torch.compile(module, mode='max-autotune') # Optimize for inference
  63. module.eval() # Switch to evaluation mode
  64. stitching_retargeting_module[retarget] = module
  65. return compiled_models, stitching_retargeting_module
  66. def warm_up_models(compiled_models, stitching_retargeting_module, inputs):
  67. """
  68. Warm up models to prepare them for benchmarking
  69. """
  70. print("Warm up start!")
  71. with torch.no_grad():
  72. for _ in range(10):
  73. compiled_models['Appearance Feature Extractor'](inputs['source_image'])
  74. compiled_models['Motion Extractor'](inputs['source_image'])
  75. compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
  76. compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required
  77. stitching_retargeting_module['stitching'](inputs['feat_stitching'])
  78. stitching_retargeting_module['eye'](inputs['feat_eye'])
  79. stitching_retargeting_module['lip'](inputs['feat_lip'])
  80. print("Warm up end!")
  81. def measure_inference_times(compiled_models, stitching_retargeting_module, inputs):
  82. """
  83. Measure inference times for each model
  84. """
  85. times = {name: [] for name in compiled_models.keys()}
  86. times['Stitching and Retargeting Modules'] = []
  87. overall_times = []
  88. with torch.no_grad():
  89. for _ in range(100):
  90. torch.cuda.synchronize()
  91. overall_start = time.time()
  92. start = time.time()
  93. compiled_models['Appearance Feature Extractor'](inputs['source_image'])
  94. torch.cuda.synchronize()
  95. times['Appearance Feature Extractor'].append(time.time() - start)
  96. start = time.time()
  97. compiled_models['Motion Extractor'](inputs['source_image'])
  98. torch.cuda.synchronize()
  99. times['Motion Extractor'].append(time.time() - start)
  100. start = time.time()
  101. compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
  102. torch.cuda.synchronize()
  103. times['Warping Network'].append(time.time() - start)
  104. start = time.time()
  105. compiled_models['SPADE Decoder'](inputs['generator_input']) # Adjust input as required
  106. torch.cuda.synchronize()
  107. times['SPADE Decoder'].append(time.time() - start)
  108. start = time.time()
  109. stitching_retargeting_module['stitching'](inputs['feat_stitching'])
  110. stitching_retargeting_module['eye'](inputs['feat_eye'])
  111. stitching_retargeting_module['lip'](inputs['feat_lip'])
  112. torch.cuda.synchronize()
  113. times['Stitching and Retargeting Modules'].append(time.time() - start)
  114. overall_times.append(time.time() - overall_start)
  115. return times, overall_times
  116. def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
  117. """
  118. Print benchmark results with average and standard deviation of inference times
  119. """
  120. average_times = {name: np.mean(times[name]) * 1000 for name in times.keys()}
  121. std_times = {name: np.std(times[name]) * 1000 for name in times.keys()}
  122. for name, model in compiled_models.items():
  123. num_params = sum(p.numel() for p in model.parameters())
  124. num_params_in_millions = num_params / 1e6
  125. print(f"Number of parameters for {name}: {num_params_in_millions:.2f} M")
  126. for index, retarget in enumerate(retargeting_models):
  127. num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters())
  128. num_params_in_millions = num_params / 1e6
  129. print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params_in_millions:.2f} M")
  130. for name, avg_time in average_times.items():
  131. std_time = std_times[name]
  132. print(f"Average inference time for {name} over 100 runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")
  133. def main():
  134. """
  135. Main function to benchmark speed and model parameters
  136. """
  137. # Load configuration
  138. cfg = InferenceConfig()
  139. model_config_path = cfg.models_config
  140. with open(model_config_path, 'r') as file:
  141. model_config = yaml.safe_load(file)
  142. # Sample input tensors
  143. inputs = initialize_inputs(device_id = cfg.device_id)
  144. # Load and compile models
  145. compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)
  146. # Warm up models
  147. warm_up_models(compiled_models, stitching_retargeting_module, inputs)
  148. # Measure inference times
  149. times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs)
  150. # Print benchmark results
  151. print_benchmark_results(compiled_models, stitching_retargeting_module, ['stitching', 'eye', 'lip'], times, overall_times)
  152. if __name__ == "__main__":
  153. main()