  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 视频绿幕替换工具
  5. 使用 RobustVideoMatting (RVM) 模型将视频背景替换为绿色
  6. 已优化:本地模型极速加载,不联网、不卡
  7. """
  8. import os
  9. import sys
  10. import cv2
  11. import torch
  12. from torch.utils.data import DataLoader
  13. from torchvision.transforms import ToTensor
  14. import numpy as np
  15. from tqdm import tqdm
  16. import argparse
  17. class VideoReader:
  18. """视频读取器"""
  19. def __init__(self, video_path, transform=None):
  20. self.cap = cv2.VideoCapture(video_path)
  21. if not self.cap.isOpened():
  22. raise ValueError(f"Cannot open video: {video_path}")
  23. self.transform = transform
  24. self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
  25. self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  26. self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  27. self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
  28. def __len__(self):
  29. return self.total_frames
  30. def __getitem__(self, idx):
  31. ret, frame = self.cap.read()
  32. if not ret:
  33. raise IndexError(f"Failed to read frame {idx}")
  34. # BGR to RGB
  35. frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
  36. if self.transform:
  37. frame = self.transform(frame)
  38. return frame
  39. def release(self):
  40. self.cap.release()
  41. class VideoWriter:
  42. """视频写入器"""
  43. def __init__(self, output_path, frame_rate, width, height):
  44. self.output_path = output_path
  45. self.frame_rate = frame_rate
  46. self.width = width
  47. self.height = height
  48. fourcc = cv2.VideoWriter_fourcc(*'mp4v')
  49. self.writer = cv2.VideoWriter(output_path, fourcc, frame_rate, (width, height))
  50. def write(self, tensor):
  51. """Write tensor (C, H, W) to video"""
  52. # Tensor to numpy
  53. frame = tensor.cpu().numpy().transpose(1, 2, 0)
  54. frame = np.clip(frame * 255.0, 0, 255).astype(np.uint8)
  55. # RGB to BGR
  56. frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
  57. self.writer.write(frame)
  58. def release(self):
  59. self.writer.release()
  60. class GreenScreenProcessor:
  61. def __init__(self, model_path=None, device='cuda'):
  62. """
  63. 初始化绿幕处理器 —— 使用 torch.hub 官方结构加载
  64. Args:
  65. model_path: 模型路径,默认使用脚本所在目录的 rvm_resnet50.pth
  66. device: 计算设备 'cuda' 或 'cpu'
  67. """
  68. self.device = device
  69. # 使用动态路径,避免硬编码
  70. if model_path is None:
  71. script_dir = os.path.dirname(os.path.abspath(__file__))
  72. model_path = os.path.join(script_dir, 'rvm_resnet50.pth')
  73. print(f"🚀 加载 RVM 模型: {model_path}")
  74. # 尝试加载模型,如果网络不可用则使用离线模式
  75. print("📦 加载模型结构...")
  76. try:
  77. # 尝试从 GitHub 加载(需要网络)
  78. model = torch.hub.load('PeterL1n/RobustVideoMatting', 'resnet50', skip_validation=True)
  79. print("✅ 从 GitHub 加载模型结构成功")
  80. except Exception as e:
  81. # 网络不可用,尝试使用本地缓存
  82. print(f"⚠️ 网络加载失败: {e}")
  83. print("📦 尝试使用本地缓存的模型结构...")
  84. # 设置 torch.hub 为离线模式
  85. torch.hub.set_dir(os.path.expanduser('~/.cache/torch/hub'))
  86. try:
  87. model = torch.hub.load('PeterL1n/RobustVideoMatting', 'resnet50', skip_validation=True, force_reload=False)
  88. print("✅ 从本地缓存加载模型结构成功")
  89. except Exception as e2:
  90. print(f"❌ 本地缓存也不可用: {e2}")
  91. print("\n💡 解决方案:")
  92. print(" 1. 确保网络连接正常")
  93. print(" 2. 或者预先下载模型:python -c \"import torch; torch.hub.load('PeterL1n/RobustVideoMatting', 'resnet50')\"")
  94. print(" 3. 或者使用已缓存的模型(首次运行后会自动缓存)")
  95. raise
  96. # 加载本地权重
  97. print(f"📦 加载本地权重:{model_path}")
  98. checkpoint = torch.load(model_path, map_location=device, weights_only=True)
  99. model.load_state_dict(checkpoint)
  100. self.model = model.to(device)
  101. self.model.eval()
  102. print("✅ RVM 模型加载完成!")
  103. def process_video(self, input_path, output_path, max_frames=None, downsample_ratio=0.25, progress_callback=None):
  104. """
  105. 处理整个视频
  106. Args:
  107. input_path: 输入视频路径
  108. output_path: 输出视频路径
  109. max_frames: 最大处理帧数
  110. downsample_ratio: 下采样比例
  111. progress_callback: 进度回调函数 callback(current, total, percentage)
  112. """
  113. print(f"\nProcessing video: {input_path}")
  114. reader = VideoReader(input_path, transform=ToTensor())
  115. total_frames = reader.total_frames if max_frames is None else min(reader.total_frames, max_frames)
  116. print(f"Video info: {reader.width}x{reader.height}, FPS={reader.fps}, Frames={total_frames}")
  117. writer = VideoWriter(output_path, reader.fps, reader.width, reader.height)
  118. # 绿色背景(0,255,0)
  119. bgr = torch.tensor([0, 0, 0]).view(3, 1, 1).to(self.device)
  120. rec = [None] * 4
  121. frame_count = 0
  122. with torch.no_grad():
  123. for src in tqdm(DataLoader(reader, batch_size=1), total=total_frames):
  124. if frame_count >= total_frames:
  125. break
  126. src = src.to(self.device)
  127. fgr, pha, *rec = self.model(src, *rec, downsample_ratio)
  128. # 合成绿幕
  129. com = fgr * pha + bgr * (1 - pha)
  130. writer.write(com[0])
  131. frame_count += 1
  132. # 报告进度
  133. if progress_callback and frame_count % 10 == 0: # 每10帧报告一次
  134. percentage = (frame_count / total_frames) * 100
  135. progress_callback(frame_count, total_frames, percentage)
  136. reader.release()
  137. writer.release()
  138. # 确保最终进度被报告
  139. if progress_callback:
  140. progress_callback(frame_count, total_frames, 100.0)
  141. print(f"\n✓ 视频已保存:{output_path}")
  142. print(f"✓ 处理帧数:{frame_count}")
  143. def main():
  144. parser = argparse.ArgumentParser(description='将视频背景替换为绿幕 (RVM 极速版)')
  145. parser.add_argument('input_video', type=str, help='输入视频路径')
  146. parser.add_argument('-o', '--output', type=str, default=None, help='输出视频路径')
  147. parser.add_argument('-m', '--model', type=str, default=None, help='RVM模型路径 (默认使用脚本目录下的rvm_resnet50.pth)')
  148. parser.add_argument('-d', '--device', type=str, default='cuda', help='计算设备 cuda/cpu')
  149. parser.add_argument('--max-frames', type=int, default=None, help='最大处理帧数')
  150. parser.add_argument('--downsample-ratio', type=float, default=0.25, help='下采样比例')
  151. args = parser.parse_args()
  152. script_dir = os.path.dirname(os.path.abspath(__file__))
  153. # 使用动态路径,支持命令行参数覆盖
  154. if args.model:
  155. model_path = args.model
  156. else:
  157. model_path = os.path.join(script_dir, 'rvm_resnet50.pth')
  158. # 输出路径
  159. if args.output is None:
  160. base_name = os.path.splitext(os.path.basename(args.input_video))[0]
  161. output_dir = os.path.join(script_dir, 'video')
  162. os.makedirs(output_dir, exist_ok=True)
  163. args.output = os.path.join(output_dir, f"{base_name}_greenscreen.mp4")
  164. # 检查文件
  165. if not os.path.isfile(args.input_video):
  166. print(f"❌ 输入视频不存在:{args.input_video}")
  167. sys.exit(1)
  168. if not os.path.isfile(model_path):
  169. print(f"❌ 模型不存在:{model_path}")
  170. sys.exit(1)
  171. # 运行
  172. processor = GreenScreenProcessor(model_path, args.device)
  173. try:
  174. processor.process_video(
  175. args.input_video,
  176. args.output,
  177. args.max_frames,
  178. args.downsample_ratio
  179. )
  180. except Exception as e:
  181. print(f"\n❌ 错误:{e}")
  182. import traceback
  183. traceback.print_exc()
  184. sys.exit(1)
  185. if __name__ == "__main__":
  186. main()