#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Video green-screen replacement tool.

Uses the RobustVideoMatting (RVM) model to matte out the foreground of a
video and composite it over a solid green background.

Optimized: loads local weights fast, falls back gracefully when offline.
"""
import os
import sys
import cv2
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import numpy as np
from tqdm import tqdm
import argparse


class VideoReader:
    """Sequential frame reader over a video file (cv2.VideoCapture backend).

    Designed to be wrapped in a single-worker DataLoader: ``__getitem__``
    ignores ``idx`` and always returns the *next* frame, so frames come out
    in order only when read sequentially from one process.
    """

    def __init__(self, video_path, transform=None):
        """Open ``video_path``; raise ValueError if it cannot be opened.

        Args:
            video_path: path to the input video file.
            transform: optional callable applied to each RGB frame
                (e.g. torchvision ``ToTensor``).
        """
        self.cap = cv2.VideoCapture(video_path)
        if not self.cap.isOpened():
            raise ValueError(f"Cannot open video: {video_path}")
        self.transform = transform
        # Keep the native (possibly fractional) frame rate. Truncating to
        # int (e.g. 29.97 -> 29) would make the output drift out of sync
        # with the input's duration.
        self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        """Number of frames reported by the container (may be approximate)."""
        return self.total_frames

    def __getitem__(self, idx):
        """Return the next frame as RGB (transformed if a transform was set).

        NOTE: ``idx`` is ignored — reading is strictly sequential, so this
        only behaves correctly with a single-worker, in-order DataLoader.

        Raises:
            IndexError: when no more frames can be decoded.
        """
        ret, frame = self.cap.read()
        if not ret:
            raise IndexError(f"Failed to read frame {idx}")
        # OpenCV decodes BGR; the model and compositing expect RGB.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if self.transform:
            frame = self.transform(frame)
        return frame

    def release(self):
        """Release the underlying capture handle."""
        self.cap.release()


class VideoWriter:
    """Writes float RGB tensors (C, H, W) in [0, 1] to an mp4 file."""

    def __init__(self, output_path, frame_rate, width, height):
        """Open an mp4v-encoded writer at the given geometry and frame rate."""
        self.output_path = output_path
        self.frame_rate = frame_rate
        self.width = width
        self.height = height
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        self.writer = cv2.VideoWriter(output_path, fourcc, frame_rate,
                                      (width, height))

    def write(self, tensor):
        """Write one tensor frame (C, H, W), values in [0, 1], RGB order."""
        # Tensor (C, H, W) -> numpy (H, W, C), rescale to uint8.
        frame = tensor.cpu().numpy().transpose(1, 2, 0)
        frame = np.clip(frame * 255.0, 0, 255).astype(np.uint8)
        # Back to BGR for OpenCV's encoder.
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        self.writer.write(frame)

    def release(self):
        """Finalize and close the output file."""
        self.writer.release()


class GreenScreenProcessor:
    """Loads the RVM model and composites videos over a green background."""

    def __init__(self, model_path=None, device='cuda'):
        """
        Initialize the green-screen processor — model architecture comes
        from the official torch.hub repo, weights from a local file.

        Args:
            model_path: path to the weights; defaults to ``rvm_resnet50.pth``
                next to this script.
            device: compute device, 'cuda' or 'cpu'.

        Raises:
            Exception: re-raised when neither the network nor the local
                torch.hub cache can provide the model architecture.
        """
        self.device = device
        # Resolve the default weights path relative to this script so the
        # tool works regardless of the current working directory.
        if model_path is None:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            model_path = os.path.join(script_dir, 'rvm_resnet50.pth')

        print(f"🚀 加载 RVM 模型: {model_path}")

        # Try to load the model architecture; fall back to the local
        # torch.hub cache when the network is unavailable.
        print("📦 加载模型结构...")
        try:
            # First attempt: fetch the repo definition from GitHub.
            model = torch.hub.load('PeterL1n/RobustVideoMatting', 'resnet50',
                                   skip_validation=True)
            print("✅ 从 GitHub 加载模型结构成功")
        except Exception as e:
            # Network failed — retry against the locally cached hub repo.
            print(f"⚠️ 网络加载失败: {e}")
            print("📦 尝试使用本地缓存的模型结构...")
            torch.hub.set_dir(os.path.expanduser('~/.cache/torch/hub'))
            try:
                model = torch.hub.load('PeterL1n/RobustVideoMatting',
                                       'resnet50', skip_validation=True,
                                       force_reload=False)
                print("✅ 从本地缓存加载模型结构成功")
            except Exception as e2:
                print(f"❌ 本地缓存也不可用: {e2}")
                print("\n💡 解决方案:")
                print(" 1. 确保网络连接正常")
                print(" 2. 或者预先下载模型:python -c \"import torch; torch.hub.load('PeterL1n/RobustVideoMatting', 'resnet50')\"")
                print(" 3. 或者使用已缓存的模型(首次运行后会自动缓存)")
                raise

        # Load the local weights onto the target device.
        print(f"📦 加载本地权重:{model_path}")
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects —
        # only load checkpoint files from a trusted source.
        checkpoint = torch.load(model_path, map_location=device,
                                weights_only=False)
        model.load_state_dict(checkpoint)

        self.model = model.to(device)
        self.model.eval()
        print("✅ RVM 模型加载完成!")

    def process_video(self, input_path, output_path, max_frames=None,
                      downsample_ratio=0.25, progress_callback=None):
        """
        Process an entire video, replacing its background with solid green.

        Args:
            input_path: input video path.
            output_path: output video path.
            max_frames: optional cap on the number of frames processed.
            downsample_ratio: RVM internal downsample ratio (speed/quality).
            progress_callback: optional ``callback(current, total, percentage)``
                invoked every 10 frames and once at completion.
        """
        print(f"\nProcessing video: {input_path}")

        reader = VideoReader(input_path, transform=ToTensor())
        total_frames = (reader.total_frames if max_frames is None
                        else min(reader.total_frames, max_frames))
        print(f"Video info: {reader.width}x{reader.height}, "
              f"FPS={reader.fps}, Frames={total_frames}")

        writer = VideoWriter(output_path, reader.fps,
                             reader.width, reader.height)

        # Green background in normalized RGB (frames are RGB here; the
        # writer converts to BGR on output). BUGFIX: this was previously
        # [0, 0, 0] — a black background — contradicting the tool's
        # green-screen purpose.
        bgr = torch.tensor([0., 1., 0.]).view(3, 1, 1).to(self.device)

        # RVM recurrent state, threaded through successive frames.
        rec = [None] * 4
        frame_count = 0

        with torch.no_grad():
            for src in tqdm(DataLoader(reader, batch_size=1),
                            total=total_frames):
                if frame_count >= total_frames:
                    break
                src = src.to(self.device)
                fgr, pha, *rec = self.model(src, *rec, downsample_ratio)
                # Alpha-composite the foreground over the green background.
                com = fgr * pha + bgr * (1 - pha)
                writer.write(com[0])
                frame_count += 1

                # Report progress every 10 frames to keep callback overhead low.
                if progress_callback and frame_count % 10 == 0:
                    percentage = (frame_count / total_frames) * 100
                    progress_callback(frame_count, total_frames, percentage)

        reader.release()
        writer.release()

        # Guarantee a final 100% progress report.
        if progress_callback:
            progress_callback(frame_count, total_frames, 100.0)

        print(f"\n✓ 视频已保存:{output_path}")
        print(f"✓ 处理帧数:{frame_count}")


def main():
    """CLI entry point: parse arguments, validate paths, run the processor."""
    parser = argparse.ArgumentParser(description='将视频背景替换为绿幕 (RVM 极速版)')
    parser.add_argument('input_video', type=str, help='输入视频路径')
    parser.add_argument('-o', '--output', type=str, default=None,
                        help='输出视频路径')
    parser.add_argument('-m', '--model', type=str, default=None,
                        help='RVM模型路径 (默认使用脚本目录下的rvm_resnet50.pth)')
    parser.add_argument('-d', '--device', type=str, default='cpu',
                        help='计算设备 cuda/cpu')
    parser.add_argument('--max-frames', type=int, default=None,
                        help='最大处理帧数')
    parser.add_argument('--downsample-ratio', type=float, default=0.25,
                        help='下采样比例')
    args = parser.parse_args()

    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Model path: CLI flag wins, otherwise default to the script directory.
    if args.model:
        model_path = args.model
    else:
        model_path = os.path.join(script_dir, 'rvm_resnet50.pth')

    # Default output path: <script_dir>/video/<name>_greenscreen.mp4
    if args.output is None:
        base_name = os.path.splitext(os.path.basename(args.input_video))[0]
        output_dir = os.path.join(script_dir, 'video')
        os.makedirs(output_dir, exist_ok=True)
        args.output = os.path.join(output_dir, f"{base_name}_greenscreen.mp4")

    # Validate inputs before loading the (expensive) model.
    if not os.path.isfile(args.input_video):
        print(f"❌ 输入视频不存在:{args.input_video}")
        sys.exit(1)
    if not os.path.isfile(model_path):
        print(f"❌ 模型不存在:{model_path}")
        sys.exit(1)

    # Run the pipeline; exit non-zero with a traceback on any failure.
    processor = GreenScreenProcessor(model_path, args.device)
    try:
        processor.process_video(
            args.input_video,
            args.output,
            args.max_frames,
            args.downsample_ratio
        )
    except Exception as e:
        print(f"\n❌ 错误:{e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()