"""
Voice wake-word detector.

Implements wake-word (keyword) spotting on top of sherpa-onnx.
"""

import asyncio
import inspect
import time
from typing import Callable, Optional

import numpy as np
import sherpa_onnx

from .config import WakeWordConfig
from .audio_capture import AudioListener


class WakeWordDetector(AudioListener):
    """Wake-word detector driven by a sherpa-onnx ``KeywordSpotter``.

    Audio chunks arrive through :meth:`on_audio_data` (the ``AudioListener``
    interface), are buffered in a bounded queue, and are consumed by an
    asyncio task started with :meth:`start`. When a keyword is spotted, the
    callback registered via :meth:`on_detected` is invoked, rate-limited by
    ``config.detection_cooldown``.
    """

    def __init__(self, config: WakeWordConfig):
        """Create the detector and load the keyword-spotting model.

        Args:
            config: Wake-word configuration (model path, thresholds, ...).

        Raises:
            RuntimeError: If config validation or model loading fails.
        """
        self.config = config

        # Run-state flags.
        self.is_running = False
        self.paused = False
        self.detection_task: Optional[asyncio.Task] = None

        # Bounded buffer of pending audio chunks; the oldest chunk is dropped
        # when it is full (see on_audio_data).
        self._audio_queue = asyncio.Queue(maxsize=100)

        # time.monotonic() value of the last accepted detection, used for the
        # anti-retrigger cooldown. A monotonic clock is used so wall-clock
        # adjustments cannot suppress detections; -inf means "never".
        self.last_detection_time = float("-inf")

        # User-supplied callbacks (plain functions or coroutine functions).
        self.on_detected_callback: Optional[Callable] = None
        self.on_error_callback: Optional[Callable] = None

        # sherpa-onnx objects: the spotter and its per-session stream.
        self.keyword_spotter = None
        self.stream = None

        self._init_kws_model()

    def _init_kws_model(self):
        """Build the sherpa-onnx ``KeywordSpotter`` from ``config.model_path``.

        The directory is expected to contain ``encoder.onnx``,
        ``decoder.onnx``, ``joiner.onnx``, ``tokens.txt`` and ``keywords.txt``.

        Raises:
            RuntimeError: If validation or model construction fails; the
                original exception is chained as ``__cause__``.
        """
        try:
            self.config.validate()

            model_dir = self.config.model_path
            encoder_path = model_dir / "encoder.onnx"
            decoder_path = model_dir / "decoder.onnx"
            joiner_path = model_dir / "joiner.onnx"
            tokens_path = model_dir / "tokens.txt"
            keywords_path = model_dir / "keywords.txt"

            print(f"加载Sherpa-ONNX KeywordSpotter模型: {self.config.model_path}")

            self.keyword_spotter = sherpa_onnx.KeywordSpotter(
                tokens=str(tokens_path),
                encoder=str(encoder_path),
                decoder=str(decoder_path),
                joiner=str(joiner_path),
                keywords_file=str(keywords_path),
                num_threads=self.config.num_threads,
                sample_rate=self.config.sample_rate,
                feature_dim=80,
                max_active_paths=self.config.max_active_paths,
                keywords_score=self.config.keywords_score,
                keywords_threshold=self.config.keywords_threshold,
                num_trailing_blanks=self.config.num_trailing_blanks,
                provider=self.config.provider,
            )
            print("Sherpa-ONNX KeywordSpotter模型加载成功")
        except Exception as e:
            raise RuntimeError(f"Sherpa-ONNX KeywordSpotter初始化失败: {e}") from e

    def on_detected(self, callback: Callable):
        """Register the callback invoked when a wake word is detected.

        The callback is called with the spotter result passed twice
        (``callback(result, result)``); it may be sync or async.
        """
        self.on_detected_callback = callback

    def on_error(self, callback: Callable):
        """Register the callback invoked with the exception on loop errors.

        The callback may be sync or async.
        """
        self.on_error_callback = callback

    def on_audio_data(self, audio_data: np.ndarray):
        """Receive one audio chunk (``AudioListener`` interface).

        The chunk is copied and queued for the detection task. Chunks are
        ignored while the detector is stopped or paused. When the queue is
        full, the oldest chunk is dropped to make room.

        NOTE(review): asyncio.Queue is not thread-safe. If the audio capture
        invokes this from a non-event-loop thread, consider handing chunks
        over with loop.call_soon_threadsafe. Confirm against audio_capture.
        """
        if not self.is_running or self.paused:
            return

        try:
            chunk = audio_data.copy()
            try:
                self._audio_queue.put_nowait(chunk)
            except asyncio.QueueFull:
                # Drop the oldest chunk so the freshest audio is kept.
                try:
                    self._audio_queue.get_nowait()
                except asyncio.QueueEmpty:
                    pass
                self._audio_queue.put_nowait(chunk)
        except Exception as e:
            print(f"音频数据入队失败: {e}")

    @staticmethod
    async def _invoke_callback(callback: Callable, *args) -> None:
        """Call ``callback(*args)`` and await the result if it is awaitable.

        Handles coroutine functions as well as wrappers (partials, lambdas)
        that return a coroutine.
        """
        outcome = callback(*args)
        if inspect.isawaitable(outcome):
            await outcome

    async def start(self) -> bool:
        """Start the detector.

        Creates a fresh decoding stream and launches the detection task.
        Calling this while a detection task is already running only clears
        the paused flag; it does not spawn a second loop.

        Returns:
            True if the detector is running, False if it could not start.
        """
        if not self.keyword_spotter:
            print("KeywordSpotter未初始化")
            return False

        task = self.detection_task
        if self.is_running and task is not None and not task.done():
            # Already running: a second task would share the same stream/queue.
            self.paused = False
            return True

        try:
            self.is_running = True
            self.paused = False

            # New decoding stream for this session.
            self.stream = self.keyword_spotter.create_stream()

            self.detection_task = asyncio.create_task(self._detection_loop())

            print("唤醒词检测器启动成功")
            return True
        except Exception as e:
            # Don't leave the detector flagged as running when it isn't.
            self.is_running = False
            print(f"启动唤醒词检测器失败: {e}")
            return False

    async def _detection_loop(self):
        """Consume queued audio until stopped or too many consecutive errors.

        After ``MAX_ERRORS`` consecutive failures, the loop gives up and
        marks the detector as not running.
        """
        error_count = 0
        MAX_ERRORS = 5

        while self.is_running:
            try:
                if self.paused:
                    await asyncio.sleep(0.1)
                    continue

                await self._process_audio()

                # Short poll interval keeps wake-word latency low.
                await asyncio.sleep(0.005)
                error_count = 0
            except asyncio.CancelledError:
                break
            except Exception as e:
                error_count += 1
                print(f"KWS检测循环错误({error_count}/{MAX_ERRORS}): {e}")

                if self.on_error_callback:
                    try:
                        await self._invoke_callback(self.on_error_callback, e)
                    except Exception as callback_error:
                        print(f"执行错误回调时失败: {callback_error}")

                if error_count >= MAX_ERRORS:
                    print("达到最大错误次数,停止KWS检测")
                    # The loop is gone; reflect that so on_audio_data stops
                    # queueing and start() can relaunch cleanly.
                    self.is_running = False
                    break

                await asyncio.sleep(1)

    @staticmethod
    def _to_float32(audio_data: np.ndarray) -> np.ndarray:
        """Convert a chunk to float32 for the spotter.

        Signed-integer PCM is scaled by its full-scale value (e.g. int16 by
        32768) into [-1, 1). Other dtypes are cast to float32 unscaled.
        """
        if np.issubdtype(audio_data.dtype, np.signedinteger):
            full_scale = float(-np.iinfo(audio_data.dtype).min)
            return audio_data.astype(np.float32) / full_scale
        return audio_data.astype(np.float32)

    async def _process_audio(self):
        """Feed at most one queued chunk to the spotter and decode it fully.

        Decoding repeats while the stream has enough frames, so audio does
        not accumulate undecoded. The stream is reset right after each
        detection.

        Raises:
            Exception: Any error from conversion or sherpa-onnx is printed
                and re-raised so the detection loop can count it.
        """
        try:
            # Local reference: stop()/start() may rebind self.stream while a
            # callback is awaited below.
            stream = self.stream
            if not stream:
                return

            try:
                audio_data = self._audio_queue.get_nowait()
            except asyncio.QueueEmpty:
                return

            if audio_data is None or len(audio_data) == 0:
                return

            samples = self._to_float32(audio_data)

            stream.accept_waveform(sample_rate=self.config.sample_rate,
                                   waveform=samples)

            spotter = self.keyword_spotter
            # `while`, not `if`: one chunk may provide several decode steps.
            while self.is_running and spotter.is_ready(stream):
                spotter.decode_stream(stream)
                result = spotter.get_result(stream)
                if result:
                    await self._handle_detection_result(result)
                    # Reset right after a hit so the keyword isn't re-reported.
                    spotter.reset_stream(stream)
        except Exception as e:
            print(f"KWS音频处理错误: {e}")
            raise

    async def _handle_detection_result(self, result):
        """Dispatch a detection to the user callback, honoring the cooldown.

        Detections within ``config.detection_cooldown`` seconds of the last
        accepted one are ignored. Callback errors are printed, not raised.
        """
        now = time.monotonic()
        if now - self.last_detection_time < self.config.detection_cooldown:
            return
        self.last_detection_time = now

        if self.on_detected_callback:
            try:
                # The result is passed twice to keep the existing
                # two-argument callback signature.
                await self._invoke_callback(self.on_detected_callback, result, result)
            except Exception as e:
                print(f"唤醒词回调执行失败: {e}")

    def _drain_queue(self):
        """Discard all queued audio chunks."""
        while True:
            try:
                self._audio_queue.get_nowait()
            except asyncio.QueueEmpty:
                break

    async def stop(self):
        """Stop the detector, cancel the detection task and clear the queue.

        Safe to call from inside a detection callback. In that case, the task
        is not cancelled or awaited (a task cannot await itself); it exits on
        its own once ``is_running`` is False.
        """
        self.is_running = False

        task = self.detection_task
        self.detection_task = None
        if task is not None and task is not asyncio.current_task():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        self._drain_queue()

        print("唤醒词检测器已停止")

    def pause(self):
        """Pause detection; incoming audio is ignored until resume()."""
        self.paused = True

    def resume(self):
        """Resume detection after pause()."""
        self.paused = False