"""UNet wrapper: builds a diffusers UNet2DConditionModel from a JSON config,
loads checkpoint weights, and provides a sinusoidal PositionalEncoding module."""
  1. import torch
  2. import torch.nn as nn
  3. import math
  4. import json
  5. from diffusers import UNet2DConditionModel
  6. import sys
  7. import time
  8. import numpy as np
  9. import os
  10. class PositionalEncoding(nn.Module):
  11. def __init__(self, d_model=384, max_len=5000):
  12. super(PositionalEncoding, self).__init__()
  13. pe = torch.zeros(max_len, d_model)
  14. position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
  15. div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
  16. pe[:, 0::2] = torch.sin(position * div_term)
  17. pe[:, 1::2] = torch.cos(position * div_term)
  18. pe = pe.unsqueeze(0)
  19. self.register_buffer('pe', pe)
  20. def forward(self, x):
  21. b, seq_len, d_model = x.size()
  22. pe = self.pe[:, :seq_len, :]
  23. x = x + pe.to(x.device)
  24. return x
  25. class UNet():
  26. def __init__(self,
  27. unet_config,
  28. model_path,
  29. use_float16=False,
  30. device=None
  31. ):
  32. with open(unet_config, 'r') as f:
  33. unet_config = json.load(f)
  34. self.model = UNet2DConditionModel(**unet_config)
  35. self.pe = PositionalEncoding(d_model=384)
  36. if device != None:
  37. self.device = device
  38. else:
  39. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  40. weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
  41. self.model.load_state_dict(weights)
  42. if use_float16:
  43. self.model = self.model.half()
  44. self.model.to(self.device)
  45. if __name__ == "__main__":
  46. unet = UNet()