syncnet.py

  1. """
  2. This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/models/stable_syncnet.py).
  3. """
  4. import torch
  5. from torch import nn
  6. from einops import rearrange
  7. from torch.nn import functional as F
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from diffusers.models.attention import Attention as CrossAttention, FeedForward
  11. from diffusers.utils.import_utils import is_xformers_available
  12. from einops import rearrange


class SyncNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.audio_encoder = DownEncoder2D(
            in_channels=config["audio_encoder"]["in_channels"],
            block_out_channels=config["audio_encoder"]["block_out_channels"],
            downsample_factors=config["audio_encoder"]["downsample_factors"],
            dropout=config["audio_encoder"]["dropout"],
            attn_blocks=config["audio_encoder"]["attn_blocks"],
        )
        self.visual_encoder = DownEncoder2D(
            in_channels=config["visual_encoder"]["in_channels"],
            block_out_channels=config["visual_encoder"]["block_out_channels"],
            downsample_factors=config["visual_encoder"]["downsample_factors"],
            dropout=config["visual_encoder"]["dropout"],
            attn_blocks=config["visual_encoder"]["attn_blocks"],
        )
        self.eval()

    def forward(self, image_sequences, audio_sequences):
        vision_embeds = self.visual_encoder(image_sequences)  # (b, c, 1, 1)
        audio_embeds = self.audio_encoder(audio_sequences)  # (b, c, 1, 1)

        vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1)  # (b, c)
        audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1)  # (b, c)

        # Make them unit vectors
        vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
        audio_embeds = F.normalize(audio_embeds, p=2, dim=1)

        return vision_embeds, audio_embeds

    def get_image_embed(self, image_sequences):
        vision_embeds = self.visual_encoder(image_sequences)  # (b, c, 1, 1)
        vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1)  # (b, c)
        # Make them unit vectors
        vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
        return vision_embeds

    def get_audio_embed(self, audio_sequences):
        audio_embeds = self.audio_encoder(audio_sequences)  # (b, c, 1, 1)
        audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1)  # (b, c)
        # Make them unit vectors
        audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
        return audio_embeds
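

# Since both embeddings are L2-normalized, a lip-sync confidence score can be
# read off as a plain dot product (cosine similarity). A minimal sketch,
# assuming `model` is a constructed SyncNet and the two input tensors match
# the shapes its encoders were configured for (the variable names here are
# illustrative, not part of this file's API):
#
#     with torch.no_grad():
#         vision_embeds, audio_embeds = model(image_sequences, audio_sequences)
#         sync_scores = (vision_embeds * audio_embeds).sum(dim=1)  # (b,), in [-1, 1]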


class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        eps: float = 1e-6,
        act_fn: str = "silu",
        downsample_factor=2,
    ):
        super().__init__()

        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "silu":
            self.act_fn = nn.SiLU()
        else:
            raise ValueError(f"Unsupported act_fn: {act_fn}, expected 'relu' or 'silu'")

        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

        if isinstance(downsample_factor, list):
            downsample_factor = tuple(downsample_factor)

        if downsample_factor == 1:
            self.downsample_conv = None
        else:
            self.downsample_conv = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
            )
            # F.pad takes (left, right, top, bottom), i.e. padding order is from back to front
            self.pad = (0, 1, 0, 1)
            if isinstance(downsample_factor, tuple):
                if downsample_factor[0] == 1:
                    self.pad = (0, 1, 1, 1)  # height is kept, so pad it symmetrically; downsample width only
                elif downsample_factor[1] == 1:
                    self.pad = (1, 1, 0, 1)  # width is kept, so pad it symmetrically; downsample height only

    def forward(self, input_tensor):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        hidden_states += input_tensor

        if self.downsample_conv is not None:
            hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
            hidden_states = self.downsample_conv(hidden_states)

        return hidden_states
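

# A minimal shape check for the asymmetric padding above (values are
# illustrative): with downsample_factor=2, padding (0, 1, 0, 1) followed by a
# stride-2 3x3 conv with no built-in padding halves both spatial dims exactly,
# e.g. (1, 64, 8, 8) -> pad -> (1, 64, 9, 9) -> conv -> (1, 64, 4, 4):
#
#     block = ResnetBlock2D(in_channels=64, out_channels=64, downsample_factor=2)
#     out = block(torch.randn(1, 64, 8, 8))
#     assert out.shape == (1, 64, 4, 4)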


class AttentionBlock2D(nn.Module):
    def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
        super().__init__()
        if not is_xformers_available():
            raise ModuleNotFoundError(
                "You have to install xformers to enable memory-efficient attention", name="xformers"
            )

        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
        self.norm2 = nn.LayerNorm(query_dim)
        self.norm3 = nn.LayerNorm(query_dim)

        self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")

        # 1x1 convs project into and out of the attention block; spatial shape is unchanged
        self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
        self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)

        self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
        # Private diffusers flag, kept from the original implementation
        self.attn._use_memory_efficient_attention_xformers = True

    def forward(self, hidden_states):
        assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."

        batch, channel, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.conv_in(hidden_states)

        # Flatten the spatial dims so each pixel becomes a token for self-attention
        hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")

        norm_hidden_states = self.norm2(hidden_states)
        hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
        hidden_states = self.conv_out(hidden_states)
        hidden_states = hidden_states + residual

        return hidden_states


class DownEncoder2D(nn.Module):
    def __init__(
        self,
        in_channels=4 * 16,
        block_out_channels=(64, 128, 256, 256),
        downsample_factors=(2, 2, 2, 2),
        layers_per_block=2,
        norm_num_groups=32,
        attn_blocks=(1, 1, 1, 1),
        dropout: float = 0.0,
        act_fn="silu",
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        # in
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        # down
        self.down_blocks = nn.ModuleList([])

        output_channels = block_out_channels[0]
        for i, block_out_channel in enumerate(block_out_channels):
            input_channels = output_channels
            output_channels = block_out_channel

            down_block = ResnetBlock2D(
                in_channels=input_channels,
                out_channels=output_channels,
                downsample_factor=downsample_factors[i],
                norm_num_groups=norm_num_groups,
                dropout=dropout,
                act_fn=act_fn,
            )
            self.down_blocks.append(down_block)

            # Optionally follow each resnet stage with a spatial self-attention block
            if attn_blocks[i] == 1:
                attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
                self.down_blocks.append(attention_block)

        # out
        self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.act_fn_out = nn.ReLU()

    def forward(self, hidden_states):
        hidden_states = self.conv_in(hidden_states)

        # down
        for down_block in self.down_blocks:
            hidden_states = down_block(hidden_states)

        # post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.act_fn_out(hidden_states)

        return hidden_states
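

# A minimal smoke test sketch. The channel counts, downsample factors, and
# input shapes below are illustrative placeholders, not the values shipped
# with the original repo's configs; with attn_blocks all zero no
# AttentionBlock2D is constructed, so this runs even without xformers.
if __name__ == "__main__":
    config = {
        "audio_encoder": {
            "in_channels": 1,
            "block_out_channels": [32, 64, 128, 256],
            "downsample_factors": [2, 2, 2, 2],
            "dropout": 0.0,
            "attn_blocks": [0, 0, 0, 0],
        },
        "visual_encoder": {
            "in_channels": 48,  # e.g. 16 frames x 3 RGB channels stacked on the channel dim (an assumption)
            "block_out_channels": [32, 64, 128, 256],
            "downsample_factors": [2, 2, 2, 2],
            "dropout": 0.0,
            "attn_blocks": [0, 0, 0, 0],
        },
    }
    model = SyncNet(config)
    with torch.no_grad():
        vision_embeds, audio_embeds = model(
            torch.randn(2, 48, 16, 16),  # 16x16 -> 1x1 after four stride-2 stages
            torch.randn(2, 1, 16, 16),  # audio features, same spatial budget for this sketch
        )
    print(vision_embeds.shape, audio_embeds.shape)  # torch.Size([2, 256]) for both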