Răsfoiți Sursa

自定义缓存注解

SpringSunYY 4 luni în urmă
părinte
comite
3e3ea60f2e

+ 10 - 0
ruoyi_framework/descriptor/__init__.py

@@ -0,0 +1,10 @@
+# -*- coding: utf-8 -*-
+"""
+descriptor 包用于存放与 AOP/注解语义相关的工具。
+"""
+
+from .custom_cacheable import custom_cacheable
+from .custom_cache_evict import custom_cache_evict
+
+__all__ = ["custom_cacheable", "custom_cache_evict"]
+

+ 139 - 0
ruoyi_framework/descriptor/custom_cache_evict.py

@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+"""
+自定义缓存清理装饰器,对应 Java 版本的 `@CustomCacheEvict`。
+执行目标函数后,根据前缀/字段路径/参数组合构造通配符 Key,并批量删除 Redis 缓存。
+"""
+
+from __future__ import annotations
+
+import functools
+import inspect
+import logging
+from typing import Any, Callable, Iterable, Mapping, MutableMapping, Sequence
+
+from werkzeug.local import LocalProxy
+
+from .custom_cacheable import (
+    ARGS_HASH_PREFIX,
+    COMMON_SEPARATOR,
+    _get_value_by_field_path,
+    _hash_arguments,
+    _resolve_redis_client,
+)
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["custom_cache_evict"]
+
+
+def custom_cache_evict(
+    key_prefixes: Sequence[str],
+    key_fields: Sequence[str] | None = None,
+    use_query_params_as_key: bool = False,
+) -> Callable:
+    """
+    Redis 缓存清理装饰器。
+
+    Args:
+        key_prefixes: 缓存前缀数组,必填,对应待清理的一组 Key。
+        key_fields:   与前缀一一对应的字段路径,允许缺省;缺省时直接按前缀通配符清理。
+        use_query_params_as_key: 是否将函数参数序列化为 Key 的一部分(需与存储端保持一致)。
+    """
+
+    if not key_prefixes:
+        raise ValueError("key_prefixes 不能为空")
+
+    def decorator(func: Callable) -> Callable:
+        signature = inspect.signature(func)
+
+        @functools.wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            result = func(*args, **kwargs)
+            client = _resolve_redis_client()
+            if client is None:
+                return result
+
+            params = _bind_arguments(signature, *args, **kwargs)
+            args_hash = _hash_arguments(params) if use_query_params_as_key else None
+
+            for idx, prefix in enumerate(key_prefixes):
+                if not prefix:
+                    continue
+
+                pattern = prefix
+                field_value = _extract_field_value(params, key_fields, idx)
+
+                if field_value not in (None, ""):
+                    pattern = f"{pattern}{COMMON_SEPARATOR}{field_value}"
+
+                if use_query_params_as_key and args_hash:
+                    pattern = f"{pattern}{COMMON_SEPARATOR}{ARGS_HASH_PREFIX}:{args_hash}"
+
+                pattern = f"{pattern}*"
+                _delete_keys_by_pattern(client, pattern)
+
+            return result
+
+        return wrapper
+
+    return decorator
+
+
+def _bind_arguments(signature: inspect.Signature, *args: Any, **kwargs: Any) -> MutableMapping[str, Any]:
+    """
+    对函数参数做一次绑定,得到“参数名 -> 值”的映射,便于后续取字段。
+    """
+
+    bound_args = signature.bind_partial(*args, **kwargs)
+    bound_args.apply_defaults()
+    return bound_args.arguments
+
+
+def _extract_field_value(
+    params: Mapping[str, Any],
+    key_fields: Sequence[str] | None,
+    index: int,
+) -> Any:
+    """
+    根据 key_fields 配置提取对应的嵌套字段值,超过范围时返回 None。
+    """
+
+    if not key_fields or index >= len(key_fields):
+        return None
+    field_path = key_fields[index]
+    if not field_path:
+        return None
+    return _get_value_by_field_path(params, field_path)
+
+
+def _delete_keys_by_pattern(client: LocalProxy, pattern: str) -> None:
+    """
+    使用 scan_iter 增量拉取匹配 Key 并删除,避免阻塞 Redis。
+    """
+
+    try:
+        pipeline = client.pipeline(transaction=False)
+        batch: list[str] = []
+        for key in client.scan_iter(match=pattern, count=200):
+            batch.append(key)
+            if len(batch) >= 200:
+                _execute_delete_batch(pipeline, batch)
+                batch.clear()
+        if batch:
+            _execute_delete_batch(pipeline, batch)
+    except Exception as exc:  # noqa: BLE001
+        logger.warning("按模式删除缓存失败 %s: %s", pattern, exc)
+
+
+def _execute_delete_batch(pipeline: Any, batch: Iterable[str]) -> None:
+    """
+    批量删除 Key 并立即执行 pipeline。
+    """
+
+    try:
+        for key in batch:
+            pipeline.delete(key)
+        pipeline.execute()
+    except Exception as exc:  # noqa: BLE001
+        logger.warning("批量删除缓存失败: %s", exc)
+

+ 244 - 0
ruoyi_framework/descriptor/custom_cacheable.py

@@ -0,0 +1,244 @@
+# -*- coding: utf-8 -*-
+"""
+自定义缓存装饰器,复刻 Java 版 `@CustomCacheable` 的核心能力:
+按照前缀、字段路径以及完整参数组合构造缓存 Key,并可选支持分页缓存。
+"""
+
+from __future__ import annotations
+
+import functools
+import hashlib
+import inspect
+import json
+import logging
+import pickle
+from typing import Any, Callable, Mapping, MutableMapping
+
+from werkzeug.local import LocalProxy
+
+from ruoyi_admin.ext import redis_cache
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_PAGE_SIZE = 30
+DEFAULT_PAGE_NUM = 1
+COMMON_SEPARATOR = ":"
+ARGS_HASH_PREFIX = "args"
+
+__all__ = ["custom_cacheable"]
+
+
+def custom_cacheable(
+    key_prefix: str,
+    key_field: str | None = None,
+    use_query_params_as_key: bool = False,
+    expire_time: int = 300,
+    paginate: bool = False,
+    page_number_field: str = "page_num",
+    page_size_field: str = "page_size",
+) -> Callable:
+    """
+    Redis 缓存装饰器,参数含义与用户给出的 Java 版注解保持一致,便于迁移。
+
+    示例:
+        @custom_cacheable(
+            key_prefix="recruit:list",
+            key_field="query.company_id",
+            paginate=True,
+            page_number_field="query.page_num",
+            page_size_field="query.page_size",
+        )
+        def list_recruit(query: RecruitQuery):
+            ...
+    """
+
+    def decorator(func: Callable) -> Callable:
+        signature = inspect.signature(func)
+
+        @functools.wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            client = _resolve_redis_client()
+            if client is None or expire_time <= 0:
+                return func(*args, **kwargs)
+
+            bound_args = signature.bind_partial(*args, **kwargs)
+            bound_args.apply_defaults()
+            params = bound_args.arguments  # OrderedDict:保留原始参数顺序
+
+            base_key_segments = [key_prefix] if key_prefix else []
+
+            if key_field:
+                field_value = _get_value_by_field_path(params, key_field)
+                if field_value not in (None, ""):
+                    base_key_segments.append(str(field_value))
+
+            if use_query_params_as_key:
+                args_hash = _hash_arguments(params)
+                base_key_segments.append(f"{ARGS_HASH_PREFIX}:{args_hash}")
+
+            if not base_key_segments:
+                # 如果开发者没有提供前缀,则退回到函数限定名,避免空 key。
+                base_key_segments.append(func.__qualname__)
+
+            cache_key = COMMON_SEPARATOR.join(base_key_segments)
+
+            if paginate:
+                page_number = _extract_int_value(params, page_number_field, DEFAULT_PAGE_NUM)
+                page_size = _extract_int_value(params, page_size_field, DEFAULT_PAGE_SIZE)
+                cache_key = (
+                    f"{cache_key}{COMMON_SEPARATOR}{page_number}{COMMON_SEPARATOR}{page_size}"
+                )
+            else:
+                page_number = page_size = None
+
+            cached = _safe_redis_get(client, cache_key)
+            if cached is not None:
+                try:
+                    return pickle.loads(cached)
+                except Exception as exc:  # noqa: BLE001
+                    logger.debug("反序列化缓存数据失败 %s: %s", cache_key, exc)
+
+            result = func(*args, **kwargs)
+
+            # 开启分页时仅缓存列表或元组,避免单个对象导致缓存结构不一致。
+            if paginate and not isinstance(result, (list, tuple)):
+                return result
+
+            try:
+                payload = pickle.dumps(result)
+            except Exception as exc:  # noqa: BLE001
+                logger.warning("序列化缓存数据失败 %s: %s", cache_key, exc)
+                return result
+
+            _safe_redis_setex(client, cache_key, int(expire_time), payload)
+            return result
+
+        return wrapper
+
+    return decorator
+
+
+def _resolve_redis_client() -> LocalProxy | None:
+    """
+    兼容 Flask LocalProxy 的获取逻辑,若无上下文则直接放弃缓存。
+    """
+
+    try:
+        return redis_cache
+    except RuntimeError:
+        logger.debug("当前无应用上下文,跳过缓存调用")
+        return None
+    except Exception as exc:  # noqa: BLE001
+        logger.warning("获取 redis 连接失败: %s", exc)
+        return None
+
+
+def _safe_redis_get(client: LocalProxy, cache_key: str) -> bytes | None:
+    """
+    捕获 Redis 异常,防止缓存故障影响主流程。
+    """
+
+    try:
+        return client.get(cache_key)
+    except Exception as exc:  # noqa: BLE001
+        logger.warning("读取缓存失败 %s: %s", cache_key, exc)
+        return None
+
+
+def _safe_redis_setex(client: LocalProxy, cache_key: str, expire: int, payload: bytes) -> None:
+    """
+    setex 包装,写入失败时仅记录日志。
+    """
+
+    try:
+        client.setex(cache_key, expire, payload)
+    except Exception as exc:  # noqa: BLE001
+        logger.warning("写入缓存失败 %s: %s", cache_key, exc)
+
+
+def _hash_arguments(params: Mapping[str, Any]) -> str:
+    """
+    将参数转为稳定 JSON,并计算 SHA1,避免直接存储长 JSON。
+    """
+
+    normalized = _normalize_for_hash(params)
+    serialized = json.dumps(normalized, sort_keys=True, ensure_ascii=True, default=str)
+    return hashlib.sha1(serialized.encode("utf-8")).hexdigest()
+
+
+def _normalize_for_hash(value: Any) -> Any:
+    """
+    递归展开常见类型,保证同样语义的参数能得到一致的哈希。
+    """
+
+    if isinstance(value, (str, int, float, bool)) or value is None:
+        return value
+    if isinstance(value, Mapping):
+        return {str(k): _normalize_for_hash(v) for k, v in value.items()}
+    if isinstance(value, (list, tuple, set)):
+        return [_normalize_for_hash(v) for v in value]
+    if hasattr(value, "__dict__"):
+        data = {
+            k: _normalize_for_hash(v)
+            for k, v in vars(value).items()
+            if not k.startswith("_")
+        }
+        if data:
+            return data
+    return repr(value)
+
+
+def _get_value_by_field_path(params: MutableMapping[str, Any], field_path: str) -> Any:
+    """
+    按“参数名.属性.子属性”路径提取嵌套值。
+    """
+
+    if not field_path:
+        return None
+    parts = field_path.split(".")
+    if not parts:
+        return None
+
+    target = params.get(parts[0])
+    for part in parts[1:]:
+        if target is None:
+            return None
+        target = _dig_value(target, part)
+    return target
+
+
+def _dig_value(value: Any, attribute: str) -> Any:
+    """
+    支持字典、列表(下标)、对象属性的通用取值方法。
+    """
+
+    if value is None:
+        return None
+    if isinstance(value, Mapping):
+        return value.get(attribute)
+    if isinstance(value, (list, tuple)):
+        if attribute.isdigit():
+            index = int(attribute)
+            if 0 <= index < len(value):
+                return value[index]
+        return None
+    return getattr(value, attribute, None)
+
+
+def _extract_int_value(
+    params: MutableMapping[str, Any], field_path: str | None, default_value: int
+) -> int:
+    """
+    读取分页参数,自动完成类型转换及异常兜底。
+    """
+
+    if not field_path:
+        return default_value
+    raw_value = _get_value_by_field_path(params, field_path)
+    if raw_value is None or isinstance(raw_value, bool):
+        return default_value
+    try:
+        return int(raw_value)
+    except (TypeError, ValueError):
+        return default_value
+

+ 32 - 6
ruoyi_generator/vm/py/mapper.py.vm

@@ -31,12 +31,38 @@ class {{ underscore(table.class_name) }}_mapper:
             # 构建查询条件
             stmt = select({{ underscore(table.class_name) }}_po)
 {% for column in table.columns %}
-{% if column.is_query and column.query_type == 'EQ' %}
-            if {{ underscore(table.business_name) }}.{{ underscore(column.java_field) }} is not None:
-                stmt = stmt.where({{ underscore(table.class_name) }}_po.{{ underscore(column.java_field) }} == {{ underscore(table.business_name) }}.{{ underscore(column.java_field) }})
-{% elif column.is_query and column.query_type == 'LIKE' %}
-            if {{ underscore(table.business_name) }}.{{ underscore(column.java_field) }}:
-                stmt = stmt.where({{ underscore(table.class_name) }}_po.{{ underscore(column.java_field) }}.like("%" + str({{ underscore(table.business_name) }}.{{ underscore(column.java_field) }}) + "%"))
+{% if column.is_query %}
+{%- set field_name = underscore(column.java_field) %}
+{%- if column.query_type == 'EQ' %}
+            if {{ underscore(table.business_name) }}.{{ field_name }} is not None:
+                stmt = stmt.where({{ underscore(table.class_name) }}_po.{{ field_name }} == {{ underscore(table.business_name) }}.{{ field_name }})
+{%- elif column.query_type == 'NE' %}
+            if {{ underscore(table.business_name) }}.{{ field_name }} is not None:
+                stmt = stmt.where({{ underscore(table.class_name) }}_po.{{ field_name }} != {{ underscore(table.business_name) }}.{{ field_name }})
+{%- elif column.query_type == 'GT' %}
+            if {{ underscore(table.business_name) }}.{{ field_name }} is not None:
+                stmt = stmt.where({{ underscore(table.class_name) }}_po.{{ field_name }} > {{ underscore(table.business_name) }}.{{ field_name }})
+{%- elif column.query_type == 'GTE' %}
+            if {{ underscore(table.business_name) }}.{{ field_name }} is not None:
+                stmt = stmt.where({{ underscore(table.class_name) }}_po.{{ field_name }} >= {{ underscore(table.business_name) }}.{{ field_name }})
+{%- elif column.query_type == 'LT' %}
+            if {{ underscore(table.business_name) }}.{{ field_name }} is not None:
+                stmt = stmt.where({{ underscore(table.class_name) }}_po.{{ field_name }} < {{ underscore(table.business_name) }}.{{ field_name }})
+{%- elif column.query_type == 'LTE' %}
+            if {{ underscore(table.business_name) }}.{{ field_name }} is not None:
+                stmt = stmt.where({{ underscore(table.class_name) }}_po.{{ field_name }} <= {{ underscore(table.business_name) }}.{{ field_name }})
+{%- elif column.query_type == 'LIKE' %}
+            if {{ underscore(table.business_name) }}.{{ field_name }}:
+                stmt = stmt.where({{ underscore(table.class_name) }}_po.{{ field_name }}.like("%" + str({{ underscore(table.business_name) }}.{{ field_name }}) + "%"))
+{%- elif column.query_type == 'BETWEEN' %}
+            _params = getattr({{ underscore(table.business_name) }}, "params", {}) or {}
+            begin_val = _params.get("begin{{ capitalize_first(column.java_field) }}")
+            end_val = _params.get("end{{ capitalize_first(column.java_field) }}")
+            if begin_val is not None:
+                stmt = stmt.where({{ underscore(table.class_name) }}_po.{{ field_name }} >= begin_val)
+            if end_val is not None:
+                stmt = stmt.where({{ underscore(table.class_name) }}_po.{{ field_name }} <= end_val)
+{%- endif %}
 {% endif %}
 {% endfor %}
             if "criterian_meta" in g and g.criterian_meta.page: