# audio2feature.py

from transformers import Wav2Vec2Processor, HubertModel
import torch
import numpy as np


class Audio2Feature():
    """Extract HuBERT hidden-state features from 16 kHz speech and slice them per video frame."""

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
        self.model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft").to(self.device)

    @torch.no_grad()
    def get_hubert_from_16k_speech(self, speech):
        if speech.ndim == 2:
            speech = speech[:, 0]  # stereo [T, 2] ==> mono [T,]
        input_values_all = self.processor(speech, return_tensors="pt", sampling_rate=16000).input_values  # [1, T]
        input_values_all = input_values_all.to(self.device)
        # HuBERT's CNN feature extractor behaves like a conv with an effective kernel of
        # 400 samples and a stride of 320 samples (one feature every 20 ms), so long audio
        # is processed in overlapping clips to bound memory usage.
        kernel = 400
        stride = 320
        clip_length = stride * 1000
        num_iter = input_values_all.shape[1] // clip_length
        expected_T = (input_values_all.shape[1] - (kernel - stride)) // stride
        res_lst = []
        for i in range(num_iter):
            if i == 0:
                start_idx = 0
                end_idx = clip_length - stride + kernel
            else:
                start_idx = clip_length * i
                end_idx = start_idx + (clip_length - stride + kernel)
            input_values = input_values_all[:, start_idx: end_idx]
            hidden_states = self.model(input_values).last_hidden_state  # [B=1, T=pts//320, hid=1024]
            res_lst.append(hidden_states[0])
        if num_iter > 0:
            input_values = input_values_all[:, clip_length * num_iter:]
        else:
            input_values = input_values_all
        if input_values.shape[1] >= kernel:  # if the last clip is shorter than kernel_size, skip it
            hidden_states = self.model(input_values).last_hidden_state  # [B=1, T=pts//320, hid=1024]
            res_lst.append(hidden_states[0])
        ret = torch.cat(res_lst, dim=0).cpu()  # [T, 1024]
        assert abs(ret.shape[0] - expected_T) <= 1
        # Pad or trim so the output length matches the expected frame count exactly.
        if ret.shape[0] < expected_T:
            ret = torch.nn.functional.pad(ret, (0, 0, 0, expected_T - ret.shape[0]))
        else:
            ret = ret[:expected_T]
        return ret
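
    # Worked example (illustrative only, not in the original code): a 1-second clip at
    # 16 kHz has 16000 samples, so expected_T = (16000 - (400 - 320)) // 320 = 49 feature
    # frames, each a 1024-dim HuBERT hidden state (roughly 50 features per second).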

    def get_sliced_feature(self,
                           feature_array,
                           vid_idx,
                           audio_feat_length=[8, 8],
                           fps=25):
        """
        Slice a window of audio features centered on a given video frame index.
        :param feature_array: HuBERT feature array at 50 feature frames per second
        :param vid_idx: index of the video frame the window is centered on
        :param audio_feat_length: [left, right] window half-lengths, in video-frame units
        :param fps: video frame rate
        :return: (selected_feature, selected_idx)
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []
        center_idx = int(vid_idx * 50 / fps)  # map video frame index to the 50 Hz feature index
        left_idx = center_idx - audio_feat_length[0] * 2
        right_idx = center_idx + audio_feat_length[1] * 2
        for idx in range(left_idx, right_idx):
            idx = max(0, idx)  # clamp to the valid feature range
            idx = min(length - 1, idx)
            x = feature_array[idx]
            selected_feature.append(x)
            selected_idx.append(idx)
        selected_feature = np.concatenate(selected_feature, axis=0)
        selected_feature = selected_feature.reshape(-1, 1024)
        return selected_feature, selected_idx
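
    # Illustrative note (an assumption about typical use, not in the original code): with
    # fps=25 and audio_feat_length=[8, 8], vid_idx=10 maps to center_idx=20 and the window
    # covers feature indices 4..35, so selected_feature has shape (32, 1024), assuming the
    # feature array holds at least 36 frames.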

    def feature2chunks(self, feature_array, fps, batch_size, audio_feat_length=[8, 8], start=0):
        whisper_chunks = []
        whisper_idx_multiplier = 50. / fps
        i = 0
        # print(f"video in {fps} FPS, audio idx in 50FPS")
        for _ in range(batch_size):
            # start_idx = int(i * whisper_idx_multiplier)
            # if start_idx >= len(feature_array):
            #     break
            selected_feature, selected_idx = self.get_sliced_feature(
                feature_array=feature_array,
                vid_idx=i + start,
                audio_feat_length=audio_feat_length,
                fps=fps)
            # print(f"i:{i}, selected_idx {selected_idx}")
            whisper_chunks.append(selected_feature)
            i += 1
        return whisper_chunks
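

# Usage sketch (not part of the original file): load a 16 kHz mono waveform and turn it
# into per-frame feature chunks. The wav path and the use of librosa for loading are
# assumptions for illustration only.
if __name__ == "__main__":
    import librosa

    audio2feature = Audio2Feature()
    # librosa resamples to 16 kHz mono here; any loader yielding a float waveform works.
    speech, _ = librosa.load("example.wav", sr=16000, mono=True)
    hubert_feats = audio2feature.get_hubert_from_16k_speech(speech)  # torch.Tensor [T, 1024]
    chunks = audio2feature.feature2chunks(hubert_feats.numpy(), fps=25, batch_size=10)
    print(len(chunks), chunks[0].shape)  # 10 windows, each (32, 1024) with the default [8, 8]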