| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075 |
- # server.py
- import flask
- from flask import Flask, render_template,send_from_directory,request, jsonify
- from flask_sockets import Sockets
- import base64
- import json
- #import gevent
- #from gevent import pywsgi
- #from geventwebsocket.handler import WebSocketHandler
- import re
- import numpy as np
- from threading import Thread,Event
- #import multiprocessing
- # 禁用 TorchDynamo 编译(避免 VoxCPM2 兼容性问题)
- import os
- os.environ['TORCHDYNAMO_DISABLE'] = '1'
- import torch
- if hasattr(torch, '_dynamo'):
- torch._dynamo.config.suppress_errors = True
- import torch.multiprocessing as mp
- from aiohttp import web
- import aiohttp
- import aiohttp_cors
- from aiortc import RTCPeerConnection, RTCSessionDescription,RTCIceServer,RTCConfiguration
- from aiortc.rtcrtpsender import RTCRtpSender
- from webrtc import HumanPlayer
- from basereal import BaseReal
- from llm import llm_response
- import argparse
- import random
- import shutil
- import asyncio
- import torch
- import os
- import socket
- import signal
- import sys
- from typing import Dict
- from logger import logger
- import gc
- app = Flask(__name__)  # Flask app object; the HTTP routes visible in this file are registered on the aiohttp app instead
- #sockets = Sockets(app)
- nerfreals:Dict[int, BaseReal] = {} # sessionid -> BaseReal instance (None while the instance is still being built)
- opt = None  # parsed command-line options; assigned in the __main__ block
- model = None  # object returned by load_model(); assigned in the __main__ block
- avatar = None  # object returned by load_avatar(); assigned in the __main__ block
-
- #####webrtc###############################
- pcs = set()  # RTCPeerConnection objects that are tracked so they can be closed on shutdown
- # Avatar directories used during this run, remembered so they can be removed on exit
- _avatar_dirs_to_clean = set()
- _enable_avatar_cleanup = False  # set from --cleanup_avatar_on_exit in the __main__ block
def randN(N)->int:
    """Return a random integer that has exactly ``N`` decimal digits.

    The value is drawn from the inclusive range ``[10**(N-1), 10**N - 1]``.
    """
    lower_bound = pow(10, N - 1)
    upper_bound = pow(10, N) - 1
    return random.randint(lower_bound, upper_bound)
def build_nerfreal(sessionid:int)->BaseReal:
    """Build the renderer object for one session.

    Stores ``sessionid`` on the shared ``opt`` namespace, records the avatar
    directory (if it exists on disk) for optional cleanup on exit, and then
    instantiates the implementation selected by ``opt.model``.

    Raises:
        ValueError: if ``opt.model`` is not one of the supported model names.
    """
    opt.sessionid = sessionid

    # Remember which avatar directory this run uses so it can be removed on exit.
    avatar_id = getattr(opt, 'avatar_id', None)
    if avatar_id:
        base_dir = os.path.dirname(os.path.abspath(__file__))
        avatar_dir = os.path.join(base_dir, 'data', 'avatars', avatar_id)
        if os.path.exists(avatar_dir):
            _avatar_dirs_to_clean.add(avatar_dir)
            logger.info(f'记录 avatar 目录用于退出时清理: {avatar_dir}')

    # The implementation modules are imported lazily, only for the selected model.
    model_type = opt.model
    if model_type == 'wav2lip':
        from lipreal import LipReal
        return LipReal(opt, model, avatar)
    if model_type == 'musetalk':
        from musereal import MuseReal
        return MuseReal(opt, model, avatar)
    if model_type == 'ultralight':
        from lightreal import LightReal
        return LightReal(opt, model, avatar)
    raise ValueError(f"Unsupported model type: {opt.model}")
- #@app.route('/offer', methods=['POST'])
def _collect_local_ipv4():
    """Return the set of local IPv4 addresses that may appear in the SDP.

    When ``AIOICE_BIND_IP`` is set, only that address is returned. Otherwise
    the addresses resolved for the host name are used, supplemented on a
    best-effort basis by the adapters reported by the optional ``ifaddr``
    package. Addresses starting with ``127.`` are skipped.
    """
    bind_ip = os.environ.get('AIOICE_BIND_IP', '')
    if bind_ip:
        return {bind_ip}

    local_ips = set()
    for ip_info in socket.getaddrinfo(socket.gethostname(), None, socket.AF_INET):
        ip = ip_info[4][0]
        if not ip.startswith('127.'):
            local_ips.add(ip)

    # getaddrinfo may not list every interface, so also enumerate the adapters.
    # NOTE(review): the original indentation was lost; this is assumed to belong
    # to the "no AIOICE_BIND_IP" branch - confirm against the upstream file.
    try:
        import ifaddr
        for adapter in ifaddr.get_adapters():
            for ip_obj in adapter.ips:
                # Only string-valued addresses are considered.
                if isinstance(ip_obj.ip, str) and not ip_obj.ip.startswith('127.'):
                    local_ips.add(ip_obj.ip)
    except Exception:
        # Best effort only: fall back to the getaddrinfo results.
        logger.debug('ifaddr enumeration failed; using getaddrinfo results only', exc_info=True)
    return local_ips


def _rewrite_sdp_ips(sdp, local_ips, public_ip):
    """Replace whole occurrences of each address in ``local_ips`` with ``public_ip``."""
    for local_ip in local_ips:
        if not local_ip:
            continue
        # The digit/dot guards keep "10.0.0.1" from matching inside "10.0.0.12".
        pattern = r'(?<![\d.])' + re.escape(local_ip) + r'(?![\d.])'
        sdp = re.sub(pattern, lambda _match: public_ip, sdp)
    return sdp


async def offer(request):
    """Handle a WebRTC offer: create a session and return the SDP answer.

    Request JSON: ``{"sdp": ..., "type": ...}``.
    Response JSON on success: ``{"sdp", "type", "sessionid"}``.
    On any failure the partially created session and peer connection are
    released and ``{"code": -1, "msg": ...}`` is returned with HTTP 500.

    When the ``WEBRTC_NAT_IP`` environment variable is set, no STUN servers are
    configured and the local IPs in the answer SDP are rewritten to that IP.
    """
    sessionid = None
    pc = None
    try:
        params = await request.json()
        remote_offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

        # Reserve a session id that is not already in use; the None placeholder
        # marks the session as "being built".
        candidate = randN(6)
        while candidate in nerfreals:
            candidate = randN(6)
        nerfreals[candidate] = None
        sessionid = candidate
        logger.info('sessionid=%d, session num=%d',sessionid,len(nerfreals))

        # Building the renderer is blocking work, so run it in the default executor.
        nerfreal = await asyncio.get_running_loop().run_in_executor(None, build_nerfreal, sessionid)
        nerfreals[sessionid] = nerfreal

        # Intranet / jump-host deployment: WEBRTC_NAT_IP is the address the
        # browser can reach; in that case only host candidates are used.
        nat_public_ip = os.environ.get('WEBRTC_NAT_IP', '')
        if nat_public_ip:
            logger.info('Using NAT public IP for ICE: %s', nat_public_ip)
            ice_servers = []
        else:
            ice_servers = [
                RTCIceServer(urls='stun:stun.l.google.com:19302'),
                RTCIceServer(urls='stun:stun1.l.google.com:19302'),
                RTCIceServer(urls='stun:stun2.l.google.com:19302'),
                RTCIceServer(urls='stun:stun3.l.google.com:19302'),
                RTCIceServer(urls='stun:stun4.l.google.com:19302'),
            ]

        pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=ice_servers))
        pcs.add(pc)

        @pc.on("connectionstatechange")
        async def on_connectionstatechange():
            logger.info(f"Session {sessionid} - Connection state is {pc.connectionState}")
            if pc.connectionState == "connected":
                logger.info(f"Session {sessionid} - WebRTC connection established successfully!")
            if pc.connectionState == "failed":
                logger.error(f"Session {sessionid} - Connection failed!")
                await pc.close()
                pcs.discard(pc)
                nerfreals.pop(sessionid, None)
            if pc.connectionState == "closed":
                logger.info(f"Session {sessionid} - Connection closed")
                pcs.discard(pc)
                nerfreals.pop(sessionid, None)

        # Audio is added first so that the video transceiver is at index 1 below.
        player = HumanPlayer(nerfreal)
        pc.addTrack(player.audio)
        pc.addTrack(player.video)
        logger.info(f"Added tracks for session {sessionid}: audio={player.audio}, video={player.video}")

        # Codec preference order for video: H264, then VP8, then rtx.
        capabilities = RTCRtpSender.getCapabilities("video")
        preferences = [
            codec
            for codec_name in ("H264", "VP8", "rtx")
            for codec in capabilities.codecs
            if codec.name == codec_name
        ]
        pc.getTransceivers()[1].setCodecPreferences(preferences)

        await pc.setRemoteDescription(remote_offer)
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)

        sdp = pc.localDescription.sdp
        if nat_public_ip:
            # Only the IP is rewritten; the candidate ports in the SDP are kept.
            try:
                local_ips = _collect_local_ipv4()
                sdp = _rewrite_sdp_ips(sdp, local_ips, nat_public_ip)
                logger.info('NAT mapping: %s -> %s in SDP', local_ips, nat_public_ip)
            except Exception as e:
                logger.warning('Failed to apply NAT IP mapping to SDP: %s', e)
        else:
            logger.info('WEBRTC_NAT_IP not set. SDP uses server internal IPs. '
                        'Set WEBRTC_NAT_IP=<jumphost_ip> if browser cannot reach server directly.')

        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"sdp": sdp, "type": pc.localDescription.type, "sessionid": sessionid}
            ),
        )
    except Exception as e:
        logger.exception('Error in offer:')
        # Release whatever was created so a failed offer does not leak a
        # placeholder session or an open peer connection.
        if pc is not None:
            pcs.discard(pc)
            try:
                await pc.close()
            except Exception:
                logger.debug('Closing peer connection after failed offer raised', exc_info=True)
        if sessionid is not None:
            nerfreals.pop(sessionid, None)
        return web.Response(
            content_type="application/json",
            text=json.dumps({"code": -1, "msg": str(e)}),
            status=500
        )
async def log(request):
    """Receive a log line forwarded by the front end and write it to the server log.

    Request JSON: ``{"message": ...}`` (other keys are ignored).
    Always answers ``{"code": 0}``; a malformed payload is ignored on purpose so
    that client-side logging can never fail a page.
    """
    try:
        params = await request.json()
        logger.info('[WEBRTC] %s', params.get('message', ''))
    except Exception:
        # Best-effort endpoint: note the problem at debug level instead of hiding it.
        logger.debug('Ignoring malformed front-end log payload', exc_info=True)
    return web.Response(
        content_type="application/json",
        text=json.dumps({"code": 0}),
    )
async def human(request):
    """Make the digital human of a session speak or answer.

    Request JSON keys: ``sessionid``, ``type`` (``"echo"`` speaks ``text``
    verbatim, ``"chat"`` sends ``text`` to ``llm_response``), ``text``, and the
    optional ``interrupt``, ``knowledge_base`` and ``during_intro``.
    Responds with ``{"code": 0, "msg": "ok"}`` or ``{"code": -1, "msg": ...}``.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        params = await request.json()
        sessionid = params.get('sessionid',0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply({"code": -1, "msg": f"Session {sessionid} not found or not initialized"})

        # Stop whatever is currently playing before handling the new message.
        if params.get('interrupt'):
            nerfreal.flush_talk()

        msg_type = params['type']
        if msg_type == 'echo':
            nerfreal.put_msg_txt(params['text'])
        elif msg_type == 'chat':
            # llm_response is blocking, so it runs in the default executor and
            # the request returns without waiting for it.
            asyncio.get_running_loop().run_in_executor(
                None,
                llm_response,
                params['text'],
                nerfreal,
                params.get('knowledge_base', None),
                params.get('during_intro', False),
            )
        return reply({"code": 0, "msg":"ok"})
    except Exception as e:
        logger.exception('exception:')
        return reply({"code": -1, "msg": str(e)})
async def api_chat(request):
    """Standalone chat API: make the digital human speak or answer a question.

    Request body::

        {
            "text": "what the user said",
            "type": "chat",          # "chat" = question answering, "echo" = speak verbatim
            "interrupt": true,       # whether to interrupt the current playback
            "sessionid": 123456,     # WebRTC session id
            "knowledge_base": "",    # optional: knowledge base to use
            "during_intro": false    # optional: whether an introduction is playing
        }
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        params = await request.json()

        # Parameter validation.
        text = params.get('text', '').strip()
        if not text:
            return reply({"code": -1, "msg": "text parameter is required"})

        sessionid = params.get('sessionid', 0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply({"code": -1, "msg": f"Session {sessionid} not found or not initialized"})

        msg_type = params.get('type', 'chat')
        knowledge_base = params.get('knowledge_base', None)
        during_intro = params.get('during_intro', False)

        logger.info(f"API Chat - Session {sessionid}, Type: {msg_type}, Text: {text[:50]}...")

        # Interrupting is the default unless the caller turns it off.
        if params.get('interrupt', True):
            nerfreal.flush_talk()

        if msg_type != 'echo':
            # Chat / question-answering mode: hand the text to the LLM in the executor.
            asyncio.get_event_loop().run_in_executor(
                None, llm_response, text, nerfreal, knowledge_base, during_intro
            )
        else:
            # Echo mode: speak the text as-is.
            nerfreal.put_msg_txt(text)

        confirmation = {
            "sessionid": sessionid,
            "type": msg_type,
            "text": text[:100],  # first 100 characters, returned for confirmation
        }
        return reply({"code": 0, "msg": "ok", "data": confirmation})
    except Exception as e:
        logger.exception('API Chat exception:')
        return reply({"code": -1, "msg": str(e)})
async def interrupt_talk(request):
    """Interrupt the current speech of the session given by ``sessionid``.

    Responds with ``{"code": 0, "msg": "ok"}`` or ``{"code": -1, "msg": ...}``.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        params = await request.json()
        sessionid = params.get('sessionid',0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply({"code": -1, "msg": f"Session {sessionid} not found or not initialized"})
        nerfreal.flush_talk()
        return reply({"code": 0, "msg":"ok"})
    except Exception as e:
        logger.exception('exception:')
        return reply({"code": -1, "msg": str(e)})
async def humanaudio(request):
    """Feed an uploaded audio file to the session's digital human.

    Expects a form post with ``sessionid`` and a ``file`` upload field.
    Responds with ``{"code": 0, "msg": "ok"}`` or ``{"code": -1, "msg": ...}``.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        form = await request.post()
        sessionid = int(form.get('sessionid',0))
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply({"code": -1, "msg": f"Session {sessionid} not found or not initialized"})

        # Reject a missing or non-file "file" field with a readable message
        # instead of surfacing a bare KeyError/AttributeError text.
        fileobj = form.get("file")
        if fileobj is None or not hasattr(fileobj, 'file'):
            return reply({"code": -1, "msg": "file upload field is required"})

        nerfreal.put_audio_file(fileobj.file.read())
        return reply({"code": 0, "msg":"ok"})
    except Exception as e:
        logger.exception('exception:')
        return reply({"code": -1, "msg": str(e)})
async def set_audiotype(request):
    """Forward ``audiotype`` and ``reinit`` from the request to ``set_custom_state``.

    Responds with ``{"code": 0, "msg": "ok"}`` or ``{"code": -1, "msg": ...}``.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        params = await request.json()
        sessionid = params.get('sessionid',0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply({"code": -1, "msg": f"Session {sessionid} not found or not initialized"})
        audiotype = params['audiotype']
        reinit = params['reinit']
        nerfreal.set_custom_state(audiotype, reinit)
        return reply({"code": 0, "msg":"ok"})
    except Exception as e:
        logger.exception('exception:')
        return reply({"code": -1, "msg": str(e)})
async def record(request):
    """Start or stop recording for a session.

    ``type`` must be ``"start_record"`` or ``"end_record"``; any other value is
    accepted and does nothing. Responds with ``{"code": 0, "msg": "ok"}`` or
    ``{"code": -1, "msg": ...}``.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        params = await request.json()
        sessionid = params.get('sessionid',0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply({"code": -1, "msg": f"Session {sessionid} not found or not initialized"})
        action = params['type']
        if action == 'start_record':
            nerfreal.start_recording()
        elif action == 'end_record':
            nerfreal.stop_recording()
        return reply({"code": 0, "msg":"ok"})
    except Exception as e:
        logger.exception('exception:')
        return reply({"code": -1, "msg": str(e)})
async def is_speaking(request):
    """Report whether the session's digital human is currently speaking.

    Responds with ``{"code": 0, "data": <is_speaking() result>}``, or with
    ``{"code": -1, "msg": ...}`` when the session is unknown or the request
    cannot be processed (same error envelope as the other handlers).
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        params = await request.json()
        sessionid = params.get('sessionid',0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply({"code": -1, "msg": f"Session {sessionid} not found or not initialized"})
        return reply({"code": 0, "data": nerfreal.is_speaking()})
    except Exception as e:
        logger.exception('exception:')
        return reply({"code": -1, "msg": str(e)})
async def knowledge_intro(request):
    """Start the full knowledge-base introduction for a session.

    The request and session are validated first, so an invalid request does not
    start introduction playback. Responds with the first introduction text plus
    ``play_index`` / ``total_count``, or with ``{"code": -1, "msg": ...}``.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        from knowledge_intro import start_intro_play
        from knowledge_intro import knowledge_intro as intro_instance

        params = await request.json()
        sessionid = params.get('sessionid', 0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply({"code": -1, "msg": f"Session {sessionid} not found or not initialized"})

        # Start the full introduction and fetch its first text.
        play_result = start_intro_play("full")
        intro_text = play_result.get("text", "")

        # Keep the introduction object on the session for later requests.
        if not hasattr(nerfreal, 'knowledge_intro_instance'):
            nerfreal.knowledge_intro_instance = intro_instance

        # Keep the playback state on the session as well (only created once).
        if not hasattr(nerfreal, 'intro_play_state'):
            nerfreal.intro_play_state = {
                "is_playing": True,
                "current_type": "full",
                "last_played_index": 0,
                "is_paused": False,
                "is_waiting_next": False
            }

        # Play through the session method that supports interrupt and resume.
        nerfreal.start_intro_with_interrupt_capability(intro_text)

        return reply({
            "code": 0,
            "msg": "Knowledge intro played successfully",
            "text": intro_text,
            "mode": "intro",
            "play_index": play_result.get("play_index", 1),
            "total_count": play_result.get("total_count", 8),
        })
    except Exception as e:
        logger.exception('exception in knowledge_intro:')
        return reply({"code": -1, "msg": str(e)})
async def resume_interrupted(request):
    """Resume playback of the messages that were interrupted for a session.

    Responds with code 0 in both the "resumed" and "nothing to resume" cases
    (the ``msg`` differs), or with ``{"code": -1, "msg": ...}`` on error.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        params = await request.json()
        sessionid = params.get('sessionid', 0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply({"code": -1, "msg": f"Session {sessionid} not found or not initialized"})

        # Try to resume; the return value says whether anything was resumed.
        if not nerfreal.resume_interrupted():
            return reply({"code": 0, "msg": "No interrupted messages to resume"})
        return reply({"code": 0, "msg": "Interrupted messages resumed successfully", "mode": "intro"})
    except Exception as e:
        logger.exception('exception in resume_interrupted:')
        return reply({"code": -1, "msg": str(e)})
async def handle_user_question(request):
    """Handle a user question, pausing the introduction first when requested.

    Request JSON keys: ``sessionid``, ``text`` and the optional
    ``during_intro`` and ``knowledge_base``. The question is passed to
    ``llm_response`` in the default executor; the response does not wait for it.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        params = await request.json()
        sessionid = params.get('sessionid', 0)
        question = params.get('text', '')

        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply({"code": -1, "msg": f"Session {sessionid} not found or not initialized"})
        if not question:
            return reply({"code": -1, "msg": "Question text is required"})

        # Pause the introduction when the question arrives while it is playing.
        during_intro = params.get('during_intro', False)
        if during_intro and hasattr(nerfreal, 'knowledge_intro_instance'):
            nerfreal.knowledge_intro_instance.pause_play()

        knowledge_base_type = params.get('knowledge_base', None)
        loop = asyncio.get_event_loop()
        loop.run_in_executor(None, llm_response, question, nerfreal, knowledge_base_type, during_intro)

        return reply({"code": 0, "msg": "User question processed successfully", "mode": "qa"})
    except Exception as e:
        logger.exception('exception in handle_user_question:')
        return reply({"code": -1, "msg": str(e)})
async def continue_after_qa(request):
    """Continue the earlier content once a user question has been answered.

    If the session has a ``knowledge_intro_instance`` and its ``resume_play()``
    returns content, that text is spoken; otherwise the session's
    ``resume_interrupted()`` is used as the fallback.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        params = await request.json()
        sessionid = params.get('sessionid', 0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply({"code": -1, "msg": f"Session {sessionid} not found or not initialized"})

        # Preferred path: ask the introduction object for the next content.
        if hasattr(nerfreal, 'knowledge_intro_instance'):
            next_content = nerfreal.knowledge_intro_instance.resume_play()
            if next_content:
                nerfreal.put_msg_txt(next_content['text'])
                return reply({
                    "code": 0,
                    "msg": "Previously interrupted content resumed successfully",
                    "text": next_content['text'],
                    "mode": "intro",
                    "play_index": next_content.get("play_index", 1),
                    "total_count": next_content.get("total_count", 8),
                })

        # Fallback: no introduction object, or it had nothing left to play.
        if not nerfreal.resume_interrupted():
            return reply({"code": 0, "msg": "No interrupted content to resume"})
        return reply({"code": 0, "msg": "Previously interrupted content resumed successfully", "mode": "intro"})
    except Exception as e:
        logger.exception('exception in continue_after_qa:')
        return reply({"code": -1, "msg": str(e)})
async def intro_play_completed(request):
    """Callback for "introduction item finished": play the next item, if any.

    Responds with the next text and its position, with a code-0 message when
    playback is paused or exhausted, or with ``{"code": -1, "msg": ...}``.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        params = await request.json()
        sessionid = params.get('sessionid', 0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply({"code": -1, "msg": f"Session {sessionid} not found or not initialized"})

        # An introduction must have been started for this session.
        if not hasattr(nerfreal, 'knowledge_intro_instance'):
            return reply({"code": -1, "msg": "Knowledge intro instance not found"})

        # Do nothing while the recorded playback state says "not playing".
        has_state = hasattr(nerfreal, 'intro_play_state')
        if has_state and not nerfreal.intro_play_state.get("is_playing", True):
            return reply({"code": 0, "msg": "Introduction playback is paused"})

        next_content = nerfreal.knowledge_intro_instance._get_next_content()
        if not next_content:
            return reply({"code": 0, "msg": "No more introduction content"})

        # Speak the next item, then record how far playback has progressed.
        nerfreal.put_msg_txt(next_content['text'])
        if hasattr(nerfreal, 'intro_play_state'):
            nerfreal.intro_play_state["last_played_index"] = next_content.get("play_index", 1)

        return reply({
            "code": 0,
            "msg": "Next introduction content played successfully",
            "text": next_content['text'],
            "mode": "intro",
            "play_index": next_content.get("play_index", 1),
            "total_count": next_content.get("total_count", 8),
            "is_last": next_content.get("is_last", False),
        })
    except Exception as e:
        logger.exception('exception in intro_play_completed:')
        return reply({"code": -1, "msg": str(e)})
async def on_shutdown(app):
    """Release all resources when the aiohttp application shuts down.

    Closes every tracked peer connection, pauses TTS and empties the message
    queues of each session, drops all sessions, and finally runs the garbage
    collector and clears the CUDA cache when CUDA is available. Each cleanup
    step is guarded separately so one failure does not skip the others.
    """
    logger.info("开始清理所有资源...")

    # 1. Close all WebRTC connections; failures are collected, not raised.
    await asyncio.gather(*[pc.close() for pc in pcs], return_exceptions=True)
    pcs.clear()

    # 2. Clean up every digital-human instance.
    for sessionid, nerfreal in list(nerfreals.items()):
        if nerfreal is None:
            continue
        logger.info(f"清理 session {sessionid} 的资源")

        # Pause TTS. No `State` name is imported in this module, so the PAUSE
        # member is looked up on the type of the current state value.
        # NOTE(review): assumes tts.state is an enum member whose class defines
        # PAUSE - confirm against the TTS implementation.
        try:
            tts = getattr(nerfreal, 'tts', None)
            if tts:
                pause_state = getattr(type(getattr(tts, 'state', None)), 'PAUSE', None)
                if pause_state is not None:
                    tts.state = pause_state
        except Exception as e:
            logger.error(f"清理 session {sessionid} 时出错: {e}")

        # Empty the pending queues while holding each queue's own mutex.
        for queue_name in ('msg_queue', 'interrupted_queue'):
            try:
                pending = getattr(nerfreal, queue_name, None)
                if pending is not None:
                    with pending.mutex:
                        pending.queue.clear()
            except Exception as e:
                logger.error(f"清理 session {sessionid} 时出错: {e}")

    # 3. Drop all sessions.
    nerfreals.clear()

    # 4. Force a garbage-collection pass and release cached GPU memory.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        logger.info(f"清理后 GPU 显存: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")

    logger.info("所有资源清理完成")
def cleanup_avatar_directories():
    """Delete the avatar directories recorded during this run.

    Does nothing (apart from logging) when cleanup is disabled or when no
    directory was recorded. A failure to delete one directory is logged and the
    remaining directories are still processed.
    """
    if not _enable_avatar_cleanup:
        logger.info("未启用 avatar 目录自动清理功能")
        return

    if not _avatar_dirs_to_clean:
        logger.info("没有需要清理的 avatar 目录")
        return

    logger.info(f"开始清理 {len(_avatar_dirs_to_clean)} 个 avatar 目录...")

    for avatar_dir in _avatar_dirs_to_clean:
        try:
            if not os.path.exists(avatar_dir):
                logger.info(f"目录不存在,无需删除: {avatar_dir}")
                continue
            logger.info(f"正在删除 avatar 目录: {avatar_dir}")
            shutil.rmtree(avatar_dir)
            logger.info(f"✅ 已删除: {avatar_dir}")
        except Exception as e:
            logger.error(f"❌ 删除 avatar 目录失败 {avatar_dir}: {e}")

    _avatar_dirs_to_clean.clear()
    logger.info("所有 avatar 目录清理完成")
def signal_handler(signum, frame):
    """Handle an exit signal: clean up avatar directories, then exit with status 0.

    ``frame`` is part of the signal-handler signature and is not used.
    """
    logger.info(f"收到退出信号 {signum},开始清理...")
    cleanup_avatar_directories()
    raise SystemExit(0)
async def post(url,data):
    """POST ``data`` to ``url`` and return the response body as text.

    Returns ``None`` when the request fails with an ``aiohttp.ClientError`` or
    when the server answers with an HTTP error status (>= 400), so callers
    never mistake an error page for a valid answer.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url,data=data) as response:
                body = await response.text()
                if response.status >= 400:
                    logger.error(f'POST请求错误: HTTP {response.status} {body[:200]}')
                    return None
                return body
    except aiohttp.ClientError as e:
        logger.error(f'POST请求错误: {e}')
        return None  # explicit failure marker for the caller
async def run(push_url,sessionid):
    """Build the session renderer and start pushing its stream to ``push_url``.

    ``rtmp://`` URLs only start the render loop in a daemon thread (no WebRTC
    connection is made). Any other URL is treated as a WHIP endpoint: a local
    offer is posted there and the returned text is applied as the answer.
    """
    loop = asyncio.get_event_loop()
    nerfreals[sessionid] = await loop.run_in_executor(None, build_nerfreal, sessionid)

    # RTMP push mode: skip the WHIP connection and just drive render().
    if push_url.startswith('rtmp://'):
        logger.info(f'RTMP 推流模式: {push_url},跳过 WHIP 连接,推流将在 render 时自动启动')
        # A HumanPlayer is still constructed, but no WebRTC track is attached to it.
        player = HumanPlayer(nerfreals[sessionid])
        import threading
        stop_event = threading.Event()

        def rtmp_render_loop():
            # render() is called without an event loop or tracks in this mode.
            nerfreals[sessionid].render(stop_event, loop=None, audio_track=None, video_track=None)

        worker = threading.Thread(target=rtmp_render_loop, daemon=True, name='rtmp_render')
        worker.start()
        logger.info('RTMP 渲染线程已启动')
        return

    # WebRTC WHIP push.
    pc = RTCPeerConnection()
    pcs.add(pc)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info("Connection state is %s" % pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    player = HumanPlayer(nerfreals[sessionid])
    audio_sender = pc.addTrack(player.audio)
    video_sender = pc.addTrack(player.video)

    local_offer = await pc.createOffer()
    await pc.setLocalDescription(local_offer)
    answer = await post(push_url, pc.localDescription.sdp)

    # post() returns None when the request failed.
    if answer is None:
        logger.error(f'推流失败: 无法连接到 {push_url}')
        await pc.close()
        pcs.discard(pc)
        return

    remote_answer = RTCSessionDescription(sdp=answer, type='answer')
    await pc.setRemoteDescription(remote_answer)
- ##########################################
- # os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
- # os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
- if __name__ == '__main__':
- # 注册信号处理器
- signal.signal(signal.SIGINT, signal_handler)
- signal.signal(signal.SIGTERM, signal_handler)
-
- mp.set_start_method('spawn')
- parser = argparse.ArgumentParser()
-
- # audio FPS
- parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50")
- # sliding window left-middle-right length (unit: 20ms)
- parser.add_argument('-l', type=int, default=10)
- parser.add_argument('-m', type=int, default=8)
- parser.add_argument('-r', type=int, default=10)
- parser.add_argument('--W', type=int, default=450, help="GUI width")
- parser.add_argument('--H', type=int, default=450, help="GUI height")
- #musetalk opt
- parser.add_argument('--avatar_id', type=str, default='avator_1', help="define which avatar in data/avatars")
- #parser.add_argument('--bbox_shift', type=int, default=5)
- parser.add_argument('--batch_size', type=int, default=16, help="infer batch")
- parser.add_argument('--customvideo_config', type=str, default='', help="custom action json")
- parser.add_argument('--tts', type=str, default='edgetts', help="tts service type") #xtts gpt-sovits cosyvoice fishtts tencent doubao indextts2 azuretts qwen3tts voxcpm2api
- parser.add_argument('--REF_FILE', type=str, default="zh-CN-YunxiaNeural",help="参考文件名或语音模型 ID,默认值为 edgetts 的语音模型 ID zh-CN-YunxiaNeural, 若--tts 指定为 azuretts, 可以使用 Azure 语音模型 ID, 如 zh-CN-XiaoxiaoMultilingualNeural")
- parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880', help="TTS server URL")
- parser.add_argument('--QWEN3_TTS_MODEL_PATH', type=str, default='/home/test/Digital_Human/Qwen3-TTS-12Hz-1.7B-Base', help="Qwen3 TTS model path")
- parser.add_argument('--QWEN3_TTS_LANGUAGE', type=str, default='Chinese', help="Qwen3 TTS language")
- parser.add_argument('--VOXCPM2_MODEL_PATH', type=str, default='VoxCPM2', help="VoxCPM2 模型路径")
- parser.add_argument('--VOXCPM2_API_URL', type=str, default='http://localhost:6003', help="VoxCPM2 API 服务地址(API 调用模式)")
- parser.add_argument('--VOXCPM2_REF_WAV', type=str, default='voice_output.wav', help="VoxCPM2 参考音频路径")
- parser.add_argument('--VOXCPM2_REF_TEXT', type=str, default='你好,买水果,卖水果,新鲜的水果。', help="VoxCPM2 参考文本")
- parser.add_argument('--CFG_VALUE', type=float, default=2.0, help="VoxCPM2 CFG value")
- parser.add_argument('--INFERENCE_TIMESTEPS', type=int, default=10, help="VoxCPM2 inference timesteps")
- parser.add_argument('--REF_TEXT', type=str, default=None, help="参考文本")
- # parser.add_argument('--CHARACTER', type=str, default='test')
- # parser.add_argument('--EMOTION', type=str, default='default')
- parser.add_argument('--model', type=str, default='musetalk') #musetalk wav2lip ultralight
- parser.add_argument('--transport', type=str, default='rtcpush') #webrtc rtcpush virtualcam
- parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
- parser.add_argument('--max_session', type=int, default=1) #multi session count
- parser.add_argument('--listenport', type=int, default=7868, help="web listen port")
-
- # Avatar 目录清理配置
- parser.add_argument('--cleanup_avatar_on_exit', action='store_true', default=False,
- help="程序退出时自动删除本次使用的 avatar 目录")
- opt = parser.parse_args()
-
# Record whether avatar directories should be removed when the program exits.
# NOTE(review): presumably read by cleanup_avatar_directories() — confirm.
_enable_avatar_cleanup = opt.cleanup_avatar_on_exit
if _enable_avatar_cleanup:
    logger.info("✅ 已启用 avatar 目录退出自动清理功能")
else:
    logger.info("ℹ️ 未启用 avatar 目录自动清理(使用 --cleanup_avatar_on_exit 启用)")
#app.config.from_object(opt)
#print(app.config)

# Optional custom-video definitions loaded from a JSON file.
opt.customopt = []
# Truthiness check: skips both '' and None instead of passing None to open().
if opt.customvideo_config:
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(opt.customvideo_config, 'r', encoding='utf-8') as file:
        opt.customopt = json.load(file)

# Disabled ernerf backend, kept for reference:
# if opt.model == 'ernerf':
#     from nerfreal import NeRFReal,load_model,load_avatar
#     model = load_model(opt)
#     avatar = load_avatar(opt)

# Load the selected backend. The imports stay inside each branch so only the
# chosen backend's module is imported; each branch binds the same names
# (load_model, load_avatar, warm_up) plus its own renderer class.
if opt.model == 'musetalk':
    from musereal import MuseReal, load_model, load_avatar, warm_up
    logger.info(opt)
    model = load_model()
    avatar = load_avatar(opt.avatar_id)
    warm_up(opt.batch_size, model)
elif opt.model == 'wav2lip':
    from lipreal import LipReal, load_model, load_avatar, warm_up
    logger.info(opt)
    model = load_model("./models/wav2lip256.pth")
    avatar = load_avatar(opt.avatar_id)
    warm_up(opt.batch_size, model, 256)
elif opt.model == 'ultralight':
    from lightreal import LightReal, load_model, load_avatar, warm_up
    logger.info(opt)
    model = load_model(opt)
    avatar = load_avatar(opt.avatar_id)
    # Unlike the other backends, warm-up here receives the avatar, not the model.
    warm_up(opt.batch_size, avatar, 160)

# Disabled rtmp start-up path, kept for reference:
# if opt.transport=='rtmp':
#     thread_quit = Event()
#     nerfreals[0] = build_nerfreal(0)
#     rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
#     rendthrd.start()

# virtualcam: build session 0 immediately and run its render loop on a thread.
if opt.transport == 'virtualcam':
    thread_quit = Event()
    nerfreals[0] = build_nerfreal(0)
    rendthrd = Thread(target=nerfreals[0].render, args=(thread_quit,))
    rendthrd.start()
- #############################################################################
# aiohttp application; request bodies are capped at 100 MiB.
appasync = web.Application(client_max_size=100 * 1024 ** 2)
appasync.on_shutdown.append(on_shutdown)

# POST endpoints, registered in the listed order.
for _post_path, _post_handler in (
    ("/offer", offer),
    ("/human", human),
    ("/api/chat", api_chat),                            # standalone chat API
    ("/humanaudio", humanaudio),
    ("/set_audiotype", set_audiotype),
    ("/record", record),
    ("/interrupt_talk", interrupt_talk),
    ("/is_speaking", is_speaking),
    ("/knowledge_intro", knowledge_intro),              # knowledge-base introduction API
    ("/resume_interrupted", resume_interrupted),        # resume an interrupted message
    ("/handle_user_question", handle_user_question),    # user question API
    ("/continue_after_qa", continue_after_qa),          # continue playback after Q&A
    ("/intro_play_completed", intro_play_completed),    # intro playback finished callback
    ("/log", log),                                      # front-end log endpoint
):
    appasync.router.add_post(_post_path, _post_handler)

# Static front-end files are served from ./web at the site root.
appasync.router.add_static('/', path='web')

# Default CORS policy, then applied to every route registered so far.
cors = aiohttp_cors.setup(
    appasync,
    defaults={
        "*": aiohttp_cors.ResourceOptions(
            allow_credentials=True,
            expose_headers="*",
            allow_headers="*",
        ),
    },
)
for route in list(appasync.router.routes()):
    cors.add(route)

# Landing page advertised in the log depends on the transport; anything other
# than rtmp/rtcpush falls back to the WebRTC page.
pagename = {
    'rtmp': 'echoapi.html',
    'rtcpush': 'rtcpushapi.html',
}.get(opt.transport, 'webrtcapi.html')

logger.info(f'start http server; http://<serverip>:{opt.listenport}/{pagename}')
logger.info(f'如果使用webrtc,推荐访问webrtc集成前端: http://<serverip>:{opt.listenport}/dashboard.html')
def run_server(runner):
    """Serve the aiohttp application on a fresh event loop until it stops.

    Sets up ``runner``, binds a TCP site on ``0.0.0.0:opt.listenport`` and,
    when ``opt.transport`` is ``'rtcpush'``, starts one push session per
    ``opt.max_session`` before blocking in ``loop.run_forever()``.

    On the way out (loop stopped, KeyboardInterrupt, or an error) the runner
    is cleaned up so the application's shutdown handlers get a chance to run,
    and ``cleanup_avatar_directories()`` is always called last.

    Args:
        runner: the application runner to serve (created by the caller with
            ``web.AppRunner(...)``); it must not have been set up yet.
    """
    loop = None
    try:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(runner.setup())
        site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
        loop.run_until_complete(site.start())
        if opt.transport == 'rtcpush':
            for k in range(opt.max_session):
                # Session 0 pushes to the URL as given; later sessions append
                # their index to it.
                push_url = opt.push_url if k == 0 else opt.push_url + str(k)
                loop.run_until_complete(run(push_url, k))
        loop.run_forever()
    except KeyboardInterrupt:
        logger.info("收到 KeyboardInterrupt,正在退出...")
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"服务器运行出错: {e}")
    finally:
        try:
            # Without runner.cleanup() the handlers registered on
            # appasync.on_shutdown are never invoked.
            if loop is not None and not loop.is_closed():
                try:
                    loop.run_until_complete(runner.cleanup())
                except Exception:
                    logger.exception("关闭 HTTP 服务器时出错")
        finally:
            # Always clean avatar directories, even if shutdown above failed
            # or was interrupted.
            cleanup_avatar_directories()


# Run the server in the foreground; this call blocks until the loop stops.
# Disabled alternative (background thread), kept for reference:
#Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
run_server(web.AppRunner(appasync))

# Disabled legacy server setup, kept for reference:
#app.on_shutdown.append(on_shutdown)
#app.router.add_post("/offer", offer)
# print('start websocket server')
# server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
# server.serve_forever()
-
-
|