hparams.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from glob import glob
  2. import os
  3. def get_image_list(data_root, split):
  4. filelist = []
  5. with open('filelists/{}.txt'.format(split)) as f:
  6. for line in f:
  7. line = line.strip()
  8. if ' ' in line: line = line.split()[0]
  9. filelist.append(os.path.join(data_root, line))
  10. return filelist
  11. class HParams:
  12. def __init__(self, **kwargs):
  13. self.data = {}
  14. for key, value in kwargs.items():
  15. self.data[key] = value
  16. def __getattr__(self, key):
  17. if key not in self.data:
  18. raise AttributeError("'HParams' object has no attribute %s" % key)
  19. return self.data[key]
  20. def set_hparam(self, key, value):
  21. self.data[key] = value
  22. # Default hyperparameters
  23. hparams = HParams(
  24. num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
  25. # network
  26. rescale=True, # Whether to rescale audio prior to preprocessing
  27. rescaling_max=0.9, # Rescaling value
  28. # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
  29. # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
  30. # Does not work if n_ffit is not multiple of hop_size!!
  31. use_lws=False,
  32. n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
  33. hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
  34. win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
  35. sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>)
  36. frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
  37. # Mel and Linear spectrograms normalization/scaling and clipping
  38. signal_normalization=True,
  39. # Whether to normalize mel spectrograms to some predefined range (following below parameters)
  40. allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
  41. symmetric_mels=True,
  42. # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
  43. # faster and cleaner convergence)
  44. max_abs_value=4.,
  45. # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
  46. # be too big to avoid gradient explosion,
  47. # not too small for fast convergence)
  48. # Contribution by @begeekmyfriend
  49. # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
  50. # levels. Also allows for better G&L phase reconstruction)
  51. preemphasize=True, # whether to apply filter
  52. preemphasis=0.97, # filter coefficient.
  53. # Limits
  54. min_level_db=-100,
  55. ref_level_db=20,
  56. fmin=55,
  57. # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
  58. # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
  59. fmax=7600, # To be increased/reduced depending on data.
  60. ###################### Our training parameters #################################
  61. img_size=96,
  62. fps=25,
  63. batch_size=16,
  64. initial_learning_rate=1e-4,
  65. nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
  66. num_workers=16,
  67. checkpoint_interval=3000,
  68. eval_interval=3000,
  69. save_optimizer_state=True,
  70. syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
  71. syncnet_batch_size=64,
  72. syncnet_lr=1e-4,
  73. syncnet_eval_interval=10000,
  74. syncnet_checkpoint_interval=10000,
  75. disc_wt=0.07,
  76. disc_initial_learning_rate=1e-4,
  77. )
  78. def hparams_debug_string():
  79. values = hparams.values()
  80. hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
  81. return "Hyperparameters:\n" + "\n".join(hp)