detector.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. """
  2. 语音唤醒检测器
  3. 基于 sherpa-onnx 实现唤醒词检测
  4. """
  5. import asyncio
  6. import time
  7. from typing import Callable, Optional
  8. import numpy as np
  9. import sherpa_onnx
  10. from .config import WakeWordConfig
  11. from .audio_capture import AudioListener
  class WakeWordDetector(AudioListener):
      """Wake-word detector backed by a sherpa-onnx ``KeywordSpotter``.

      Audio chunks delivered through :meth:`on_audio_data` are buffered in a
      bounded queue (100 chunks, oldest dropped first) and consumed by an
      asyncio task launched in :meth:`start`. Each keyword hit reported by the
      spotter is debounced by ``config.detection_cooldown`` seconds and
      forwarded to the callback registered with :meth:`on_detected`.
      """

      def __init__(self, config: WakeWordConfig):
          """Create the detector and load the keyword-spotting model.

          Args:
              config: Wake-word configuration (model directory, decoding
                  parameters, sample rate, detection cooldown, ...).

          Raises:
              RuntimeError: If the configuration is invalid or the model
                  cannot be loaded.
          """
          self.config = config

          # Run-state flags.
          self.is_running = False
          self.paused = False
          self.detection_task: Optional[asyncio.Task] = None

          # Bounded audio buffer; on overflow the oldest chunk is discarded.
          self._audio_queue = asyncio.Queue(maxsize=100)

          # Debounce state: time.monotonic() of the last accepted detection.
          # -inf guarantees the very first detection is never suppressed.
          self.last_detection_time = float("-inf")

          # User callbacks; each may be a plain function or a coroutine function.
          self.on_detected_callback: Optional[Callable] = None
          self.on_error_callback: Optional[Callable] = None

          # sherpa-onnx components.
          self.keyword_spotter = None
          self.stream = None

          self._init_kws_model()

      def _init_kws_model(self):
          """Validate the config and build the sherpa-onnx ``KeywordSpotter``.

          The model directory ``config.model_path`` must contain
          ``encoder.onnx``, ``decoder.onnx``, ``joiner.onnx``, ``tokens.txt``
          and ``keywords.txt``.

          Raises:
              RuntimeError: Wrapping any error raised by validation or model
                  construction; the original exception is chained as the cause.
          """
          try:
              self.config.validate()

              model_dir = self.config.model_path
              encoder_path = model_dir / "encoder.onnx"
              decoder_path = model_dir / "decoder.onnx"
              joiner_path = model_dir / "joiner.onnx"
              tokens_path = model_dir / "tokens.txt"
              keywords_path = model_dir / "keywords.txt"

              print(f"加载Sherpa-ONNX KeywordSpotter模型: {self.config.model_path}")

              self.keyword_spotter = sherpa_onnx.KeywordSpotter(
                  tokens=str(tokens_path),
                  encoder=str(encoder_path),
                  decoder=str(decoder_path),
                  joiner=str(joiner_path),
                  keywords_file=str(keywords_path),
                  num_threads=self.config.num_threads,
                  sample_rate=self.config.sample_rate,
                  # NOTE(review): hard-coded; presumably matches the model's
                  # 80-bin fbank features — confirm if other models are used.
                  feature_dim=80,
                  max_active_paths=self.config.max_active_paths,
                  keywords_score=self.config.keywords_score,
                  keywords_threshold=self.config.keywords_threshold,
                  num_trailing_blanks=self.config.num_trailing_blanks,
                  provider=self.config.provider,
              )
              print("Sherpa-ONNX KeywordSpotter模型加载成功")
          except Exception as e:
              raise RuntimeError(f"Sherpa-ONNX KeywordSpotter初始化失败: {e}") from e

      def on_detected(self, callback: Callable):
          """Register the callback invoked when a wake word is detected."""
          self.on_detected_callback = callback

      def on_error(self, callback: Callable):
          """Register the callback invoked with the exception on loop errors."""
          self.on_error_callback = callback

      @staticmethod
      async def _invoke_callback(callback: Optional[Callable], failure_msg: str, *args) -> None:
          """Call a sync or async user callback; log (never raise) its failures."""
          if not callback:
              return
          try:
              if asyncio.iscoroutinefunction(callback):
                  await callback(*args)
              else:
                  callback(*args)
          except Exception as e:
              print(f"{failure_msg}: {e}")

      def _drain_queue(self):
          """Discard every audio chunk currently buffered in the queue."""
          while True:
              try:
                  self._audio_queue.get_nowait()
              except asyncio.QueueEmpty:
                  break

      def on_audio_data(self, audio_data: np.ndarray):
          """Buffer one audio chunk for detection (AudioListener interface).

          Chunks are ignored while the detector is stopped or paused. When the
          queue is full, the oldest chunk is dropped so the newest audio wins.
          Errors are logged rather than propagated to the audio producer.
          """
          # NOTE(review): asyncio.Queue is not thread-safe. If this is invoked
          # from an audio-capture thread rather than the event loop thread, the
          # hand-off should go through loop.call_soon_threadsafe — confirm.
          if not self.is_running or self.paused:
              return
          try:
              chunk = audio_data.copy()
              try:
                  self._audio_queue.put_nowait(chunk)
              except asyncio.QueueFull:
                  # Make room by discarding the oldest buffered chunk.
                  try:
                      self._audio_queue.get_nowait()
                  except asyncio.QueueEmpty:
                      pass
                  self._audio_queue.put_nowait(chunk)
          except Exception as e:
              print(f"音频数据入队失败: {e}")

      async def start(self) -> bool:
          """Create a fresh decoding stream and launch the detection task.

          Returns:
              True if the detector is running after the call (including when a
              detection task was already alive), False if the spotter is not
              initialised or startup failed.
          """
          if not self.keyword_spotter:
              print("KeywordSpotter未初始化")
              return False

          if self.detection_task is not None and not self.detection_task.done():
              # A loop is already alive (start() called twice, or stop() invoked
              # from inside a callback): reuse it rather than spawning a second
              # loop that would compete for the same queue and stream.
              self.is_running = True
              self.paused = False
              return True

          try:
              self.stream = self.keyword_spotter.create_stream()
              # Audio left over from a previous run must not reach the new stream.
              self._drain_queue()
              self.is_running = True
              self.paused = False
              self.detection_task = asyncio.create_task(self._detection_loop())
          except Exception as e:
              # Roll back so on_audio_data() stops queueing for a loop that never ran.
              self.is_running = False
              print(f"启动唤醒词检测器失败: {e}")
              return False

          print("唤醒词检测器启动成功")
          return True

      async def _detection_loop(self):
          """Consume queued audio until stopped, cancelled, or too many errors.

          Each failure is reported to the error callback. After ``MAX_ERRORS``
          consecutive failures the loop exits and marks the detector as not
          running; the error counter resets after every successful iteration.
          """
          error_count = 0
          MAX_ERRORS = 5

          while self.is_running:
              try:
                  if self.paused:
                      await asyncio.sleep(0.1)
                      continue

                  await self._process_audio()
                  # Short poll interval keeps wake-word latency low.
                  await asyncio.sleep(0.005)
                  error_count = 0
              except asyncio.CancelledError:
                  break
              except Exception as e:
                  error_count += 1
                  print(f"KWS检测循环错误({error_count}/{MAX_ERRORS}): {e}")
                  await self._invoke_callback(self.on_error_callback, "执行错误回调时失败", e)

                  if error_count >= MAX_ERRORS:
                      print("达到最大错误次数,停止KWS检测")
                      # Reflect the dead loop so callers and on_audio_data see it.
                      self.is_running = False
                      break
                  await asyncio.sleep(1)

      @staticmethod
      def _to_float32(audio_data: np.ndarray) -> np.ndarray:
          """Convert samples to float32, scaling signed-integer PCM to [-1, 1)."""
          if np.issubdtype(audio_data.dtype, np.signedinteger):
              # e.g. int16 -> 32768.0, int32 -> 2147483648.0
              full_scale = float(np.iinfo(audio_data.dtype).max) + 1.0
              return audio_data.astype(np.float32) / full_scale
          return audio_data.astype(np.float32)

      async def _process_audio(self):
          """Feed at most one queued chunk to the spotter and decode all ready frames.

          Raises:
              Exception: Any conversion or spotter error is logged and re-raised
                  so the detection loop can count it.
          """
          try:
              if not self.stream:
                  return
              try:
                  audio_data = self._audio_queue.get_nowait()
              except asyncio.QueueEmpty:
                  return
              if audio_data is None or audio_data.size == 0:
                  return

              samples = self._to_float32(audio_data)
              self.stream.accept_waveform(sample_rate=self.config.sample_rate, waveform=samples)

              # Decode everything that is ready, not just one step per chunk,
              # so decoding never falls behind the incoming audio.
              while self.keyword_spotter.is_ready(self.stream):
                  self.keyword_spotter.decode_stream(self.stream)
                  result = self.keyword_spotter.get_result(self.stream)
                  if result:
                      await self._handle_detection_result(result)
                      # sherpa-onnx's examples reset the stream right after each hit.
                      self.keyword_spotter.reset_stream(self.stream)
          except Exception as e:
              print(f"KWS音频处理错误: {e}")
              raise

      async def _handle_detection_result(self, result):
          """Debounce a detection and forward it to the detected-callback.

          Hits within ``config.detection_cooldown`` seconds of the previously
          accepted one are ignored. The callback receives ``result`` as both of
          its two positional arguments (existing callback contract).
          """
          # Monotonic clock: wall-clock jumps (NTP, manual changes) must not
          # suppress detections or shorten the cooldown.
          now = time.monotonic()
          if now - self.last_detection_time < self.config.detection_cooldown:
              return
          self.last_detection_time = now
          await self._invoke_callback(self.on_detected_callback, "唤醒词回调执行失败", result, result)

      async def stop(self):
          """Stop detection, wait for the loop to finish, and drop buffered audio.

          Safe to call from inside a detection callback. In that case the loop
          is neither cancelled nor awaited, because a task cannot await itself.
          Clearing ``is_running`` lets the loop exit at its next check.
          """
          self.is_running = False
          task = self.detection_task
          if task is not None and task is not asyncio.current_task():
              task.cancel()
              try:
                  await task
              except asyncio.CancelledError:
                  pass
              self.detection_task = None

          self._drain_queue()
          print("唤醒词检测器已停止")

      def pause(self):
          """Pause detection; incoming audio is ignored until resume()."""
          self.paused = True

      def resume(self):
          """Resume detection after pause()."""
          self.paused = False