utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. from __future__ import print_function
  2. import os
  3. import sys
  4. import time
  5. import torch
  6. import math
  7. import numpy as np
  8. import cv2
  9. def _gaussian(
  10. size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
  11. height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
  12. mean_vert=0.5):
  13. # handle some defaults
  14. if width is None:
  15. width = size
  16. if height is None:
  17. height = size
  18. if sigma_horz is None:
  19. sigma_horz = sigma
  20. if sigma_vert is None:
  21. sigma_vert = sigma
  22. center_x = mean_horz * width + 0.5
  23. center_y = mean_vert * height + 0.5
  24. gauss = np.empty((height, width), dtype=np.float32)
  25. # generate kernel
  26. for i in range(height):
  27. for j in range(width):
  28. gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
  29. sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
  30. if normalize:
  31. gauss = gauss / np.sum(gauss)
  32. return gauss
  33. def draw_gaussian(image, point, sigma):
  34. # Check if the gaussian is inside
  35. ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
  36. br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
  37. if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
  38. return image
  39. size = 6 * sigma + 1
  40. g = _gaussian(size)
  41. g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
  42. g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
  43. img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
  44. img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
  45. assert (g_x[0] > 0 and g_y[1] > 0)
  46. image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
  47. ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
  48. image[image > 1] = 1
  49. return image
  50. def transform(point, center, scale, resolution, invert=False):
  51. """Generate and affine transformation matrix.
  52. Given a set of points, a center, a scale and a targer resolution, the
  53. function generates and affine transformation matrix. If invert is ``True``
  54. it will produce the inverse transformation.
  55. Arguments:
  56. point {torch.tensor} -- the input 2D point
  57. center {torch.tensor or numpy.array} -- the center around which to perform the transformations
  58. scale {float} -- the scale of the face/object
  59. resolution {float} -- the output resolution
  60. Keyword Arguments:
  61. invert {bool} -- define wherever the function should produce the direct or the
  62. inverse transformation matrix (default: {False})
  63. """
  64. _pt = torch.ones(3)
  65. _pt[0] = point[0]
  66. _pt[1] = point[1]
  67. h = 200.0 * scale
  68. t = torch.eye(3)
  69. t[0, 0] = resolution / h
  70. t[1, 1] = resolution / h
  71. t[0, 2] = resolution * (-center[0] / h + 0.5)
  72. t[1, 2] = resolution * (-center[1] / h + 0.5)
  73. if invert:
  74. t = torch.inverse(t)
  75. new_point = (torch.matmul(t, _pt))[0:2]
  76. return new_point.int()
  77. def crop(image, center, scale, resolution=256.0):
  78. """Center crops an image or set of heatmaps
  79. Arguments:
  80. image {numpy.array} -- an rgb image
  81. center {numpy.array} -- the center of the object, usually the same as of the bounding box
  82. scale {float} -- scale of the face
  83. Keyword Arguments:
  84. resolution {float} -- the size of the output cropped image (default: {256.0})
  85. Returns:
  86. [type] -- [description]
  87. """ # Crop around the center point
  88. """ Crops the image around the center. Input is expected to be an np.ndarray """
  89. ul = transform([1, 1], center, scale, resolution, True)
  90. br = transform([resolution, resolution], center, scale, resolution, True)
  91. # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
  92. if image.ndim > 2:
  93. newDim = np.array([br[1] - ul[1], br[0] - ul[0],
  94. image.shape[2]], dtype=np.int32)
  95. newImg = np.zeros(newDim, dtype=np.uint8)
  96. else:
  97. newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
  98. newImg = np.zeros(newDim, dtype=np.uint8)
  99. ht = image.shape[0]
  100. wd = image.shape[1]
  101. newX = np.array(
  102. [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
  103. newY = np.array(
  104. [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
  105. oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
  106. oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
  107. newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
  108. ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
  109. newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
  110. interpolation=cv2.INTER_LINEAR)
  111. return newImg
  112. def get_preds_fromhm(hm, center=None, scale=None):
  113. """Obtain (x,y) coordinates given a set of N heatmaps. If the center
  114. and the scale is provided the function will return the points also in
  115. the original coordinate frame.
  116. Arguments:
  117. hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
  118. Keyword Arguments:
  119. center {torch.tensor} -- the center of the bounding box (default: {None})
  120. scale {float} -- face scale (default: {None})
  121. """
  122. max, idx = torch.max(
  123. hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
  124. idx += 1
  125. preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
  126. preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
  127. preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
  128. for i in range(preds.size(0)):
  129. for j in range(preds.size(1)):
  130. hm_ = hm[i, j, :]
  131. pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
  132. if pX > 0 and pX < 63 and pY > 0 and pY < 63:
  133. diff = torch.FloatTensor(
  134. [hm_[pY, pX + 1] - hm_[pY, pX - 1],
  135. hm_[pY + 1, pX] - hm_[pY - 1, pX]])
  136. preds[i, j].add_(diff.sign_().mul_(.25))
  137. preds.add_(-.5)
  138. preds_orig = torch.zeros(preds.size())
  139. if center is not None and scale is not None:
  140. for i in range(hm.size(0)):
  141. for j in range(hm.size(1)):
  142. preds_orig[i, j] = transform(
  143. preds[i, j], center, scale, hm.size(2), True)
  144. return preds, preds_orig
  145. def get_preds_fromhm_batch(hm, centers=None, scales=None):
  146. """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
  147. and the scales is provided the function will return the points also in
  148. the original coordinate frame.
  149. Arguments:
  150. hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
  151. Keyword Arguments:
  152. centers {torch.tensor} -- the centers of the bounding box (default: {None})
  153. scales {float} -- face scales (default: {None})
  154. """
  155. max, idx = torch.max(
  156. hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
  157. idx += 1
  158. preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
  159. preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
  160. preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
  161. for i in range(preds.size(0)):
  162. for j in range(preds.size(1)):
  163. hm_ = hm[i, j, :]
  164. pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
  165. if pX > 0 and pX < 63 and pY > 0 and pY < 63:
  166. diff = torch.FloatTensor(
  167. [hm_[pY, pX + 1] - hm_[pY, pX - 1],
  168. hm_[pY + 1, pX] - hm_[pY - 1, pX]])
  169. preds[i, j].add_(diff.sign_().mul_(.25))
  170. preds.add_(-.5)
  171. preds_orig = torch.zeros(preds.size())
  172. if centers is not None and scales is not None:
  173. for i in range(hm.size(0)):
  174. for j in range(hm.size(1)):
  175. preds_orig[i, j] = transform(
  176. preds[i, j], centers[i], scales[i], hm.size(2), True)
  177. return preds, preds_orig
  178. def shuffle_lr(parts, pairs=None):
  179. """Shuffle the points left-right according to the axis of symmetry
  180. of the object.
  181. Arguments:
  182. parts {torch.tensor} -- a 3D or 4D object containing the
  183. heatmaps.
  184. Keyword Arguments:
  185. pairs {list of integers} -- [order of the flipped points] (default: {None})
  186. """
  187. if pairs is None:
  188. pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
  189. 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
  190. 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
  191. 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
  192. 62, 61, 60, 67, 66, 65]
  193. if parts.ndimension() == 3:
  194. parts = parts[pairs, ...]
  195. else:
  196. parts = parts[:, pairs, ...]
  197. return parts
  198. def flip(tensor, is_label=False):
  199. """Flip an image or a set of heatmaps left-right
  200. Arguments:
  201. tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
  202. Keyword Arguments:
  203. is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
  204. """
  205. if not torch.is_tensor(tensor):
  206. tensor = torch.from_numpy(tensor)
  207. if is_label:
  208. tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
  209. else:
  210. tensor = tensor.flip(tensor.ndimension() - 1)
  211. return tensor
  212. # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
  213. def appdata_dir(appname=None, roaming=False):
  214. """ appdata_dir(appname=None, roaming=False)
  215. Get the path to the application directory, where applications are allowed
  216. to write user specific files (e.g. configurations). For non-user specific
  217. data, consider using common_appdata_dir().
  218. If appname is given, a subdir is appended (and created if necessary).
  219. If roaming is True, will prefer a roaming directory (Windows Vista/7).
  220. """
  221. # Define default user directory
  222. userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
  223. if userDir is None:
  224. userDir = os.path.expanduser('~')
  225. if not os.path.isdir(userDir): # pragma: no cover
  226. userDir = '/var/tmp' # issue #54
  227. # Get system app data dir
  228. path = None
  229. if sys.platform.startswith('win'):
  230. path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
  231. path = (path2 or path1) if roaming else (path1 or path2)
  232. elif sys.platform.startswith('darwin'):
  233. path = os.path.join(userDir, 'Library', 'Application Support')
  234. # On Linux and as fallback
  235. if not (path and os.path.isdir(path)):
  236. path = userDir
  237. # Maybe we should store things local to the executable (in case of a
  238. # portable distro or a frozen application that wants to be portable)
  239. prefix = sys.prefix
  240. if getattr(sys, 'frozen', None):
  241. prefix = os.path.abspath(os.path.dirname(sys.executable))
  242. for reldir in ('settings', '../settings'):
  243. localpath = os.path.abspath(os.path.join(prefix, reldir))
  244. if os.path.isdir(localpath): # pragma: no cover
  245. try:
  246. open(os.path.join(localpath, 'test.write'), 'wb').close()
  247. os.remove(os.path.join(localpath, 'test.write'))
  248. except IOError:
  249. pass # We cannot write in this directory
  250. else:
  251. path = localpath
  252. break
  253. # Get path specific for this app
  254. if appname:
  255. if path == userDir:
  256. appname = '.' + appname.lstrip('.') # Make it a hidden directory
  257. path = os.path.join(path, appname)
  258. if not os.path.isdir(path): # pragma: no cover
  259. os.mkdir(path)
  260. # Done
  261. return path