# audio_processor.py
import math
import os

import librosa
import numpy as np
import torch
from einops import rearrange
from transformers import AutoFeatureExtractor
  8. class AudioProcessor:
  9. def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
  10. self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
  11. def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
  12. if not os.path.exists(wav_path):
  13. return None
  14. librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
  15. assert sampling_rate == 16000
  16. # Split audio into 30s segments
  17. segment_length = 30 * sampling_rate
  18. segments = [librosa_output[i:i + segment_length] for i in range(0, len(librosa_output), segment_length)]
  19. features = []
  20. for segment in segments:
  21. audio_feature = self.feature_extractor(
  22. segment,
  23. return_tensors="pt",
  24. sampling_rate=sampling_rate
  25. ).input_features
  26. if weight_dtype is not None:
  27. audio_feature = audio_feature.to(dtype=weight_dtype)
  28. features.append(audio_feature)
  29. return features, len(librosa_output)
  30. def get_whisper_chunk(
  31. self,
  32. whisper_input_features,
  33. device,
  34. weight_dtype,
  35. whisper,
  36. librosa_length,
  37. fps=25,
  38. audio_padding_length_left=2,
  39. audio_padding_length_right=2,
  40. ):
  41. audio_feature_length_per_frame = 2 * (audio_padding_length_left + audio_padding_length_right + 1)
  42. whisper_feature = []
  43. # Process multiple 30s mel input features
  44. for input_feature in whisper_input_features:
  45. input_feature = input_feature.to(device).to(weight_dtype)
  46. audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
  47. audio_feats = torch.stack(audio_feats, dim=2)
  48. whisper_feature.append(audio_feats)
  49. whisper_feature = torch.cat(whisper_feature, dim=1)
  50. # Trim the last segment to remove padding
  51. sr = 16000
  52. audio_fps = 50
  53. fps = int(fps)
  54. whisper_idx_multiplier = audio_fps / fps
  55. num_frames = math.floor((librosa_length / sr) * fps)
  56. actual_length = math.floor((librosa_length / sr) * audio_fps)
  57. whisper_feature = whisper_feature[:,:actual_length,...]
  58. # Calculate padding amount
  59. padding_nums = math.ceil(whisper_idx_multiplier)
  60. # Add padding at start and end
  61. whisper_feature = torch.cat([
  62. torch.zeros_like(whisper_feature[:, :padding_nums * audio_padding_length_left]),
  63. whisper_feature,
  64. # Add extra padding to prevent out of bounds
  65. torch.zeros_like(whisper_feature[:, :padding_nums * 3 * audio_padding_length_right])
  66. ], 1)
  67. audio_prompts = []
  68. for frame_index in range(num_frames):
  69. try:
  70. audio_index = math.floor(frame_index * whisper_idx_multiplier)
  71. audio_clip = whisper_feature[:, audio_index: audio_index + audio_feature_length_per_frame]
  72. assert audio_clip.shape[1] == audio_feature_length_per_frame
  73. audio_prompts.append(audio_clip)
  74. except Exception as e:
  75. print(f"Error occurred: {e}")
  76. print(f"whisper_feature.shape: {whisper_feature.shape}")
  77. print(f"audio_clip.shape: {audio_clip.shape}")
  78. print(f"num frames: {num_frames}, fps: {fps}, whisper_idx_multiplier: {whisper_idx_multiplier}")
  79. print(f"frame_index: {frame_index}, audio_index: {audio_index}-{audio_index + audio_feature_length_per_frame}")
  80. exit()
  81. audio_prompts = torch.cat(audio_prompts, dim=0) # T, 10, 5, 384
  82. audio_prompts = rearrange(audio_prompts, 'b c h w -> b (c h) w')
  83. return audio_prompts
  84. if __name__ == "__main__":
  85. audio_processor = AudioProcessor()
  86. wav_path = "./2.wav"
  87. audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
  88. print("Audio Feature shape:", audio_feature.shape)
  89. print("librosa_feature_length:", librosa_feature_length)