| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940 |
- ###############################################################################
- # Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
- # email: lipku@foxmail.com
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- ###############################################################################
- # server.py
- import flask
- from flask import Flask, render_template,send_from_directory,request, jsonify
- from flask_sockets import Sockets
- import base64
- import json
- #import gevent
- #from gevent import pywsgi
- #from geventwebsocket.handler import WebSocketHandler
- import re
- import numpy as np
- from threading import Thread,Event
- #import multiprocessing
- import torch.multiprocessing as mp
- from aiohttp import web
- import aiohttp
- import aiohttp_cors
- from aiortc import RTCPeerConnection, RTCSessionDescription,RTCIceServer,RTCConfiguration
- from aiortc.rtcrtpsender import RTCRtpSender
- from webrtc import HumanPlayer
- from basereal import BaseReal
- from llm import llm_response
- import argparse
- import random
- import shutil
- import asyncio
- import torch
- import os
- import socket
- from typing import Dict
- from logger import logger
- import gc
# Flask application object created at import time.
app = Flask(__name__)
#sockets = Sockets(app)
# Live sessions keyed by session id. offer() stores None as a placeholder
# while the renderer for a new session is still being built.
nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
# Parsed command-line options, loaded model and loaded avatar; all three are
# assigned in the ``__main__`` block and read by build_nerfreal().
opt = None
model = None
avatar = None

#####webrtc###############################
# Peer connections that are currently open; on_shutdown() closes them all.
pcs = set()
def randN(N) -> int:
    """Return a uniformly random integer that has exactly ``N`` decimal digits.

    Args:
        N: Number of decimal digits; must be at least 1.

    Returns:
        An integer in the inclusive range ``[10**(N-1), 10**N - 1]``.

    Raises:
        ValueError: If ``N`` is smaller than 1.
    """
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N!r}")
    # Avoid naming the bounds ``min``/``max``: that would shadow the builtins.
    lower = 10 ** (N - 1)
    upper = 10 ** N - 1
    return random.randint(lower, upper)
def build_nerfreal(sessionid:int)->BaseReal:
    """Instantiate the renderer selected by ``opt.model`` for one session.

    The session id is written onto the module-level ``opt`` object before the
    renderer is constructed with ``(opt, model, avatar)``. The renderer module
    is imported lazily so that only the selected backend gets loaded.

    Args:
        sessionid: Identifier of the session the renderer belongs to.

    Returns:
        The renderer instance for the configured model type.

    Raises:
        ValueError: If ``opt.model`` is not 'wav2lip', 'musetalk' or 'ultralight'.
    """
    # NOTE(review): ``opt`` is a shared module global, so builds running
    # concurrently overwrite each other's ``sessionid`` - confirm whether the
    # renderers copy the value during construction.
    opt.sessionid = sessionid
    selected = opt.model

    if selected == 'wav2lip':
        from lipreal import LipReal
        return LipReal(opt, model, avatar)

    if selected == 'musetalk':
        from musereal import MuseReal
        return MuseReal(opt, model, avatar)

    if selected == 'ultralight':
        from lightreal import LightReal
        return LightReal(opt, model, avatar)

    raise ValueError(f"Unsupported model type: {opt.model}")
- #@app.route('/offer', methods=['POST'])
async def offer(request):
    """Answer a browser's WebRTC offer and start a new avatar session.

    The JSON request body must carry ``sdp`` and ``type``. An unused
    six-digit session id is reserved in ``nerfreals``, the renderer is built
    on the default executor, and a peer connection that streams the
    renderer's audio and video tracks is created.

    When the ``WEBRTC_NAT_IP`` environment variable is set, no STUN servers
    are configured and every local IPv4 address found in the answer SDP is
    rewritten to that address.

    Returns:
        A JSON response ``{"sdp", "type", "sessionid"}`` on success, or
        ``{"code": -1, "msg": ...}`` with HTTP status 500 on any error. On
        error the reserved session slot and the peer connection are released.
    """
    sessionid = None
    pc = None
    try:
        params = await request.json()
        remote_offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

        # NOTE: opt.max_session is not enforced for incoming offers.
        # Draw ids until one is free so a live session is never overwritten.
        candidate = randN(6)
        while candidate in nerfreals:
            candidate = randN(6)
        sessionid = candidate
        nerfreals[sessionid] = None  # placeholder while the renderer is built
        logger.info('sessionid=%d, session num=%d', sessionid, len(nerfreals))
        nerfreal = await asyncio.get_running_loop().run_in_executor(
            None, build_nerfreal, sessionid
        )
        nerfreals[sessionid] = nerfreal

        # Intranet / jump-host deployments cannot reach public STUN servers,
        # so only host candidates are used when WEBRTC_NAT_IP is configured.
        nat_public_ip = os.environ.get('WEBRTC_NAT_IP', '')
        if nat_public_ip:
            logger.info('Using NAT public IP for ICE: %s', nat_public_ip)
            ice_servers = []
        else:
            ice_servers = [
                RTCIceServer(urls='stun:stun.l.google.com:19302'),
                RTCIceServer(urls='stun:stun1.l.google.com:19302'),
                RTCIceServer(urls='stun:stun2.l.google.com:19302'),
                RTCIceServer(urls='stun:stun3.l.google.com:19302'),
                RTCIceServer(urls='stun:stun4.l.google.com:19302'),
            ]
        pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=ice_servers))
        pcs.add(pc)

        @pc.on("connectionstatechange")
        async def on_connectionstatechange():
            logger.info(f"Session {sessionid} - Connection state is {pc.connectionState}")
            if pc.connectionState == "connected":
                logger.info(f"Session {sessionid} - WebRTC connection established successfully!")
            if pc.connectionState == "failed":
                logger.error(f"Session {sessionid} - Connection failed!")
                await pc.close()
                pcs.discard(pc)
                nerfreals.pop(sessionid, None)
            # Re-read the state: closing a failed connection moves it to "closed".
            if pc.connectionState == "closed":
                logger.info(f"Session {sessionid} - Connection closed")
                pcs.discard(pc)
                nerfreals.pop(sessionid, None)

        player = HumanPlayer(nerfreal)
        pc.addTrack(player.audio)
        video_sender = pc.addTrack(player.video)
        logger.info(f"Added tracks for session {sessionid}: audio={player.audio}, video={player.video}")

        # Prefer H264, then VP8, then rtx for the outgoing video.
        capabilities = RTCRtpSender.getCapabilities("video")
        preferences = [
            codec
            for wanted in ("H264", "VP8", "rtx")
            for codec in capabilities.codecs
            if codec.name == wanted
        ]
        # Look the transceiver up by its sender instead of a hard-coded index.
        transceiver = next(t for t in pc.getTransceivers() if t.sender == video_sender)
        transceiver.setCodecPreferences(preferences)

        await pc.setRemoteDescription(remote_offer)
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
        sdp = pc.localDescription.sdp

        if nat_public_ip:
            # Jump-host scenario: the candidate ports are already in the SDP,
            # only the internal IPs must become the externally reachable one.
            try:
                local_ips = _local_ipv4_addresses()
                sdp = _replace_ips_in_sdp(sdp, local_ips, nat_public_ip)
                logger.info('NAT mapping: %s -> %s in SDP', local_ips, nat_public_ip)
            except Exception as e:
                logger.warning('Failed to apply NAT IP mapping to SDP: %s', e)
        else:
            logger.info('WEBRTC_NAT_IP not set. SDP uses server internal IPs. '
                        'Set WEBRTC_NAT_IP=<jumphost_ip> if browser cannot reach server directly.')

        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"sdp": sdp, "type": pc.localDescription.type, "sessionid": sessionid}
            ),
        )
    except Exception as e:
        logger.exception('Error in offer:')
        await _discard_failed_session(sessionid, pc)
        return web.Response(
            content_type="application/json",
            text=json.dumps({"code": -1, "msg": str(e)}),
            status=500
        )


async def _discard_failed_session(sessionid, pc):
    """Release the session slot and peer connection of an offer that failed."""
    if pc is not None:
        pcs.discard(pc)
        try:
            await pc.close()
        except Exception:
            logger.exception('Error closing peer connection for session %s', sessionid)
    if sessionid is not None:
        nerfreals.pop(sessionid, None)


def _local_ipv4_addresses():
    """Return the set of this host's non-loopback IPv4 addresses.

    ``AIOICE_BIND_IP``, when set, is returned as the only address.
    """
    bind_ip = os.environ.get('AIOICE_BIND_IP', '')
    if bind_ip:
        return {bind_ip}

    addresses = set()
    try:
        for info in socket.getaddrinfo(socket.gethostname(), None, socket.AF_INET):
            ip = info[4][0]
            if not ip.startswith('127.'):
                addresses.add(ip)
    except OSError as e:
        # Keep going: the adapter enumeration below may still find addresses.
        logger.warning('Hostname lookup for local IPs failed: %s', e)

    # getaddrinfo may miss interfaces, so also enumerate adapters when the
    # optional ``ifaddr`` package is installed.
    try:
        import ifaddr
    except ImportError:
        return addresses
    try:
        for adapter in ifaddr.get_adapters():
            for ip_obj in adapter.ips:
                # Only string addresses are considered here.
                if isinstance(ip_obj.ip, str) and not ip_obj.ip.startswith('127.'):
                    addresses.add(ip_obj.ip)
    except Exception as e:
        logger.warning('Adapter enumeration for local IPs failed: %s', e)
    return addresses


def _replace_ips_in_sdp(sdp, local_ips, public_ip):
    """Replace each whole occurrence of a local IP in ``sdp`` by ``public_ip``.

    Matches must not be preceded or followed by a digit or dot, so
    ``10.0.0.1`` is not replaced inside ``10.0.0.12`` or ``110.0.0.1``.
    """
    for local_ip in local_ips:
        if not local_ip:
            continue
        pattern = r'(?<![\d.])' + re.escape(local_ip) + r'(?![\d.])'
        # A callable replacement keeps ``public_ip`` free of escape processing.
        sdp = re.sub(pattern, lambda _match: public_ip, sdp)
    return sdp
async def log(request):
    """Receive a log entry from the front end and mirror it to the server log.

    The JSON body may carry ``type`` (defaults to 'info') and ``message``.
    Entries typed 'error' are logged as errors, 'warn'/'warning' as warnings,
    and everything else at info level. Malformed payloads are ignored on a
    best-effort basis; the endpoint always answers ``{"code": 0}``.
    """
    try:
        params = await request.json()
        log_type = str(params.get('type', 'info')).lower()
        message = params.get('message', '')
        if log_type == 'error':
            emit = logger.error
        elif log_type in ('warn', 'warning'):
            emit = logger.warning
        else:
            emit = logger.info
        emit('[WEBRTC] %s', message)
    except Exception:
        # Front-end logging must never fail the request, but leave a trace.
        logger.debug('Ignoring malformed front-end log payload', exc_info=True)
    return web.Response(
        content_type="application/json",
        text=json.dumps({"code": 0}),
    )
async def human(request):
    """Make the avatar of a session speak, either verbatim or through the LLM.

    JSON body fields read here:
        sessionid: Session to address (defaults to 0).
        type: 'echo' speaks ``text`` verbatim, 'chat' sends it to the LLM.
        text: The text to speak or the user's question.
        interrupt: When truthy, current speech is flushed first.
        knowledge_base, during_intro: Optional values forwarded to
            ``llm_response``.

    With ``interrupt`` set, a chat request is completed before the response
    is sent; it runs on the default executor so the event loop is not blocked
    while the LLM is working. Without ``interrupt`` the chat request is
    started in the background and the response is sent immediately.

    Returns:
        ``{"code": 0, "msg": "ok"}`` on success, otherwise ``{"code": -1, ...}``.
    """
    try:
        params = await request.json()
        sessionid = params.get('sessionid',0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return web.Response(
                content_type="application/json",
                text=json.dumps(
                    {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
                ),
            )

        interrupted = bool(params.get('interrupt'))
        if interrupted:
            # Stop the current playback right away, whatever the request type.
            nerfreal.flush_talk()

        msg_type = params['type']
        if msg_type == 'echo':
            nerfreal.put_msg_txt(params['text'])
        elif msg_type == 'chat':
            knowledge_base_type = params.get('knowledge_base', None)
            during_intro = params.get('during_intro', False)
            pending = asyncio.get_running_loop().run_in_executor(
                None, llm_response, params['text'], nerfreal,
                knowledge_base_type, during_intro,
            )
            if interrupted:
                # The question interrupted playback: finish handling it (and
                # surface its errors) before answering the request.
                await pending

        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"code": 0, "msg":"ok"}
            ),
        )
    except Exception as e:
        logger.exception('exception:')
        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"code": -1, "msg": str(e)}
            ),
        )
async def api_chat(request):
    """
    Standalone chat API - make the digital human speak or answer a question.

    Request body: {
        "text": "what the user said",
        "type": "chat",          // "chat" = question answering, "echo" = speak verbatim
        "interrupt": true,       // whether to interrupt the current playback
        "sessionid": 123456,     // WebRTC session ID
        "knowledge_base": "",    // optional: knowledge base to use
        "during_intro": false    // optional: whether an introduction is playing
    }

    Responds with ``{"code": 0, "msg": "ok", "data": {...}}`` on success and
    ``{"code": -1, "msg": ...}`` on any validation error or exception.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        params = await request.json()

        # Validate the text first, then the session.
        text = params.get('text', '').strip()
        if not text:
            return reply({"code": -1, "msg": "text parameter is required"})

        sessionid = params.get('sessionid', 0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply(
                {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
            )

        msg_type = params.get('type', 'chat')
        interrupt = params.get('interrupt', True)
        knowledge_base = params.get('knowledge_base', None)
        during_intro = params.get('during_intro', False)

        logger.info(f"API Chat - Session {sessionid}, Type: {msg_type}, Text: {text[:50]}...")

        if interrupt:
            nerfreal.flush_talk()

        if msg_type != 'echo':
            # Chat / Q&A mode: hand the text to the LLM in the background.
            asyncio.get_running_loop().run_in_executor(
                None, llm_response, text, nerfreal, knowledge_base, during_intro
            )
        else:
            # Echo mode: speak the text verbatim.
            nerfreal.put_msg_txt(text)

        # Only the first 100 characters are echoed back as confirmation.
        summary = {"sessionid": sessionid, "type": msg_type, "text": text[:100]}
        return reply({"code": 0, "msg": "ok", "data": summary})

    except Exception as e:
        logger.exception('API Chat exception:')
        return reply({"code": -1, "msg": str(e)})
async def interrupt_talk(request):
    """Flush the current speech of the session named in the JSON body.

    Responds with ``{"code": 0, "msg": "ok"}`` on success and
    ``{"code": -1, "msg": ...}`` for an unknown session or any exception.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        body = await request.json()
        sessionid = body.get('sessionid',0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply(
                {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
            )
        nerfreal.flush_talk()
        return reply({"code": 0, "msg":"ok"})
    except Exception as e:
        logger.exception('exception:')
        return reply({"code": -1, "msg": str(e)})
async def humanaudio(request):
    """Feed an uploaded audio file to the avatar of a session.

    Reads a form body with ``sessionid`` (converted with ``int``) and a
    ``file`` upload, and passes the file's bytes to ``put_audio_file``.

    Returns:
        ``{"code": 0, "msg": "ok"}`` on success, otherwise ``{"code": -1, ...}``.
    """
    try:
        form = await request.post()
        sessionid = int(form.get('sessionid',0))
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return web.Response(
                content_type="application/json",
                text=json.dumps(
                    {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
                ),
            )
        # Only the payload is needed; the uploaded file name is not used.
        filebytes = form["file"].file.read()
        nerfreal.put_audio_file(filebytes)
        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"code": 0, "msg":"ok"}
            ),
        )
    except Exception as e:
        logger.exception('exception:')
        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"code": -1, "msg": str(e)}
            ),
        )
async def set_audiotype(request):
    """Forward ``audiotype`` and ``reinit`` from the JSON body to the session.

    Calls ``set_custom_state(audiotype, reinit)`` on the addressed session and
    responds with ``{"code": 0, "msg": "ok"}``; an unknown session, a missing
    field or any other exception yields ``{"code": -1, "msg": ...}``.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        body = await request.json()
        sessionid = body.get('sessionid',0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply(
                {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
            )
        audiotype = body['audiotype']
        reinit = body['reinit']
        nerfreal.set_custom_state(audiotype, reinit)
        return reply({"code": 0, "msg":"ok"})
    except Exception as e:
        logger.exception('exception:')
        return reply({"code": -1, "msg": str(e)})
async def record(request):
    """Start or stop recording for the session named in the JSON body.

    ``type`` selects the action: 'start_record' calls ``start_recording()``
    and 'end_record' calls ``stop_recording()``; any other value does nothing.
    Responds with ``{"code": 0, "msg": "ok"}`` or ``{"code": -1, "msg": ...}``.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        body = await request.json()
        sessionid = body.get('sessionid',0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply(
                {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
            )
        action = body['type']
        if action == 'start_record':
            nerfreal.start_recording()
        elif action == 'end_record':
            nerfreal.stop_recording()
        return reply({"code": 0, "msg":"ok"})
    except Exception as e:
        logger.exception('exception:')
        return reply({"code": -1, "msg": str(e)})
async def is_speaking(request):
    """Report the result of ``is_speaking()`` for the addressed session.

    Returns:
        ``{"code": 0, "data": <is_speaking() result>}`` on success. An unknown
        session, a malformed body or any other exception yields
        ``{"code": -1, "msg": ...}``, matching the other handlers in this file.
    """
    try:
        params = await request.json()
        sessionid = params.get('sessionid',0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return web.Response(
                content_type="application/json",
                text=json.dumps(
                    {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
                ),
            )
        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"code": 0, "data": nerfreal.is_speaking()}
            )
        )
    except Exception as e:
        logger.exception('exception:')
        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"code": -1, "msg": str(e)}
            ),
        )
async def knowledge_intro(request):
    """Start the full knowledge-base introduction for a session.

    The request body and the session are validated first, so that an invalid
    request does not start intro playback. Then ``start_intro_play("full")``
    supplies the first text, the intro instance and an initial play-state
    dict are attached to the session (only if not already present), and the
    text is played through ``start_intro_with_interrupt_capability``.

    Returns:
        ``{"code": 0, ...}`` with ``text``, ``mode``, ``play_index`` and
        ``total_count`` on success, otherwise ``{"code": -1, "msg": ...}``.
    """
    try:
        params = await request.json()
        sessionid = params.get('sessionid', 0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return web.Response(
                content_type="application/json",
                text=json.dumps(
                    {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
                ),
            )

        # Aliased so the import does not reuse this handler's own name.
        from knowledge_intro import start_intro_play, knowledge_intro as intro_instance
        play_result = start_intro_play("full")
        intro_text = play_result.get("text", "")

        # Keep the intro instance on the session for the follow-up endpoints.
        if not hasattr(nerfreal, 'knowledge_intro_instance'):
            nerfreal.knowledge_intro_instance = intro_instance

        # Initial playback state; an existing state is left untouched.
        if not hasattr(nerfreal, 'intro_play_state'):
            nerfreal.intro_play_state = {
                "is_playing": True,
                "current_type": "full",
                "last_played_index": 0,
                "is_paused": False,
                "is_waiting_next": False
            }

        nerfreal.start_intro_with_interrupt_capability(intro_text)

        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"code": 0, "msg": "Knowledge intro played successfully", "text": intro_text, "mode": "intro", "play_index": play_result.get("play_index", 1), "total_count": play_result.get("total_count", 8)}
            ),
        )
    except Exception as e:
        logger.exception('exception in knowledge_intro:')
        return web.Response(
            content_type="application/json",
            text=json.dumps(
                {"code": -1, "msg": str(e)}
            ),
        )
async def resume_interrupted(request):
    """Ask a session to resume messages that were interrupted.

    Calls ``resume_interrupted()`` on the addressed session. Both outcomes
    answer with code 0; the message (and the ``mode`` field) tell whether
    anything was resumed. Errors yield ``{"code": -1, "msg": ...}``.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        body = await request.json()
        sessionid = body.get('sessionid', 0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply(
                {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
            )

        if not nerfreal.resume_interrupted():
            return reply({"code": 0, "msg": "No interrupted messages to resume"})
        return reply(
            {"code": 0, "msg": "Interrupted messages resumed successfully", "mode": "intro"}
        )
    except Exception as e:
        logger.exception('exception in resume_interrupted:')
        return reply({"code": -1, "msg": str(e)})
async def handle_user_question(request):
    """Handle a user question, pausing the introduction if one is playing.

    Reads ``sessionid``, ``text``, ``during_intro`` and ``knowledge_base``
    from the JSON body. When ``during_intro`` is truthy and the session has a
    ``knowledge_intro_instance``, its ``pause_play()`` is called. The question
    is then passed to ``llm_response`` on the default executor without
    waiting for it to finish.

    Returns:
        ``{"code": 0, ..., "mode": "qa"}`` once the question is dispatched,
        otherwise ``{"code": -1, "msg": ...}``.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        body = await request.json()
        sessionid = body.get('sessionid', 0)
        question = body.get('text', '')

        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply(
                {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
            )
        if not question:
            return reply({"code": -1, "msg": "Question text is required"})

        during_intro = body.get('during_intro', False)
        if during_intro and hasattr(nerfreal, 'knowledge_intro_instance'):
            # Pause the introduction while the question is being answered.
            nerfreal.knowledge_intro_instance.pause_play()

        knowledge_base_type = body.get('knowledge_base', None)
        asyncio.get_running_loop().run_in_executor(
            None, llm_response, question, nerfreal, knowledge_base_type, during_intro
        )

        return reply(
            {"code": 0, "msg": "User question processed successfully", "mode": "qa"}
        )
    except Exception as e:
        logger.exception('exception in handle_user_question:')
        return reply({"code": -1, "msg": str(e)})
async def continue_after_qa(request):
    """Continue the earlier content once a user question has been answered.

    If the session has a ``knowledge_intro_instance`` and its ``resume_play()``
    returns content, that text is queued with ``put_msg_txt`` and reported in
    the response. Otherwise the session's ``resume_interrupted()`` is tried.
    All of these outcomes answer with code 0; errors answer with code -1.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        body = await request.json()
        sessionid = body.get('sessionid', 0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply(
                {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
            )

        # First choice: the next piece of introduction content.
        if hasattr(nerfreal, 'knowledge_intro_instance'):
            next_content = nerfreal.knowledge_intro_instance.resume_play()
            if next_content:
                nerfreal.put_msg_txt(next_content['text'])
                return reply(
                    {"code": 0, "msg": "Previously interrupted content resumed successfully", "text": next_content['text'], "mode": "intro", "play_index": next_content.get("play_index", 1), "total_count": next_content.get("total_count", 8)}
                )

        # Fallback: no intro instance, or it had nothing left to play.
        if nerfreal.resume_interrupted():
            outcome = {"code": 0, "msg": "Previously interrupted content resumed successfully", "mode": "intro"}
        else:
            outcome = {"code": 0, "msg": "No interrupted content to resume"}
        return reply(outcome)
    except Exception as e:
        logger.exception('exception in continue_after_qa:')
        return reply({"code": -1, "msg": str(e)})
async def intro_play_completed(request):
    """Callback for a finished intro segment: queue the next one, if any.

    Requires the session to have a ``knowledge_intro_instance``. If the
    session's ``intro_play_state`` says playback is not active, nothing is
    queued. Otherwise the next content from ``_get_next_content()`` is queued
    with ``put_msg_txt`` and ``last_played_index`` is updated.

    Returns:
        Code 0 with the queued text and position, code 0 with a message when
        playback is paused or exhausted, or code -1 on errors.
    """
    def reply(payload):
        return web.Response(content_type="application/json", text=json.dumps(payload))

    try:
        body = await request.json()
        sessionid = body.get('sessionid', 0)
        nerfreal = nerfreals.get(sessionid)
        if nerfreal is None:
            return reply(
                {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
            )

        if not hasattr(nerfreal, 'knowledge_intro_instance'):
            return reply({"code": -1, "msg": "Knowledge intro instance not found"})

        # A missing state, or a state without "is_playing", counts as playing.
        state_known = hasattr(nerfreal, 'intro_play_state')
        if state_known and not nerfreal.intro_play_state.get("is_playing", True):
            return reply({"code": 0, "msg": "Introduction playback is paused"})

        next_content = nerfreal.knowledge_intro_instance._get_next_content()
        if not next_content:
            return reply({"code": 0, "msg": "No more introduction content"})

        nerfreal.put_msg_txt(next_content['text'])

        # Remember how far playback has progressed.
        if hasattr(nerfreal, 'intro_play_state'):
            nerfreal.intro_play_state["last_played_index"] = next_content.get("play_index", 1)

        return reply(
            {"code": 0, "msg": "Next introduction content played successfully", "text": next_content['text'], "mode": "intro", "play_index": next_content.get("play_index", 1), "total_count": next_content.get("total_count", 8), "is_last": next_content.get("is_last", False)}
        )
    except Exception as e:
        logger.exception('exception in intro_play_completed:')
        return reply({"code": -1, "msg": str(e)})
async def on_shutdown(app):
    """Close every tracked peer connection, then empty the tracking set."""
    # Create all close() coroutines up front, then await them together.
    pending_closes = []
    for connection in pcs:
        pending_closes.append(connection.close())
    await asyncio.gather(*pending_closes)
    pcs.clear()
async def post(url,data):
    """POST ``data`` to ``url`` and return the response body as text.

    An ``aiohttp.ClientError`` is logged and ``None`` is returned instead.
    """
    try:
        async with aiohttp.ClientSession() as session, \
                session.post(url, data=data) as response:
            body = await response.text()
            return body
    except aiohttp.ClientError as e:
        logger.info(f'Error: {e}')
        return None
async def run(push_url,sessionid):
    """Build a session and push its audio/video to ``push_url`` over WebRTC.

    The renderer is built on the default executor and registered in
    ``nerfreals``. A peer connection carrying the renderer's tracks creates an
    SDP offer, which is POSTed to ``push_url``; the response body is applied
    as the remote answer.

    Args:
        push_url: Endpoint that accepts the SDP offer and replies with an answer.
        sessionid: Key under which the renderer is stored in ``nerfreals``.
    """
    nerfreal = await asyncio.get_running_loop().run_in_executor(None, build_nerfreal,sessionid)
    nerfreals[sessionid] = nerfreal
    pc = RTCPeerConnection()
    pcs.add(pc)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info("Connection state is %s" % pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    player = HumanPlayer(nerfreal)
    pc.addTrack(player.audio)
    pc.addTrack(player.video)
    await pc.setLocalDescription(await pc.createOffer())
    answer = await post(push_url,pc.localDescription.sdp)
    if answer is None:
        # post() returns None when the HTTP request failed; there is no SDP
        # answer to apply, so release the connection instead of passing None on.
        logger.error('Push to %s failed; session %s is not streaming', push_url, sessionid)
        pcs.discard(pc)
        await pc.close()
        return
    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
- ##########################################
- # os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
- # os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
if __name__ == '__main__':
    # Script entry point: parse options, load the selected model and avatar,
    # register the HTTP routes and run the aiohttp server forever.
    mp.set_start_method('spawn')
    parser = argparse.ArgumentParser()

    # audio FPS
    parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50")
    # sliding window left-middle-right length (unit: 20ms)
    parser.add_argument('-l', type=int, default=10)
    parser.add_argument('-m', type=int, default=8)
    parser.add_argument('-r', type=int, default=10)
    parser.add_argument('--W', type=int, default=450, help="GUI width")
    parser.add_argument('--H', type=int, default=450, help="GUI height")
    # musetalk opt
    parser.add_argument('--avatar_id', type=str, default='avator_1', help="define which avatar in data/avatars")
    parser.add_argument('--batch_size', type=int, default=16, help="infer batch")
    parser.add_argument('--customvideo_config', type=str, default='', help="custom action json")
    # tts choices: xtts gpt-sovits cosyvoice fishtts tencent doubao indextts2 azuretts qwen3tts
    parser.add_argument('--tts', type=str, default='edgetts', help="tts service type")
    parser.add_argument('--REF_FILE', type=str, default="zh-CN-YunxiaNeural",help="参考文件名或语音模型 ID,默认值为 edgetts 的语音模型 ID zh-CN-YunxiaNeural, 若--tts 指定为 azuretts, 可以使用 Azure 语音模型 ID, 如 zh-CN-XiaoxiaoMultilingualNeural")
    parser.add_argument('--REF_TEXT', type=str, default=None)
    parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000
    parser.add_argument('--QWEN3_TTS_MODEL_PATH', type=str, default='/home/test/Digital_Human/Qwen3-TTS-12Hz-1.7B-Base', help="Qwen3-TTS 模型本地路径")
    parser.add_argument('--QWEN3_TTS_LANGUAGE', type=str, default='Chinese', help="Qwen3-TTS 语言设置 (Chinese, English 等)")
    parser.add_argument('--model', type=str, default='musetalk') # musetalk wav2lip ultralight
    parser.add_argument('--transport', type=str, default='rtcpush') # webrtc rtcpush virtualcam
    parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') # rtmp://localhost/live/livestream
    parser.add_argument('--max_session', type=int, default=1) # multi session count
    parser.add_argument('--listenport', type=int, default=7868, help="web listen port")
    opt = parser.parse_args()

    # Optional custom-action configuration; read as UTF-8 so the result does
    # not depend on the platform's default encoding.
    opt.customopt = []
    if opt.customvideo_config!='':
        with open(opt.customvideo_config,'r', encoding='utf-8') as file:
            opt.customopt = json.load(file)

    # Load and warm up the backend selected by --model.
    if opt.model == 'musetalk':
        from musereal import MuseReal,load_model,load_avatar,warm_up
        logger.info(opt)
        model = load_model()
        avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size,model)
    elif opt.model == 'wav2lip':
        from lipreal import LipReal,load_model,load_avatar,warm_up
        logger.info(opt)
        model = load_model("./models/wav2lip256.pth")
        avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size,model,256)
    elif opt.model == 'ultralight':
        from lightreal import LightReal,load_model,load_avatar,warm_up
        logger.info(opt)
        model = load_model(opt)
        avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size,avatar,160)
    else:
        # Fail fast: build_nerfreal() rejects any other model anyway, but only
        # later, when the first session is created.
        parser.error(f"Unsupported model type: {opt.model}")

    if opt.transport=='virtualcam':
        # Virtual-camera mode renders session 0 on a dedicated thread.
        thread_quit = Event()
        nerfreals[0] = build_nerfreal(0)
        rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
        rendthrd.start()

    #############################################################################
    appasync = web.Application(client_max_size=1024**2*100)
    appasync.on_shutdown.append(on_shutdown)
    appasync.router.add_post("/offer", offer)
    appasync.router.add_post("/human", human)
    appasync.router.add_post("/api/chat", api_chat)  # standalone chat API
    appasync.router.add_post("/humanaudio", humanaudio)
    appasync.router.add_post("/set_audiotype", set_audiotype)
    appasync.router.add_post("/record", record)
    appasync.router.add_post("/interrupt_talk", interrupt_talk)
    appasync.router.add_post("/is_speaking", is_speaking)
    appasync.router.add_post("/knowledge_intro", knowledge_intro)  # knowledge-base introduction
    appasync.router.add_post("/resume_interrupted", resume_interrupted)  # resume interrupted messages
    appasync.router.add_post("/handle_user_question", handle_user_question)  # user question
    appasync.router.add_post("/continue_after_qa", continue_after_qa)  # continue playback after Q&A
    appasync.router.add_post("/intro_play_completed", intro_play_completed)  # intro segment finished callback
    appasync.router.add_post("/log", log)  # front-end log endpoint
    appasync.router.add_static('/',path='web')

    # Configure default CORS settings.
    cors = aiohttp_cors.setup(appasync, defaults={
        "*": aiohttp_cors.ResourceOptions(
            allow_credentials=True,
            expose_headers="*",
            allow_headers="*",
        )
    })
    # Configure CORS on all routes.
    for route in list(appasync.router.routes()):
        cors.add(route)

    # Page suggested in the start-up log for the chosen transport.
    pagename='webrtcapi.html'
    if opt.transport=='rtmp':
        pagename='echoapi.html'
    elif opt.transport=='rtcpush':
        pagename='rtcpushapi.html'
    logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
    logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html')

    def run_server(runner):
        """Start the HTTP site on a fresh event loop and run it forever.

        With the 'rtcpush' transport, ``opt.max_session`` push sessions are
        started first; every session after the first appends its index to
        the push URL.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(runner.setup())
        site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
        loop.run_until_complete(site.start())
        if opt.transport=='rtcpush':
            for k in range(opt.max_session):
                push_url = opt.push_url
                if k!=0:
                    push_url = opt.push_url+str(k)
                loop.run_until_complete(run(push_url,k))
        loop.run_forever()

    run_server(web.AppRunner(appasync))
- #app.on_shutdown.append(on_shutdown)
- #app.router.add_post("/offer", offer)
- # print('start websocket server')
- # server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
- # server.serve_forever()
-
-
|