###############################################################################
#  Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
#  email: lipku@foxmail.com
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
###############################################################################

# server.py
import flask
from flask import Flask, render_template,send_from_directory,request, jsonify
from flask_sockets import Sockets
import base64
import json
#import gevent
#from gevent import pywsgi
#from geventwebsocket.handler import WebSocketHandler
import re
import numpy as np
from threading import Thread,Event
#import multiprocessing
import torch.multiprocessing as mp

from aiohttp import web
import aiohttp
import aiohttp_cors
from aiortc import RTCPeerConnection, RTCSessionDescription,RTCIceServer,RTCConfiguration
from aiortc.rtcrtpsender import RTCRtpSender
from webrtc import HumanPlayer
from basereal import BaseReal
from llm import llm_response

import argparse
import random
import shutil
import asyncio
import torch
import os
import socket
from typing import Dict
from logger import logger
import gc


app = Flask(__name__)
#sockets = Sockets(app)
nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
opt = None
model = None
avatar = None


#####webrtc###############################
pcs = set()


def randN(N)->int:
    """Return a random integer that has exactly N decimal digits."""
    # Locals deliberately avoid the names `min`/`max` so the builtins are not shadowed.
    lower = pow(10, N - 1)
    upper = pow(10, N)
    return random.randint(lower, upper - 1)


def build_nerfreal(sessionid:int)->BaseReal:
    """Create the avatar renderer selected by ``opt.model`` for one session.

    The model-specific module is imported lazily so that only the chosen
    backend has to be installed.

    Raises:
        ValueError: if ``opt.model`` is not one of the supported backends.
    """
    # NOTE(review): this writes to the shared module-level `opt`; when several
    # sessions are built concurrently in executor threads they race on this
    # attribute - confirm the backends read it only during construction.
    opt.sessionid=sessionid
    if opt.model == 'wav2lip':
        from lipreal import LipReal
        nerfreal = LipReal(opt,model,avatar)
    elif opt.model == 'musetalk':
        from musereal import MuseReal
        nerfreal = MuseReal(opt,model,avatar)
    # elif opt.model == 'ernerf':
    #     from nerfreal import NeRFReal
    #     nerfreal = NeRFReal(opt,model,avatar)
    elif opt.model == 'ultralight':
        from lightreal import LightReal
        nerfreal = LightReal(opt,model,avatar)
    else:
        raise ValueError(f"Unsupported model type: {opt.model}")
    return nerfreal


def _json_response(payload, status=200):
    """Serialize *payload* to JSON and wrap it in an aiohttp response."""
    return web.Response(
        content_type="application/json",
        text=json.dumps(payload),
        status=status,
    )


def _get_session(sessionid):
    """Return the renderer registered for *sessionid*, or None.

    None is returned both for unknown ids and for sessions whose renderer is
    still being built (their placeholder value is None).
    """
    return nerfreals.get(sessionid)


def _session_not_found(sessionid):
    """Build the standard error reply for a missing/uninitialized session."""
    return _json_response(
        {"code": -1, "msg": f"Session {sessionid} not found or not initialized"}
    )


def _log_background_failure(future):
    """Done-callback that logs an exception raised by a background LLM job."""
    if future.cancelled():
        return
    exc = future.exception()
    if exc is not None:
        logger.error('llm_response failed in background: %s', exc,
                     exc_info=(type(exc), exc, exc.__traceback__))


def _submit_llm(text, nerfreal, knowledge_base, during_intro):
    """Run ``llm_response`` in the default executor without awaiting it.

    A done-callback is attached so that a failure inside the worker thread is
    logged instead of being silently dropped with the future.
    """
    future = asyncio.get_running_loop().run_in_executor(
        None, llm_response, text, nerfreal, knowledge_base, during_intro
    )
    future.add_done_callback(_log_background_failure)
    return future


def _collect_local_ipv4():
    """Collect the server's non-loopback IPv4 addresses as a set of strings.

    When the environment variable AIOICE_BIND_IP is set, only that address is
    returned. Otherwise the addresses resolved for the host name are used and,
    if the optional ``ifaddr`` package is importable, the addresses of all
    network adapters are added as well (getaddrinfo may be incomplete).
    """
    local_ips = set()
    bind_ip = os.environ.get('AIOICE_BIND_IP', '')
    if bind_ip:
        local_ips.add(bind_ip)
        return local_ips

    hostname = socket.gethostname()
    for ip_info in socket.getaddrinfo(hostname, None, socket.AF_INET):
        ip = ip_info[4][0]
        if not ip.startswith('127.'):
            local_ips.add(ip)
    try:
        import ifaddr
        for adapter in ifaddr.get_adapters():
            for ip_obj in adapter.ips:
                # Only string addresses are taken; non-string entries are skipped.
                if isinstance(ip_obj.ip, str) and not ip_obj.ip.startswith('127.'):
                    local_ips.add(ip_obj.ip)
    except Exception:
        # Best effort: ifaddr is optional, fall back to the getaddrinfo result.
        logger.debug('ifaddr enumeration skipped', exc_info=True)
    return local_ips


def _rewrite_sdp_ips(sdp, local_ips, public_ip):
    """Replace every address of *local_ips* found in *sdp* by *public_ip*.

    Matching is anchored so that an address is only replaced when it is not
    part of a longer dotted number (``10.0.0.1`` must not match inside
    ``10.0.0.12``), and all addresses are substituted in a single pass so a
    replacement is never rewritten a second time.
    """
    candidates = [ip for ip in local_ips if ip]
    if not candidates:
        return sdp
    alternatives = '|'.join(re.escape(ip) for ip in candidates)
    pattern = r'(?<![\d.])(?:' + alternatives + r')(?![\d.])'
    # A callable replacement keeps *public_ip* free from backslash processing.
    return re.sub(pattern, lambda _match: public_ip, sdp)


#@app.route('/offer', methods=['POST'])
async def offer(request):
    """Handle a WebRTC offer: create a session and return the SDP answer.

    A new session id is allocated, the renderer is built in the default
    executor, audio/video tracks are attached to a new peer connection and the
    local description is returned together with the session id. On any error
    the half-created session and peer connection are released and a JSON error
    is returned with HTTP status 500.
    """
    sessionid = None
    pc = None
    try:
        params = await request.json()
        offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

        # if len(nerfreals) >= opt.max_session:
        #     logger.info('reach max session')
        #     return web.Response(
        #         content_type="application/json",
        #         text=json.dumps(
        #             {"code": -1, "msg": "reach max session"}
        #         ),
        #     )
        sessionid = randN(6) #len(nerfreals)
        # Redraw on collision so an existing session is never overwritten.
        while sessionid in nerfreals:
            sessionid = randN(6)
        # Placeholder marks the id as taken while the renderer is being built.
        nerfreals[sessionid] = None
        logger.info('sessionid=%d, session num=%d',sessionid,len(nerfreals))
        nerfreal = await asyncio.get_running_loop().run_in_executor(None, build_nerfreal,sessionid)
        nerfreals[sessionid] = nerfreal

        # Intranet / jump-host setup: when WEBRTC_NAT_IP is set, no STUN server
        # is configured and only host candidates are used. WEBRTC_NAT_IP is the
        # address that is written into the SDP answer in place of local ones.
        nat_public_ip = os.environ.get('WEBRTC_NAT_IP', '')
        if nat_public_ip:
            logger.info('Using NAT public IP for ICE: %s', nat_public_ip)
            ice_servers = []
        else:
            ice_servers = [
                RTCIceServer(urls='stun:stun.l.google.com:19302'),
                RTCIceServer(urls='stun:stun1.l.google.com:19302'),
                RTCIceServer(urls='stun:stun2.l.google.com:19302'),
                RTCIceServer(urls='stun:stun3.l.google.com:19302'),
                RTCIceServer(urls='stun:stun4.l.google.com:19302'),
            ]

        pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=ice_servers))
        pcs.add(pc)

        @pc.on("connectionstatechange")
        async def on_connectionstatechange():
            logger.info(f"Session {sessionid} - Connection state is {pc.connectionState}")
            if pc.connectionState == "connected":
                logger.info(f"Session {sessionid} - WebRTC connection established successfully!")
            if pc.connectionState == "failed":
                logger.error(f"Session {sessionid} - Connection failed!")
                await pc.close()
                pcs.discard(pc)
                nerfreals.pop(sessionid, None)
            if pc.connectionState == "closed":
                logger.info(f"Session {sessionid} - Connection closed")
                pcs.discard(pc)
                nerfreals.pop(sessionid, None)
                # gc.collect()

        player = HumanPlayer(nerfreal)
        pc.addTrack(player.audio)
        pc.addTrack(player.video)
        logger.info(f"Added tracks for session {sessionid}: audio={player.audio}, video={player.video}")

        # Codec preference order for video: H264, then VP8, then rtx.
        capabilities = RTCRtpSender.getCapabilities("video")
        preferences = list(filter(lambda x: x.name == "H264", capabilities.codecs))
        preferences += list(filter(lambda x: x.name == "VP8", capabilities.codecs))
        preferences += list(filter(lambda x: x.name == "rtx", capabilities.codecs))
        # Index 1 is the video transceiver because video was added second.
        transceiver = pc.getTransceivers()[1]
        transceiver.setCodecPreferences(preferences)

        await pc.setRemoteDescription(offer)

        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)

        sdp = pc.localDescription.sdp
        if nat_public_ip:
            # Only the addresses are rewritten; ports in the SDP are left as is.
            try:
                local_ips = _collect_local_ipv4()
                sdp = _rewrite_sdp_ips(sdp, local_ips, nat_public_ip)
                logger.info('NAT mapping: %s -> %s in SDP', local_ips, nat_public_ip)
            except Exception as e:
                logger.warning('Failed to apply NAT IP mapping to SDP: %s', e)
        else:
            logger.info('WEBRTC_NAT_IP not set. SDP uses server internal IPs. '
                        'Set WEBRTC_NAT_IP= if browser cannot reach server directly.')

        return _json_response(
            {"sdp": sdp, "type": pc.localDescription.type, "sessionid": sessionid}
        )
    except Exception as e:
        logger.exception('Error in offer:')
        # Release whatever was allocated before the failure so that neither a
        # None placeholder nor an open peer connection is left behind.
        if sessionid is not None:
            nerfreals.pop(sessionid, None)
        if pc is not None:
            pcs.discard(pc)
            try:
                await pc.close()
            except Exception:
                logger.exception('Error closing peer connection after failed offer:')
        return _json_response({"code": -1, "msg": str(e)}, status=500)


async def log(request):
    """Receive a log message from the front end and write it to the logger."""
    try:
        params = await request.json()
        message = params.get('message', '')
        logger.info('[WEBRTC] %s', message)
    except Exception:
        # Front-end logging is best effort; never fail the request because of it.
        logger.debug('Ignoring malformed front-end log request', exc_info=True)
    return _json_response({"code": 0})


async def human(request):
    """Feed text to a session, either echoed verbatim or answered by the LLM.

    JSON fields read: ``sessionid``, ``type`` ('echo' or 'chat'), ``text``,
    and optionally ``interrupt``, ``knowledge_base``, ``during_intro``.
    """
    try:
        params = await request.json()

        sessionid = params.get('sessionid',0)
        nerfreal = _get_session(sessionid)
        if nerfreal is None:
            return _session_not_found(sessionid)

        # Tracks whether the chat request was already handled in the interrupt path.
        chat_processed = False

        if params.get('interrupt'):
            # Stop the current playback immediately, whatever the request type.
            nerfreal.flush_talk()

            if params['type'] == 'chat':
                knowledge_base_type = params.get('knowledge_base', None)
                during_intro = params.get('during_intro', False)
                # The reply is sent only after the question has been processed,
                # but the blocking call runs in the executor so the event loop
                # (and with it every WebRTC connection) is not stalled.
                await asyncio.get_running_loop().run_in_executor(
                    None, llm_response, params['text'], nerfreal,
                    knowledge_base_type, during_intro
                )
                chat_processed = True

        if params['type']=='echo':
            nerfreal.put_msg_txt(params['text'])
        elif params['type']=='chat' and not chat_processed:
            knowledge_base_type = params.get('knowledge_base', None)
            during_intro = params.get('during_intro', False)
            _submit_llm(params['text'], nerfreal, knowledge_base_type, during_intro)
            #nerfreals[sessionid].put_msg_txt(res)

        return _json_response({"code": 0, "msg":"ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def api_chat(request):
    """
    Stand-alone chat API - make the digital human speak or answer.

    Request body:
    {
        "text": "what the user said",
        "type": "chat",           // "chat" = question answering, "echo" = read aloud
        "interrupt": true,        // whether to interrupt the current playback
        "sessionid": 123456,      // WebRTC session id
        "knowledge_base": "",     // optional: knowledge base to use
        "during_intro": false     // optional: whether an introduction is playing
    }
    """
    try:
        params = await request.json()

        text = params.get('text', '').strip()
        if not text:
            return _json_response({"code": -1, "msg": "text parameter is required"})

        sessionid = params.get('sessionid', 0)
        nerfreal = _get_session(sessionid)
        if nerfreal is None:
            return _session_not_found(sessionid)

        msg_type = params.get('type', 'chat')
        interrupt = params.get('interrupt', True)
        knowledge_base = params.get('knowledge_base', None)
        during_intro = params.get('during_intro', False)

        logger.info(f"API Chat - Session {sessionid}, Type: {msg_type}, Text: {text[:50]}...")

        if interrupt:
            nerfreal.flush_talk()

        if msg_type == 'echo':
            # Echo mode: play the text directly.
            nerfreal.put_msg_txt(text)
        else:
            # Any other type goes through the LLM in the background.
            _submit_llm(text, nerfreal, knowledge_base, during_intro)

        return _json_response({
            "code": 0,
            "msg": "ok",
            "data": {
                "sessionid": sessionid,
                "type": msg_type,
                "text": text[:100]  # first 100 characters, returned as confirmation
            }
        })
    except Exception as e:
        logger.exception('API Chat exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def interrupt_talk(request):
    """Interrupt the current playback of a session."""
    try:
        params = await request.json()

        sessionid = params.get('sessionid',0)
        nerfreal = _get_session(sessionid)
        if nerfreal is None:
            return _session_not_found(sessionid)
        nerfreal.flush_talk()

        return _json_response({"code": 0, "msg":"ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def humanaudio(request):
    """Accept an uploaded audio file (form field ``file``) for a session."""
    try:
        form= await request.post()
        sessionid = int(form.get('sessionid',0))
        nerfreal = _get_session(sessionid)
        if nerfreal is None:
            return _session_not_found(sessionid)
        fileobj = form["file"]
        filebytes=fileobj.file.read()
        nerfreal.put_audio_file(filebytes)

        return _json_response({"code": 0, "msg":"ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def set_audiotype(request):
    """Switch a session's custom state using ``audiotype`` and ``reinit``."""
    try:
        params = await request.json()

        sessionid = params.get('sessionid',0)
        nerfreal = _get_session(sessionid)
        if nerfreal is None:
            return _session_not_found(sessionid)
        nerfreal.set_custom_state(params['audiotype'],params['reinit'])

        return _json_response({"code": 0, "msg":"ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def record(request):
    """Start or stop recording; ``type`` is 'start_record' or 'end_record'."""
    try:
        params = await request.json()

        sessionid = params.get('sessionid',0)
        nerfreal = _get_session(sessionid)
        if nerfreal is None:
            return _session_not_found(sessionid)
        if params['type']=='start_record':
            # nerfreals[sessionid].put_msg_txt(params['text'])
            nerfreal.start_recording()
        elif params['type']=='end_record':
            nerfreal.stop_recording()
        return _json_response({"code": 0, "msg":"ok"})
    except Exception as e:
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def is_speaking(request):
    """Report whether the session's avatar is currently speaking."""
    try:
        params = await request.json()

        sessionid = params.get('sessionid',0)
        nerfreal = _get_session(sessionid)
        if nerfreal is None:
            return _session_not_found(sessionid)
        return _json_response({"code": 0, "data": nerfreal.is_speaking()})
    except Exception as e:
        # Same error contract as the other handlers instead of a bare HTTP 500.
        logger.exception('exception:')
        return _json_response({"code": -1, "msg": str(e)})


async def knowledge_intro(request):
    """Start the knowledge-base introduction and return its first text."""
    try:
        # Aliased so the import does not rebind this handler's own name locally.
        from knowledge_intro import start_intro_play
        from knowledge_intro import knowledge_intro as intro_instance

        params = await request.json()
        sessionid = params.get('sessionid', 0)
        nerfreal = _get_session(sessionid)
        if nerfreal is None:
            return _session_not_found(sessionid)

        # Started only after the session is validated, so a request for an
        # unknown session does not start an introduction playback.
        play_result = start_intro_play("full")
        intro_text = play_result.get("text", "")

        # Keep the introduction object on the session for the follow-up APIs.
        if not hasattr(nerfreal, 'knowledge_intro_instance'):
            nerfreal.knowledge_intro_instance = intro_instance

        # Keep the playback state on the session.
        if not hasattr(nerfreal, 'intro_play_state'):
            nerfreal.intro_play_state = {
                "is_playing": True,
                "current_type": "full",
                "last_played_index": 0,
                "is_paused": False,
                "is_waiting_next": False
            }

        # Play through the method that supports interruption and resuming.
        nerfreal.start_intro_with_interrupt_capability(intro_text)

        return _json_response(
            {"code": 0, "msg": "Knowledge intro played successfully", "text": intro_text, "mode": "intro",
             "play_index": play_result.get("play_index", 1), "total_count": play_result.get("total_count", 8)}
        )
    except Exception as e:
        logger.exception('exception in knowledge_intro:')
        return _json_response({"code": -1, "msg": str(e)})


async def resume_interrupted(request):
    """Resume playback of messages that were interrupted."""
    try:
        params = await request.json()
        sessionid = params.get('sessionid', 0)
        nerfreal = _get_session(sessionid)
        if nerfreal is None:
            return _session_not_found(sessionid)

        resumed = nerfreal.resume_interrupted()

        if resumed:
            return _json_response(
                {"code": 0, "msg": "Interrupted messages resumed successfully", "mode": "intro"}
            )
        return _json_response(
            {"code": 0, "msg": "No interrupted messages to resume"}
        )
    except Exception as e:
        logger.exception('exception in resume_interrupted:')
        return _json_response({"code": -1, "msg": str(e)})


async def handle_user_question(request):
    """Handle a user question: pause the introduction and answer it first."""
    try:
        params = await request.json()
        sessionid = params.get('sessionid', 0)
        question = params.get('text', '')

        nerfreal = _get_session(sessionid)
        if nerfreal is None:
            return _session_not_found(sessionid)

        if not question:
            return _json_response({"code": -1, "msg": "Question text is required"})

        # Pause the introduction when the question arrives while it is playing.
        during_intro = params.get('during_intro', False)
        if during_intro and hasattr(nerfreal, 'knowledge_intro_instance'):
            nerfreal.knowledge_intro_instance.pause_play()

        knowledge_base_type = params.get('knowledge_base', None)
        _submit_llm(question, nerfreal, knowledge_base_type, during_intro)

        return _json_response(
            {"code": 0, "msg": "User question processed successfully", "mode": "qa"}
        )
    except Exception as e:
        logger.exception('exception in handle_user_question:')
        return _json_response({"code": -1, "msg": str(e)})


async def continue_after_qa(request):
    """Continue the previous content once a user question has been answered."""
    try:
        params = await request.json()
        sessionid = params.get('sessionid', 0)
        nerfreal = _get_session(sessionid)
        if nerfreal is None:
            return _session_not_found(sessionid)

        # Prefer the next introduction item when an introduction is attached.
        if hasattr(nerfreal, 'knowledge_intro_instance'):
            next_content = nerfreal.knowledge_intro_instance.resume_play()
            if next_content:
                nerfreal.put_msg_txt(next_content['text'])
                return _json_response(
                    {"code": 0, "msg": "Previously interrupted content resumed successfully",
                     "text": next_content['text'], "mode": "intro",
                     "play_index": next_content.get("play_index", 1),
                     "total_count": next_content.get("total_count", 8)}
                )

        # No introduction attached, or nothing left in it.
        resumed = nerfreal.resume_interrupted()

        if resumed:
            return _json_response(
                {"code": 0, "msg": "Previously interrupted content resumed successfully", "mode": "intro"}
            )
        return _json_response(
            {"code": 0, "msg": "No interrupted content to resume"}
        )
    except Exception as e:
        logger.exception('exception in continue_after_qa:')
        return _json_response({"code": -1, "msg": str(e)})


async def intro_play_completed(request):
    """Callback for a finished introduction item: play the next one."""
    try:
        params = await request.json()
        sessionid = params.get('sessionid', 0)
        nerfreal = _get_session(sessionid)
        if nerfreal is None:
            return _session_not_found(sessionid)

        if not hasattr(nerfreal, 'knowledge_intro_instance'):
            return _json_response(
                {"code": -1, "msg": "Knowledge intro instance not found"}
            )

        # Do nothing when the stored state says playback is not running.
        if hasattr(nerfreal, 'intro_play_state') and not nerfreal.intro_play_state.get("is_playing", True):
            return _json_response(
                {"code": 0, "msg": "Introduction playback is paused"}
            )

        next_content = nerfreal.knowledge_intro_instance._get_next_content()

        if next_content:
            nerfreal.put_msg_txt(next_content['text'])

            if hasattr(nerfreal, 'intro_play_state'):
                nerfreal.intro_play_state["last_played_index"] = next_content.get("play_index", 1)

            return _json_response(
                {"code": 0, "msg": "Next introduction content played successfully",
                 "text": next_content['text'], "mode": "intro",
                 "play_index": next_content.get("play_index", 1),
                 "total_count": next_content.get("total_count", 8),
                 "is_last": next_content.get("is_last", False)}
            )
        return _json_response(
            {"code": 0, "msg": "No more introduction content"}
        )
    except Exception as e:
        logger.exception('exception in intro_play_completed:')
        return _json_response({"code": -1, "msg": str(e)})


async def on_shutdown(app):
    """Close every open peer connection when the web application stops."""
    coros = [pc.close() for pc in pcs]
    # return_exceptions keeps one failing close from skipping the final clear().
    await asyncio.gather(*coros, return_exceptions=True)
    pcs.clear()


async def post(url,data):
    """POST *data* to *url* and return the response text.

    Returns None when an aiohttp client error occurs (the error is logged).
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url,data=data) as response:
                return await response.text()
    except aiohttp.ClientError as e:
        logger.error('Error: %s', e)


async def run(push_url,sessionid):
    """Build a session and push its stream to *push_url* (rtcpush transport).

    Raises:
        RuntimeError: if the push server returned no SDP answer.
    """
    nerfreal = await asyncio.get_running_loop().run_in_executor(None, build_nerfreal,sessionid)
    nerfreals[sessionid] = nerfreal

    pc = RTCPeerConnection()
    pcs.add(pc)

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        logger.info("Connection state is %s" % pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    player = HumanPlayer(nerfreal)
    pc.addTrack(player.audio)
    pc.addTrack(player.video)

    await pc.setLocalDescription(await pc.createOffer())
    answer = await post(push_url,pc.localDescription.sdp)
    if answer is None:
        # post() already logged the client error; fail with a clear message
        # instead of passing None on as the remote SDP.
        pcs.discard(pc)
        await pc.close()
        raise RuntimeError(f'No SDP answer received from push url: {push_url}')
    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer,type='answer'))
##########################################
# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
if __name__ == '__main__':
    mp.set_start_method('spawn')
    parser = argparse.ArgumentParser()

    # audio FPS
    parser.add_argument('--fps', type=int, default=50, help="audio fps,must be 50")
    # sliding window left-middle-right length (unit: 20ms)
    parser.add_argument('-l', type=int, default=10)
    parser.add_argument('-m', type=int, default=8)
    parser.add_argument('-r', type=int, default=10)

    parser.add_argument('--W', type=int, default=450, help="GUI width")
    parser.add_argument('--H', type=int, default=450, help="GUI height")

    #musetalk opt
    parser.add_argument('--avatar_id', type=str, default='avator_1', help="define which avatar in data/avatars")
    #parser.add_argument('--bbox_shift', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=16, help="infer batch")

    parser.add_argument('--customvideo_config', type=str, default='', help="custom action json")

    parser.add_argument('--tts', type=str, default='edgetts', help="tts service type") #xtts gpt-sovits cosyvoice fishtts tencent doubao indextts2 azuretts qwen3tts
    parser.add_argument('--REF_FILE', type=str, default="zh-CN-YunxiaNeural",help="参考文件名或语音模型 ID,默认值为 edgetts 的语音模型 ID zh-CN-YunxiaNeural, 若--tts 指定为 azuretts, 可以使用 Azure 语音模型 ID, 如 zh-CN-XiaoxiaoMultilingualNeural")
    parser.add_argument('--REF_TEXT', type=str, default=None)
    parser.add_argument('--TTS_SERVER', type=str, default='http://127.0.0.1:9880') # http://localhost:9000
    parser.add_argument('--QWEN3_TTS_MODEL_PATH', type=str, default='/home/test/Digital_Human/Qwen3-TTS-12Hz-1.7B-Base', help="Qwen3-TTS 模型本地路径")
    parser.add_argument('--QWEN3_TTS_LANGUAGE', type=str, default='Chinese', help="Qwen3-TTS 语言设置 (Chinese, English 等)")
    # parser.add_argument('--CHARACTER', type=str, default='test')
    # parser.add_argument('--EMOTION', type=str, default='default')

    parser.add_argument('--model', type=str, default='musetalk') #musetalk wav2lip ultralight

    parser.add_argument('--transport', type=str, default='rtcpush') #webrtc rtcpush virtualcam
    parser.add_argument('--push_url', type=str, default='http://localhost:1985/rtc/v1/whip/?app=live&stream=livestream') #rtmp://localhost/live/livestream

    parser.add_argument('--max_session', type=int, default=1)  #multi session count
    parser.add_argument('--listenport', type=int, default=7868, help="web listen port")

    opt = parser.parse_args()
    #app.config.from_object(opt)
    #print(app.config)
    opt.customopt = []
    if opt.customvideo_config!='':
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(opt.customvideo_config,'r',encoding='utf-8') as file:
            opt.customopt = json.load(file)

    # if opt.model == 'ernerf':
    #     from nerfreal import NeRFReal,load_model,load_avatar
    #     model = load_model(opt)
    #     avatar = load_avatar(opt)
    if opt.model == 'musetalk':
        from musereal import MuseReal,load_model,load_avatar,warm_up
        logger.info(opt)
        model = load_model()
        avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size,model)
    elif opt.model == 'wav2lip':
        from lipreal import LipReal,load_model,load_avatar,warm_up
        logger.info(opt)
        model = load_model("./models/wav2lip256.pth")
        avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size,model,256)
    elif opt.model == 'ultralight':
        from lightreal import LightReal,load_model,load_avatar,warm_up
        logger.info(opt)
        model = load_model(opt)
        avatar = load_avatar(opt.avatar_id)
        warm_up(opt.batch_size,avatar,160)

    # if opt.transport=='rtmp':
    #     thread_quit = Event()
    #     nerfreals[0] = build_nerfreal(0)
    #     rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
    #     rendthrd.start()
    if opt.transport=='virtualcam':
        thread_quit = Event()
        nerfreals[0] = build_nerfreal(0)
        rendthrd = Thread(target=nerfreals[0].render,args=(thread_quit,))
        rendthrd.start()

    #############################################################################
    appasync = web.Application(client_max_size=1024**2*100)
    appasync.on_shutdown.append(on_shutdown)
    appasync.router.add_post("/offer", offer)
    appasync.router.add_post("/human", human)
    appasync.router.add_post("/api/chat", api_chat)  # stand-alone chat API
    appasync.router.add_post("/humanaudio", humanaudio)
    appasync.router.add_post("/set_audiotype", set_audiotype)
    appasync.router.add_post("/record", record)
    appasync.router.add_post("/interrupt_talk", interrupt_talk)
    appasync.router.add_post("/is_speaking", is_speaking)
    appasync.router.add_post("/knowledge_intro", knowledge_intro)  # knowledge-base introduction API
    appasync.router.add_post("/resume_interrupted", resume_interrupted)  # resume interrupted messages API
    appasync.router.add_post("/handle_user_question", handle_user_question)  # user question API
    appasync.router.add_post("/continue_after_qa", continue_after_qa)  # continue playback after Q&A API
    appasync.router.add_post("/intro_play_completed", intro_play_completed)  # introduction item finished callback API
    appasync.router.add_post("/log", log)  # front-end log endpoint
    appasync.router.add_static('/',path='web')

    # Configure default CORS settings.
    cors = aiohttp_cors.setup(appasync, defaults={
            "*": aiohttp_cors.ResourceOptions(
                allow_credentials=True,
                expose_headers="*",
                allow_headers="*",
            )
        })
    # Configure CORS on all routes.
    for route in list(appasync.router.routes()):
        cors.add(route)

    pagename='webrtcapi.html'
    if opt.transport=='rtmp':
        pagename='echoapi.html'
    elif opt.transport=='rtcpush':
        pagename='rtcpushapi.html'
    logger.info('start http server; http://:'+str(opt.listenport)+'/'+pagename)
    logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://:'+str(opt.listenport)+'/dashboard.html')

    def run_server(runner):
        """Start the HTTP site on a fresh event loop and serve forever.

        With the rtcpush transport one push session per ``--max_session`` is
        started first; sessions after the first append their index to the URL.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(runner.setup())
        site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
        loop.run_until_complete(site.start())
        if opt.transport=='rtcpush':
            for k in range(opt.max_session):
                push_url = opt.push_url
                if k!=0:
                    push_url = opt.push_url+str(k)
                loop.run_until_complete(run(push_url,k))
        loop.run_forever()
    #Thread(target=run_server, args=(web.AppRunner(appasync),)).start()
    run_server(web.AppRunner(appasync))

    #app.on_shutdown.append(on_shutdown)
    #app.router.add_post("/offer", offer)

    # print('start websocket server')
    # server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
    # server.serve_forever()