# generator.py
"""
Snowflake ID generator, set up for three MySQL servers (IP last octets 251/252/253).
"""
  4. import os
  5. import socket
  6. import threading
  7. import time
  8. import warnings
  9. from typing import Dict, Optional
  10. # Snowflake 位分配:41-bit 时间戳 | 5-bit 机房 | 5-bit 节点 | 12-bit 自增序列
  11. EPOCH_MS = 1704067200000 # 2024-01-01 00:00:00 UTC
  12. WORKER_ID_BITS = 5
  13. DATACENTER_ID_BITS = 5
  14. SEQUENCE_BITS = 12
  15. MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
  16. MAX_DATACENTER_ID = (1 << DATACENTER_ID_BITS) - 1
  17. SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
  18. WORKER_SHIFT = SEQUENCE_BITS
  19. DATACENTER_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS
  20. TIMESTAMP_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS + DATACENTER_ID_BITS
  21. # 默认映射:将三台 MySQL 服务器 IP(或尾段)映射到 worker_id
  22. # DEFAULT_IP_WORKER_MAP: Dict[str, int] = {
  23. # "251": 1,
  24. # "252": 2,
  25. # "253": 3,
  26. # }
  27. # 单机
  28. DEFAULT_IP_WORKER_MAP: Dict[str, int] = {
  29. "168": 1,
  30. }
  31. def _now_ms() -> int:
  32. return int(time.time() * 1000)
  33. def _wait_next_ms(last_ts: int) -> int:
  34. ts = _now_ms()
  35. while ts <= last_ts:
  36. ts = _now_ms()
  37. return ts
  38. def _get_host_ip() -> str:
  39. """获取当前主机的主 IP 地址。"""
  40. s = None
  41. try:
  42. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  43. # 连接一个无需真实可达的地址,只为获取本机出站 IP
  44. s.connect(("10.255.255.255", 1))
  45. ip = s.getsockname()[0]
  46. except Exception:
  47. ip = socket.gethostbyname(socket.gethostname())
  48. finally:
  49. if s is not None:
  50. s.close()
  51. return ip
  52. def resolve_worker_id(
  53. ip_worker_map: Optional[Dict[str, int]] = None,
  54. override_ip: Optional[str] = None,
  55. ) -> int:
  56. """
  57. 根据主机 IP 推导 worker_id。
  58. - 优先读取环境变量 SNOWFLAKE_WORKER_IP 作为 IP。
  59. - 其次使用 override_ip。
  60. - 再次使用本机 IP。
  61. - 支持完整 IP 匹配和末段匹配(如 "192.168.0.251" 或 "251")。
  62. """
  63. ip_worker_map = ip_worker_map or DEFAULT_IP_WORKER_MAP
  64. chosen_ip = os.getenv("SNOWFLAKE_WORKER_IP") or override_ip or _get_host_ip()
  65. if chosen_ip in ip_worker_map:
  66. return ip_worker_map[chosen_ip]
  67. last_octet = chosen_ip.split(".")[-1]
  68. if last_octet in ip_worker_map:
  69. return ip_worker_map[last_octet]
  70. raise ValueError(
  71. f"未能为 IP {chosen_ip} 匹配到 worker_id,请检查映射或设置 SNOWFLAKE_WORKER_IP。"
  72. )
  73. class SnowflakeGenerator:
  74. """
  75. 线程安全的 Snowflake ID 生成器。
  76. 默认 datacenter_id=0;worker_id 将根据 IP 映射推导。
  77. """
  78. def __init__(
  79. self,
  80. datacenter_id: int = 0,
  81. worker_id: Optional[int] = None,
  82. epoch_ms: int = EPOCH_MS,
  83. ip_worker_map: Optional[Dict[str, int]] = None,
  84. ):
  85. self.datacenter_id = datacenter_id
  86. self.worker_id = worker_id or resolve_worker_id(ip_worker_map)
  87. self.epoch_ms = epoch_ms
  88. if not 0 <= self.worker_id <= MAX_WORKER_ID:
  89. raise ValueError(f"worker_id 必须在 0-{MAX_WORKER_ID} 之间")
  90. if not 0 <= self.datacenter_id <= MAX_DATACENTER_ID:
  91. raise ValueError(f"datacenter_id 必须在 0-{MAX_DATACENTER_ID} 之间")
  92. self.sequence = 0
  93. self.last_timestamp = -1
  94. self._lock = threading.Lock()
  95. def next_id(self) -> int:
  96. """生成下一个全局唯一 ID。"""
  97. with self._lock:
  98. timestamp = _now_ms()
  99. if timestamp < self.last_timestamp:
  100. # 时钟回拨保护
  101. raise RuntimeError("检测到系统时钟回拨,停止发号。")
  102. if timestamp == self.last_timestamp:
  103. self.sequence = (self.sequence + 1) & SEQUENCE_MASK
  104. if self.sequence == 0:
  105. timestamp = _wait_next_ms(self.last_timestamp)
  106. else:
  107. self.sequence = 0
  108. self.last_timestamp = timestamp
  109. return (
  110. ((timestamp - self.epoch_ms) << TIMESTAMP_SHIFT)
  111. | (self.datacenter_id << DATACENTER_SHIFT)
  112. | (self.worker_id << WORKER_SHIFT)
  113. | self.sequence
  114. )
  115. # 默认生成器:datacenter_id=0,worker_id 按 IP 映射自动推导
  116. def _build_default_generator() -> SnowflakeGenerator:
  117. try:
  118. return SnowflakeGenerator()
  119. except ValueError as exc:
  120. warnings.warn(
  121. f"Snowflake 默认生成器未找到匹配的 IP,退回 worker_id=0。详情:{exc}",
  122. RuntimeWarning,
  123. )
  124. return SnowflakeGenerator(worker_id=0)
# Module-level default generator, built once when this module is imported.
default_generator = _build_default_generator()
  126. def next_id() -> int:
  127. """便捷函数,返回默认生成器的下一个 ID。"""
  128. return default_generator.next_id()