app.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075
  1. # server.py
  2. import flask
  3. from flask import Flask, render_template,send_from_directory,request, jsonify
  4. from flask_sockets import Sockets
  5. import base64
  6. import json
  7. #import gevent
  8. #from gevent import pywsgi
  9. #from geventwebsocket.handler import WebSocketHandler
  10. import re
  11. import numpy as np
  12. from threading import Thread,Event
  13. #import multiprocessing
  14. # 禁用 TorchDynamo 编译(避免 VoxCPM2 兼容性问题)
  15. import os
  16. os.environ['TORCHDYNAMO_DISABLE'] = '1'
  17. import torch
  18. if hasattr(torch, '_dynamo'):
  19. torch._dynamo.config.suppress_errors = True
  20. import torch.multiprocessing as mp
  21. from aiohttp import web
  22. import aiohttp
  23. import aiohttp_cors
  24. from aiortc import RTCPeerConnection, RTCSessionDescription,RTCIceServer,RTCConfiguration
  25. from aiortc.rtcrtpsender import RTCRtpSender
  26. from webrtc import HumanPlayer
  27. from basereal import BaseReal
  28. from llm import llm_response
  29. import argparse
  30. import random
  31. import shutil
  32. import asyncio
  33. import torch
  34. import os
  35. import socket
  36. import signal
  37. import sys
  38. from typing import Dict
  39. from logger import logger
  40. import gc
# Flask application object. The request handlers in this file return
# aiohttp ``web.Response`` objects rather than Flask responses.
app = Flask(__name__)
#sockets = Sockets(app)
# Digital-human instances keyed by session id. offer() stores None as a
# placeholder while the instance is still being built.
nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
# Shared options / model / avatar objects handed to the model constructors
# in build_nerfreal(); they are None until assigned elsewhere.
opt = None
model = None
avatar = None
#####webrtc###############################
# RTCPeerConnection objects that are currently tracked for cleanup.
pcs = set()
# Avatar directories used during this run, recorded so they can be removed on exit
_avatar_dirs_to_clean = set()
# cleanup_avatar_directories() does nothing unless this flag is true.
_enable_avatar_cleanup = False
  52. def randN(N)->int:
  53. '''生成长度为 N的随机数 '''
  54. min = pow(10, N - 1)
  55. max = pow(10, N)
  56. return random.randint(min, max - 1)
  57. def build_nerfreal(sessionid:int)->BaseReal:
  58. opt.sessionid=sessionid
  59. # 记录本次运行使用的 avatar 目录
  60. if hasattr(opt, 'avatar_id') and opt.avatar_id:
  61. avatar_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data', 'avatars', opt.avatar_id)
  62. if os.path.exists(avatar_dir):
  63. _avatar_dirs_to_clean.add(avatar_dir)
  64. logger.info(f'记录 avatar 目录用于退出时清理: {avatar_dir}')
  65. if opt.model == 'wav2lip':
  66. from lipreal import LipReal
  67. nerfreal = LipReal(opt,model,avatar)
  68. elif opt.model == 'musetalk':
  69. from musereal import MuseReal
  70. nerfreal = MuseReal(opt,model,avatar)
  71. # elif opt.model == 'ernerf':
  72. # from nerfreal import NeRFReal
  73. # nerfreal = NeRFReal(opt,model,avatar)
  74. elif opt.model == 'ultralight':
  75. from lightreal import LightReal
  76. nerfreal = LightReal(opt,model,avatar)
  77. else:
  78. raise ValueError(f"Unsupported model type: {opt.model}")
  79. return nerfreal
  80. #@app.route('/offer', methods=['POST'])
  81. async def offer(request):
  82. try:
  83. params = await request.json()
  84. offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
  85. # if len(nerfreals) >= opt.max_session:
  86. # logger.info('reach max session')
  87. # return web.Response(
  88. # content_type="application/json",
  89. # text=json.dumps(
  90. # {"code": -1, "msg": "reach max session"}
  91. # ),
  92. # )
  93. sessionid = randN(6) #len(nerfreals)
  94. nerfreals[sessionid] = None
  95. logger.info('sessionid=%d, session num=%d',sessionid,len(nerfreals))
  96. nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
  97. nerfreals[sessionid] = nerfreal
  98. # 内网/跳板机环境:不使用外网 STUN(内网访问不到),只用 host 候选
  99. # 通过环境变量 WEBRTC_NAT_IP 指定跳板机对外暴露的 IP(浏览器能访问的 IP)
  100. nat_public_ip = os.environ.get('WEBRTC_NAT_IP', '')
  101. if nat_public_ip:
  102. logger.info('Using NAT public IP for ICE: %s', nat_public_ip)
  103. ice_servers = [] # 内网不用 STUN
  104. else:
  105. # 使用多个 STUN 服务器(包括国内和国外的)
  106. ice_servers = [
  107. RTCIceServer(urls='stun:stun.l.google.com:19302'),
  108. RTCIceServer(urls='stun:stun1.l.google.com:19302'),
  109. RTCIceServer(urls='stun:stun2.l.google.com:19302'),
  110. RTCIceServer(urls='stun:stun3.l.google.com:19302'),
  111. RTCIceServer(urls='stun:stun4.l.google.com:19302'),
  112. ]
  113. pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=ice_servers))
  114. pcs.add(pc)
  115. @pc.on("connectionstatechange")
  116. async def on_connectionstatechange():
  117. logger.info(f"Session {sessionid} - Connection state is {pc.connectionState}")
  118. if pc.connectionState == "connected":
  119. logger.info(f"Session {sessionid} - WebRTC connection established successfully!")
  120. if pc.connectionState == "failed":
  121. logger.error(f"Session {sessionid} - Connection failed!")
  122. await pc.close()
  123. pcs.discard(pc)
  124. if sessionid in nerfreals:
  125. del nerfreals[sessionid]
  126. if pc.connectionState == "closed":
  127. logger.info(f"Session {sessionid} - Connection closed")
  128. pcs.discard(pc)
  129. if sessionid in nerfreals:
  130. del nerfreals[sessionid]
  131. # gc.collect()
  132. player = HumanPlayer(nerfreals[sessionid])
  133. audio_sender = pc.addTrack(player.audio)
  134. video_sender = pc.addTrack(player.video)
  135. # 记录轨道添加信息
  136. logger.info(f"Added tracks for session {sessionid}: audio={player.audio}, video={player.video}")
  137. capabilities = RTCRtpSender.getCapabilities("video")
  138. preferences = list(filter(lambda x: x.name == "H264", capabilities.codecs))
  139. preferences += list(filter(lambda x: x.name == "VP8", capabilities.codecs))
  140. preferences += list(filter(lambda x: x.name == "rtx", capabilities.codecs))
  141. transceiver = pc.getTransceivers()[1]
  142. transceiver.setCodecPreferences(preferences)
  143. await pc.setRemoteDescription(offer)
  144. answer = await pc.createAnswer()
  145. await pc.setLocalDescription(answer)
  146. sdp = pc.localDescription.sdp
  147. # 跳板机内网场景:如果配置了 WEBRTC_NAT_IP,将 SDP 中的内网 IP 替换为跳板机对外 IP
  148. # aioice 已经将真实绑定的端口(AIOICE_PORT_MIN~MAX 范围内)写入 SDP 的 a=candidate 行
  149. # 我们只需要把内网 IP 替换成浏览器能访问的 IP 即可
  150. nat_public_ip = os.environ.get('WEBRTC_NAT_IP', '')
  151. if nat_public_ip:
  152. try:
  153. # 获取本机所有非 loopback IPv4 地址
  154. local_ips = set()
  155. bind_ip = os.environ.get('AIOICE_BIND_IP', '')
  156. if bind_ip:
  157. local_ips.add(bind_ip)
  158. else:
  159. hostname = socket.gethostname()
  160. for ip_info in socket.getaddrinfo(hostname, None, socket.AF_INET):
  161. ip = ip_info[4][0]
  162. if not ip.startswith('127.'):
  163. local_ips.add(ip)
  164. # 也枚举所有网卡 IP(getaddrinfo 可能不全)
  165. try:
  166. import ifaddr
  167. for adapter in ifaddr.get_adapters():
  168. for ip_obj in adapter.ips:
  169. if isinstance(ip_obj.ip, str) and not ip_obj.ip.startswith('127.'):
  170. local_ips.add(ip_obj.ip)
  171. except Exception:
  172. pass
  173. # 将 SDP 中所有内网 IP 替换为跳板机对外 IP
  174. for local_ip in local_ips:
  175. sdp = sdp.replace(local_ip, nat_public_ip)
  176. logger.info('NAT mapping: %s -> %s in SDP', local_ips, nat_public_ip)
  177. except Exception as e:
  178. logger.warning('Failed to apply NAT IP mapping to SDP: %s', e)
  179. else:
  180. logger.info('WEBRTC_NAT_IP not set. SDP uses server internal IPs. '
  181. 'Set WEBRTC_NAT_IP=<jumphost_ip> if browser cannot reach server directly.')
  182. return web.Response(
  183. content_type="application/json",
  184. text=json.dumps(
  185. {"sdp": sdp, "type": pc.localDescription.type, "sessionid": sessionid}
  186. ),
  187. )
  188. except Exception as e:
  189. logger.exception('Error in offer:')
  190. return web.Response(
  191. content_type="application/json",
  192. text=json.dumps({"code": -1, "msg": str(e)}),
  193. status=500
  194. )
  195. async def log(request):
  196. """接收前端日志的接口"""
  197. try:
  198. params = await request.json()
  199. log_type = params.get('type', 'info')
  200. message = params.get('message', '')
  201. logger.info('[WEBRTC] %s', message)
  202. except Exception:
  203. pass
  204. return web.Response(
  205. content_type="application/json",
  206. text=json.dumps({"code": 0}),
  207. )
  208. async def human(request):
  209. try:
  210. params = await request.json()
  211. sessionid = params.get('sessionid',0)
  212. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  213. return web.Response(
  214. content_type="application/json",
  215. text=json.dumps(
  216. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  217. ),
  218. )
  219. # 标记是否已经处理了聊天请求
  220. chat_processed = False
  221. if params.get('interrupt'):
  222. # 立即中断当前播放,无论是否是聊天请求
  223. nerfreals[sessionid].flush_talk()
  224. # 检查是否是用户在介绍过程中提问
  225. if params['type'] == 'chat':
  226. knowledge_base_type = params.get('knowledge_base', None)
  227. during_intro = params.get('during_intro', False)
  228. # 处理用户问题(异步执行,避免阻塞事件循环)
  229. asyncio.get_event_loop().run_in_executor(
  230. None,
  231. llm_response,
  232. params['text'],
  233. nerfreals[sessionid],
  234. knowledge_base_type,
  235. during_intro
  236. )
  237. chat_processed = True
  238. if params['type']=='echo':
  239. nerfreals[sessionid].put_msg_txt(params['text'])
  240. elif params['type']=='chat' and not chat_processed:
  241. # 如果没有中断标志,或者已经处理过聊天请求,则不再处理
  242. knowledge_base_type = params.get('knowledge_base', None)
  243. during_intro = params.get('during_intro', False) # 是否在介绍过程中
  244. asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'], nerfreals[sessionid], knowledge_base_type, during_intro)
  245. #nerfreals[sessionid].put_msg_txt(res)
  246. return web.Response(
  247. content_type="application/json",
  248. text=json.dumps(
  249. {"code": 0, "msg":"ok"}
  250. ),
  251. )
  252. except Exception as e:
  253. logger.exception('exception:')
  254. return web.Response(
  255. content_type="application/json",
  256. text=json.dumps(
  257. {"code": -1, "msg": str(e)}
  258. ),
  259. )
  260. async def api_chat(request):
  261. """
  262. 独立聊天API - 让数字人说话/问答
  263. 请求体: {
  264. "text": "用户说的话",
  265. "type": "chat", // "chat"=问答, "echo"=纯播报
  266. "interrupt": true, // 是否打断当前播放
  267. "sessionid": 123456, // WebRTC session ID
  268. "knowledge_base": "", // 可选: 指定知识库
  269. "during_intro": false // 可选: 是否在介绍过程中
  270. }
  271. """
  272. try:
  273. params = await request.json()
  274. # 参数验证
  275. text = params.get('text', '').strip()
  276. if not text:
  277. return web.Response(
  278. content_type="application/json",
  279. text=json.dumps({"code": -1, "msg": "text parameter is required"}),
  280. )
  281. sessionid = params.get('sessionid', 0)
  282. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  283. return web.Response(
  284. content_type="application/json",
  285. text=json.dumps(
  286. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  287. ),
  288. )
  289. msg_type = params.get('type', 'chat')
  290. interrupt = params.get('interrupt', True)
  291. knowledge_base = params.get('knowledge_base', None)
  292. during_intro = params.get('during_intro', False)
  293. logger.info(f"API Chat - Session {sessionid}, Type: {msg_type}, Text: {text[:50]}...")
  294. # 处理中断
  295. if interrupt:
  296. nerfreals[sessionid].flush_talk()
  297. # 根据类型处理
  298. if msg_type == 'echo':
  299. # 纯播报模式 - 直接播放文本
  300. nerfreals[sessionid].put_msg_txt(text)
  301. else:
  302. # 聊天/问答模式 - 走LLM
  303. asyncio.get_event_loop().run_in_executor(
  304. None,
  305. llm_response,
  306. text,
  307. nerfreals[sessionid],
  308. knowledge_base,
  309. during_intro
  310. )
  311. return web.Response(
  312. content_type="application/json",
  313. text=json.dumps({
  314. "code": 0,
  315. "msg": "ok",
  316. "data": {
  317. "sessionid": sessionid,
  318. "type": msg_type,
  319. "text": text[:100] # 返回前100字符用于确认
  320. }
  321. }),
  322. )
  323. except Exception as e:
  324. logger.exception('API Chat exception:')
  325. return web.Response(
  326. content_type="application/json",
  327. text=json.dumps({"code": -1, "msg": str(e)}),
  328. )
  329. async def interrupt_talk(request):
  330. try:
  331. params = await request.json()
  332. sessionid = params.get('sessionid',0)
  333. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  334. return web.Response(
  335. content_type="application/json",
  336. text=json.dumps(
  337. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  338. ),
  339. )
  340. nerfreals[sessionid].flush_talk()
  341. return web.Response(
  342. content_type="application/json",
  343. text=json.dumps(
  344. {"code": 0, "msg":"ok"}
  345. ),
  346. )
  347. except Exception as e:
  348. logger.exception('exception:')
  349. return web.Response(
  350. content_type="application/json",
  351. text=json.dumps(
  352. {"code": -1, "msg": str(e)}
  353. ),
  354. )
  355. async def humanaudio(request):
  356. try:
  357. form= await request.post()
  358. sessionid = int(form.get('sessionid',0))
  359. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  360. return web.Response(
  361. content_type="application/json",
  362. text=json.dumps(
  363. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  364. ),
  365. )
  366. fileobj = form["file"]
  367. filename=fileobj.filename
  368. filebytes=fileobj.file.read()
  369. nerfreals[sessionid].put_audio_file(filebytes)
  370. return web.Response(
  371. content_type="application/json",
  372. text=json.dumps(
  373. {"code": 0, "msg":"ok"}
  374. ),
  375. )
  376. except Exception as e:
  377. logger.exception('exception:')
  378. return web.Response(
  379. content_type="application/json",
  380. text=json.dumps(
  381. {"code": -1, "msg": str(e)}
  382. ),
  383. )
  384. async def set_audiotype(request):
  385. try:
  386. params = await request.json()
  387. sessionid = params.get('sessionid',0)
  388. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  389. return web.Response(
  390. content_type="application/json",
  391. text=json.dumps(
  392. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  393. ),
  394. )
  395. nerfreals[sessionid].set_custom_state(params['audiotype'],params['reinit'])
  396. return web.Response(
  397. content_type="application/json",
  398. text=json.dumps(
  399. {"code": 0, "msg":"ok"}
  400. ),
  401. )
  402. except Exception as e:
  403. logger.exception('exception:')
  404. return web.Response(
  405. content_type="application/json",
  406. text=json.dumps(
  407. {"code": -1, "msg": str(e)}
  408. ),
  409. )
  410. async def record(request):
  411. try:
  412. params = await request.json()
  413. sessionid = params.get('sessionid',0)
  414. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  415. return web.Response(
  416. content_type="application/json",
  417. text=json.dumps(
  418. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  419. ),
  420. )
  421. if params['type']=='start_record':
  422. # nerfreals[sessionid].put_msg_txt(params['text'])
  423. nerfreals[sessionid].start_recording()
  424. elif params['type']=='end_record':
  425. nerfreals[sessionid].stop_recording()
  426. return web.Response(
  427. content_type="application/json",
  428. text=json.dumps(
  429. {"code": 0, "msg":"ok"}
  430. ),
  431. )
  432. except Exception as e:
  433. logger.exception('exception:')
  434. return web.Response(
  435. content_type="application/json",
  436. text=json.dumps(
  437. {"code": -1, "msg": str(e)}
  438. ),
  439. )
  440. async def is_speaking(request):
  441. params = await request.json()
  442. sessionid = params.get('sessionid',0)
  443. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  444. return web.Response(
  445. content_type="application/json",
  446. text=json.dumps(
  447. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  448. ),
  449. )
  450. return web.Response(
  451. content_type="application/json",
  452. text=json.dumps(
  453. {"code": 0, "data": nerfreals[sessionid].is_speaking()}
  454. )
  455. )
  456. async def knowledge_intro(request):
  457. """返回知识库介绍内容"""
  458. try:
  459. from knowledge_intro import start_intro_play, knowledge_intro
  460. # 启动完整版介绍播放,获取第一条文案
  461. play_result = start_intro_play("full")
  462. intro_text = play_result.get("text", "")
  463. params = await request.json()
  464. sessionid = params.get('sessionid', 0)
  465. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  466. return web.Response(
  467. content_type="application/json",
  468. text=json.dumps(
  469. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  470. ),
  471. )
  472. # 保存介绍实例到会话中,便于后续操作
  473. if not hasattr(nerfreals[sessionid], 'knowledge_intro_instance'):
  474. nerfreals[sessionid].knowledge_intro_instance = knowledge_intro
  475. # 保存介绍播放状态到会话中
  476. if not hasattr(nerfreals[sessionid], 'intro_play_state'):
  477. nerfreals[sessionid].intro_play_state = {
  478. "is_playing": True,
  479. "current_type": "full",
  480. "last_played_index": 0,
  481. "is_paused": False,
  482. "is_waiting_next": False
  483. }
  484. # 使用支持打断恢复的新方法来播放介绍内容
  485. nerfreals[sessionid].start_intro_with_interrupt_capability(intro_text)
  486. return web.Response(
  487. content_type="application/json",
  488. text=json.dumps(
  489. {"code": 0, "msg": "Knowledge intro played successfully", "text": intro_text, "mode": "intro", "play_index": play_result.get("play_index", 1), "total_count": play_result.get("total_count", 8)}
  490. ),
  491. )
  492. except Exception as e:
  493. logger.exception('exception in knowledge_intro:')
  494. return web.Response(
  495. content_type="application/json",
  496. text=json.dumps(
  497. {"code": -1, "msg": str(e)}
  498. ),
  499. )
  500. async def resume_interrupted(request):
  501. """恢复播放被中断的消息"""
  502. try:
  503. params = await request.json()
  504. sessionid = params.get('sessionid', 0)
  505. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  506. return web.Response(
  507. content_type="application/json",
  508. text=json.dumps(
  509. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  510. ),
  511. )
  512. # 尝试恢复被中断的消息
  513. resumed = nerfreals[sessionid].resume_interrupted()
  514. if resumed:
  515. return web.Response(
  516. content_type="application/json",
  517. text=json.dumps(
  518. {"code": 0, "msg": "Interrupted messages resumed successfully", "mode": "intro"}
  519. ),
  520. )
  521. else:
  522. return web.Response(
  523. content_type="application/json",
  524. text=json.dumps(
  525. {"code": 0, "msg": "No interrupted messages to resume"}
  526. ),
  527. )
  528. except Exception as e:
  529. logger.exception('exception in resume_interrupted:')
  530. return web.Response(
  531. content_type="application/json",
  532. text=json.dumps(
  533. {"code": -1, "msg": str(e)}
  534. ),
  535. )
  536. async def handle_user_question(request):
  537. """处理用户提问,暂停当前内容并优先回答问题"""
  538. try:
  539. params = await request.json()
  540. sessionid = params.get('sessionid', 0)
  541. question = params.get('text', '')
  542. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  543. return web.Response(
  544. content_type="application/json",
  545. text=json.dumps(
  546. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  547. ),
  548. )
  549. if not question:
  550. return web.Response(
  551. content_type="application/json",
  552. text=json.dumps(
  553. {"code": -1, "msg": "Question text is required"}
  554. ),
  555. )
  556. # 检查是否在介绍过程中,如果是,则标记
  557. during_intro = params.get('during_intro', False)
  558. if during_intro and hasattr(nerfreals[sessionid], 'knowledge_intro_instance'):
  559. # 暂停介绍播放
  560. nerfreals[sessionid].knowledge_intro_instance.pause_play()
  561. # 使用高优先级方法处理用户问题
  562. knowledge_base_type = params.get('knowledge_base', None)
  563. # 直接调用llm_response,它会使用put_user_question方法
  564. asyncio.get_event_loop().run_in_executor(None, llm_response, question, nerfreals[sessionid], knowledge_base_type, during_intro)
  565. return web.Response(
  566. content_type="application/json",
  567. text=json.dumps(
  568. {"code": 0, "msg": "User question processed successfully", "mode": "qa"}
  569. ),
  570. )
  571. except Exception as e:
  572. logger.exception('exception in handle_user_question:')
  573. return web.Response(
  574. content_type="application/json",
  575. text=json.dumps(
  576. {"code": -1, "msg": str(e)}
  577. ),
  578. )
  579. async def continue_after_qa(request):
  580. """在用户问题回答完毕后,继续播放之前的内容"""
  581. try:
  582. params = await request.json()
  583. sessionid = params.get('sessionid', 0)
  584. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  585. return web.Response(
  586. content_type="application/json",
  587. text=json.dumps(
  588. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  589. ),
  590. )
  591. # 尝试恢复被中断的内容
  592. if hasattr(nerfreals[sessionid], 'knowledge_intro_instance'):
  593. # 从knowledge_intro_instance获取下一条内容
  594. next_content = nerfreals[sessionid].knowledge_intro_instance.resume_play()
  595. if next_content:
  596. # 播放下一条介绍内容
  597. nerfreals[sessionid].put_msg_txt(next_content['text'])
  598. return web.Response(
  599. content_type="application/json",
  600. text=json.dumps(
  601. {"code": 0, "msg": "Previously interrupted content resumed successfully", "text": next_content['text'], "mode": "intro", "play_index": next_content.get("play_index", 1), "total_count": next_content.get("total_count", 8)}
  602. ),
  603. )
  604. # 如果没有找到knowledge_intro_instance或没有剩余内容
  605. resumed = nerfreals[sessionid].resume_interrupted()
  606. if resumed:
  607. return web.Response(
  608. content_type="application/json",
  609. text=json.dumps(
  610. {"code": 0, "msg": "Previously interrupted content resumed successfully", "mode": "intro"}
  611. ),
  612. )
  613. else:
  614. return web.Response(
  615. content_type="application/json",
  616. text=json.dumps(
  617. {"code": 0, "msg": "No interrupted content to resume"}
  618. ),
  619. )
  620. except Exception as e:
  621. logger.exception('exception in continue_after_qa:')
  622. return web.Response(
  623. content_type="application/json",
  624. text=json.dumps(
  625. {"code": -1, "msg": str(e)}
  626. ),
  627. )
  628. async def intro_play_completed(request):
  629. """处理介绍播放完成的回调,自动播放下一条介绍内容"""
  630. try:
  631. params = await request.json()
  632. sessionid = params.get('sessionid', 0)
  633. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  634. return web.Response(
  635. content_type="application/json",
  636. text=json.dumps(
  637. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  638. ),
  639. )
  640. # 检查是否有介绍实例
  641. if not hasattr(nerfreals[sessionid], 'knowledge_intro_instance'):
  642. return web.Response(
  643. content_type="application/json",
  644. text=json.dumps(
  645. {"code": -1, "msg": "Knowledge intro instance not found"}
  646. ),
  647. )
  648. # 检查介绍播放状态
  649. if hasattr(nerfreals[sessionid], 'intro_play_state') and not nerfreals[sessionid].intro_play_state.get("is_playing", True):
  650. return web.Response(
  651. content_type="application/json",
  652. text=json.dumps(
  653. {"code": 0, "msg": "Introduction playback is paused"}
  654. ),
  655. )
  656. # 获取下一条介绍内容
  657. next_content = nerfreals[sessionid].knowledge_intro_instance._get_next_content()
  658. if next_content:
  659. # 播放下一条介绍内容
  660. nerfreals[sessionid].put_msg_txt(next_content['text'])
  661. # 更新播放状态
  662. if hasattr(nerfreals[sessionid], 'intro_play_state'):
  663. nerfreals[sessionid].intro_play_state["last_played_index"] = next_content.get("play_index", 1)
  664. return web.Response(
  665. content_type="application/json",
  666. text=json.dumps(
  667. {"code": 0, "msg": "Next introduction content played successfully", "text": next_content['text'], "mode": "intro", "play_index": next_content.get("play_index", 1), "total_count": next_content.get("total_count", 8), "is_last": next_content.get("is_last", False)}
  668. ),
  669. )
  670. else:
  671. return web.Response(
  672. content_type="application/json",
  673. text=json.dumps(
  674. {"code": 0, "msg": "No more introduction content"}
  675. ),
  676. )
  677. except Exception as e:
  678. logger.exception('exception in intro_play_completed:')
  679. return web.Response(
  680. content_type="application/json",
  681. text=json.dumps(
  682. {"code": -1, "msg": str(e)}
  683. ),
  684. )
  685. async def on_shutdown(app):
  686. """关闭时清理所有资源"""
  687. logger.info("开始清理所有资源...")
  688. # 1. 关闭所有 WebRTC 连接
  689. coros = [pc.close() for pc in pcs]
  690. await asyncio.gather(*coros, return_exceptions=True)
  691. pcs.clear()
  692. # 2. 清理所有数字人实例
  693. for sessionid, nerfreal in nerfreals.items():
  694. try:
  695. if nerfreal is not None:
  696. logger.info(f"清理 session {sessionid} 的资源")
  697. # 停止 TTS
  698. if hasattr(nerfreal, 'tts') and nerfreal.tts:
  699. nerfreal.tts.state = State.PAUSE
  700. # 清理队列
  701. if hasattr(nerfreal, 'msg_queue'):
  702. with nerfreal.msg_queue.mutex:
  703. nerfreal.msg_queue.queue.clear()
  704. if hasattr(nerfreal, 'interrupted_queue'):
  705. with nerfreal.interrupted_queue.mutex:
  706. nerfreal.interrupted_queue.queue.clear()
  707. except Exception as e:
  708. logger.error(f"清理 session {sessionid} 时出错: {e}")
  709. # 3. 清理数字人实例字典
  710. nerfreals.clear()
  711. # 4. 强制垃圾回收
  712. import gc
  713. gc.collect()
  714. if torch.cuda.is_available():
  715. torch.cuda.empty_cache()
  716. logger.info(f"清理后 GPU 显存: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
  717. logger.info("所有资源清理完成")
  718. def cleanup_avatar_directories():
  719. """程序退出时清理 avatar 目录"""
  720. global _avatar_dirs_to_clean, _enable_avatar_cleanup
  721. if not _enable_avatar_cleanup:
  722. logger.info("未启用 avatar 目录自动清理功能")
  723. return
  724. if not _avatar_dirs_to_clean:
  725. logger.info("没有需要清理的 avatar 目录")
  726. return
  727. logger.info(f"开始清理 {len(_avatar_dirs_to_clean)} 个 avatar 目录...")
  728. for avatar_dir in _avatar_dirs_to_clean:
  729. try:
  730. if os.path.exists(avatar_dir):
  731. logger.info(f"正在删除 avatar 目录: {avatar_dir}")
  732. shutil.rmtree(avatar_dir)
  733. logger.info(f"✅ 已删除: {avatar_dir}")
  734. else:
  735. logger.info(f"目录不存在,无需删除: {avatar_dir}")
  736. except Exception as e:
  737. logger.error(f"❌ 删除 avatar 目录失败 {avatar_dir}: {e}")
  738. _avatar_dirs_to_clean.clear()
  739. logger.info("所有 avatar 目录清理完成")
  740. def signal_handler(signum, frame):
  741. """信号处理器 - 捕获退出信号"""
  742. logger.info(f"收到退出信号 {signum},开始清理...")
  743. cleanup_avatar_directories()
  744. sys.exit(0)
  745. async def post(url,data):
  746. try:
  747. async with aiohttp.ClientSession() as session:
  748. async with session.post(url,data=data) as response:
  749. return await response.text()
  750. except aiohttp.ClientError as e:
  751. logger.error(f'POST请求错误: {e}')
  752. return None # 明确返回None
  753. async def run(push_url,sessionid):
  754. nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
  755. nerfreals[sessionid] = nerfreal
  756. # RTMP 推流模式:跳过 WHIP 连接,使用 basereal.py 中的 ffmpeg 管道推流
  757. if push_url.startswith('rtmp://'):
  758. logger.info(f'RTMP 推流模式: {push_url},跳过 WHIP 连接,推流将在 render 时自动启动')
  759. # 需要启动渲染循环以产生视频/音频帧
  760. # 使用 HumanPlayer 触发 render,但不建立 WebRTC 连接
  761. player = HumanPlayer(nerfreals[sessionid])
  762. # 手动启动 player worker thread 来触发 nerfreal.render()
  763. # HumanPlayer._start 会在 track 被 recv 时调用,这里直接启动
  764. from threading import Event as ThreadEvent
  765. render_quit_event = ThreadEvent()
  766. # 直接启动 render 线程,不需要 WebRTC track
  767. import threading
  768. def rtmp_render_loop():
  769. nerfreals[sessionid].render(render_quit_event, loop=None, audio_track=None, video_track=None)
  770. render_thread = threading.Thread(target=rtmp_render_loop, daemon=True, name='rtmp_render')
  771. render_thread.start()
  772. logger.info('RTMP 渲染线程已启动')
  773. return
  774. # WebRTC WHIP 推流(原有逻辑)
  775. pc = RTCPeerConnection()
  776. pcs.add(pc)
  777. @pc.on("connectionstatechange")
  778. async def on_connectionstatechange():
  779. logger.info("Connection state is %s" % pc.connectionState)
  780. if pc.connectionState == "failed":
  781. await pc.close()
  782. pcs.discard(pc)
  783. player = HumanPlayer(nerfreals[sessionid])
  784. audio_sender = pc.addTrack(player.audio)
  785. video_sender = pc.addTrack(player.video)
  786. await pc.setLocalDescription(await pc.createOffer())
  787. answer = await post(push_url,pc.localDescription.sdp)
  788. # 检查POST请求是否成功
  789. if answer is None:
  790. logger.error(f'推流失败: 无法连接到 {push_url}')
  791. await pc.close()
  792. pcs.discard(pc)
  793. return
  794. await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
  795. ##########################################
  796. # os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
  797. # os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
  798. if __name__ == '__main__':
  799. # 注册信号处理器
  800. signal.signal(signal.SIGINT, signal_handler)
  801. signal.signal(signal.SIGTERM, signal_handler)
  802. mp.set_start_method('spawn')
  803. parser = argparse.ArgumentParser()
  804. # audio FPS
  805. parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50")
  806. # sliding window left-middle-right length (unit: 20ms)
  807. parser.add_argument('-l', type=int, default=10)
  808. parser.add_argument('-m', type=int, default=8)
  809. parser.add_argument('-r', type=int, default=10)
  810. parser.add_argument('--W', type=int, default=450, help="GUI width")
  811. parser.add_argument('--H', type=int, default=450, help="GUI height")
  812. #musetalk opt
  813. parser.add_argument('--avatar_id', type=str, default='avator_1', help="define which avatar in data/avatars")
  814. #parser.add_argument('--bbox_shift', type=int, default=5)
  815. parser.add_argument('--batch_size', type=int, default=16, help="infer batch")
  816. parser.add_argument('--customvideo_config', type=str, default='', help="custom action json")
  817. parser.add_argument('--tts', type=str, default='edgetts', help="tts service type") #xtts gpt-sovits cosyvoice fishtts tencent doubao indextts2 azuretts qwen3tts voxcpm2api
  818. parser.add_argument('--REF_FILE', type=str, default="zh-CN-YunxiaNeural",help="参考文件名或语音模型 ID,默认值为 edgetts 的语音模型 ID zh-CN-YunxiaNeural, 若--tts 指定为 azuretts, 可以使用 Azure 语音模型 ID, 如 zh-CN-XiaoxiaoMultilingualNeural")
  819. parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880', help="TTS server URL")
  820. parser.add_argument('--QWEN3_TTS_MODEL_PATH', type=str, default='/home/test/Digital_Human/Qwen3-TTS-12Hz-1.7B-Base', help="Qwen3 TTS model path")
  821. parser.add_argument('--QWEN3_TTS_LANGUAGE', type=str, default='Chinese', help="Qwen3 TTS language")
  822. parser.add_argument('--VOXCPM2_MODEL_PATH', type=str, default='VoxCPM2', help="VoxCPM2 模型路径")
  823. parser.add_argument('--VOXCPM2_API_URL', type=str, default='http://localhost:6003', help="VoxCPM2 API 服务地址(API 调用模式)")
  824. parser.add_argument('--VOXCPM2_REF_WAV', type=str, default='voice_output.wav', help="VoxCPM2 参考音频路径")
  825. parser.add_argument('--VOXCPM2_REF_TEXT', type=str, default='你好,买水果,卖水果,新鲜的水果。', help="VoxCPM2 参考文本")
  826. parser.add_argument('--CFG_VALUE', type=float, default=2.0, help="VoxCPM2 CFG value")
  827. parser.add_argument('--INFERENCE_TIMESTEPS', type=int, default=10, help="VoxCPM2 inference timesteps")
  828. parser.add_argument('--REF_TEXT', type=str, default=None, help="参考文本")
  829. # parser.add_argument('--CHARACTER', type=str, default='test')
  830. # parser.add_argument('--EMOTION', type=str, default='default')
  831. parser.add_argument('--model', type=str, default='musetalk') #musetalk wav2lip ultralight
  832. parser.add_argument('--transport', type=str, default='rtcpush') #webrtc rtcpush virtualcam
  833. parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream
  834. parser.add_argument('--max_session', type=int, default=1) #multi session count
  835. parser.add_argument('--listenport', type=int, default=7868, help="web listen port")
  836. # Avatar 目录清理配置
  837. parser.add_argument('--cleanup_avatar_on_exit', action='store_true', default=False,
  838. help="程序退出时自动删除本次使用的 avatar 目录")
  839. opt = parser.parse_args()
  840. # 设置全局清理标志
  841. _enable_avatar_cleanup = opt.cleanup_avatar_on_exit
  842. if _enable_avatar_cleanup:
  843. logger.info("✅ 已启用 avatar 目录退出自动清理功能")
  844. else:
  845. logger.info("ℹ️ 未启用 avatar 目录自动清理(使用 --cleanup_avatar_on_exit 启用)")
  846. #app.config.from_object(opt)
  847. #print(app.config)
  848. opt.customopt = []
  849. if opt.customvideo_config!='':
  850. with open(opt.customvideo_config,'r') as file:
  851. opt.customopt = json.load(file)
  852. # if opt.model == 'ernerf':
  853. # from nerfreal import NeRFReal,load_model,load_avatar
  854. # model = load_model(opt)
  855. # avatar = load_avatar(opt)
  856. if opt.model == 'musetalk':
  857. from musereal import MuseReal,load_model,load_avatar,warm_up
  858. logger.info(opt)
  859. model = load_model()
  860. avatar = load_avatar(opt.avatar_id)
  861. warm_up(opt.batch_size,model)
  862. elif opt.model == 'wav2lip':
  863. from lipreal import LipReal,load_model,load_avatar,warm_up
  864. logger.info(opt)
  865. model = load_model("./models/wav2lip256.pth")
  866. avatar = load_avatar(opt.avatar_id)
  867. warm_up(opt.batch_size,model,256)
  868. elif opt.model == 'ultralight':
  869. from lightreal import LightReal,load_model,load_avatar,warm_up
  870. logger.info(opt)
  871. model = load_model(opt)
  872. avatar = load_avatar(opt.avatar_id)
  873. warm_up(opt.batch_size,avatar,160)
  874. # if opt.transport=='rtmp':
  875. # thread_quit = Event()
  876. # nerfreals[0] = build_nerfreal(0)
  877. # rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
  878. # rendthrd.start()
  879. if opt.transport=='virtualcam':
  880. thread_quit = Event()
  881. nerfreals[0] = build_nerfreal(0)
  882. rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
  883. rendthrd.start()
  884. #############################################################################
  885. appasync = web.Application(client_max_size=1024**2*100)
  886. appasync.on_shutdown.append(on_shutdown)
  887. appasync.router.add_post("/offer", offer)
  888. appasync.router.add_post("/human", human)
  889. appasync.router.add_post("/api/chat", api_chat) # 新增独立聊天API
  890. appasync.router.add_post("/humanaudio", humanaudio)
  891. appasync.router.add_post("/set_audiotype", set_audiotype)
  892. appasync.router.add_post("/record", record)
  893. appasync.router.add_post("/interrupt_talk", interrupt_talk)
  894. appasync.router.add_post("/is_speaking", is_speaking)
  895. appasync.router.add_post("/knowledge_intro", knowledge_intro) # 新增知识库介绍 API
  896. appasync.router.add_post("/resume_interrupted", resume_interrupted) # 新增恢复中断消息 API
  897. appasync.router.add_post("/handle_user_question", handle_user_question) # 新增用户提问 API
  898. appasync.router.add_post("/continue_after_qa", continue_after_qa) # 新增问答后继续播放 API
  899. appasync.router.add_post("/intro_play_completed", intro_play_completed) # 新增介绍播放完成回调 API
  900. appasync.router.add_post("/log", log) # 前端日志接口
  901. appasync.router.add_static('/',path='web')
  902. # Configure default CORS settings.
  903. cors = aiohttp_cors.setup(appasync, defaults={
  904. "*": aiohttp_cors.ResourceOptions(
  905. allow_credentials=True,
  906. expose_headers="*",
  907. allow_headers="*",
  908. )
  909. })
  910. # Configure CORS on all routes.
  911. for route in list(appasync.router.routes()):
  912. cors.add(route)
  913. pagename='webrtcapi.html'
  914. if opt.transport=='rtmp':
  915. pagename='echoapi.html'
  916. elif opt.transport=='rtcpush':
  917. pagename='rtcpushapi.html'
  918. logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
  919. logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html')
  920. def run_server(runner):
  921. try:
  922. loop = asyncio.new_event_loop()
  923. asyncio.set_event_loop(loop)
  924. loop.run_until_complete(runner.setup())
  925. site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
  926. loop.run_until_complete(site.start())
  927. if opt.transport=='rtcpush':
  928. for k in range(opt.max_session):
  929. push_url = opt.push_url
  930. if k!=0:
  931. push_url = opt.push_url+str(k)
  932. loop.run_until_complete(run(push_url,k))
  933. loop.run_forever()
  934. except KeyboardInterrupt:
  935. logger.info("收到 KeyboardInterrupt,正在退出...")
  936. except Exception as e:
  937. logger.error(f"服务器运行出错: {e}")
  938. finally:
  939. # 服务器退出时清理 avatar 目录
  940. cleanup_avatar_directories()
  941. #Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
  942. run_server(web.AppRunner(appasync))
  943. #app.on_shutdown.append(on_shutdown)
  944. #app.router.add_post("/offer", offer)
  945. # print('start websocket server')
  946. # server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
  947. # server.serve_forever()