# audio2feature.py
import os
from .whisper import load_model
import soundfile as sf
import numpy as np
import time
import sys
from transformers import AutoFeatureExtractor
from transformers import WhisperModel
import torch

# Allow sibling-package imports when this file is run as a script.
sys.path.append("..")

# Module-level inference settings: GPU + fp16 when CUDA is available,
# otherwise CPU + fp32. Used by Audio2Feature below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
  13. class Audio2Feature():
  14. def __init__(self,
  15. whisper_model_type="tiny",
  16. model_path="./models/whisper"):
  17. # self.whisper_model_type = whisper_model_type
  18. # self.model = load_model(model_path) #
  19. self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
  20. self.whisper = WhisperModel.from_pretrained(model_path)
  21. self.whisper = self.whisper.to(device=device, dtype=weight_dtype).eval()
  22. self.whisper.requires_grad_(False)
  23. def get_sliced_feature(self,
  24. feature_array,
  25. vid_idx,
  26. audio_feat_length=[2,2],
  27. fps=25):
  28. """
  29. Get sliced features based on a given index
  30. :param feature_array:
  31. :param start_idx: the start index of the feature
  32. :param audio_feat_length:
  33. :return:
  34. """
  35. length = len(feature_array)
  36. selected_feature = []
  37. selected_idx = []
  38. center_idx = int(vid_idx*50/fps)
  39. left_idx = center_idx; #-audio_feat_length[0]*2
  40. right_idx = center_idx + (audio_feat_length[0]+audio_feat_length[1]+1)*2
  41. for idx in range(left_idx,right_idx):
  42. idx = max(0, idx)
  43. idx = min(length-1, idx)
  44. x = feature_array[idx]
  45. selected_feature.append(x)
  46. selected_idx.append(idx)
  47. selected_feature = np.concatenate(selected_feature, axis=0)
  48. selected_feature = selected_feature.reshape(-1, 384)# 50*384
  49. return selected_feature,selected_idx
  50. def get_sliced_feature_sparse(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25):
  51. """
  52. Get sliced features based on a given index
  53. :param feature_array:
  54. :param start_idx: the start index of the feature
  55. :param audio_feat_length:
  56. :return:
  57. """
  58. length = len(feature_array)
  59. selected_feature = []
  60. selected_idx = []
  61. for dt in range(-audio_feat_length[0],audio_feat_length[1]+1):
  62. left_idx = int((vid_idx+dt)*50/fps)
  63. if left_idx<1 or left_idx>length-1:
  64. print('test-----,left_idx=',left_idx)
  65. left_idx = max(0, left_idx)
  66. left_idx = min(length-1, left_idx)
  67. x = feature_array[left_idx]
  68. x = x[np.newaxis,:,:]
  69. x = np.repeat(x, 2, axis=0)
  70. selected_feature.append(x)
  71. selected_idx.append(left_idx)
  72. selected_idx.append(left_idx)
  73. else:
  74. x = feature_array[left_idx-1:left_idx+1]
  75. selected_feature.append(x)
  76. selected_idx.append(left_idx-1)
  77. selected_idx.append(left_idx)
  78. selected_feature = np.concatenate(selected_feature, axis=0)
  79. selected_feature = selected_feature.reshape(-1, 384)# 50*384
  80. return selected_feature,selected_idx
  81. def feature2chunks(self,feature_array,fps,batch_size,audio_feat_length = [2,2],start=0):
  82. whisper_chunks = []
  83. whisper_idx_multiplier = 50./fps
  84. i = 0
  85. #print(f"video in {fps} FPS, audio idx in 50FPS")
  86. for _ in range(batch_size):
  87. # start_idx = int(i * whisper_idx_multiplier)
  88. # if start_idx>=len(feature_array):
  89. # break
  90. selected_feature,selected_idx = self.get_sliced_feature(feature_array= feature_array,vid_idx = i+start,audio_feat_length=audio_feat_length,fps=fps)
  91. #print(f"i:{i},selected_idx {selected_idx}")
  92. whisper_chunks.append(selected_feature)
  93. i += 1
  94. return whisper_chunks
  95. def audio2feat(self, wav_data): #, weight_dtype=None
  96. input_feature = self.feature_extractor(
  97. wav_data,
  98. return_tensors="pt",
  99. sampling_rate=16000
  100. ).input_features
  101. input_feature = input_feature.to(device).to(weight_dtype)
  102. whisper_feature = self.whisper.encoder(input_feature, output_hidden_states=True).hidden_states
  103. #print(f"input_feature shape:{input_feature.shape}, whisper_feature shape:{whisper_feature[0].shape}, whisper_feature len:{len(whisper_feature)}")
  104. whisper_feature = torch.stack(whisper_feature, dim=2)
  105. #print(f"stacked whisper_feature shape:{whisper_feature.shape}")
  106. return whisper_feature.squeeze(0).cpu().numpy()
  107. # def audio2feat(self,audio_path):
  108. # # get the sample rate of the audio
  109. # result = self.model.transcribe(audio_path)
  110. # embed_list = []
  111. # for emb in result['segments']:
  112. # encoder_embeddings = emb['encoder_embeddings']
  113. # encoder_embeddings = encoder_embeddings.transpose(0,2,1,3)
  114. # encoder_embeddings = encoder_embeddings.squeeze(0)
  115. # start_idx = int(emb['start'])
  116. # end_idx = int(emb['end'])
  117. # emb_end_idx = int((end_idx - start_idx)/2)
  118. # embed_list.append(encoder_embeddings[:emb_end_idx])
  119. # concatenated_array = np.concatenate(embed_list, axis=0)
  120. # return concatenated_array
  121. if __name__ == "__main__":
  122. audio_processor = Audio2Feature(model_path="../../models/whisper/whisper_tiny.pt")
  123. audio_path = "./test.mp3"
  124. array = audio_processor.audio2feat(audio_path)
  125. print(array.shape)
  126. fps = 25
  127. whisper_idx_multiplier = 50./fps
  128. i = 0
  129. print(f"video in {fps} FPS, audio idx in 50FPS")
  130. while 1:
  131. start_idx = int(i * whisper_idx_multiplier)
  132. selected_feature,selected_idx = audio_processor.get_sliced_feature(feature_array= array,vid_idx = i,audio_feat_length=[2,2],fps=fps)
  133. print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}")
  134. i += 1
  135. if start_idx>len(array):
  136. break