# wav2lip_v2.py

import torch
from torch import nn
from torch.nn import functional as F

from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d

class Wav2Lip(nn.Module):
    def __init__(self):
        super(Wav2Lip, self).__init__()

        self.face_encoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)),

            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(512, 512, kernel_size=4, stride=1, padding=0),
                          Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),
        ])
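
        # Shape trace, assuming 256x256 face crops as in the shape comments in
        # forward() (6 input channels: in the Wav2Lip setup, a lower-half-masked
        # target frame stacked with a reference frame):
        # 6x256x256 -> 16x256x256 -> 32x128x128 -> 64x64x64 -> 128x32x32
        # -> 256x16x16 -> 512x8x8 -> 512x4x4 -> 512x1x1.
        # Every intermediate activation is saved in forward() and reused as a
        # U-Net style skip connection by the decoder.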

        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )
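
        # Shape trace for one mel chunk of shape (B, 1, 80, 16), matching the
        # comment in forward(): the strided convolutions shrink 80x16 -> 27x16
        # -> 9x6 -> 3x3, and the final padding-0 3x3 conv reduces it to 1x1,
        # yielding a (B, 512, 1, 1) audio embedding that lines up with the
        # bottleneck of the face encoder.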

        self.face_decoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=4, stride=1, padding=0),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
        ])

        self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
                                          nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
                                          nn.Sigmoid())
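
        # Decoder input channels = previous decoder output + matching encoder
        # skip feature (concatenated in reverse order of the encoder blocks):
        # 512+512=1024, 512+512=1024, 512+512=1024, 512+256=768, 384+128=512,
        # 256+64=320, 128+32=160, and finally 64+16=80 channels feeding
        # output_block, which maps to a 3-channel image in [0, 1] via Sigmoid.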

    def audio_forward(self, audio_sequences, a_alpha=1.):
        audio_embedding = self.audio_encoder(audio_sequences)  # (B, 512, 1, 1)
        if a_alpha != 1.:
            audio_embedding *= a_alpha
        return audio_embedding

    def inference(self, audio_embedding, face_sequences):
        feats = []
        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)
            feats.append(x)

        x = audio_embedding
        for f in self.face_decoder_blocks:
            x = f(x)
            try:
                x = torch.cat((x, feats[-1]), dim=1)
            except Exception as e:
                print(x.size())
                print(feats[-1].size())
                raise e
            feats.pop()

        x = self.output_block(x)
        outputs = x

        return outputs
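
    # Note: inference() mirrors the encoder/decoder loop in forward() but takes a
    # precomputed audio embedding (e.g. from audio_forward) and an already
    # flattened 4D face batch, so it skips the 5D reshaping done in forward().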

    def forward(self, audio_sequences, face_sequences, a_alpha=1.):
        # audio_sequences: (B, T, 1, 80, 16)
        B = audio_sequences.size(0)

        input_dim_size = len(face_sequences.size())
        if input_dim_size > 4:
            # [bz, 5, 1, 80, 16] -> [bz*5, 1, 80, 16]
            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
            # [bz, 6, 5, 256, 256] -> [bz*5, 6, 256, 256]
            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)

        audio_embedding = self.audio_encoder(audio_sequences)  # [bz*5, 1, 80, 16] -> [bz*5, 512, 1, 1]
        if a_alpha != 1.:
            audio_embedding *= a_alpha  # amplify the audio embedding strength

        feats = []
        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)
            feats.append(x)

        x = audio_embedding
        for f in self.face_decoder_blocks:
            x = f(x)
            try:
                x = torch.cat((x, feats[-1]), dim=1)
            except Exception as e:
                print(x.size())
                print(feats[-1].size())
                raise e
            feats.pop()

        x = self.output_block(x)  # [bz*5, 80, 256, 256] -> [bz*5, 3, 256, 256]

        if input_dim_size > 4:  # [bz*5, 3, 256, 256] -> [B, 3, 5, 256, 256]
            x = torch.split(x, B, dim=0)
            outputs = torch.stack(x, dim=2)
        else:
            outputs = x

        return outputs


class Wav2Lip_disc_qual(nn.Module):
    def __init__(self):
        super(Wav2Lip_disc_qual, self).__init__()

        self.face_encoder_blocks = nn.ModuleList([
            nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)),

            nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2),
                          nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
                          nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
                          nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),

            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),

            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),

            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=4, stride=1, padding=0),
                          nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),
        ])

        self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
        self.label_noise = .0

    def get_lower_half(self, face_sequences):
        # Keep only the lower half of the input frames.
        return face_sequences[:, :, face_sequences.size(2) // 2:]

    def to_2d(self, face_sequences):
        # Concatenate the frame sequence along the batch dimension,
        # turning the 5D video tensor into a 4D batch of 2D images.
        B = face_sequences.size(0)
        face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
        return face_sequences

    def perceptual_forward(self, false_face_sequences):
        # Forward pass on generated (fake) frames.
        false_face_sequences = self.to_2d(false_face_sequences)           # [bz, 3, 5, 256, 256] -> [bz*5, 3, 256, 256]
        false_face_sequences = self.get_lower_half(false_face_sequences)  # [bz*5, 3, 256, 256] -> [bz*5, 3, 128, 256]

        false_feats = false_face_sequences
        for f in self.face_encoder_blocks:  # [bz*5, 3, 128, 256] -> [bz*5, 512, 1, 1]
            false_feats = f(false_feats)

        return self.binary_pred(false_feats).view(len(false_feats), -1)  # [bz*5, 512, 1, 1] -> [bz*5, 1]

    def forward(self, face_sequences):
        # Forward pass on ground-truth frames.
        face_sequences = self.to_2d(face_sequences)
        face_sequences = self.get_lower_half(face_sequences)

        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)

        return self.binary_pred(x).view(len(x), -1)
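

if __name__ == "__main__":
    # Minimal shape-check sketch, assuming the sibling conv.py module provides the
    # usual Wav2Lip Conv2d / Conv2dTranspose / nonorm_Conv2d wrappers. Because of
    # the relative import above, run it in package context, e.g.
    # `python -m <package>.wav2lip_v2`.
    B, T = 2, 5
    mels = torch.randn(B, T, 1, 80, 16)     # (B, T, 1, 80, 16) mel chunks
    faces = torch.randn(B, 6, T, 256, 256)  # (B, 6, T, H, W) face inputs

    generator = Wav2Lip()
    discriminator = Wav2Lip_disc_qual()

    with torch.no_grad():
        fake = generator(mels, faces)       # expected: (B, 3, T, 256, 256)
        score = discriminator(fake)         # expected: (B*T, 1)

    print(fake.shape, score.shape)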