# vae.py
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from diffusers import AutoencoderKL
from PIL import Image
  9. class VAE():
  10. """
  11. VAE (Variational Autoencoder) class for image processing.
  12. """
  13. def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
  14. """
  15. Initialize the VAE instance.
  16. :param model_path: Path to the trained model.
  17. :param resized_img: The size to which images are resized.
  18. :param use_float16: Whether to use float16 precision.
  19. """
  20. self.model_path = model_path
  21. self.vae = AutoencoderKL.from_pretrained(self.model_path)
  22. self.device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
  23. self.vae.to(self.device)
  24. if use_float16:
  25. self.vae = self.vae.half()
  26. self._use_float16 = True
  27. else:
  28. self._use_float16 = False
  29. self.scaling_factor = self.vae.config.scaling_factor
  30. self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
  31. self._resized_img = resized_img
  32. self._mask_tensor = self.get_mask_tensor()
  33. def get_mask_tensor(self):
  34. """
  35. Creates a mask tensor for image processing.
  36. :return: A mask tensor.
  37. """
  38. mask_tensor = torch.zeros((self._resized_img,self._resized_img))
  39. mask_tensor[:self._resized_img//2,:] = 1
  40. mask_tensor[mask_tensor< 0.5] = 0
  41. mask_tensor[mask_tensor>= 0.5] = 1
  42. return mask_tensor
  43. def preprocess_img(self,img_name,half_mask=False):
  44. """
  45. Preprocess an image for the VAE.
  46. :param img_name: The image file path or a list of image file paths.
  47. :param half_mask: Whether to apply a half mask to the image.
  48. :return: A preprocessed image tensor.
  49. """
  50. window = []
  51. if isinstance(img_name, str):
  52. window_fnames = [img_name]
  53. for fname in window_fnames:
  54. img = cv2.imread(fname)
  55. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  56. img = cv2.resize(img, (self._resized_img, self._resized_img),
  57. interpolation=cv2.INTER_LANCZOS4)
  58. window.append(img)
  59. else:
  60. img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
  61. window.append(img)
  62. x = np.asarray(window) / 255.
  63. x = np.transpose(x, (3, 0, 1, 2))
  64. x = torch.squeeze(torch.FloatTensor(x))
  65. if half_mask:
  66. x = x * (self._mask_tensor>0.5)
  67. x = self.transform(x)
  68. x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor
  69. x = x.to(self.vae.device)
  70. return x
  71. def encode_latents(self,image):
  72. """
  73. Encode an image into latent variables.
  74. :param image: The image tensor to encode.
  75. :return: The encoded latent variables.
  76. """
  77. with torch.no_grad():
  78. init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
  79. init_latents = self.scaling_factor * init_latent_dist.sample()
  80. return init_latents
  81. def decode_latents(self, latents):
  82. """
  83. Decode latent variables back into an image.
  84. :param latents: The latent variables to decode.
  85. :return: A NumPy array representing the decoded image.
  86. """
  87. latents = (1/ self.scaling_factor) * latents
  88. image = self.vae.decode(latents.to(self.vae.dtype)).sample
  89. image = (image / 2 + 0.5).clamp(0, 1)
  90. image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
  91. image = (image * 255).round().astype("uint8")
  92. image = image[...,::-1] # RGB to BGR
  93. return image
  94. def get_latents_for_unet(self,img):
  95. """
  96. Prepare latent variables for a U-Net model.
  97. :param img: The image to process.
  98. :return: A concatenated tensor of latents for U-Net input.
  99. """
  100. ref_image = self.preprocess_img(img,half_mask=True) # [1, 3, 256, 256] RGB, torch tensor
  101. masked_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
  102. ref_image = self.preprocess_img(img,half_mask=False) # [1, 3, 256, 256] RGB, torch tensor
  103. ref_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
  104. latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
  105. return latent_model_input
  106. if __name__ == "__main__":
  107. vae_mode_path = "./models/sd-vae-ft-mse/"
  108. vae = VAE(model_path = vae_mode_path,use_float16=False)
  109. img_path = "./results/sun001_crop/00000.png"
  110. crop_imgs_path = "./results/sun001_crop/"
  111. latents_out_path = "./results/latents/"
  112. if not os.path.exists(latents_out_path):
  113. os.mkdir(latents_out_path)
  114. files = os.listdir(crop_imgs_path)
  115. files.sort()
  116. files = [file for file in files if file.split(".")[-1] == "png"]
  117. for file in files:
  118. index = file.split(".")[0]
  119. img_path = crop_imgs_path + file
  120. latents = vae.get_latents_for_unet(img_path)
  121. print(img_path,"latents",latents.size())
  122. #torch.save(latents,os.path.join(latents_out_path,index+".pt"))
  123. #reload_tensor = torch.load('tensor.pt')
  124. #print(reload_tensor.size())