| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256 |
"""
Voice wake-word detector.
Wake-word detection implemented on top of sherpa-onnx.
"""
- import asyncio
- import time
- from typing import Callable, Optional
- import numpy as np
- import sherpa_onnx
- from .config import WakeWordConfig
- from .audio_capture import AudioListener
class WakeWordDetector(AudioListener):
    """Wake-word detector built on a sherpa-onnx ``KeywordSpotter``.

    Audio chunks are delivered through :meth:`on_audio_data`, buffered in a
    bounded ``asyncio.Queue`` and consumed by a background asyncio task that
    feeds them to the keyword spotter. Detections are reported through the
    callback registered with :meth:`on_detected`; errors raised inside the
    detection loop are reported through the callback registered with
    :meth:`on_error`.
    """

    def __init__(self, config: WakeWordConfig):
        """Create the detector and load the keyword-spotting model.

        Args:
            config: Wake-word configuration. This class reads ``model_path``,
                ``num_threads``, ``sample_rate``, ``max_active_paths``,
                ``keywords_score``, ``keywords_threshold``,
                ``num_trailing_blanks``, ``provider`` and
                ``detection_cooldown`` from it and calls ``validate()``.

        Raises:
            RuntimeError: If validation or model loading fails.
        """
        self.config = config

        # State flags.
        self.is_running = False
        self.paused = False
        self.detection_task: Optional[asyncio.Task] = None

        # Bounded buffer of pending audio chunks (oldest is dropped when full).
        self._audio_queue = asyncio.Queue(maxsize=100)

        # Timestamp (time.time()) of the last accepted detection, used to
        # suppress repeated triggers within the cooldown window.
        self.last_detection_time = 0

        # User callbacks.
        self.on_detected_callback: Optional[Callable] = None
        self.on_error_callback: Optional[Callable] = None

        # Sherpa-ONNX components.
        self.keyword_spotter = None
        self.stream = None

        # Load the model eagerly so configuration problems surface here.
        self._init_kws_model()

    def _init_kws_model(self):
        """Validate the configuration and build the ``KeywordSpotter``.

        The model files ``encoder.onnx``, ``decoder.onnx``, ``joiner.onnx``,
        ``tokens.txt`` and ``keywords.txt`` are looked up directly under
        ``config.model_path``.

        Raises:
            RuntimeError: If validation or model construction fails. The
                original exception is attached as ``__cause__``.
        """
        try:
            self.config.validate()

            model_dir = self.config.model_path
            encoder_path = model_dir / "encoder.onnx"
            decoder_path = model_dir / "decoder.onnx"
            joiner_path = model_dir / "joiner.onnx"
            tokens_path = model_dir / "tokens.txt"
            keywords_path = model_dir / "keywords.txt"

            print(f"加载Sherpa-ONNX KeywordSpotter模型: {self.config.model_path}")

            self.keyword_spotter = sherpa_onnx.KeywordSpotter(
                tokens=str(tokens_path),
                encoder=str(encoder_path),
                decoder=str(decoder_path),
                joiner=str(joiner_path),
                keywords_file=str(keywords_path),
                num_threads=self.config.num_threads,
                sample_rate=self.config.sample_rate,
                feature_dim=80,
                max_active_paths=self.config.max_active_paths,
                keywords_score=self.config.keywords_score,
                keywords_threshold=self.config.keywords_threshold,
                num_trailing_blanks=self.config.num_trailing_blanks,
                provider=self.config.provider,
            )

            print("Sherpa-ONNX KeywordSpotter模型加载成功")
        except Exception as e:
            # Chain the cause explicitly so the original traceback is kept.
            raise RuntimeError(f"Sherpa-ONNX KeywordSpotter初始化失败: {e}") from e

    def on_detected(self, callback: Callable):
        """Register the callback invoked when a wake word is detected.

        The callback is called as ``callback(result, result)`` and may be a
        plain function or a coroutine function.
        """
        self.on_detected_callback = callback

    def on_error(self, callback: Callable):
        """Register the callback invoked with the exception on a loop error.

        The callback may be a plain function or a coroutine function.
        """
        self.on_error_callback = callback

    def on_audio_data(self, audio_data: np.ndarray):
        """Buffer one audio chunk for detection (``AudioListener`` interface).

        Chunks are ignored while the detector is stopped or paused. When the
        queue is full the oldest chunk is discarded to make room, so the
        detector always works on the most recent audio. This method never
        raises; failures are printed.

        Args:
            audio_data: Audio samples; a private copy is queued.
        """
        if not self.is_running or self.paused:
            return

        # NOTE(review): asyncio.Queue is not thread-safe; this assumes the
        # audio source calls us on the event-loop thread - confirm against
        # the audio capture implementation.
        try:
            chunk = audio_data.copy()  # copy once, reuse on the retry path
            try:
                self._audio_queue.put_nowait(chunk)
            except asyncio.QueueFull:
                # Drop the oldest chunk, then retry.
                try:
                    self._audio_queue.get_nowait()
                except asyncio.QueueEmpty:
                    pass
                self._audio_queue.put_nowait(chunk)
        except Exception as e:
            print(f"音频数据入队失败: {e}")

    async def start(self) -> bool:
        """Start the background detection task.

        Calling ``start()`` while a detection task is already alive does not
        spawn a second task; it only clears the paused flag.

        Returns:
            ``True`` if the detector is running after the call, ``False`` if
            the model is not initialised or start-up failed.
        """
        if not self.keyword_spotter:
            print("KeywordSpotter未初始化")
            return False

        # Guard against double start: a second task would compete for the
        # same queue and stream, and the first one would be orphaned.
        if self.detection_task is not None and not self.detection_task.done():
            self.is_running = True
            self.paused = False
            return True

        try:
            self.is_running = True
            self.paused = False

            # A fresh stream for this run.
            self.stream = self.keyword_spotter.create_stream()

            self.detection_task = asyncio.create_task(self._detection_loop())

            print("唤醒词检测器启动成功")
            return True
        except Exception as e:
            # Roll back so a failed start does not leave the detector
            # accepting audio that nothing will consume.
            self.is_running = False
            print(f"启动唤醒词检测器失败: {e}")
            return False

    async def _invoke_callback(self, callback: Callable, *args):
        """Call ``callback(*args)``, awaiting it if it is a coroutine function."""
        if asyncio.iscoroutinefunction(callback):
            await callback(*args)
        else:
            callback(*args)

    async def _detection_loop(self):
        """Consume queued audio until stopped, cancelled or too many errors.

        Each iteration processes at most one queued chunk and then sleeps
        briefly. After ``MAX_ERRORS`` consecutive failures the loop gives up
        and clears ``is_running`` so that incoming audio is no longer queued.
        """
        error_count = 0
        MAX_ERRORS = 5

        while self.is_running:
            try:
                if self.paused:
                    await asyncio.sleep(0.1)
                    continue

                await self._process_audio()

                # Short sleep keeps latency low while yielding to the loop.
                await asyncio.sleep(0.005)
                error_count = 0
            except asyncio.CancelledError:
                break
            except Exception as e:
                error_count += 1
                print(f"KWS检测循环错误({error_count}/{MAX_ERRORS}): {e}")

                if self.on_error_callback:
                    try:
                        await self._invoke_callback(self.on_error_callback, e)
                    except Exception as callback_error:
                        print(f"执行错误回调时失败: {callback_error}")

                if error_count >= MAX_ERRORS:
                    print("达到最大错误次数,停止KWS检测")
                    # The loop is dead; make the public flag reflect that.
                    self.is_running = False
                    break

                # Back off before retrying.
                await asyncio.sleep(1)

    async def _process_audio(self):
        """Feed one queued chunk to the spotter and decode everything ready.

        Returns immediately when there is no stream, the queue is empty or
        the chunk is empty. ``int16`` input is scaled by 1/32768; any other
        dtype is cast to ``float32`` without scaling.

        Raises:
            Exception: Any processing error is printed and re-raised so the
                detection loop can count it.
        """
        try:
            if not self.stream:
                return

            try:
                audio_data = self._audio_queue.get_nowait()
            except asyncio.QueueEmpty:
                return

            if audio_data is None or len(audio_data) == 0:
                return

            if audio_data.dtype == np.int16:
                samples = audio_data.astype(np.float32) / 32768.0
            else:
                samples = audio_data.astype(np.float32)

            self.stream.accept_waveform(
                sample_rate=self.config.sample_rate, waveform=samples
            )

            # Decode until the stream has no more ready frames. A single
            # `if` would decode only one step per chunk and let undecoded
            # frames accumulate, delaying detection.
            while self.keyword_spotter.is_ready(self.stream):
                self.keyword_spotter.decode_stream(self.stream)
                result = self.keyword_spotter.get_result(self.stream)

                if result:
                    await self._handle_detection_result(result)
                    # Reset right after a hit so the next keyword starts clean.
                    self.keyword_spotter.reset_stream(self.stream)
        except Exception as e:
            print(f"KWS音频处理错误: {e}")
            raise

    async def _handle_detection_result(self, result):
        """Apply the cooldown and dispatch the detection callback.

        Detections arriving less than ``config.detection_cooldown`` seconds
        after the previous accepted one are dropped. Exceptions raised by the
        callback are printed and swallowed.

        Args:
            result: The value returned by the spotter's ``get_result``.
        """
        current_time = time.time()
        if current_time - self.last_detection_time < self.config.detection_cooldown:
            return

        self.last_detection_time = current_time

        if self.on_detected_callback:
            try:
                # The callback takes two positional arguments; the same
                # result value is passed for both.
                await self._invoke_callback(
                    self.on_detected_callback, result, result
                )
            except Exception as e:
                print(f"唤醒词回调执行失败: {e}")

    async def stop(self):
        """Stop the detection task and discard any buffered audio."""
        self.is_running = False

        if self.detection_task:
            self.detection_task.cancel()
            try:
                await self.detection_task
            except asyncio.CancelledError:
                pass
            # Drop the reference to the finished task.
            self.detection_task = None

        # Drain the queue so stale audio is not replayed on the next start.
        while not self._audio_queue.empty():
            try:
                self._audio_queue.get_nowait()
            except asyncio.QueueEmpty:
                break

        print("唤醒词检测器已停止")

    def pause(self):
        """Pause detection; incoming audio is ignored until :meth:`resume`."""
        self.paused = True

    def resume(self):
        """Resume detection after :meth:`pause`."""
        self.paused = False
|