| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244 |
- # -*- coding: utf-8 -*-
- """
- 自定义缓存装饰器,复刻 Java 版 `@CustomCacheable` 的核心能力:
- 按照前缀、字段路径以及完整参数组合构造缓存 Key,并可选支持分页缓存。
- """
- from __future__ import annotations
- import functools
- import hashlib
- import inspect
- import json
- import logging
- import pickle
- from typing import Any, Callable, Mapping, MutableMapping
- from werkzeug.local import LocalProxy
- from ruoyi_admin.ext import redis_cache
# Module-level logger; cache read/write problems are logged through it.
logger = logging.getLogger(__name__)

# Page size used in the cache key when the page-size argument is missing
# or cannot be converted to an integer.
DEFAULT_PAGE_SIZE = 30
# Page number used in the cache key when the page-number argument is
# missing or cannot be converted to an integer.
DEFAULT_PAGE_NUM = 1
# Separator placed between the individual segments of a cache key.
COMMON_SEPARATOR = ":"
# Label that precedes the argument hash inside a cache key.
ARGS_HASH_PREFIX = "args"

# Only the decorator itself is exported by a star-import.
__all__ = ["custom_cacheable"]
def custom_cacheable(
    key_prefix: str,
    key_field: str | None = None,
    use_query_params_as_key: bool = False,
    expire_time: int = 300,
    paginate: bool = False,
    page_number_field: str = "page_num",
    page_size_field: str = "page_size",
) -> Callable:
    """Cache a function's return value in Redis under a key built from its arguments.

    The cache key is assembled from these segments, joined by
    ``COMMON_SEPARATOR``:

    1. ``key_prefix`` (when non-empty);
    2. the value found at ``key_field`` (when given and neither ``None``
       nor ``""``);
    3. ``"args:<sha1>"`` of all bound arguments (when
       ``use_query_params_as_key`` is true);
    4. the function's ``__qualname__`` if none of the above produced a
       segment;
    5. page number and page size (when ``paginate`` is true).

    Args:
        key_prefix: Leading segment of the cache key.
        key_field: Dotted path ``"param.attr.sub"`` into the call
            arguments whose value becomes a key segment.
        use_query_params_as_key: Add a hash of all arguments to the key.
        expire_time: Time-to-live in seconds; a value ``<= 0`` disables
            caching entirely.
        paginate: Append page number/size to the key and cache only
            ``list``/``tuple`` results.
        page_number_field: Dotted path to the page number argument.
        page_size_field: Dotted path to the page size argument.

    Returns:
        A decorator that wraps the target function with caching.

    Example::

        @custom_cacheable(
            key_prefix="recruit:list",
            key_field="query.company_id",
            paginate=True,
            page_number_field="query.page_num",
            page_size_field="query.page_size",
        )
        def list_recruit(query: RecruitQuery):
            ...
    """

    def decorator(func: Callable) -> Callable:
        signature = inspect.signature(func)

        def build_cache_key(params: MutableMapping[str, Any]) -> str:
            """Assemble the cache key for one call from its bound arguments."""
            segments = [key_prefix] if key_prefix else []

            if key_field:
                field_value = _get_value_by_field_path(params, key_field)
                if field_value not in (None, ""):
                    segments.append(str(field_value))

            if use_query_params_as_key:
                segments.append(f"{ARGS_HASH_PREFIX}:{_hash_arguments(params)}")

            if not segments:
                # No prefix and no usable field: fall back to the function's
                # qualified name so the key is never empty.
                segments.append(func.__qualname__)

            if paginate:
                page_number = _extract_int_value(
                    params, page_number_field, DEFAULT_PAGE_NUM
                )
                page_size = _extract_int_value(
                    params, page_size_field, DEFAULT_PAGE_SIZE
                )
                segments.extend((str(page_number), str(page_size)))

            return COMMON_SEPARATOR.join(segments)

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Caching disabled by configuration: do not touch Redis at all.
            if expire_time <= 0:
                return func(*args, **kwargs)

            client = _resolve_redis_client()
            if client is None:
                return func(*args, **kwargs)

            try:
                bound_args = signature.bind_partial(*args, **kwargs)
                bound_args.apply_defaults()
                # ``arguments`` keeps the parameters in declaration order.
                cache_key = build_cache_key(bound_args.arguments)
            except Exception as exc:  # noqa: BLE001
                # Like every other cache failure, a key-building problem must
                # not break the wrapped call; run it uncached instead.
                logger.warning("构造缓存 Key 失败,跳过缓存: %s", exc)
                return func(*args, **kwargs)

            cached = _safe_redis_get(client, cache_key)
            if cached is not None:
                try:
                    # SECURITY: pickle.loads executes arbitrary code embedded
                    # in the payload. This is only safe while the Redis
                    # instance is trusted and not writable by other parties.
                    return pickle.loads(cached)
                except Exception as exc:  # noqa: BLE001
                    # Unreadable entry: fall through, recompute and overwrite.
                    logger.debug("反序列化缓存数据失败 %s: %s", cache_key, exc)

            result = func(*args, **kwargs)

            # With pagination enabled only lists/tuples are cached, so a
            # single object never ends up under a page-shaped key.
            if paginate and not isinstance(result, (list, tuple)):
                return result

            try:
                payload = pickle.dumps(result)
            except Exception as exc:  # noqa: BLE001
                logger.warning("序列化缓存数据失败 %s: %s", cache_key, exc)
                return result

            _safe_redis_setex(client, cache_key, int(expire_time), payload)
            return result

        return wrapper

    return decorator
def _resolve_redis_client() -> Any | None:
    """Return a usable Redis client, or ``None`` when caching must be skipped.

    When ``redis_cache`` is a werkzeug ``LocalProxy`` it is lazy: simply
    returning the proxy can never raise, so a missing application context
    would only surface later, on the first Redis call. The proxy is
    therefore resolved here, which makes the ``RuntimeError`` raised for an
    unbound proxy observable at this point.

    Returns:
        The resolved client object, or ``None`` if it cannot be obtained.
    """
    try:
        if isinstance(redis_cache, LocalProxy):
            # Raises RuntimeError when the proxy is not bound to an object.
            return redis_cache._get_current_object()
        return redis_cache
    except RuntimeError:
        logger.debug("当前无应用上下文,跳过缓存调用")
        return None
    except Exception as exc:  # noqa: BLE001
        logger.warning("获取 redis 连接失败: %s", exc)
        return None
- def _safe_redis_get(client: LocalProxy, cache_key: str) -> bytes | None:
- """
- 捕获 Redis 异常,防止缓存故障影响主流程。
- """
- try:
- return client.get(cache_key)
- except Exception as exc: # noqa: BLE001
- logger.warning("读取缓存失败 %s: %s", cache_key, exc)
- return None
- def _safe_redis_setex(client: LocalProxy, cache_key: str, expire: int, payload: bytes) -> None:
- """
- setex 包装,写入失败时仅记录日志。
- """
- try:
- client.setex(cache_key, expire, payload)
- except Exception as exc: # noqa: BLE001
- logger.warning("写入缓存失败 %s: %s", cache_key, exc)
def _hash_arguments(params: Mapping[str, Any]) -> str:
    """Return the SHA-1 hex digest of a key-sorted JSON rendering of ``params``.

    Hashing keeps the cache key short instead of embedding the full JSON.
    """
    canonical_json = json.dumps(
        _normalize_for_hash(params),
        sort_keys=True,
        ensure_ascii=True,
        default=str,
    )
    digest = hashlib.sha1(canonical_json.encode("utf-8"))
    return digest.hexdigest()
- def _normalize_for_hash(value: Any) -> Any:
- """
- 递归展开常见类型,保证同样语义的参数能得到一致的哈希。
- """
- if isinstance(value, (str, int, float, bool)) or value is None:
- return value
- if isinstance(value, Mapping):
- return {str(k): _normalize_for_hash(v) for k, v in value.items()}
- if isinstance(value, (list, tuple, set)):
- return [_normalize_for_hash(v) for v in value]
- if hasattr(value, "__dict__"):
- data = {
- k: _normalize_for_hash(v)
- for k, v in vars(value).items()
- if not k.startswith("_")
- }
- if data:
- return data
- return repr(value)
- def _get_value_by_field_path(params: MutableMapping[str, Any], field_path: str) -> Any:
- """
- 按“参数名.属性.子属性”路径提取嵌套值。
- """
- if not field_path:
- return None
- parts = field_path.split(".")
- if not parts:
- return None
- target = params.get(parts[0])
- for part in parts[1:]:
- if target is None:
- return None
- target = _dig_value(target, part)
- return target
- def _dig_value(value: Any, attribute: str) -> Any:
- """
- 支持字典、列表(下标)、对象属性的通用取值方法。
- """
- if value is None:
- return None
- if isinstance(value, Mapping):
- return value.get(attribute)
- if isinstance(value, (list, tuple)):
- if attribute.isdigit():
- index = int(attribute)
- if 0 <= index < len(value):
- return value[index]
- return None
- return getattr(value, attribute, None)
- def _extract_int_value(
- params: MutableMapping[str, Any], field_path: str | None, default_value: int
- ) -> int:
- """
- 读取分页参数,自动完成类型转换及异常兜底。
- """
- if not field_path:
- return default_value
- raw_value = _get_value_by_field_path(params, field_path)
- if raw_value is None or isinstance(raw_value, bool):
- return default_value
- try:
- return int(raw_value)
- except (TypeError, ValueError):
- return default_value
|