custom_cacheable.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # -*- coding: utf-8 -*-
  2. """
  3. 自定义缓存装饰器,复刻 Java 版 `@CustomCacheable` 的核心能力:
  4. 按照前缀、字段路径以及完整参数组合构造缓存 Key,并可选支持分页缓存。
  5. """
  6. from __future__ import annotations
  7. import functools
  8. import hashlib
  9. import inspect
  10. import json
  11. import logging
  12. import pickle
  13. from typing import Any, Callable, Mapping, MutableMapping
  14. from werkzeug.local import LocalProxy
  15. from ruoyi_admin.ext import redis_cache
  16. logger = logging.getLogger(__name__)
  17. DEFAULT_PAGE_SIZE = 30
  18. DEFAULT_PAGE_NUM = 1
  19. COMMON_SEPARATOR = ":"
  20. ARGS_HASH_PREFIX = "args"
  21. __all__ = ["custom_cacheable"]
  22. def custom_cacheable(
  23. key_prefix: str,
  24. key_field: str | None = None,
  25. use_query_params_as_key: bool = False,
  26. expire_time: int = 300,
  27. paginate: bool = False,
  28. page_number_field: str = "page_num",
  29. page_size_field: str = "page_size",
  30. ) -> Callable:
  31. """
  32. Redis 缓存装饰器,参数含义与用户给出的 Java 版注解保持一致,便于迁移。
  33. 示例:
  34. @custom_cacheable(
  35. key_prefix="recruit:list",
  36. key_field="query.company_id",
  37. paginate=True,
  38. page_number_field="query.page_num",
  39. page_size_field="query.page_size",
  40. )
  41. def list_recruit(query: RecruitQuery):
  42. ...
  43. """
  44. def decorator(func: Callable) -> Callable:
  45. signature = inspect.signature(func)
  46. @functools.wraps(func)
  47. def wrapper(*args: Any, **kwargs: Any) -> Any:
  48. client = _resolve_redis_client()
  49. if client is None or expire_time <= 0:
  50. return func(*args, **kwargs)
  51. bound_args = signature.bind_partial(*args, **kwargs)
  52. bound_args.apply_defaults()
  53. params = bound_args.arguments # OrderedDict:保留原始参数顺序
  54. base_key_segments = [key_prefix] if key_prefix else []
  55. if key_field:
  56. field_value = _get_value_by_field_path(params, key_field)
  57. if field_value not in (None, ""):
  58. base_key_segments.append(str(field_value))
  59. if use_query_params_as_key:
  60. args_hash = _hash_arguments(params)
  61. base_key_segments.append(f"{ARGS_HASH_PREFIX}:{args_hash}")
  62. if not base_key_segments:
  63. # 如果开发者没有提供前缀,则退回到函数限定名,避免空 key。
  64. base_key_segments.append(func.__qualname__)
  65. cache_key = COMMON_SEPARATOR.join(base_key_segments)
  66. if paginate:
  67. page_number = _extract_int_value(params, page_number_field, DEFAULT_PAGE_NUM)
  68. page_size = _extract_int_value(params, page_size_field, DEFAULT_PAGE_SIZE)
  69. cache_key = (
  70. f"{cache_key}{COMMON_SEPARATOR}{page_number}{COMMON_SEPARATOR}{page_size}"
  71. )
  72. else:
  73. page_number = page_size = None
  74. cached = _safe_redis_get(client, cache_key)
  75. if cached is not None:
  76. try:
  77. return pickle.loads(cached)
  78. except Exception as exc: # noqa: BLE001
  79. logger.debug("反序列化缓存数据失败 %s: %s", cache_key, exc)
  80. result = func(*args, **kwargs)
  81. # 开启分页时仅缓存列表或元组,避免单个对象导致缓存结构不一致。
  82. if paginate and not isinstance(result, (list, tuple)):
  83. return result
  84. try:
  85. payload = pickle.dumps(result)
  86. except Exception as exc: # noqa: BLE001
  87. logger.warning("序列化缓存数据失败 %s: %s", cache_key, exc)
  88. return result
  89. _safe_redis_setex(client, cache_key, int(expire_time), payload)
  90. return result
  91. return wrapper
  92. return decorator
  93. def _resolve_redis_client() -> LocalProxy | None:
  94. """
  95. 兼容 Flask LocalProxy 的获取逻辑,若无上下文则直接放弃缓存。
  96. """
  97. try:
  98. return redis_cache
  99. except RuntimeError:
  100. logger.debug("当前无应用上下文,跳过缓存调用")
  101. return None
  102. except Exception as exc: # noqa: BLE001
  103. logger.warning("获取 redis 连接失败: %s", exc)
  104. return None
  105. def _safe_redis_get(client: LocalProxy, cache_key: str) -> bytes | None:
  106. """
  107. 捕获 Redis 异常,防止缓存故障影响主流程。
  108. """
  109. try:
  110. return client.get(cache_key)
  111. except Exception as exc: # noqa: BLE001
  112. logger.warning("读取缓存失败 %s: %s", cache_key, exc)
  113. return None
  114. def _safe_redis_setex(client: LocalProxy, cache_key: str, expire: int, payload: bytes) -> None:
  115. """
  116. setex 包装,写入失败时仅记录日志。
  117. """
  118. try:
  119. client.setex(cache_key, expire, payload)
  120. except Exception as exc: # noqa: BLE001
  121. logger.warning("写入缓存失败 %s: %s", cache_key, exc)
  122. def _hash_arguments(params: Mapping[str, Any]) -> str:
  123. """
  124. 将参数转为稳定 JSON,并计算 SHA1,避免直接存储长 JSON。
  125. """
  126. normalized = _normalize_for_hash(params)
  127. serialized = json.dumps(normalized, sort_keys=True, ensure_ascii=True, default=str)
  128. return hashlib.sha1(serialized.encode("utf-8")).hexdigest()
  129. def _normalize_for_hash(value: Any) -> Any:
  130. """
  131. 递归展开常见类型,保证同样语义的参数能得到一致的哈希。
  132. """
  133. if isinstance(value, (str, int, float, bool)) or value is None:
  134. return value
  135. if isinstance(value, Mapping):
  136. return {str(k): _normalize_for_hash(v) for k, v in value.items()}
  137. if isinstance(value, (list, tuple, set)):
  138. return [_normalize_for_hash(v) for v in value]
  139. if hasattr(value, "__dict__"):
  140. data = {
  141. k: _normalize_for_hash(v)
  142. for k, v in vars(value).items()
  143. if not k.startswith("_")
  144. }
  145. if data:
  146. return data
  147. return repr(value)
  148. def _get_value_by_field_path(params: MutableMapping[str, Any], field_path: str) -> Any:
  149. """
  150. 按“参数名.属性.子属性”路径提取嵌套值。
  151. """
  152. if not field_path:
  153. return None
  154. parts = field_path.split(".")
  155. if not parts:
  156. return None
  157. target = params.get(parts[0])
  158. for part in parts[1:]:
  159. if target is None:
  160. return None
  161. target = _dig_value(target, part)
  162. return target
  163. def _dig_value(value: Any, attribute: str) -> Any:
  164. """
  165. 支持字典、列表(下标)、对象属性的通用取值方法。
  166. """
  167. if value is None:
  168. return None
  169. if isinstance(value, Mapping):
  170. return value.get(attribute)
  171. if isinstance(value, (list, tuple)):
  172. if attribute.isdigit():
  173. index = int(attribute)
  174. if 0 <= index < len(value):
  175. return value[index]
  176. return None
  177. return getattr(value, attribute, None)
  178. def _extract_int_value(
  179. params: MutableMapping[str, Any], field_path: str | None, default_value: int
  180. ) -> int:
  181. """
  182. 读取分页参数,自动完成类型转换及异常兜底。
  183. """
  184. if not field_path:
  185. return default_value
  186. raw_value = _get_value_by_field_path(params, field_path)
  187. if raw_value is None or isinstance(raw_value, bool):
  188. return default_value
  189. try:
  190. return int(raw_value)
  191. except (TypeError, ValueError):
  192. return default_value