| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 视频绿幕替换工具
- 使用 RobustVideoMatting (RVM) 模型将视频背景替换为绿色
- 已优化:本地模型极速加载,不联网、不卡
- """
- import os
- import sys
- import cv2
- import torch
- from torch.utils.data import DataLoader
- from torchvision.transforms import ToTensor
- import numpy as np
- from tqdm import tqdm
- import argparse
class VideoReader:
    """Map-style frame source backed by ``cv2.VideoCapture``.

    Frames are decoded on demand, converted from OpenCV's BGR channel order
    to RGB, and passed through the optional ``transform`` before being
    returned. The object supports ``len()`` and integer indexing, so it can
    be handed to ``torch.utils.data.DataLoader``.

    Attributes:
        cap: The underlying ``cv2.VideoCapture``.
        transform: Optional callable applied to each RGB frame.
        fps: Frame rate reported by OpenCV, kept as a float.
        width: Frame width in pixels.
        height: Frame height in pixels.
        total_frames: Frame count reported by OpenCV.
    """

    def __init__(self, video_path, transform=None):
        """Open ``video_path`` and cache its basic stream properties.

        Args:
            video_path: Path of the video to read.
            transform: Optional callable applied to every RGB frame.

        Raises:
            ValueError: If OpenCV cannot open the video.
        """
        self.cap = cv2.VideoCapture(video_path)
        if not self.cap.isOpened():
            raise ValueError(f"Cannot open video: {video_path}")

        self.transform = transform
        # Keep the fractional part: truncating e.g. 29.97 to 29 would make a
        # re-encoded copy play at a different speed than the source.
        self.fps = float(self.cap.get(cv2.CAP_PROP_FPS))
        self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # NOTE(review): the reported count may be an estimate for some
        # containers, so a read can still fail before this many frames.
        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Index of the frame the next sequential read() is expected to return;
        # None means the decoder position is unknown and a seek is required.
        self._next_index = 0

    def __len__(self):
        """Return the frame count reported when the video was opened."""
        return self.total_frames

    def __getitem__(self, idx):
        """Return frame ``idx`` as RGB, after applying ``transform`` if set.

        Sequential access (0, 1, 2, ...) reads straight through the stream.
        Any other index first repositions the decoder so that the frame
        returned matches the index requested.

        Args:
            idx: Frame index; negative values count from ``total_frames``.

        Raises:
            IndexError: If the index is out of range or the frame cannot be
                decoded.
        """
        requested = idx
        if idx < 0:
            idx += self.total_frames
        if idx < 0:
            raise IndexError(f"Frame index out of range: {requested}")

        if idx != self._next_index:
            # Random access: move the decoder to the requested frame.
            # NOTE(review): seek accuracy depends on the OpenCV backend/codec.
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)

        ret, frame = self.cap.read()
        if not ret:
            # The position after a failed read is unknown; force a seek on
            # the next access instead of trusting the counter.
            self._next_index = None
            raise IndexError(f"Failed to read frame {idx}")
        self._next_index = idx + 1

        # BGR to RGB
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        if self.transform:
            frame = self.transform(frame)

        return frame

    def release(self):
        """Release the underlying capture handle."""
        self.cap.release()
class VideoWriter:
    """Write RGB float tensors to an ``mp4v``-encoded video via OpenCV.

    Attributes:
        output_path: Destination file path.
        frame_rate: Frame rate passed to the encoder.
        width: Frame width in pixels passed to the encoder.
        height: Frame height in pixels passed to the encoder.
        writer: The underlying ``cv2.VideoWriter``.
    """

    def __init__(self, output_path, frame_rate, width, height):
        """Open ``output_path`` for writing.

        Args:
            output_path: Destination file path.
            frame_rate: Output frame rate.
            width: Frame width in pixels.
            height: Frame height in pixels.

        Raises:
            ValueError: If OpenCV could not open the output (for example a
                missing directory or an unavailable codec).
        """
        self.output_path = output_path
        self.frame_rate = frame_rate
        self.width = width
        self.height = height

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        self.writer = cv2.VideoWriter(output_path, fourcc, frame_rate, (width, height))
        # Fail loudly here: a writer that did not open would otherwise accept
        # write() calls without producing a usable file.
        if not self.writer.isOpened():
            raise ValueError(f"Cannot open video writer: {output_path}")

    def write(self, tensor):
        """Write one frame given as a (C, H, W) RGB tensor scaled to 0..1.

        Values are scaled by 255, clipped to 0..255 and truncated to uint8.
        Only the first three channels are used.

        Args:
            tensor: Frame tensor on any device; it may require grad.

        Raises:
            ValueError: If the tensor is not (C, H, W) with at least three
                channels.
        """
        # detach() so tensors that still track gradients can be converted.
        frame = tensor.detach().cpu().numpy()
        if frame.ndim != 3 or frame.shape[0] < 3:
            raise ValueError(
                f"Expected a (C, H, W) tensor with at least 3 channels, got shape {tuple(frame.shape)}"
            )

        # (C, H, W) -> (H, W, C), then 0..1 floats -> 0..255 bytes.
        frame = frame.transpose(1, 2, 0)
        frame = np.clip(frame * 255.0, 0, 255).astype(np.uint8)

        # RGB to BGR is a reversal of the first three channels; copy into a
        # contiguous buffer because the reversed view has negative strides.
        frame = np.ascontiguousarray(frame[:, :, 2::-1])

        self.writer.write(frame)

    def release(self):
        """Finalize the output file and release the encoder."""
        self.writer.release()
class GreenScreenProcessor:
    """Replace a video's background with solid green using RobustVideoMatting.

    The network definition comes from the ``PeterL1n/RobustVideoMatting``
    torch.hub repository (local cache first, GitHub second); the weights are
    always loaded from a local checkpoint file.

    Attributes:
        device: Device the model runs on (after the CUDA availability check).
        model: The loaded network in eval mode.
    """

    # torch.hub repository and entry point that provide the network definition.
    HUB_REPO = 'PeterL1n/RobustVideoMatting'
    HUB_ENTRY = 'resnet50'
    # Background colour as RGB in the 0..1 range used by the frame tensors.
    BACKGROUND_RGB = (0.0, 1.0, 0.0)

    def __init__(self, model_path=None, device='cuda'):
        """Load the network definition and the local weights.

        Args:
            model_path: Checkpoint path; defaults to ``rvm_resnet50.pth`` in
                the directory containing this script.
            device: Compute device, ``'cuda'`` or ``'cpu'``. Falls back to
                ``'cpu'`` when CUDA is requested but not available.

        Raises:
            Exception: Whatever ``torch.hub.load`` raised when the network
                definition could be obtained neither from the local cache
                nor from GitHub; errors from reading the checkpoint also
                propagate.
        """
        self.device = self._resolve_device(device)

        # Resolve relative to this script instead of hard-coding a path.
        if model_path is None:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            model_path = os.path.join(script_dir, 'rvm_resnet50.pth')

        print(f"🚀 加载 RVM 模型: {model_path}")
        print("📦 加载模型结构...")
        model = self._load_architecture()

        # Load the local weights.
        print(f"📦 加载本地权重:{model_path}")
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        model.load_state_dict(checkpoint)

        self.model = model.to(self.device)
        self.model.eval()

        print("✅ RVM 模型加载完成!")

    @staticmethod
    def _resolve_device(device):
        """Return ``device``, or ``'cpu'`` if CUDA was requested but is unavailable."""
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            print("⚠️ CUDA 不可用,改用 CPU")
            return 'cpu'
        return device

    @classmethod
    def _find_cached_repo(cls):
        """Return the path of a locally cached hub checkout, or None.

        NOTE(review): assumes torch.hub stores checkouts as
        ``<hub_dir>/<owner>_<repo>_<ref>``. If that layout does not match,
        None is returned and the caller falls back to loading from GitHub.
        """
        folder_prefix = cls.HUB_REPO.replace('/', '_')
        # Current hub dir first, then the default location the original
        # fallback pointed at.
        hub_dirs = [torch.hub.get_dir(), os.path.expanduser('~/.cache/torch/hub')]
        for hub_dir in dict.fromkeys(hub_dirs):
            for ref in ('master', 'main'):
                candidate = os.path.join(hub_dir, f'{folder_prefix}_{ref}')
                if os.path.isfile(os.path.join(candidate, 'hubconf.py')):
                    return candidate
        return None

    @classmethod
    def _load_architecture(cls):
        """Build the network definition, preferring the local hub cache.

        ``pretrained=False`` is passed to the hub entry point because the
        weights are loaded from the local checkpoint afterwards; there is no
        reason to download the published weights first.
        """
        cached_repo = cls._find_cached_repo()
        if cached_repo is not None:
            print("📦 尝试使用本地缓存的模型结构...")
            try:
                # source='local' loads the checkout directly from disk.
                model = torch.hub.load(cached_repo, cls.HUB_ENTRY, source='local', pretrained=False)
                print("✅ 从本地缓存加载模型结构成功")
                return model
            except Exception as e:
                # Best effort: a broken cache must not prevent the GitHub attempt.
                print(f"⚠️ 本地缓存加载失败: {e}")

        try:
            # Needs network access.
            model = torch.hub.load(cls.HUB_REPO, cls.HUB_ENTRY, skip_validation=True, pretrained=False)
            print("✅ 从 GitHub 加载模型结构成功")
            return model
        except Exception as e:
            print(f"⚠️ 网络加载失败: {e}")
            print("\n💡 解决方案:")
            print(" 1. 确保网络连接正常")
            print(" 2. 或者预先下载模型:python -c \"import torch; torch.hub.load('PeterL1n/RobustVideoMatting', 'resnet50')\"")
            print(" 3. 或者使用已缓存的模型(首次运行后会自动缓存)")
            raise

    @staticmethod
    def _iter_frames(reader, limit):
        """Yield up to ``limit`` frames from ``reader`` as batches of one.

        Stops quietly when ``reader`` raises ``IndexError``, so a source that
        holds fewer frames than announced ends the run instead of aborting it.
        """
        for index in range(limit):
            try:
                frame = reader[index]
            except IndexError:
                return
            # Add the batch dimension the model call expects: (1, C, H, W).
            yield frame.unsqueeze(0)

    def process_video(self, input_path, output_path, max_frames=None, downsample_ratio=0.25, progress_callback=None):
        """Matte every frame of a video onto a green background.

        Args:
            input_path: Input video path.
            output_path: Output video path.
            max_frames: Maximum number of frames to process; None for all.
            downsample_ratio: Downsample ratio forwarded to the model.
            progress_callback: Optional ``callback(current, total, percentage)``
                invoked every 10 frames and once more at the end.

        Raises:
            ValueError: If the input video cannot be opened.
        """
        print(f"\nProcessing video: {input_path}")

        reader = VideoReader(input_path, transform=ToTensor())
        writer = None
        frame_count = 0
        try:
            total_frames = reader.total_frames if max_frames is None else min(reader.total_frames, max_frames)

            print(f"Video info: {reader.width}x{reader.height}, FPS={reader.fps}, Frames={total_frames}")

            writer = VideoWriter(output_path, reader.fps, reader.width, reader.height)

            # Solid green, shaped (3, 1, 1) so it broadcasts over H and W.
            background = torch.tensor(self.BACKGROUND_RGB, device=self.device).view(3, 1, 1)

            # Recurrent states fed back into the model on every frame.
            rec = [None] * 4

            with torch.no_grad():
                for src in tqdm(self._iter_frames(reader, total_frames), total=total_frames):
                    src = src.to(self.device)
                    fgr, pha, *rec = self.model(src, *rec, downsample_ratio)

                    # Alpha-composite the foreground over the green background.
                    com = fgr * pha + background * (1 - pha)

                    writer.write(com[0])
                    frame_count += 1

                    # Report progress every 10 frames.
                    if progress_callback and frame_count % 10 == 0:
                        percentage = (frame_count / total_frames) * 100
                        progress_callback(frame_count, total_frames, percentage)
        finally:
            # Always release, so a failure mid-run still finalizes the output
            # file and frees the capture handle.
            reader.release()
            if writer is not None:
                writer.release()

        # Make sure the final progress is reported.
        if progress_callback:
            progress_callback(frame_count, total_frames, 100.0)

        print(f"\n✓ 视频已保存:{output_path}")
        print(f"✓ 处理帧数:{frame_count}")
def main(argv=None):
    """Command-line entry point: replace a video's background with green.

    Args:
        argv: Optional list of command-line arguments; when None, argparse
            reads ``sys.argv[1:]``.

    Exits with status 1 when the input video or the model file is missing,
    or when processing fails.
    """
    parser = argparse.ArgumentParser(description='将视频背景替换为绿幕 (RVM 极速版)')
    parser.add_argument('input_video', type=str, help='输入视频路径')
    parser.add_argument('-o', '--output', type=str, default=None, help='输出视频路径')
    parser.add_argument('-m', '--model', type=str, default=None, help='RVM模型路径 (默认使用脚本目录下的rvm_resnet50.pth)')
    parser.add_argument('-d', '--device', type=str, default='cuda', help='计算设备 cuda/cpu')
    parser.add_argument('--max-frames', type=int, default=None, help='最大处理帧数')
    parser.add_argument('--downsample-ratio', type=float, default=0.25, help='下采样比例')

    args = parser.parse_args(argv)
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # A model path given on the command line overrides the script-local default.
    model_path = args.model or os.path.join(script_dir, 'rvm_resnet50.pth')

    # Validate the inputs before touching the file system.
    if not os.path.isfile(args.input_video):
        print(f"❌ 输入视频不存在:{args.input_video}")
        sys.exit(1)

    if not os.path.isfile(model_path):
        print(f"❌ 模型不存在:{model_path}")
        sys.exit(1)

    # Default output: <script dir>/video/<input name>_greenscreen.mp4
    if args.output is None:
        base_name = os.path.splitext(os.path.basename(args.input_video))[0]
        args.output = os.path.join(script_dir, 'video', f"{base_name}_greenscreen.mp4")

    # Create the destination directory for user-supplied paths as well; the
    # video writer cannot create missing directories itself.
    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)

    # Run.
    processor = GreenScreenProcessor(model_path, args.device)

    try:
        processor.process_video(
            args.input_video,
            args.output,
            args.max_frames,
            args.downsample_ratio
        )
    except Exception as e:
        # Top-level boundary: report the failure with a traceback and exit
        # with a non-zero status.
        print(f"\n❌ 错误:{e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
|