inference.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # coding: utf-8
  2. """
Command-line entry point (for human users) that runs the LivePortrait inference pipeline.
  4. """
  5. import os
  6. import os.path as osp
  7. import tyro
  8. import subprocess
  9. from src.config.argument_config import ArgumentConfig
  10. from src.config.inference_config import InferenceConfig
  11. from src.config.crop_config import CropConfig
  12. from src.live_portrait_pipeline import LivePortraitPipeline
  13. def partial_fields(target_class, kwargs):
  14. return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
  15. def fast_check_ffmpeg():
  16. try:
  17. subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
  18. return True
  19. except:
  20. return False
  21. def fast_check_args(args: ArgumentConfig):
  22. if not osp.exists(args.source):
  23. raise FileNotFoundError(f"source info not found: {args.source}")
  24. if not osp.exists(args.driving):
  25. raise FileNotFoundError(f"driving info not found: {args.driving}")
  26. def main():
  27. # set tyro theme
  28. tyro.extras.set_accent_color("bright_cyan")
  29. args = tyro.cli(ArgumentConfig)
  30. ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
  31. if osp.exists(ffmpeg_dir):
  32. os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
  33. if not fast_check_ffmpeg():
  34. raise ImportError(
  35. "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
  36. )
  37. fast_check_args(args)
  38. # specify configs for inference
  39. inference_cfg = partial_fields(InferenceConfig, args.__dict__)
  40. crop_cfg = partial_fields(CropConfig, args.__dict__)
  41. live_portrait_pipeline = LivePortraitPipeline(
  42. inference_cfg=inference_cfg,
  43. crop_cfg=crop_cfg
  44. )
  45. # run
  46. live_portrait_pipeline.execute(args)
  47. if __name__ == "__main__":
  48. main()