| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148 |
- from diffusers import AutoencoderKL
- import torch
- import torchvision.transforms as transforms
- import torch.nn.functional as F
- import cv2
- import numpy as np
- from PIL import Image
- import os
class VAE:
    """Wrap a diffusers ``AutoencoderKL`` to encode/decode image crops.

    Images are turned into ``[1, 3, H, W]`` tensors normalised to [-1, 1],
    encoded into scaled latents, and latents can be decoded back into uint8
    BGR images.
    """

    def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
        """Load the VAE and prepare preprocessing state.

        :param model_path: Path passed to ``AutoencoderKL.from_pretrained``.
        :param resized_img: Side length that image *files* are resized to; also
            the side length of the half mask.
        :param use_float16: Whether to cast the VAE weights to float16.
        """
        self.model_path = model_path
        self.vae = AutoencoderKL.from_pretrained(self.model_path)

        # Prefer CUDA, then Apple MPS (older torch builds have no
        # torch.backends.mps, hence the hasattr guard), then CPU.
        if torch.cuda.is_available():
            device_name = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device_name = "mps"
        else:
            device_name = "cpu"
        self.device = torch.device(device_name)
        self.vae.to(self.device)

        self._use_float16 = bool(use_float16)
        if self._use_float16:
            self.vae = self.vae.half()

        self.scaling_factor = self.vae.config.scaling_factor
        # Per-channel (x - 0.5) / 0.5: maps [0, 1] pixels to [-1, 1].
        self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self._resized_img = resized_img
        self._mask_tensor = self.get_mask_tensor()

    def get_mask_tensor(self):
        """Build the half mask used for masked preprocessing.

        :return: A float32 ``[resized_img, resized_img]`` tensor that is 1 on the
            top ``resized_img // 2`` rows and 0 on the remaining (bottom) rows.
        """
        mask_tensor = torch.zeros((self._resized_img, self._resized_img))
        mask_tensor[: self._resized_img // 2, :] = 1
        return mask_tensor

    def _load_rgb(self, img_name):
        """Return an RGB uint8 image array for a path or a BGR array.

        Paths are read with OpenCV and resized to ``resized_img`` x
        ``resized_img`` (Lanczos). Arrays are only colour-converted, not
        resized.
        """
        if isinstance(img_name, (str, os.PathLike)):
            path = os.fspath(img_name)
            bgr = cv2.imread(path)
            # cv2.imread signals a missing/unreadable file by returning None.
            if bgr is None:
                raise FileNotFoundError(f"Could not read image: {path}")
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            return cv2.resize(rgb, (self._resized_img, self._resized_img),
                              interpolation=cv2.INTER_LANCZOS4)
        # NOTE(review): arrays are assumed to already be resized_img x
        # resized_img BGR images; other sizes break half_mask broadcasting.
        return cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)

    def _rgb_to_input(self, rgb, half_mask):
        """Convert an HxWx3 RGB uint8 array into a normalised [1, 3, H, W] tensor on the VAE device."""
        # Divide in float64 (as numpy does) before casting to float32.
        chw = np.transpose(np.asarray(rgb) / 255.0, (2, 0, 1))
        x = torch.from_numpy(np.ascontiguousarray(chw)).float()
        if half_mask:
            # Zero the bottom half before normalisation, so it becomes -1.
            x = x * (self._mask_tensor > 0.5)
        x = self.transform(x)
        return x.unsqueeze(0).to(self.vae.device)

    def preprocess_img(self, img_name, half_mask=False):
        """Preprocess one image for the VAE.

        :param img_name: An image file path (``str`` or ``os.PathLike``), which
            is read and resized to ``resized_img``, or a BGR image array, which
            is used at its existing size.
        :param half_mask: If true, zero the bottom half of the image before
            normalisation.
        :return: A ``[1, 3, H, W]`` float tensor in [-1, 1] on the VAE device.
        :raises FileNotFoundError: If a path cannot be read by OpenCV.
        """
        return self._rgb_to_input(self._load_rgb(img_name), half_mask)

    def encode_latents(self, image):
        """Encode an image tensor into scaled latent variables.

        :param image: Preprocessed image tensor (see ``preprocess_img``).
        :return: ``scaling_factor`` times a sample from the encoder's latent
            distribution (``.sample()``, so results are stochastic).
        """
        with torch.no_grad():
            init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
            init_latents = self.scaling_factor * init_latent_dist.sample()
        return init_latents

    def decode_latents(self, latents):
        """Decode scaled latent variables back into images.

        :param latents: Latents as produced by ``encode_latents``.
        :return: A uint8 NumPy array of shape ``[B, H, W, 3]`` in BGR order.
        """
        # Inference only: without no_grad the decoder would build an autograd
        # graph that is immediately discarded.
        with torch.no_grad():
            latents = (1 / self.scaling_factor) * latents
            image = self.vae.decode(latents.to(self.vae.dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        # RGB -> BGR. Make a contiguous copy because OpenCV drawing functions
        # reject the negative-stride view that [..., ::-1] produces.
        return np.ascontiguousarray(image[..., ::-1])

    def get_latents_for_unet(self, img):
        """Build the U-Net conditioning latents for one image.

        :param img: Image path or BGR array (see ``preprocess_img``).
        :return: Masked-image latents and full-image latents concatenated on
            dim 1. The original comments give ``[1, 4, 32, 32]`` per part for
            256px input, i.e. ``[1, 8, 32, 32]`` total (depends on the VAE).
        """
        # Read/resize the image once and derive both inputs from it.
        rgb = self._load_rgb(img)
        masked_latents = self.encode_latents(self._rgb_to_input(rgb, half_mask=True))
        ref_latents = self.encode_latents(self._rgb_to_input(rgb, half_mask=False))
        return torch.cat([masked_latents, ref_latents], dim=1)
if __name__ == "__main__":
    # Encode every PNG crop in a directory and print the latent shapes.
    vae_mode_path = "./models/sd-vae-ft-mse/"
    vae = VAE(model_path=vae_mode_path, use_float16=False)

    crop_imgs_path = "./results/sun001_crop/"
    latents_out_path = "./results/latents/"
    # makedirs also creates missing parents; os.mkdir failed when ./results/
    # did not exist yet.
    os.makedirs(latents_out_path, exist_ok=True)

    # NOTE: latents_out_path is prepared for saving latents (e.g. with
    # torch.save to "<index>.pt"), but saving is currently disabled.
    files = sorted(name for name in os.listdir(crop_imgs_path) if name.endswith(".png"))
    for file in files:
        img_path = os.path.join(crop_imgs_path, file)
        latents = vae.get_latents_for_unet(img_path)
        print(img_path, "latents", latents.size())
-
-
|