# -*- coding: utf-8 -*- """ 自定义缓存装饰器,复刻 Java 版 `@CustomCacheable` 的核心能力: 按照前缀、字段路径以及完整参数组合构造缓存 Key,并可选支持分页缓存。 """ from __future__ import annotations import functools import hashlib import inspect import json import logging import pickle from typing import Any, Callable, Mapping, MutableMapping from werkzeug.local import LocalProxy from ruoyi_admin.ext import redis_cache logger = logging.getLogger(__name__) DEFAULT_PAGE_SIZE = 30 DEFAULT_PAGE_NUM = 1 COMMON_SEPARATOR = ":" ARGS_HASH_PREFIX = "args" __all__ = ["custom_cacheable"] def custom_cacheable( key_prefix: str, key_field: str | None = None, use_query_params_as_key: bool = False, expire_time: int = 300, paginate: bool = False, page_number_field: str = "page_num", page_size_field: str = "page_size", ) -> Callable: """ Redis 缓存装饰器,参数含义与用户给出的 Java 版注解保持一致,便于迁移。 示例: @custom_cacheable( key_prefix="recruit:list", key_field="query.company_id", paginate=True, page_number_field="query.page_num", page_size_field="query.page_size", ) def list_recruit(query: RecruitQuery): ... """ def decorator(func: Callable) -> Callable: signature = inspect.signature(func) @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: client = _resolve_redis_client() if client is None or expire_time <= 0: return func(*args, **kwargs) bound_args = signature.bind_partial(*args, **kwargs) bound_args.apply_defaults() params = bound_args.arguments # OrderedDict:保留原始参数顺序 base_key_segments = [key_prefix] if key_prefix else [] if key_field: field_value = _get_value_by_field_path(params, key_field) if field_value not in (None, ""): base_key_segments.append(str(field_value)) if use_query_params_as_key: args_hash = _hash_arguments(params) base_key_segments.append(f"{ARGS_HASH_PREFIX}:{args_hash}") if not base_key_segments: # 如果开发者没有提供前缀,则退回到函数限定名,避免空 key。 base_key_segments.append(func.__qualname__) cache_key = COMMON_SEPARATOR.join(base_key_segments) if paginate: page_number = _extract_int_value(params, page_number_field, DEFAULT_PAGE_NUM) page_size = _extract_int_value(params, page_size_field, DEFAULT_PAGE_SIZE) cache_key = ( f"{cache_key}{COMMON_SEPARATOR}{page_number}{COMMON_SEPARATOR}{page_size}" ) else: page_number = page_size = None cached = _safe_redis_get(client, cache_key) if cached is not None: try: return pickle.loads(cached) except Exception as exc: # noqa: BLE001 logger.debug("反序列化缓存数据失败 %s: %s", cache_key, exc) result = func(*args, **kwargs) # 开启分页时仅缓存列表或元组,避免单个对象导致缓存结构不一致。 if paginate and not isinstance(result, (list, tuple)): return result try: payload = pickle.dumps(result) except Exception as exc: # noqa: BLE001 logger.warning("序列化缓存数据失败 %s: %s", cache_key, exc) return result _safe_redis_setex(client, cache_key, int(expire_time), payload) return result return wrapper return decorator def _resolve_redis_client() -> LocalProxy | None: """ 兼容 Flask LocalProxy 的获取逻辑,若无上下文则直接放弃缓存。 """ try: return redis_cache except RuntimeError: logger.debug("当前无应用上下文,跳过缓存调用") return None except Exception as exc: # noqa: BLE001 logger.warning("获取 redis 连接失败: %s", exc) return None def _safe_redis_get(client: LocalProxy, cache_key: str) -> bytes | None: """ 捕获 Redis 异常,防止缓存故障影响主流程。 """ try: return client.get(cache_key) except Exception as exc: # noqa: BLE001 logger.warning("读取缓存失败 %s: %s", cache_key, exc) return None def _safe_redis_setex(client: LocalProxy, cache_key: str, expire: int, payload: bytes) -> None: """ setex 包装,写入失败时仅记录日志。 """ try: client.setex(cache_key, expire, payload) except Exception as exc: # noqa: BLE001 logger.warning("写入缓存失败 %s: %s", cache_key, exc) def _hash_arguments(params: Mapping[str, Any]) -> str: """ 将参数转为稳定 JSON,并计算 SHA1,避免直接存储长 JSON。 """ normalized = _normalize_for_hash(params) serialized = json.dumps(normalized, sort_keys=True, ensure_ascii=True, default=str) return hashlib.sha1(serialized.encode("utf-8")).hexdigest() def _normalize_for_hash(value: Any) -> Any: """ 递归展开常见类型,保证同样语义的参数能得到一致的哈希。 """ if isinstance(value, (str, int, float, bool)) or value is None: return value if isinstance(value, Mapping): return {str(k): _normalize_for_hash(v) for k, v in value.items()} if isinstance(value, (list, tuple, set)): return [_normalize_for_hash(v) for v in value] if hasattr(value, "__dict__"): data = { k: _normalize_for_hash(v) for k, v in vars(value).items() if not k.startswith("_") } if data: return data return repr(value) def _get_value_by_field_path(params: MutableMapping[str, Any], field_path: str) -> Any: """ 按“参数名.属性.子属性”路径提取嵌套值。 """ if not field_path: return None parts = field_path.split(".") if not parts: return None target = params.get(parts[0]) for part in parts[1:]: if target is None: return None target = _dig_value(target, part) return target def _dig_value(value: Any, attribute: str) -> Any: """ 支持字典、列表(下标)、对象属性的通用取值方法。 """ if value is None: return None if isinstance(value, Mapping): return value.get(attribute) if isinstance(value, (list, tuple)): if attribute.isdigit(): index = int(attribute) if 0 <= index < len(value): return value[index] return None return getattr(value, attribute, None) def _extract_int_value( params: MutableMapping[str, Any], field_path: str | None, default_value: int ) -> int: """ 读取分页参数,自动完成类型转换及异常兜底。 """ if not field_path: return default_value raw_value = _get_value_by_field_path(params, field_path) if raw_value is None or isinstance(raw_value, bool): return default_value try: return int(raw_value) except (TypeError, ValueError): return default_value