|
|
@@ -0,0 +1,1075 @@
|
|
|
+# server.py
|
|
|
+import flask
|
|
|
+from flask import Flask, render_template,send_from_directory,request, jsonify
|
|
|
+from flask_sockets import Sockets
|
|
|
+import base64
|
|
|
+import json
|
|
|
+#import gevent
|
|
|
+#from gevent import pywsgi
|
|
|
+#from geventwebsocket.handler import WebSocketHandler
|
|
|
+import re
|
|
|
+import numpy as np
|
|
|
+from threading import Thread,Event
|
|
|
+#import multiprocessing
|
|
|
+
|
|
|
+# 禁用 TorchDynamo 编译(避免 VoxCPM2 兼容性问题)
|
|
|
+import os
|
|
|
+os.environ['TORCHDYNAMO_DISABLE'] = '1'
|
|
|
+import torch
|
|
|
+if hasattr(torch, '_dynamo'):
|
|
|
+ torch._dynamo.config.suppress_errors = True
|
|
|
+import torch.multiprocessing as mp
|
|
|
+
|
|
|
+from aiohttp import web
|
|
|
+import aiohttp
|
|
|
+import aiohttp_cors
|
|
|
+from aiortc import RTCPeerConnection, RTCSessionDescription,RTCIceServer,RTCConfiguration
|
|
|
+from aiortc.rtcrtpsender import RTCRtpSender
|
|
|
+from webrtc import HumanPlayer
|
|
|
+from basereal import BaseReal
|
|
|
+from llm import llm_response
|
|
|
+
|
|
|
+import argparse
|
|
|
+import random
|
|
|
+import shutil
|
|
|
+import asyncio
|
|
|
+import torch
|
|
|
+import os
|
|
|
+import socket
|
|
|
+import signal
|
|
|
+import sys
|
|
|
+from typing import Dict
|
|
|
+from logger import logger
|
|
|
+import gc
|
|
|
+
|
|
|
+
|
|
|
+app = Flask(__name__)
|
|
|
+#sockets = Sockets(app)
|
|
|
+nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
|
|
|
+opt = None
|
|
|
+model = None
|
|
|
+avatar = None
|
|
|
+
|
|
|
+
|
|
|
+#####webrtc###############################
|
|
|
+pcs = set()
|
|
|
+# 记录本次运行创建的 avatar 目录,用于退出时清理
|
|
|
+_avatar_dirs_to_clean = set()
|
|
|
+_enable_avatar_cleanup = False
|
|
|
+
|
|
|
+def randN(N)->int:
|
|
|
+ '''生成长度为 N的随机数 '''
|
|
|
+ min = pow(10, N - 1)
|
|
|
+ max = pow(10, N)
|
|
|
+ return random.randint(min, max - 1)
|
|
|
+
|
|
|
+def build_nerfreal(sessionid:int)->BaseReal:
|
|
|
+ opt.sessionid=sessionid
|
|
|
+ # 记录本次运行使用的 avatar 目录
|
|
|
+ if hasattr(opt, 'avatar_id') and opt.avatar_id:
|
|
|
+ avatar_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data', 'avatars', opt.avatar_id)
|
|
|
+ if os.path.exists(avatar_dir):
|
|
|
+ _avatar_dirs_to_clean.add(avatar_dir)
|
|
|
+ logger.info(f'记录 avatar 目录用于退出时清理: {avatar_dir}')
|
|
|
+
|
|
|
+ if opt.model == 'wav2lip':
|
|
|
+ from lipreal import LipReal
|
|
|
+ nerfreal = LipReal(opt,model,avatar)
|
|
|
+ elif opt.model == 'musetalk':
|
|
|
+ from musereal import MuseReal
|
|
|
+ nerfreal = MuseReal(opt,model,avatar)
|
|
|
+ # elif opt.model == 'ernerf':
|
|
|
+ # from nerfreal import NeRFReal
|
|
|
+ # nerfreal = NeRFReal(opt,model,avatar)
|
|
|
+ elif opt.model == 'ultralight':
|
|
|
+ from lightreal import LightReal
|
|
|
+ nerfreal = LightReal(opt,model,avatar)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Unsupported model type: {opt.model}")
|
|
|
+ return nerfreal
|
|
|
+
|
|
|
+#@app.route('/offer', methods=['POST'])
|
|
|
+async def offer(request):
|
|
|
+ try:
|
|
|
+ params = await request.json()
|
|
|
+ offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
|
|
+
|
|
|
+ # if len(nerfreals) >= opt.max_session:
|
|
|
+ # logger.info('reach max session')
|
|
|
+ # return web.Response(
|
|
|
+ # content_type="application/json",
|
|
|
+ # text=json.dumps(
|
|
|
+ # {"code": -1, "msg": "reach max session"}
|
|
|
+ # ),
|
|
|
+ # )
|
|
|
+ sessionid = randN(6) #len(nerfreals)
|
|
|
+ nerfreals[sessionid] = None
|
|
|
+ logger.info('sessionid=%d, session num=%d',sessionid,len(nerfreals))
|
|
|
+ nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
|
|
|
+ nerfreals[sessionid] = nerfreal
|
|
|
+
|
|
|
+ # 内网/跳板机环境:不使用外网 STUN(内网访问不到),只用 host 候选
|
|
|
+ # 通过环境变量 WEBRTC_NAT_IP 指定跳板机对外暴露的 IP(浏览器能访问的 IP)
|
|
|
+ nat_public_ip = os.environ.get('WEBRTC_NAT_IP', '')
|
|
|
+ if nat_public_ip:
|
|
|
+ logger.info('Using NAT public IP for ICE: %s', nat_public_ip)
|
|
|
+ ice_servers = [] # 内网不用 STUN
|
|
|
+ else:
|
|
|
+ # 使用多个 STUN 服务器(包括国内和国外的)
|
|
|
+ ice_servers = [
|
|
|
+ RTCIceServer(urls='stun:stun.l.google.com:19302'),
|
|
|
+ RTCIceServer(urls='stun:stun1.l.google.com:19302'),
|
|
|
+ RTCIceServer(urls='stun:stun2.l.google.com:19302'),
|
|
|
+ RTCIceServer(urls='stun:stun3.l.google.com:19302'),
|
|
|
+ RTCIceServer(urls='stun:stun4.l.google.com:19302'),
|
|
|
+ ]
|
|
|
+ pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=ice_servers))
|
|
|
+ pcs.add(pc)
|
|
|
+
|
|
|
+ @pc.on("connectionstatechange")
|
|
|
+ async def on_connectionstatechange():
|
|
|
+ logger.info(f"Session {sessionid} - Connection state is {pc.connectionState}")
|
|
|
+ if pc.connectionState == "connected":
|
|
|
+ logger.info(f"Session {sessionid} - WebRTC connection established successfully!")
|
|
|
+ if pc.connectionState == "failed":
|
|
|
+ logger.error(f"Session {sessionid} - Connection failed!")
|
|
|
+ await pc.close()
|
|
|
+ pcs.discard(pc)
|
|
|
+ if sessionid in nerfreals:
|
|
|
+ del nerfreals[sessionid]
|
|
|
+ if pc.connectionState == "closed":
|
|
|
+ logger.info(f"Session {sessionid} - Connection closed")
|
|
|
+ pcs.discard(pc)
|
|
|
+ if sessionid in nerfreals:
|
|
|
+ del nerfreals[sessionid]
|
|
|
+ # gc.collect()
|
|
|
+
|
|
|
+ player = HumanPlayer(nerfreals[sessionid])
|
|
|
+ audio_sender = pc.addTrack(player.audio)
|
|
|
+ video_sender = pc.addTrack(player.video)
|
|
|
+
|
|
|
+ # 记录轨道添加信息
|
|
|
+ logger.info(f"Added tracks for session {sessionid}: audio={player.audio}, video={player.video}")
|
|
|
+
|
|
|
+ capabilities = RTCRtpSender.getCapabilities("video")
|
|
|
+ preferences = list(filter(lambda x: x.name == "H264", capabilities.codecs))
|
|
|
+ preferences += list(filter(lambda x: x.name == "VP8", capabilities.codecs))
|
|
|
+ preferences += list(filter(lambda x: x.name == "rtx", capabilities.codecs))
|
|
|
+ transceiver = pc.getTransceivers()[1]
|
|
|
+ transceiver.setCodecPreferences(preferences)
|
|
|
+
|
|
|
+ await pc.setRemoteDescription(offer)
|
|
|
+
|
|
|
+ answer = await pc.createAnswer()
|
|
|
+ await pc.setLocalDescription(answer)
|
|
|
+
|
|
|
+ sdp = pc.localDescription.sdp
|
|
|
+
|
|
|
+ # 跳板机内网场景:如果配置了 WEBRTC_NAT_IP,将 SDP 中的内网 IP 替换为跳板机对外 IP
|
|
|
+ # aioice 已经将真实绑定的端口(AIOICE_PORT_MIN~MAX 范围内)写入 SDP 的 a=candidate 行
|
|
|
+ # 我们只需要把内网 IP 替换成浏览器能访问的 IP 即可
|
|
|
+ nat_public_ip = os.environ.get('WEBRTC_NAT_IP', '')
|
|
|
+ if nat_public_ip:
|
|
|
+ try:
|
|
|
+ # 获取本机所有非 loopback IPv4 地址
|
|
|
+ local_ips = set()
|
|
|
+ bind_ip = os.environ.get('AIOICE_BIND_IP', '')
|
|
|
+ if bind_ip:
|
|
|
+ local_ips.add(bind_ip)
|
|
|
+ else:
|
|
|
+ hostname = socket.gethostname()
|
|
|
+ for ip_info in socket.getaddrinfo(hostname, None, socket.AF_INET):
|
|
|
+ ip = ip_info[4][0]
|
|
|
+ if not ip.startswith('127.'):
|
|
|
+ local_ips.add(ip)
|
|
|
+ # 也枚举所有网卡 IP(getaddrinfo 可能不全)
|
|
|
+ try:
|
|
|
+ import ifaddr
|
|
|
+ for adapter in ifaddr.get_adapters():
|
|
|
+ for ip_obj in adapter.ips:
|
|
|
+ if isinstance(ip_obj.ip, str) and not ip_obj.ip.startswith('127.'):
|
|
|
+ local_ips.add(ip_obj.ip)
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+
|
|
|
+ # 将 SDP 中所有内网 IP 替换为跳板机对外 IP
|
|
|
+ for local_ip in local_ips:
|
|
|
+ sdp = sdp.replace(local_ip, nat_public_ip)
|
|
|
+ logger.info('NAT mapping: %s -> %s in SDP', local_ips, nat_public_ip)
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning('Failed to apply NAT IP mapping to SDP: %s', e)
|
|
|
+ else:
|
|
|
+ logger.info('WEBRTC_NAT_IP not set. SDP uses server internal IPs. '
|
|
|
+ 'Set WEBRTC_NAT_IP=<jumphost_ip> if browser cannot reach server directly.')
|
|
|
+
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"sdp": sdp, "type": pc.localDescription.type, "sessionid": sessionid}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception('Error in offer:')
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps({"code": -1, "msg": str(e)}),
|
|
|
+ status=500
|
|
|
+ )
|
|
|
+
|
|
|
+async def log(request):
|
|
|
+ """接收前端日志的接口"""
|
|
|
+ try:
|
|
|
+ params = await request.json()
|
|
|
+ log_type = params.get('type', 'info')
|
|
|
+ message = params.get('message', '')
|
|
|
+ logger.info('[WEBRTC] %s', message)
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps({"code": 0}),
|
|
|
+ )
|
|
|
+
|
|
|
+async def human(request):
|
|
|
+ try:
|
|
|
+ params = await request.json()
|
|
|
+
|
|
|
+ sessionid = params.get('sessionid',0)
|
|
|
+ if sessionid not in nerfreals or nerfreals[sessionid] is None:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ # 标记是否已经处理了聊天请求
|
|
|
+ chat_processed = False
|
|
|
+
|
|
|
+ if params.get('interrupt'):
|
|
|
+ # 立即中断当前播放,无论是否是聊天请求
|
|
|
+ nerfreals[sessionid].flush_talk()
|
|
|
+ # 检查是否是用户在介绍过程中提问
|
|
|
+ if params['type'] == 'chat':
|
|
|
+ knowledge_base_type = params.get('knowledge_base', None)
|
|
|
+ during_intro = params.get('during_intro', False)
|
|
|
+ # 处理用户问题(异步执行,避免阻塞事件循环)
|
|
|
+ asyncio.get_event_loop().run_in_executor(
|
|
|
+ None,
|
|
|
+ llm_response,
|
|
|
+ params['text'],
|
|
|
+ nerfreals[sessionid],
|
|
|
+ knowledge_base_type,
|
|
|
+ during_intro
|
|
|
+ )
|
|
|
+ chat_processed = True
|
|
|
+
|
|
|
+ if params['type']=='echo':
|
|
|
+ nerfreals[sessionid].put_msg_txt(params['text'])
|
|
|
+ elif params['type']=='chat' and not chat_processed:
|
|
|
+ # 如果没有中断标志,或者已经处理过聊天请求,则不再处理
|
|
|
+ knowledge_base_type = params.get('knowledge_base', None)
|
|
|
+ during_intro = params.get('during_intro', False) # 是否在介绍过程中
|
|
|
+ asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'], nerfreals[sessionid], knowledge_base_type, during_intro)
|
|
|
+ #nerfreals[sessionid].put_msg_txt(res)
|
|
|
+
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg":"ok"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception('exception:')
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": str(e)}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+async def api_chat(request):
|
|
|
+ """
|
|
|
+ 独立聊天API - 让数字人说话/问答
|
|
|
+ 请求体: {
|
|
|
+ "text": "用户说的话",
|
|
|
+ "type": "chat", // "chat"=问答, "echo"=纯播报
|
|
|
+ "interrupt": true, // 是否打断当前播放
|
|
|
+ "sessionid": 123456, // WebRTC session ID
|
|
|
+ "knowledge_base": "", // 可选: 指定知识库
|
|
|
+ "during_intro": false // 可选: 是否在介绍过程中
|
|
|
+ }
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ params = await request.json()
|
|
|
+
|
|
|
+ # 参数验证
|
|
|
+ text = params.get('text', '').strip()
|
|
|
+ if not text:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps({"code": -1, "msg": "text parameter is required"}),
|
|
|
+ )
|
|
|
+
|
|
|
+ sessionid = params.get('sessionid', 0)
|
|
|
+ if sessionid not in nerfreals or nerfreals[sessionid] is None:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ msg_type = params.get('type', 'chat')
|
|
|
+ interrupt = params.get('interrupt', True)
|
|
|
+ knowledge_base = params.get('knowledge_base', None)
|
|
|
+ during_intro = params.get('during_intro', False)
|
|
|
+
|
|
|
+ logger.info(f"API Chat - Session {sessionid}, Type: {msg_type}, Text: {text[:50]}...")
|
|
|
+
|
|
|
+ # 处理中断
|
|
|
+ if interrupt:
|
|
|
+ nerfreals[sessionid].flush_talk()
|
|
|
+
|
|
|
+ # 根据类型处理
|
|
|
+ if msg_type == 'echo':
|
|
|
+ # 纯播报模式 - 直接播放文本
|
|
|
+ nerfreals[sessionid].put_msg_txt(text)
|
|
|
+ else:
|
|
|
+ # 聊天/问答模式 - 走LLM
|
|
|
+ asyncio.get_event_loop().run_in_executor(
|
|
|
+ None,
|
|
|
+ llm_response,
|
|
|
+ text,
|
|
|
+ nerfreals[sessionid],
|
|
|
+ knowledge_base,
|
|
|
+ during_intro
|
|
|
+ )
|
|
|
+
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps({
|
|
|
+ "code": 0,
|
|
|
+ "msg": "ok",
|
|
|
+ "data": {
|
|
|
+ "sessionid": sessionid,
|
|
|
+ "type": msg_type,
|
|
|
+ "text": text[:100] # 返回前100字符用于确认
|
|
|
+ }
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception('API Chat exception:')
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps({"code": -1, "msg": str(e)}),
|
|
|
+ )
|
|
|
+
|
|
|
+async def interrupt_talk(request):
|
|
|
+ try:
|
|
|
+ params = await request.json()
|
|
|
+
|
|
|
+ sessionid = params.get('sessionid',0)
|
|
|
+ if sessionid not in nerfreals or nerfreals[sessionid] is None:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ nerfreals[sessionid].flush_talk()
|
|
|
+
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg":"ok"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception('exception:')
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": str(e)}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+async def humanaudio(request):
|
|
|
+ try:
|
|
|
+ form= await request.post()
|
|
|
+ sessionid = int(form.get('sessionid',0))
|
|
|
+ if sessionid not in nerfreals or nerfreals[sessionid] is None:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ fileobj = form["file"]
|
|
|
+ filename=fileobj.filename
|
|
|
+ filebytes=fileobj.file.read()
|
|
|
+ nerfreals[sessionid].put_audio_file(filebytes)
|
|
|
+
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg":"ok"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception('exception:')
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": str(e)}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+async def set_audiotype(request):
|
|
|
+ try:
|
|
|
+ params = await request.json()
|
|
|
+
|
|
|
+ sessionid = params.get('sessionid',0)
|
|
|
+ if sessionid not in nerfreals or nerfreals[sessionid] is None:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ nerfreals[sessionid].set_custom_state(params['audiotype'],params['reinit'])
|
|
|
+
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg":"ok"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception('exception:')
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": str(e)}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+async def record(request):
|
|
|
+ try:
|
|
|
+ params = await request.json()
|
|
|
+
|
|
|
+ sessionid = params.get('sessionid',0)
|
|
|
+ if sessionid not in nerfreals or nerfreals[sessionid] is None:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ if params['type']=='start_record':
|
|
|
+ # nerfreals[sessionid].put_msg_txt(params['text'])
|
|
|
+ nerfreals[sessionid].start_recording()
|
|
|
+ elif params['type']=='end_record':
|
|
|
+ nerfreals[sessionid].stop_recording()
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg":"ok"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception('exception:')
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": str(e)}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+async def is_speaking(request):
|
|
|
+ params = await request.json()
|
|
|
+
|
|
|
+ sessionid = params.get('sessionid',0)
|
|
|
+ if sessionid not in nerfreals or nerfreals[sessionid] is None:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "data": nerfreals[sessionid].is_speaking()}
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def knowledge_intro(request):
|
|
|
+ """返回知识库介绍内容"""
|
|
|
+ try:
|
|
|
+ from knowledge_intro import start_intro_play, knowledge_intro
|
|
|
+ # 启动完整版介绍播放,获取第一条文案
|
|
|
+ play_result = start_intro_play("full")
|
|
|
+ intro_text = play_result.get("text", "")
|
|
|
+
|
|
|
+ params = await request.json()
|
|
|
+ sessionid = params.get('sessionid', 0)
|
|
|
+
|
|
|
+ if sessionid not in nerfreals or nerfreals[sessionid] is None:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ # 保存介绍实例到会话中,便于后续操作
|
|
|
+ if not hasattr(nerfreals[sessionid], 'knowledge_intro_instance'):
|
|
|
+ nerfreals[sessionid].knowledge_intro_instance = knowledge_intro
|
|
|
+
|
|
|
+ # 保存介绍播放状态到会话中
|
|
|
+ if not hasattr(nerfreals[sessionid], 'intro_play_state'):
|
|
|
+ nerfreals[sessionid].intro_play_state = {
|
|
|
+ "is_playing": True,
|
|
|
+ "current_type": "full",
|
|
|
+ "last_played_index": 0,
|
|
|
+ "is_paused": False,
|
|
|
+ "is_waiting_next": False
|
|
|
+ }
|
|
|
+
|
|
|
+ # 使用支持打断恢复的新方法来播放介绍内容
|
|
|
+ nerfreals[sessionid].start_intro_with_interrupt_capability(intro_text)
|
|
|
+
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg": "Knowledge intro played successfully", "text": intro_text, "mode": "intro", "play_index": play_result.get("play_index", 1), "total_count": play_result.get("total_count", 8)}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception('exception in knowledge_intro:')
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": str(e)}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def resume_interrupted(request):
|
|
|
+ """恢复播放被中断的消息"""
|
|
|
+ try:
|
|
|
+ params = await request.json()
|
|
|
+ sessionid = params.get('sessionid', 0)
|
|
|
+
|
|
|
+ if sessionid not in nerfreals or nerfreals[sessionid] is None:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ # 尝试恢复被中断的消息
|
|
|
+ resumed = nerfreals[sessionid].resume_interrupted()
|
|
|
+
|
|
|
+ if resumed:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg": "Interrupted messages resumed successfully", "mode": "intro"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg": "No interrupted messages to resume"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception('exception in resume_interrupted:')
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": str(e)}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def handle_user_question(request):
|
|
|
+ """处理用户提问,暂停当前内容并优先回答问题"""
|
|
|
+ try:
|
|
|
+ params = await request.json()
|
|
|
+ sessionid = params.get('sessionid', 0)
|
|
|
+ question = params.get('text', '')
|
|
|
+
|
|
|
+ if sessionid not in nerfreals or nerfreals[sessionid] is None:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ if not question:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": "Question text is required"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ # 检查是否在介绍过程中,如果是,则标记
|
|
|
+ during_intro = params.get('during_intro', False)
|
|
|
+ if during_intro and hasattr(nerfreals[sessionid], 'knowledge_intro_instance'):
|
|
|
+ # 暂停介绍播放
|
|
|
+ nerfreals[sessionid].knowledge_intro_instance.pause_play()
|
|
|
+
|
|
|
+ # 使用高优先级方法处理用户问题
|
|
|
+ knowledge_base_type = params.get('knowledge_base', None)
|
|
|
+ # 直接调用llm_response,它会使用put_user_question方法
|
|
|
+ asyncio.get_event_loop().run_in_executor(None, llm_response, question, nerfreals[sessionid], knowledge_base_type, during_intro)
|
|
|
+
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg": "User question processed successfully", "mode": "qa"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception('exception in handle_user_question:')
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": str(e)}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def continue_after_qa(request):
|
|
|
+ """在用户问题回答完毕后,继续播放之前的内容"""
|
|
|
+ try:
|
|
|
+ params = await request.json()
|
|
|
+ sessionid = params.get('sessionid', 0)
|
|
|
+
|
|
|
+ if sessionid not in nerfreals or nerfreals[sessionid] is None:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ # 尝试恢复被中断的内容
|
|
|
+ if hasattr(nerfreals[sessionid], 'knowledge_intro_instance'):
|
|
|
+ # 从knowledge_intro_instance获取下一条内容
|
|
|
+ next_content = nerfreals[sessionid].knowledge_intro_instance.resume_play()
|
|
|
+ if next_content:
|
|
|
+ # 播放下一条介绍内容
|
|
|
+ nerfreals[sessionid].put_msg_txt(next_content['text'])
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg": "Previously interrupted content resumed successfully", "text": next_content['text'], "mode": "intro", "play_index": next_content.get("play_index", 1), "total_count": next_content.get("total_count", 8)}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ # 如果没有找到knowledge_intro_instance或没有剩余内容
|
|
|
+ resumed = nerfreals[sessionid].resume_interrupted()
|
|
|
+
|
|
|
+ if resumed:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg": "Previously interrupted content resumed successfully", "mode": "intro"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg": "No interrupted content to resume"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception('exception in continue_after_qa:')
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": str(e)}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+async def intro_play_completed(request):
|
|
|
+ """处理介绍播放完成的回调,自动播放下一条介绍内容"""
|
|
|
+ try:
|
|
|
+ params = await request.json()
|
|
|
+ sessionid = params.get('sessionid', 0)
|
|
|
+
|
|
|
+ if sessionid not in nerfreals or nerfreals[sessionid] is None:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ # 检查是否有介绍实例
|
|
|
+ if not hasattr(nerfreals[sessionid], 'knowledge_intro_instance'):
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": "Knowledge intro instance not found"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ # 检查介绍播放状态
|
|
|
+ if hasattr(nerfreals[sessionid], 'intro_play_state') and not nerfreals[sessionid].intro_play_state.get("is_playing", True):
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg": "Introduction playback is paused"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ # 获取下一条介绍内容
|
|
|
+ next_content = nerfreals[sessionid].knowledge_intro_instance._get_next_content()
|
|
|
+ if next_content:
|
|
|
+ # 播放下一条介绍内容
|
|
|
+ nerfreals[sessionid].put_msg_txt(next_content['text'])
|
|
|
+
|
|
|
+ # 更新播放状态
|
|
|
+ if hasattr(nerfreals[sessionid], 'intro_play_state'):
|
|
|
+ nerfreals[sessionid].intro_play_state["last_played_index"] = next_content.get("play_index", 1)
|
|
|
+
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg": "Next introduction content played successfully", "text": next_content['text'], "mode": "intro", "play_index": next_content.get("play_index", 1), "total_count": next_content.get("total_count", 8), "is_last": next_content.get("is_last", False)}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": 0, "msg": "No more introduction content"}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception('exception in intro_play_completed:')
|
|
|
+ return web.Response(
|
|
|
+ content_type="application/json",
|
|
|
+ text=json.dumps(
|
|
|
+ {"code": -1, "msg": str(e)}
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+async def on_shutdown(app):
|
|
|
+ """关闭时清理所有资源"""
|
|
|
+ logger.info("开始清理所有资源...")
|
|
|
+
|
|
|
+ # 1. 关闭所有 WebRTC 连接
|
|
|
+ coros = [pc.close() for pc in pcs]
|
|
|
+ await asyncio.gather(*coros, return_exceptions=True)
|
|
|
+ pcs.clear()
|
|
|
+
|
|
|
+ # 2. 清理所有数字人实例
|
|
|
+ for sessionid, nerfreal in nerfreals.items():
|
|
|
+ try:
|
|
|
+ if nerfreal is not None:
|
|
|
+ logger.info(f"清理 session {sessionid} 的资源")
|
|
|
+ # 停止 TTS
|
|
|
+ if hasattr(nerfreal, 'tts') and nerfreal.tts:
|
|
|
+ nerfreal.tts.state = State.PAUSE
|
|
|
+ # 清理队列
|
|
|
+ if hasattr(nerfreal, 'msg_queue'):
|
|
|
+ with nerfreal.msg_queue.mutex:
|
|
|
+ nerfreal.msg_queue.queue.clear()
|
|
|
+ if hasattr(nerfreal, 'interrupted_queue'):
|
|
|
+ with nerfreal.interrupted_queue.mutex:
|
|
|
+ nerfreal.interrupted_queue.queue.clear()
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"清理 session {sessionid} 时出错: {e}")
|
|
|
+
|
|
|
+ # 3. 清理数字人实例字典
|
|
|
+ nerfreals.clear()
|
|
|
+
|
|
|
+ # 4. 强制垃圾回收
|
|
|
+ import gc
|
|
|
+ gc.collect()
|
|
|
+
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ torch.cuda.empty_cache()
|
|
|
+ logger.info(f"清理后 GPU 显存: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
|
|
|
+
|
|
|
+ logger.info("所有资源清理完成")
|
|
|
+
|
|
|
+def cleanup_avatar_directories():
|
|
|
+ """程序退出时清理 avatar 目录"""
|
|
|
+ global _avatar_dirs_to_clean, _enable_avatar_cleanup
|
|
|
+
|
|
|
+ if not _enable_avatar_cleanup:
|
|
|
+ logger.info("未启用 avatar 目录自动清理功能")
|
|
|
+ return
|
|
|
+
|
|
|
+ if not _avatar_dirs_to_clean:
|
|
|
+ logger.info("没有需要清理的 avatar 目录")
|
|
|
+ return
|
|
|
+
|
|
|
+ logger.info(f"开始清理 {len(_avatar_dirs_to_clean)} 个 avatar 目录...")
|
|
|
+
|
|
|
+ for avatar_dir in _avatar_dirs_to_clean:
|
|
|
+ try:
|
|
|
+ if os.path.exists(avatar_dir):
|
|
|
+ logger.info(f"正在删除 avatar 目录: {avatar_dir}")
|
|
|
+ shutil.rmtree(avatar_dir)
|
|
|
+ logger.info(f"✅ 已删除: {avatar_dir}")
|
|
|
+ else:
|
|
|
+ logger.info(f"目录不存在,无需删除: {avatar_dir}")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ 删除 avatar 目录失败 {avatar_dir}: {e}")
|
|
|
+
|
|
|
+ _avatar_dirs_to_clean.clear()
|
|
|
+ logger.info("所有 avatar 目录清理完成")
|
|
|
+
|
|
|
+def signal_handler(signum, frame):
|
|
|
+ """信号处理器 - 捕获退出信号"""
|
|
|
+ logger.info(f"收到退出信号 {signum},开始清理...")
|
|
|
+ cleanup_avatar_directories()
|
|
|
+ sys.exit(0)
|
|
|
+
|
|
|
+async def post(url,data):
|
|
|
+ try:
|
|
|
+ async with aiohttp.ClientSession() as session:
|
|
|
+ async with session.post(url,data=data) as response:
|
|
|
+ return await response.text()
|
|
|
+ except aiohttp.ClientError as e:
|
|
|
+ logger.error(f'POST请求错误: {e}')
|
|
|
+ return None # 明确返回None
|
|
|
+
|
|
|
+async def run(push_url,sessionid):
|
|
|
+ nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
|
|
|
+ nerfreals[sessionid] = nerfreal
|
|
|
+
|
|
|
+ # RTMP 推流模式:跳过 WHIP 连接,使用 basereal.py 中的 ffmpeg 管道推流
|
|
|
+ if push_url.startswith('rtmp://'):
|
|
|
+ logger.info(f'RTMP 推流模式: {push_url},跳过 WHIP 连接,推流将在 render 时自动启动')
|
|
|
+ # 需要启动渲染循环以产生视频/音频帧
|
|
|
+ # 使用 HumanPlayer 触发 render,但不建立 WebRTC 连接
|
|
|
+ player = HumanPlayer(nerfreals[sessionid])
|
|
|
+ # 手动启动 player worker thread 来触发 nerfreal.render()
|
|
|
+ # HumanPlayer._start 会在 track 被 recv 时调用,这里直接启动
|
|
|
+ from threading import Event as ThreadEvent
|
|
|
+ render_quit_event = ThreadEvent()
|
|
|
+ # 直接启动 render 线程,不需要 WebRTC track
|
|
|
+ import threading
|
|
|
+ def rtmp_render_loop():
|
|
|
+ nerfreals[sessionid].render(render_quit_event, loop=None, audio_track=None, video_track=None)
|
|
|
+ render_thread = threading.Thread(target=rtmp_render_loop, daemon=True, name='rtmp_render')
|
|
|
+ render_thread.start()
|
|
|
+ logger.info('RTMP 渲染线程已启动')
|
|
|
+ return
|
|
|
+
|
|
|
+ # WebRTC WHIP 推流(原有逻辑)
|
|
|
+ pc = RTCPeerConnection()
|
|
|
+ pcs.add(pc)
|
|
|
+
|
|
|
+ @pc.on("connectionstatechange")
|
|
|
+ async def on_connectionstatechange():
|
|
|
+ logger.info("Connection state is %s" % pc.connectionState)
|
|
|
+ if pc.connectionState == "failed":
|
|
|
+ await pc.close()
|
|
|
+ pcs.discard(pc)
|
|
|
+
|
|
|
+ player = HumanPlayer(nerfreals[sessionid])
|
|
|
+ audio_sender = pc.addTrack(player.audio)
|
|
|
+ video_sender = pc.addTrack(player.video)
|
|
|
+
|
|
|
+ await pc.setLocalDescription(await pc.createOffer())
|
|
|
+ answer = await post(push_url,pc.localDescription.sdp)
|
|
|
+
|
|
|
+ # 检查POST请求是否成功
|
|
|
+ if answer is None:
|
|
|
+ logger.error(f'推流失败: 无法连接到 {push_url}')
|
|
|
+ await pc.close()
|
|
|
+ pcs.discard(pc)
|
|
|
+ return
|
|
|
+
|
|
|
+ await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
|
|
|
+##########################################
|
|
|
+# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
|
|
|
+# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
|
|
|
+if __name__ == '__main__':
|
|
|
+ # 注册信号处理器
|
|
|
+ signal.signal(signal.SIGINT, signal_handler)
|
|
|
+ signal.signal(signal.SIGTERM, signal_handler)
|
|
|
+
|
|
|
+ mp.set_start_method('spawn')
|
|
|
+ parser = argparse.ArgumentParser()
|
|
|
+
|
|
|
+ # audio FPS
|
|
|
+ parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50")
|
|
|
+ # sliding window left-middle-right length (unit: 20ms)
|
|
|
+ parser.add_argument('-l', type=int, default=10)
|
|
|
+ parser.add_argument('-m', type=int, default=8)
|
|
|
+ parser.add_argument('-r', type=int, default=10)
|
|
|
+
|
|
|
+ parser.add_argument('--W', type=int, default=450, help="GUI width")
|
|
|
+ parser.add_argument('--H', type=int, default=450, help="GUI height")
|
|
|
+
|
|
|
+ #musetalk opt
|
|
|
+ parser.add_argument('--avatar_id', type=str, default='avator_1', help="define which avatar in data/avatars")
|
|
|
+ #parser.add_argument('--bbox_shift', type=int, default=5)
|
|
|
+ parser.add_argument('--batch_size', type=int, default=16, help="infer batch")
|
|
|
+
|
|
|
+ parser.add_argument('--customvideo_config', type=str, default='', help="custom action json")
|
|
|
+
|
|
|
+ parser.add_argument('--tts', type=str, default='edgetts', help="tts service type") #xtts gpt-sovits cosyvoice fishtts tencent doubao indextts2 azuretts qwen3tts voxcpm2api
|
|
|
+ parser.add_argument('--REF_FILE', type=str, default="zh-CN-YunxiaNeural",help="参考文件名或语音模型 ID,默认值为 edgetts 的语音模型 ID zh-CN-YunxiaNeural, 若--tts 指定为 azuretts, 可以使用 Azure 语音模型 ID, 如 zh-CN-XiaoxiaoMultilingualNeural")
|
|
|
+ parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880', help="TTS server URL")
|
|
|
+ parser.add_argument('--QWEN3_TTS_MODEL_PATH', type=str, default='/home/test/Digital_Human/Qwen3-TTS-12Hz-1.7B-Base', help="Qwen3 TTS model path")
|
|
|
+ parser.add_argument('--QWEN3_TTS_LANGUAGE', type=str, default='Chinese', help="Qwen3 TTS language")
|
|
|
+ parser.add_argument('--VOXCPM2_MODEL_PATH', type=str, default='VoxCPM2', help="VoxCPM2 模型路径")
|
|
|
+ parser.add_argument('--VOXCPM2_API_URL', type=str, default='http://localhost:6003', help="VoxCPM2 API 服务地址(API 调用模式)")
|
|
|
+ parser.add_argument('--VOXCPM2_REF_WAV', type=str, default='voice_output.wav', help="VoxCPM2 参考音频路径")
|
|
|
+ parser.add_argument('--VOXCPM2_REF_TEXT', type=str, default='你好,买水果,卖水果,新鲜的水果。', help="VoxCPM2 参考文本")
|
|
|
+ parser.add_argument('--CFG_VALUE', type=float, default=2.0, help="VoxCPM2 CFG value")
|
|
|
+ parser.add_argument('--INFERENCE_TIMESTEPS', type=int, default=10, help="VoxCPM2 inference timesteps")
|
|
|
+ parser.add_argument('--REF_TEXT', type=str, default=None, help="参考文本")
|
|
|
+ # parser.add_argument('--CHARACTER', type=str, default='test')
|
|
|
+ # parser.add_argument('--EMOTION', type=str, default='default')
|
|
|
+
|
|
|
+ parser.add_argument('--model', type=str, default='musetalk') #musetalk wav2lip ultralight
|
|
|
+
|
|
|
+ parser.add_argument('--transport', type=str, default='rtcpush') #webrtc rtcpush virtualcam
|
|
|
+ parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
|
|
|
+
|
|
|
+ parser.add_argument('--max_session', type=int, default=1) #multi session count
|
|
|
+ parser.add_argument('--listenport', type=int, default=7868, help="web listen port")
|
|
|
+
|
|
|
+ # Avatar 目录清理配置
|
|
|
+ parser.add_argument('--cleanup_avatar_on_exit', action='store_true', default=False,
|
|
|
+ help="程序退出时自动删除本次使用的 avatar 目录")
|
|
|
+
|
|
|
+ opt = parser.parse_args()
|
|
|
+
|
|
|
+ # 设置全局清理标志
|
|
|
+ _enable_avatar_cleanup = opt.cleanup_avatar_on_exit
|
|
|
+ if _enable_avatar_cleanup:
|
|
|
+ logger.info("✅ 已启用 avatar 目录退出自动清理功能")
|
|
|
+ else:
|
|
|
+ logger.info("ℹ️ 未启用 avatar 目录自动清理(使用 --cleanup_avatar_on_exit 启用)")
|
|
|
+ #app.config.from_object(opt)
|
|
|
+ #print(app.config)
|
|
|
+ opt.customopt = []
|
|
|
+ if opt.customvideo_config!='':
|
|
|
+ with open(opt.customvideo_config,'r') as file:
|
|
|
+ opt.customopt = json.load(file)
|
|
|
+
|
|
|
+ # if opt.model == 'ernerf':
|
|
|
+ # from nerfreal import NeRFReal,load_model,load_avatar
|
|
|
+ # model = load_model(opt)
|
|
|
+ # avatar = load_avatar(opt)
|
|
|
+ if opt.model == 'musetalk':
|
|
|
+ from musereal import MuseReal,load_model,load_avatar,warm_up
|
|
|
+ logger.info(opt)
|
|
|
+ model = load_model()
|
|
|
+ avatar = load_avatar(opt.avatar_id)
|
|
|
+ warm_up(opt.batch_size,model)
|
|
|
+ elif opt.model == 'wav2lip':
|
|
|
+ from lipreal import LipReal,load_model,load_avatar,warm_up
|
|
|
+ logger.info(opt)
|
|
|
+ model = load_model("./models/wav2lip256.pth")
|
|
|
+ avatar = load_avatar(opt.avatar_id)
|
|
|
+ warm_up(opt.batch_size,model,256)
|
|
|
+ elif opt.model == 'ultralight':
|
|
|
+ from lightreal import LightReal,load_model,load_avatar,warm_up
|
|
|
+ logger.info(opt)
|
|
|
+ model = load_model(opt)
|
|
|
+ avatar = load_avatar(opt.avatar_id)
|
|
|
+ warm_up(opt.batch_size,avatar,160)
|
|
|
+
|
|
|
+ # if opt.transport=='rtmp':
|
|
|
+ # thread_quit = Event()
|
|
|
+ # nerfreals[0] = build_nerfreal(0)
|
|
|
+ # rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
|
|
|
+ # rendthrd.start()
|
|
|
+ if opt.transport=='virtualcam':
|
|
|
+ thread_quit = Event()
|
|
|
+ nerfreals[0] = build_nerfreal(0)
|
|
|
+ rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
|
|
|
+ rendthrd.start()
|
|
|
+
|
|
|
+ #############################################################################
|
|
|
+ appasync = web.Application(client_max_size=1024**2*100)
|
|
|
+ appasync.on_shutdown.append(on_shutdown)
|
|
|
+ appasync.router.add_post("/offer", offer)
|
|
|
+ appasync.router.add_post("/human", human)
|
|
|
+ appasync.router.add_post("/api/chat", api_chat) # 新增独立聊天API
|
|
|
+ appasync.router.add_post("/humanaudio", humanaudio)
|
|
|
+ appasync.router.add_post("/set_audiotype", set_audiotype)
|
|
|
+ appasync.router.add_post("/record", record)
|
|
|
+ appasync.router.add_post("/interrupt_talk", interrupt_talk)
|
|
|
+ appasync.router.add_post("/is_speaking", is_speaking)
|
|
|
+ appasync.router.add_post("/knowledge_intro", knowledge_intro) # 新增知识库介绍 API
|
|
|
+ appasync.router.add_post("/resume_interrupted", resume_interrupted) # 新增恢复中断消息 API
|
|
|
+ appasync.router.add_post("/handle_user_question", handle_user_question) # 新增用户提问 API
|
|
|
+ appasync.router.add_post("/continue_after_qa", continue_after_qa) # 新增问答后继续播放 API
|
|
|
+ appasync.router.add_post("/intro_play_completed", intro_play_completed) # 新增介绍播放完成回调 API
|
|
|
+ appasync.router.add_post("/log", log) # 前端日志接口
|
|
|
+ appasync.router.add_static('/',path='web')
|
|
|
+
|
|
|
+ # Configure default CORS settings.
|
|
|
+ cors = aiohttp_cors.setup(appasync, defaults={
|
|
|
+ "*": aiohttp_cors.ResourceOptions(
|
|
|
+ allow_credentials=True,
|
|
|
+ expose_headers="*",
|
|
|
+ allow_headers="*",
|
|
|
+ )
|
|
|
+ })
|
|
|
+ # Configure CORS on all routes.
|
|
|
+ for route in list(appasync.router.routes()):
|
|
|
+ cors.add(route)
|
|
|
+
|
|
|
+ pagename='webrtcapi.html'
|
|
|
+ if opt.transport=='rtmp':
|
|
|
+ pagename='echoapi.html'
|
|
|
+ elif opt.transport=='rtcpush':
|
|
|
+ pagename='rtcpushapi.html'
|
|
|
+ logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
|
|
|
+ logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html')
|
|
|
+ def run_server(runner):
|
|
|
+ try:
|
|
|
+ loop = asyncio.new_event_loop()
|
|
|
+ asyncio.set_event_loop(loop)
|
|
|
+ loop.run_until_complete(runner.setup())
|
|
|
+ site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
|
|
|
+ loop.run_until_complete(site.start())
|
|
|
+ if opt.transport=='rtcpush':
|
|
|
+ for k in range(opt.max_session):
|
|
|
+ push_url = opt.push_url
|
|
|
+ if k!=0:
|
|
|
+ push_url = opt.push_url+str(k)
|
|
|
+ loop.run_until_complete(run(push_url,k))
|
|
|
+ loop.run_forever()
|
|
|
+ except KeyboardInterrupt:
|
|
|
+ logger.info("收到 KeyboardInterrupt,正在退出...")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"服务器运行出错: {e}")
|
|
|
+ finally:
|
|
|
+ # 服务器退出时清理 avatar 目录
|
|
|
+ cleanup_avatar_directories()
|
|
|
+ #Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
|
|
|
+ run_server(web.AppRunner(appasync))
|
|
|
+
|
|
|
+ #app.on_shutdown.append(on_shutdown)
|
|
|
+ #app.router.add_post("/offer", offer)
|
|
|
+
|
|
|
+ # print('start websocket server')
|
|
|
+ # server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
|
|
|
+ # server.serve_forever()
|
|
|
+
|
|
|
+
|