app.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940
  1. ###############################################################################
  2. # Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
  3. # email: lipku@foxmail.com
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. ###############################################################################
  17. # server.py
  18. import flask
  19. from flask import Flask, render_template,send_from_directory,request, jsonify
  20. from flask_sockets import Sockets
  21. import base64
  22. import json
  23. #import gevent
  24. #from gevent import pywsgi
  25. #from geventwebsocket.handler import WebSocketHandler
  26. import re
  27. import numpy as np
  28. from threading import Thread,Event
  29. #import multiprocessing
  30. import torch.multiprocessing as mp
  31. from aiohttp import web
  32. import aiohttp
  33. import aiohttp_cors
  34. from aiortc import RTCPeerConnection, RTCSessionDescription,RTCIceServer,RTCConfiguration
  35. from aiortc.rtcrtpsender import RTCRtpSender
  36. from webrtc import HumanPlayer
  37. from basereal import BaseReal
  38. from llm import llm_response
  39. import argparse
  40. import random
  41. import shutil
  42. import asyncio
  43. import torch
  44. import os
  45. import socket
  46. from typing import Dict
  47. from logger import logger
  48. import gc
# Flask application object. No active code in this file references it; the
# HTTP server started in the __main__ section is built on aiohttp instead.
app = Flask(__name__)
#sockets = Sockets(app)
# Active sessions keyed by session id. A value is None while the renderer
# for that session is still being built (see offer()).
nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
# Parsed command-line options; assigned in the __main__ section.
opt = None
# Model and avatar data loaded once in the __main__ section and passed to
# every renderer created by build_nerfreal().
model = None
avatar = None
#####webrtc###############################
# Peer connections that are currently tracked; closed in on_shutdown().
pcs = set()
  57. def randN(N)->int:
  58. '''生成长度为 N的随机数 '''
  59. min = pow(10, N - 1)
  60. max = pow(10, N)
  61. return random.randint(min, max - 1)
  62. def build_nerfreal(sessionid:int)->BaseReal:
  63. opt.sessionid=sessionid
  64. if opt.model == 'wav2lip':
  65. from lipreal import LipReal
  66. nerfreal = LipReal(opt,model,avatar)
  67. elif opt.model == 'musetalk':
  68. from musereal import MuseReal
  69. nerfreal = MuseReal(opt,model,avatar)
  70. # elif opt.model == 'ernerf':
  71. # from nerfreal import NeRFReal
  72. # nerfreal = NeRFReal(opt,model,avatar)
  73. elif opt.model == 'ultralight':
  74. from lightreal import LightReal
  75. nerfreal = LightReal(opt,model,avatar)
  76. else:
  77. raise ValueError(f"Unsupported model type: {opt.model}")
  78. return nerfreal
  79. #@app.route('/offer', methods=['POST'])
  80. async def offer(request):
  81. try:
  82. params = await request.json()
  83. offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
  84. # if len(nerfreals) >= opt.max_session:
  85. # logger.info('reach max session')
  86. # return web.Response(
  87. # content_type="application/json",
  88. # text=json.dumps(
  89. # {"code": -1, "msg": "reach max session"}
  90. # ),
  91. # )
  92. sessionid = randN(6) #len(nerfreals)
  93. nerfreals[sessionid] = None
  94. logger.info('sessionid=%d, session num=%d',sessionid,len(nerfreals))
  95. nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
  96. nerfreals[sessionid] = nerfreal
  97. # 内网/跳板机环境:不使用外网 STUN(内网访问不到),只用 host 候选
  98. # 通过环境变量 WEBRTC_NAT_IP 指定跳板机对外暴露的 IP(浏览器能访问的 IP)
  99. nat_public_ip = os.environ.get('WEBRTC_NAT_IP', '')
  100. if nat_public_ip:
  101. logger.info('Using NAT public IP for ICE: %s', nat_public_ip)
  102. ice_servers = [] # 内网不用 STUN
  103. else:
  104. # 使用多个 STUN 服务器(包括国内和国外的)
  105. ice_servers = [
  106. RTCIceServer(urls='stun:stun.l.google.com:19302'),
  107. RTCIceServer(urls='stun:stun1.l.google.com:19302'),
  108. RTCIceServer(urls='stun:stun2.l.google.com:19302'),
  109. RTCIceServer(urls='stun:stun3.l.google.com:19302'),
  110. RTCIceServer(urls='stun:stun4.l.google.com:19302'),
  111. ]
  112. pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=ice_servers))
  113. pcs.add(pc)
  114. @pc.on("connectionstatechange")
  115. async def on_connectionstatechange():
  116. logger.info(f"Session {sessionid} - Connection state is {pc.connectionState}")
  117. if pc.connectionState == "connected":
  118. logger.info(f"Session {sessionid} - WebRTC connection established successfully!")
  119. if pc.connectionState == "failed":
  120. logger.error(f"Session {sessionid} - Connection failed!")
  121. await pc.close()
  122. pcs.discard(pc)
  123. if sessionid in nerfreals:
  124. del nerfreals[sessionid]
  125. if pc.connectionState == "closed":
  126. logger.info(f"Session {sessionid} - Connection closed")
  127. pcs.discard(pc)
  128. if sessionid in nerfreals:
  129. del nerfreals[sessionid]
  130. # gc.collect()
  131. player = HumanPlayer(nerfreals[sessionid])
  132. audio_sender = pc.addTrack(player.audio)
  133. video_sender = pc.addTrack(player.video)
  134. # 记录轨道添加信息
  135. logger.info(f"Added tracks for session {sessionid}: audio={player.audio}, video={player.video}")
  136. capabilities = RTCRtpSender.getCapabilities("video")
  137. preferences = list(filter(lambda x: x.name == "H264", capabilities.codecs))
  138. preferences += list(filter(lambda x: x.name == "VP8", capabilities.codecs))
  139. preferences += list(filter(lambda x: x.name == "rtx", capabilities.codecs))
  140. transceiver = pc.getTransceivers()[1]
  141. transceiver.setCodecPreferences(preferences)
  142. await pc.setRemoteDescription(offer)
  143. answer = await pc.createAnswer()
  144. await pc.setLocalDescription(answer)
  145. sdp = pc.localDescription.sdp
  146. # 跳板机内网场景:如果配置了 WEBRTC_NAT_IP,将 SDP 中的内网 IP 替换为跳板机对外 IP
  147. # aioice 已经将真实绑定的端口(AIOICE_PORT_MIN~MAX 范围内)写入 SDP 的 a=candidate 行
  148. # 我们只需要把内网 IP 替换成浏览器能访问的 IP 即可
  149. nat_public_ip = os.environ.get('WEBRTC_NAT_IP', '')
  150. if nat_public_ip:
  151. try:
  152. # 获取本机所有非 loopback IPv4 地址
  153. local_ips = set()
  154. bind_ip = os.environ.get('AIOICE_BIND_IP', '')
  155. if bind_ip:
  156. local_ips.add(bind_ip)
  157. else:
  158. hostname = socket.gethostname()
  159. for ip_info in socket.getaddrinfo(hostname, None, socket.AF_INET):
  160. ip = ip_info[4][0]
  161. if not ip.startswith('127.'):
  162. local_ips.add(ip)
  163. # 也枚举所有网卡 IP(getaddrinfo 可能不全)
  164. try:
  165. import ifaddr
  166. for adapter in ifaddr.get_adapters():
  167. for ip_obj in adapter.ips:
  168. if isinstance(ip_obj.ip, str) and not ip_obj.ip.startswith('127.'):
  169. local_ips.add(ip_obj.ip)
  170. except Exception:
  171. pass
  172. # 将 SDP 中所有内网 IP 替换为跳板机对外 IP
  173. for local_ip in local_ips:
  174. sdp = sdp.replace(local_ip, nat_public_ip)
  175. logger.info('NAT mapping: %s -> %s in SDP', local_ips, nat_public_ip)
  176. except Exception as e:
  177. logger.warning('Failed to apply NAT IP mapping to SDP: %s', e)
  178. else:
  179. logger.info('WEBRTC_NAT_IP not set. SDP uses server internal IPs. '
  180. 'Set WEBRTC_NAT_IP=<jumphost_ip> if browser cannot reach server directly.')
  181. return web.Response(
  182. content_type="application/json",
  183. text=json.dumps(
  184. {"sdp": sdp, "type": pc.localDescription.type, "sessionid": sessionid}
  185. ),
  186. )
  187. except Exception as e:
  188. logger.exception('Error in offer:')
  189. return web.Response(
  190. content_type="application/json",
  191. text=json.dumps({"code": -1, "msg": str(e)}),
  192. status=500
  193. )
  194. async def log(request):
  195. """接收前端日志的接口"""
  196. try:
  197. params = await request.json()
  198. log_type = params.get('type', 'info')
  199. message = params.get('message', '')
  200. logger.info('[WEBRTC] %s', message)
  201. except Exception:
  202. pass
  203. return web.Response(
  204. content_type="application/json",
  205. text=json.dumps({"code": 0}),
  206. )
  207. async def human(request):
  208. try:
  209. params = await request.json()
  210. sessionid = params.get('sessionid',0)
  211. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  212. return web.Response(
  213. content_type="application/json",
  214. text=json.dumps(
  215. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  216. ),
  217. )
  218. # 标记是否已经处理了聊天请求
  219. chat_processed = False
  220. if params.get('interrupt'):
  221. # 立即中断当前播放,无论是否是聊天请求
  222. nerfreals[sessionid].flush_talk()
  223. # 检查是否是用户在介绍过程中提问
  224. if params['type'] == 'chat':
  225. knowledge_base_type = params.get('knowledge_base', None)
  226. during_intro = params.get('during_intro', False)
  227. # 处理用户问题(同步执行,确保立即处理)
  228. llm_response(params['text'], nerfreals[sessionid], knowledge_base_type, during_intro)
  229. chat_processed = True
  230. if params['type']=='echo':
  231. nerfreals[sessionid].put_msg_txt(params['text'])
  232. elif params['type']=='chat' and not chat_processed:
  233. # 如果没有中断标志,或者已经处理过聊天请求,则不再处理
  234. knowledge_base_type = params.get('knowledge_base', None)
  235. during_intro = params.get('during_intro', False) # 是否在介绍过程中
  236. asyncio.get_event_loop().run_in_executor(None, llm_response, params['text'], nerfreals[sessionid], knowledge_base_type, during_intro)
  237. #nerfreals[sessionid].put_msg_txt(res)
  238. return web.Response(
  239. content_type="application/json",
  240. text=json.dumps(
  241. {"code": 0, "msg":"ok"}
  242. ),
  243. )
  244. except Exception as e:
  245. logger.exception('exception:')
  246. return web.Response(
  247. content_type="application/json",
  248. text=json.dumps(
  249. {"code": -1, "msg": str(e)}
  250. ),
  251. )
  252. async def api_chat(request):
  253. """
  254. 独立聊天API - 让数字人说话/问答
  255. 请求体: {
  256. "text": "用户说的话",
  257. "type": "chat", // "chat"=问答, "echo"=纯播报
  258. "interrupt": true, // 是否打断当前播放
  259. "sessionid": 123456, // WebRTC session ID
  260. "knowledge_base": "", // 可选: 指定知识库
  261. "during_intro": false // 可选: 是否在介绍过程中
  262. }
  263. """
  264. try:
  265. params = await request.json()
  266. # 参数验证
  267. text = params.get('text', '').strip()
  268. if not text:
  269. return web.Response(
  270. content_type="application/json",
  271. text=json.dumps({"code": -1, "msg": "text parameter is required"}),
  272. )
  273. sessionid = params.get('sessionid', 0)
  274. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  275. return web.Response(
  276. content_type="application/json",
  277. text=json.dumps(
  278. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  279. ),
  280. )
  281. msg_type = params.get('type', 'chat')
  282. interrupt = params.get('interrupt', True)
  283. knowledge_base = params.get('knowledge_base', None)
  284. during_intro = params.get('during_intro', False)
  285. logger.info(f"API Chat - Session {sessionid}, Type: {msg_type}, Text: {text[:50]}...")
  286. # 处理中断
  287. if interrupt:
  288. nerfreals[sessionid].flush_talk()
  289. # 根据类型处理
  290. if msg_type == 'echo':
  291. # 纯播报模式 - 直接播放文本
  292. nerfreals[sessionid].put_msg_txt(text)
  293. else:
  294. # 聊天/问答模式 - 走LLM
  295. asyncio.get_event_loop().run_in_executor(
  296. None,
  297. llm_response,
  298. text,
  299. nerfreals[sessionid],
  300. knowledge_base,
  301. during_intro
  302. )
  303. return web.Response(
  304. content_type="application/json",
  305. text=json.dumps({
  306. "code": 0,
  307. "msg": "ok",
  308. "data": {
  309. "sessionid": sessionid,
  310. "type": msg_type,
  311. "text": text[:100] # 返回前100字符用于确认
  312. }
  313. }),
  314. )
  315. except Exception as e:
  316. logger.exception('API Chat exception:')
  317. return web.Response(
  318. content_type="application/json",
  319. text=json.dumps({"code": -1, "msg": str(e)}),
  320. )
  321. async def interrupt_talk(request):
  322. try:
  323. params = await request.json()
  324. sessionid = params.get('sessionid',0)
  325. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  326. return web.Response(
  327. content_type="application/json",
  328. text=json.dumps(
  329. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  330. ),
  331. )
  332. nerfreals[sessionid].flush_talk()
  333. return web.Response(
  334. content_type="application/json",
  335. text=json.dumps(
  336. {"code": 0, "msg":"ok"}
  337. ),
  338. )
  339. except Exception as e:
  340. logger.exception('exception:')
  341. return web.Response(
  342. content_type="application/json",
  343. text=json.dumps(
  344. {"code": -1, "msg": str(e)}
  345. ),
  346. )
  347. async def humanaudio(request):
  348. try:
  349. form= await request.post()
  350. sessionid = int(form.get('sessionid',0))
  351. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  352. return web.Response(
  353. content_type="application/json",
  354. text=json.dumps(
  355. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  356. ),
  357. )
  358. fileobj = form["file"]
  359. filename=fileobj.filename
  360. filebytes=fileobj.file.read()
  361. nerfreals[sessionid].put_audio_file(filebytes)
  362. return web.Response(
  363. content_type="application/json",
  364. text=json.dumps(
  365. {"code": 0, "msg":"ok"}
  366. ),
  367. )
  368. except Exception as e:
  369. logger.exception('exception:')
  370. return web.Response(
  371. content_type="application/json",
  372. text=json.dumps(
  373. {"code": -1, "msg": str(e)}
  374. ),
  375. )
  376. async def set_audiotype(request):
  377. try:
  378. params = await request.json()
  379. sessionid = params.get('sessionid',0)
  380. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  381. return web.Response(
  382. content_type="application/json",
  383. text=json.dumps(
  384. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  385. ),
  386. )
  387. nerfreals[sessionid].set_custom_state(params['audiotype'],params['reinit'])
  388. return web.Response(
  389. content_type="application/json",
  390. text=json.dumps(
  391. {"code": 0, "msg":"ok"}
  392. ),
  393. )
  394. except Exception as e:
  395. logger.exception('exception:')
  396. return web.Response(
  397. content_type="application/json",
  398. text=json.dumps(
  399. {"code": -1, "msg": str(e)}
  400. ),
  401. )
  402. async def record(request):
  403. try:
  404. params = await request.json()
  405. sessionid = params.get('sessionid',0)
  406. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  407. return web.Response(
  408. content_type="application/json",
  409. text=json.dumps(
  410. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  411. ),
  412. )
  413. if params['type']=='start_record':
  414. # nerfreals[sessionid].put_msg_txt(params['text'])
  415. nerfreals[sessionid].start_recording()
  416. elif params['type']=='end_record':
  417. nerfreals[sessionid].stop_recording()
  418. return web.Response(
  419. content_type="application/json",
  420. text=json.dumps(
  421. {"code": 0, "msg":"ok"}
  422. ),
  423. )
  424. except Exception as e:
  425. logger.exception('exception:')
  426. return web.Response(
  427. content_type="application/json",
  428. text=json.dumps(
  429. {"code": -1, "msg": str(e)}
  430. ),
  431. )
  432. async def is_speaking(request):
  433. params = await request.json()
  434. sessionid = params.get('sessionid',0)
  435. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  436. return web.Response(
  437. content_type="application/json",
  438. text=json.dumps(
  439. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  440. ),
  441. )
  442. return web.Response(
  443. content_type="application/json",
  444. text=json.dumps(
  445. {"code": 0, "data": nerfreals[sessionid].is_speaking()}
  446. )
  447. )
  448. async def knowledge_intro(request):
  449. """返回知识库介绍内容"""
  450. try:
  451. from knowledge_intro import start_intro_play, knowledge_intro
  452. # 启动完整版介绍播放,获取第一条文案
  453. play_result = start_intro_play("full")
  454. intro_text = play_result.get("text", "")
  455. params = await request.json()
  456. sessionid = params.get('sessionid', 0)
  457. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  458. return web.Response(
  459. content_type="application/json",
  460. text=json.dumps(
  461. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  462. ),
  463. )
  464. # 保存介绍实例到会话中,便于后续操作
  465. if not hasattr(nerfreals[sessionid], 'knowledge_intro_instance'):
  466. nerfreals[sessionid].knowledge_intro_instance = knowledge_intro
  467. # 保存介绍播放状态到会话中
  468. if not hasattr(nerfreals[sessionid], 'intro_play_state'):
  469. nerfreals[sessionid].intro_play_state = {
  470. "is_playing": True,
  471. "current_type": "full",
  472. "last_played_index": 0,
  473. "is_paused": False,
  474. "is_waiting_next": False
  475. }
  476. # 使用支持打断恢复的新方法来播放介绍内容
  477. nerfreals[sessionid].start_intro_with_interrupt_capability(intro_text)
  478. return web.Response(
  479. content_type="application/json",
  480. text=json.dumps(
  481. {"code": 0, "msg": "Knowledge intro played successfully", "text": intro_text, "mode": "intro", "play_index": play_result.get("play_index", 1), "total_count": play_result.get("total_count", 8)}
  482. ),
  483. )
  484. except Exception as e:
  485. logger.exception('exception in knowledge_intro:')
  486. return web.Response(
  487. content_type="application/json",
  488. text=json.dumps(
  489. {"code": -1, "msg": str(e)}
  490. ),
  491. )
  492. async def resume_interrupted(request):
  493. """恢复播放被中断的消息"""
  494. try:
  495. params = await request.json()
  496. sessionid = params.get('sessionid', 0)
  497. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  498. return web.Response(
  499. content_type="application/json",
  500. text=json.dumps(
  501. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  502. ),
  503. )
  504. # 尝试恢复被中断的消息
  505. resumed = nerfreals[sessionid].resume_interrupted()
  506. if resumed:
  507. return web.Response(
  508. content_type="application/json",
  509. text=json.dumps(
  510. {"code": 0, "msg": "Interrupted messages resumed successfully", "mode": "intro"}
  511. ),
  512. )
  513. else:
  514. return web.Response(
  515. content_type="application/json",
  516. text=json.dumps(
  517. {"code": 0, "msg": "No interrupted messages to resume"}
  518. ),
  519. )
  520. except Exception as e:
  521. logger.exception('exception in resume_interrupted:')
  522. return web.Response(
  523. content_type="application/json",
  524. text=json.dumps(
  525. {"code": -1, "msg": str(e)}
  526. ),
  527. )
  528. async def handle_user_question(request):
  529. """处理用户提问,暂停当前内容并优先回答问题"""
  530. try:
  531. params = await request.json()
  532. sessionid = params.get('sessionid', 0)
  533. question = params.get('text', '')
  534. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  535. return web.Response(
  536. content_type="application/json",
  537. text=json.dumps(
  538. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  539. ),
  540. )
  541. if not question:
  542. return web.Response(
  543. content_type="application/json",
  544. text=json.dumps(
  545. {"code": -1, "msg": "Question text is required"}
  546. ),
  547. )
  548. # 检查是否在介绍过程中,如果是,则标记
  549. during_intro = params.get('during_intro', False)
  550. if during_intro and hasattr(nerfreals[sessionid], 'knowledge_intro_instance'):
  551. # 暂停介绍播放
  552. nerfreals[sessionid].knowledge_intro_instance.pause_play()
  553. # 使用高优先级方法处理用户问题
  554. knowledge_base_type = params.get('knowledge_base', None)
  555. # 直接调用llm_response,它会使用put_user_question方法
  556. asyncio.get_event_loop().run_in_executor(None, llm_response, question, nerfreals[sessionid], knowledge_base_type, during_intro)
  557. return web.Response(
  558. content_type="application/json",
  559. text=json.dumps(
  560. {"code": 0, "msg": "User question processed successfully", "mode": "qa"}
  561. ),
  562. )
  563. except Exception as e:
  564. logger.exception('exception in handle_user_question:')
  565. return web.Response(
  566. content_type="application/json",
  567. text=json.dumps(
  568. {"code": -1, "msg": str(e)}
  569. ),
  570. )
  571. async def continue_after_qa(request):
  572. """在用户问题回答完毕后,继续播放之前的内容"""
  573. try:
  574. params = await request.json()
  575. sessionid = params.get('sessionid', 0)
  576. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  577. return web.Response(
  578. content_type="application/json",
  579. text=json.dumps(
  580. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  581. ),
  582. )
  583. # 尝试恢复被中断的内容
  584. if hasattr(nerfreals[sessionid], 'knowledge_intro_instance'):
  585. # 从knowledge_intro_instance获取下一条内容
  586. next_content = nerfreals[sessionid].knowledge_intro_instance.resume_play()
  587. if next_content:
  588. # 播放下一条介绍内容
  589. nerfreals[sessionid].put_msg_txt(next_content['text'])
  590. return web.Response(
  591. content_type="application/json",
  592. text=json.dumps(
  593. {"code": 0, "msg": "Previously interrupted content resumed successfully", "text": next_content['text'], "mode": "intro", "play_index": next_content.get("play_index", 1), "total_count": next_content.get("total_count", 8)}
  594. ),
  595. )
  596. # 如果没有找到knowledge_intro_instance或没有剩余内容
  597. resumed = nerfreals[sessionid].resume_interrupted()
  598. if resumed:
  599. return web.Response(
  600. content_type="application/json",
  601. text=json.dumps(
  602. {"code": 0, "msg": "Previously interrupted content resumed successfully", "mode": "intro"}
  603. ),
  604. )
  605. else:
  606. return web.Response(
  607. content_type="application/json",
  608. text=json.dumps(
  609. {"code": 0, "msg": "No interrupted content to resume"}
  610. ),
  611. )
  612. except Exception as e:
  613. logger.exception('exception in continue_after_qa:')
  614. return web.Response(
  615. content_type="application/json",
  616. text=json.dumps(
  617. {"code": -1, "msg": str(e)}
  618. ),
  619. )
  620. async def intro_play_completed(request):
  621. """处理介绍播放完成的回调,自动播放下一条介绍内容"""
  622. try:
  623. params = await request.json()
  624. sessionid = params.get('sessionid', 0)
  625. if sessionid not in nerfreals or nerfreals[sessionid] is None:
  626. return web.Response(
  627. content_type="application/json",
  628. text=json.dumps(
  629. {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
  630. ),
  631. )
  632. # 检查是否有介绍实例
  633. if not hasattr(nerfreals[sessionid], 'knowledge_intro_instance'):
  634. return web.Response(
  635. content_type="application/json",
  636. text=json.dumps(
  637. {"code": -1, "msg": "Knowledge intro instance not found"}
  638. ),
  639. )
  640. # 检查介绍播放状态
  641. if hasattr(nerfreals[sessionid], 'intro_play_state') and not nerfreals[sessionid].intro_play_state.get("is_playing", True):
  642. return web.Response(
  643. content_type="application/json",
  644. text=json.dumps(
  645. {"code": 0, "msg": "Introduction playback is paused"}
  646. ),
  647. )
  648. # 获取下一条介绍内容
  649. next_content = nerfreals[sessionid].knowledge_intro_instance._get_next_content()
  650. if next_content:
  651. # 播放下一条介绍内容
  652. nerfreals[sessionid].put_msg_txt(next_content['text'])
  653. # 更新播放状态
  654. if hasattr(nerfreals[sessionid], 'intro_play_state'):
  655. nerfreals[sessionid].intro_play_state["last_played_index"] = next_content.get("play_index", 1)
  656. return web.Response(
  657. content_type="application/json",
  658. text=json.dumps(
  659. {"code": 0, "msg": "Next introduction content played successfully", "text": next_content['text'], "mode": "intro", "play_index": next_content.get("play_index", 1), "total_count": next_content.get("total_count", 8), "is_last": next_content.get("is_last", False)}
  660. ),
  661. )
  662. else:
  663. return web.Response(
  664. content_type="application/json",
  665. text=json.dumps(
  666. {"code": 0, "msg": "No more introduction content"}
  667. ),
  668. )
  669. except Exception as e:
  670. logger.exception('exception in intro_play_completed:')
  671. return web.Response(
  672. content_type="application/json",
  673. text=json.dumps(
  674. {"code": -1, "msg": str(e)}
  675. ),
  676. )
  677. async def on_shutdown(app):
  678. # close peer connections
  679. coros = [pc.close() for pc in pcs]
  680. await asyncio.gather(*coros)
  681. pcs.clear()
  682. async def post(url,data):
  683. try:
  684. async with aiohttp.ClientSession() as session:
  685. async with session.post(url,data=data) as response:
  686. return await response.text()
  687. except aiohttp.ClientError as e:
  688. logger.info(f'Error: {e}')
  689. async def run(push_url,sessionid):
  690. nerfreal = await asyncio.get_event_loop().run_in_executor(None, build_nerfreal,sessionid)
  691. nerfreals[sessionid] = nerfreal
  692. pc = RTCPeerConnection()
  693. pcs.add(pc)
  694. @pc.on("connectionstatechange")
  695. async def on_connectionstatechange():
  696. logger.info("Connection state is %s" % pc.connectionState)
  697. if pc.connectionState == "failed":
  698. await pc.close()
  699. pcs.discard(pc)
  700. player = HumanPlayer(nerfreals[sessionid])
  701. audio_sender = pc.addTrack(player.audio)
  702. video_sender = pc.addTrack(player.video)
  703. await pc.setLocalDescription(await pc.createOffer())
  704. answer = await post(push_url,pc.localDescription.sdp)
  705. await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
  706. ##########################################
  707. # os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
  708. # os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
# Script entry point: parse options, load the selected model and avatar,
# register the HTTP routes and run the aiohttp server forever.
if __name__ == '__main__':
    mp.set_start_method('spawn')
    parser = argparse.ArgumentParser()

    # audio FPS
    parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50")
    # sliding window left-middle-right length (unit: 20ms)
    parser.add_argument('-l', type=int, default=10)
    parser.add_argument('-m', type=int, default=8)
    parser.add_argument('-r', type=int, default=10)

    parser.add_argument('--W', type=int, default=450, help="GUI width")
    parser.add_argument('--H', type=int, default=450, help="GUI height")

    #musetalk opt
    parser.add_argument('--avatar_id', type=str, default='avator_1', help="define which avatar in data/avatars")
    #parser.add_argument('--bbox_shift', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=16, help="infer batch")

    parser.add_argument('--customvideo_config', type=str, default='', help="custom action json")

    parser.add_argument('--tts', type=str, default='edgetts', help="tts service type") #xtts gpt-sovits cosyvoice fishtts tencent doubao indextts2 azuretts qwen3tts
    parser.add_argument('--REF_FILE', type=str, default="zh-CN-YunxiaNeural",help="参考文件名或语音模型 ID,默认值为 edgetts 的语音模型 ID zh-CN-YunxiaNeural, 若--tts 指定为 azuretts, 可以使用 Azure 语音模型 ID, 如 zh-CN-XiaoxiaoMultilingualNeural")
    parser.add_argument('--REF_TEXT', type=str, default=None)
    parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000
    parser.add_argument('--QWEN3_TTS_MODEL_PATH', type=str, default='/home/test/Digital_Human/Qwen3-TTS-12Hz-1.7B-Base', help="Qwen3-TTS 模型本地路径")
    parser.add_argument('--QWEN3_TTS_LANGUAGE', type=str, default='Chinese', help="Qwen3-TTS 语言设置 (Chinese, English 等)")
    # parser.add_argument('--CHARACTER', type=str, default='test')
    # parser.add_argument('--EMOTION', type=str, default='default')

    parser.add_argument('--model', type=str, default='musetalk') #musetalk wav2lip ultralight

    parser.add_argument('--transport', type=str, default='rtcpush') #webrtc rtcpush virtualcam
    parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream

    parser.add_argument('--max_session', type=int, default=1)  #multi session count
    parser.add_argument('--listenport', type=int, default=7868, help="web listen port")

    # Rebinds the module-level ``opt`` that build_nerfreal() reads.
    opt = parser.parse_args()
    #app.config.from_object(opt)
    #print(app.config)
    # Optional custom-action configuration, loaded from a JSON file.
    opt.customopt = []
    if opt.customvideo_config!='':
        with open(opt.customvideo_config,'r') as file:
            opt.customopt = json.load(file)

    # if opt.model == 'ernerf':
    #     from nerfreal import NeRFReal,load_model,load_avatar
    #     model = load_model(opt)
    #     avatar = load_avatar(opt)
    # Load the model and avatar for the selected backend and warm it up.
    # ``model`` and ``avatar`` are the module-level names used by
    # build_nerfreal(). No branch handles any other --model value here.
    if opt.model == 'musetalk':
        from musereal import MuseReal,load_model,load_avatar,warm_up
        logger.info(opt)
        model = load_model()
        avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size,model)
    elif opt.model == 'wav2lip':
        from lipreal import LipReal,load_model,load_avatar,warm_up
        logger.info(opt)
        model = load_model("./models/wav2lip256.pth")
        avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size,model,256)
    elif opt.model == 'ultralight':
        from lightreal import LightReal,load_model,load_avatar,warm_up
        logger.info(opt)
        model = load_model(opt)
        avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size,avatar,160)

    # if opt.transport=='rtmp':
    #     thread_quit = Event()
    #     nerfreals[0] = build_nerfreal(0)
    #     rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
    #     rendthrd.start()
    # virtualcam transport: a single session (id 0) rendered on its own thread.
    if opt.transport=='virtualcam':
        thread_quit = Event()
        nerfreals[0] = build_nerfreal(0)
        rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
        rendthrd.start()

    #############################################################################
    # aiohttp application; request bodies of up to 100 MiB are accepted.
    appasync = web.Application(client_max_size=1024**2*100)
    appasync.on_shutdown.append(on_shutdown)
    appasync.router.add_post("/offer", offer)
    appasync.router.add_post("/human", human)
    appasync.router.add_post("/api/chat", api_chat)  # standalone chat API
    appasync.router.add_post("/humanaudio", humanaudio)
    appasync.router.add_post("/set_audiotype", set_audiotype)
    appasync.router.add_post("/record", record)
    appasync.router.add_post("/interrupt_talk", interrupt_talk)
    appasync.router.add_post("/is_speaking", is_speaking)
    appasync.router.add_post("/knowledge_intro", knowledge_intro)  # knowledge-base introduction API
    appasync.router.add_post("/resume_interrupted", resume_interrupted)  # resume interrupted messages API
    appasync.router.add_post("/handle_user_question", handle_user_question)  # user question API
    appasync.router.add_post("/continue_after_qa", continue_after_qa)  # continue playback after Q&A API
    appasync.router.add_post("/intro_play_completed", intro_play_completed)  # intro-item-finished callback API
    appasync.router.add_post("/log", log)  # front-end log endpoint
    # Static front-end files are served from the ``web`` directory.
    appasync.router.add_static('/',path='web')

    # Configure default CORS settings.
    cors = aiohttp_cors.setup(appasync, defaults={
            "*": aiohttp_cors.ResourceOptions(
                allow_credentials=True,
                expose_headers="*",
                allow_headers="*",
            )
        })
    # Configure CORS on all routes.
    for route in list(appasync.router.routes()):
        cors.add(route)

    # Front-end page mentioned in the start-up log, chosen by transport.
    pagename='webrtcapi.html'
    if opt.transport=='rtmp':
        pagename='echoapi.html'
    elif opt.transport=='rtcpush':
        pagename='rtcpushapi.html'
    logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
    logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html')
    def run_server(runner):
        # Create a fresh event loop, start the site on all interfaces and,
        # for rtcpush, start one pushing session per configured session
        # before running the loop forever.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(runner.setup())
        site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
        loop.run_until_complete(site.start())
        if opt.transport=='rtcpush':
            for k in range(opt.max_session):
                # Session 0 uses push_url as given; later sessions append
                # their index to it.
                push_url = opt.push_url
                if k!=0:
                    push_url = opt.push_url+str(k)
                loop.run_until_complete(run(push_url,k))
        loop.run_forever()
    #Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
    run_server(web.AppRunner(appasync))
  828. #app.on_shutdown.append(on_shutdown)
  829. #app.router.add_post("/offer", offer)
  830. # print('start websocket server')
  831. # server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
  832. # server.serve_forever()