Jelajahi Sumber

优化params查询参数

SpringSunYY 4 bulan lalu
induk
melakukan
77b60bc20d
2 mengubah file dengan 64 tambahan dan 35 penghapusan
  1. 34 27
      ruoyi_common/base/reqparser.py
  2. 30 8
      ruoyi_common/base/transformer.py

+ 34 - 27
ruoyi_common/base/reqparser.py

@@ -1,13 +1,13 @@
-
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from dataclasses import dataclass
 from typing import ClassVar, Dict, Iterable, Set, Type
 from typing import ClassVar, Dict, Iterable, Set, Type
+
 from flask import g, request
 from flask import g, request
 from pydantic import AliasChoices, AliasPath, BaseModel
 from pydantic import AliasChoices, AliasPath, BaseModel
 from werkzeug.datastructures import ImmutableMultiDict
 from werkzeug.datastructures import ImmutableMultiDict
-from werkzeug.exceptions import BadRequest,UnsupportedMediaType
+from werkzeug.exceptions import BadRequest, UnsupportedMediaType
 
 
-from ruoyi_common.base.model import BaseEntity, CriterianMeta, ExtraModel, \
+from ruoyi_common.base.model import CriterianMeta, ExtraModel, \
     BaseEntity, OrderModel, PageModel, VoValidatorContext
     BaseEntity, OrderModel, PageModel, VoValidatorContext
 from ruoyi_common.base.schema_vo import BaseSchemaFactory, QuerySchemaFactory
 from ruoyi_common.base.schema_vo import BaseSchemaFactory, QuerySchemaFactory
 
 
@@ -24,7 +24,7 @@ class AbsReqParser(ABC):
         """
         """
 
 
     @abstractmethod
     @abstractmethod
-    def cast_model(self, bo_model:BaseEntity) -> BaseModel:
+    def cast_model(self, bo_model: BaseEntity) -> BaseModel:
         """
         """
         适配模型
         适配模型
 
 
@@ -37,7 +37,7 @@ class AbsReqParser(ABC):
         """
         """
 
 
     @abstractmethod
     @abstractmethod
-    def prepare_factory(self, factory:BaseSchemaFactory):
+    def prepare_factory(self, factory: BaseSchemaFactory):
         """
         """
         准备工厂
         准备工厂
 
 
@@ -51,15 +51,16 @@ class AbsReqParser(ABC):
         准备数据
         准备数据
         """
         """
 
 
+
 class BaseReqParser(AbsReqParser):
 class BaseReqParser(AbsReqParser):
 
 
     def data(self) -> Dict:
     def data(self) -> Dict:
         pass
         pass
 
 
-    def cast_model(self, bo_model:BaseEntity) -> BaseModel:
+    def cast_model(self, bo_model: BaseEntity) -> BaseModel:
         pass
         pass
 
 
-    def prepare_factory(self, factory:BaseSchemaFactory):
+    def prepare_factory(self, factory: BaseSchemaFactory):
         pass
         pass
 
 
     def prepare(self):
     def prepare(self):
@@ -68,7 +69,7 @@ class BaseReqParser(AbsReqParser):
 
 
 class QueryReqParser(BaseReqParser):
 class QueryReqParser(BaseReqParser):
 
 
-    def __init__(self, context:VoValidatorContext):
+    def __init__(self, context: VoValidatorContext):
         self.context = context
         self.context = context
         self.extra_model = ExtraModel
         self.extra_model = ExtraModel
 
 
@@ -81,28 +82,38 @@ class QueryReqParser(BaseReqParser):
         g.criterian_meta = self.criterian_meta
         g.criterian_meta = self.criterian_meta
 
 
     def validate_request(self) -> Dict:
     def validate_request(self) -> Dict:
-        return request.args.to_dict()
+        data = request.args.to_dict()
+        # 兼容前端传参形式 params[xxx]=yyy
+        params_dict = {}
+        for key, val in list(data.items()):
+            if key.startswith("params[") and key.endswith("]"):
+                inner = key[len("params["):-1]
+                params_dict[inner] = val
+                data.pop(key, None)
+        if params_dict:
+            data["params"] = params_dict
+        return data
 
 
     def data(self) -> Dict:
     def data(self) -> Dict:
         data = self.validate_request().copy()
         data = self.validate_request().copy()
         if self.context.is_page:
         if self.context.is_page:
-            page = PageModel.model_validate(data,context=self.context)
+            page = PageModel.model_validate(data, context=self.context)
             if page.model_fields_set:
             if page.model_fields_set:
                 self.criterian_meta.page = page
                 self.criterian_meta.page = page
             self._remove_model_aliases(data, PageModel)
             self._remove_model_aliases(data, PageModel)
         if self.context.is_sort:
         if self.context.is_sort:
-            sort = OrderModel.model_validate(data,context=self.context)
+            sort = OrderModel.model_validate(data, context=self.context)
             if sort.model_fields_set:
             if sort.model_fields_set:
                 self.criterian_meta.sort = sort
                 self.criterian_meta.sort = sort
             self._remove_model_aliases(data, OrderModel)
             self._remove_model_aliases(data, OrderModel)
         if self.extra_model:
         if self.extra_model:
-            extra = self.extra_model.model_validate(data,context=self.context)
+            extra = self.extra_model.model_validate(data, context=self.context)
             if extra.model_fields_set:
             if extra.model_fields_set:
                 self.criterian_meta.extra = extra
                 self.criterian_meta.extra = extra
             self._remove_model_aliases(data, self.extra_model)
             self._remove_model_aliases(data, self.extra_model)
         return data
         return data
 
 
-    def cast_model(self, bo_model:BaseEntity) -> BaseModel:
+    def cast_model(self, bo_model: BaseEntity) -> BaseModel:
         data = self.data()
         data = self.data()
         # 对于查询参数,只保留模型中定义的字段和别名,忽略额外字段
         # 对于查询参数,只保留模型中定义的字段和别名,忽略额外字段
         # 收集模型中所有字段名和别名
         # 收集模型中所有字段名和别名
@@ -177,10 +188,9 @@ class PathReqParser(BaseReqParser):
 
 
 @dataclass
 @dataclass
 class BodyReqParser(BaseReqParser):
 class BodyReqParser(BaseReqParser):
-
     minetype: ClassVar[str] = "application/json"
     minetype: ClassVar[str] = "application/json"
 
 
-    def __init__(self, context:VoValidatorContext):
+    def __init__(self, context: VoValidatorContext):
         self.context = context
         self.context = context
 
 
     def validate_request(self) -> Dict:
     def validate_request(self) -> Dict:
@@ -202,7 +212,7 @@ class BodyReqParser(BaseReqParser):
         data = self.validate_request().copy()
         data = self.validate_request().copy()
         return data
         return data
 
 
-    def cast_model(self, bo_model:BaseEntity) -> BaseModel:
+    def cast_model(self, bo_model: BaseEntity) -> BaseModel:
         data = self.data()
         data = self.data()
         bo = bo_model.model_validate(data, context=self.context)
         bo = bo_model.model_validate(data, context=self.context)
         return bo
         return bo
@@ -210,21 +220,20 @@ class BodyReqParser(BaseReqParser):
 
 
 @dataclass
 @dataclass
 class FormUrlencodedQueryReqParser(QueryReqParser):
 class FormUrlencodedQueryReqParser(QueryReqParser):
-
     minetype: ClassVar[str] = "application/x-www-form-urlencoded"
     minetype: ClassVar[str] = "application/x-www-form-urlencoded"
 
 
-    def __init__(self, context:VoValidatorContext):
+    def __init__(self, context: VoValidatorContext):
         super().__init__(context)
         super().__init__(context)
 
 
     def validate_request(self) -> Dict:
     def validate_request(self) -> Dict:
         content_type = request.headers.get("Content-Type", "").lower()
         content_type = request.headers.get("Content-Type", "").lower()
         minetype = content_type.split(";")[0]
         minetype = content_type.split(";")[0]
         if minetype == self.minetype:
         if minetype == self.minetype:
-            form:ImmutableMultiDict = request.form
+            form: ImmutableMultiDict = request.form
             body = form.to_dict()
             body = form.to_dict()
         else:
         else:
             raise UnsupportedMediaType(
             raise UnsupportedMediaType(
-                description="除了{},content-type不支持{}".format(self.minetype,minetype)
+                description="除了{},content-type不支持{}".format(self.minetype, minetype)
             )
             )
         return body
         return body
 
 
@@ -232,19 +241,18 @@ class FormUrlencodedQueryReqParser(QueryReqParser):
 @dataclass
 @dataclass
 class DownloadFileQueryReqParser(FormUrlencodedQueryReqParser):
 class DownloadFileQueryReqParser(FormUrlencodedQueryReqParser):
 
 
-    def __init__(self, context:VoValidatorContext):
+    def __init__(self, context: VoValidatorContext):
         super().__init__(context)
         super().__init__(context)
 
 
 
 
 class FormReqParser(BaseReqParser):
 class FormReqParser(BaseReqParser):
-
     minetype: ClassVar[str] = "multipart/form-data"
     minetype: ClassVar[str] = "multipart/form-data"
 
 
     def __init__(
     def __init__(
             self,
             self,
-            is_form:bool=True,
-            is_query:bool=False,
-            is_file:bool|None=None,
+            is_form: bool = True,
+            is_query: bool = False,
+            is_file: bool | None = None,
     ):
     ):
         self.is_form = is_form
         self.is_form = is_form
         self.is_query = is_query
         self.is_query = is_query
@@ -263,7 +271,7 @@ class FormReqParser(BaseReqParser):
                 new_data.update(request.files.to_dict(flat=False))
                 new_data.update(request.files.to_dict(flat=False))
         else:
         else:
             raise UnsupportedMediaType(
             raise UnsupportedMediaType(
-                description="除了{},content-type不支持{}".format(self.minetype,minetype)
+                description="除了{},content-type不支持{}".format(self.minetype, minetype)
             )
             )
         return new_data
         return new_data
 
 
@@ -283,7 +291,6 @@ class UploadFileFormReqParser(FormReqParser):
 
 
 
 
 class StreamReqParser(BaseReqParser):
 class StreamReqParser(BaseReqParser):
-
     minetype: ClassVar[str] = "application/octet-stream"
     minetype: ClassVar[str] = "application/octet-stream"
 
 
     def data(self, *args, **kwargs) -> Dict:
     def data(self, *args, **kwargs) -> Dict:

+ 30 - 8
ruoyi_common/base/transformer.py

@@ -1,6 +1,3 @@
-# -*- coding: utf-8 -*-
-# @Author  : YY
-
 from types import NoneType
 from types import NoneType
 from typing import Callable, List, Optional
 from typing import Callable, List, Optional
 from datetime import datetime
 from datetime import datetime
@@ -10,7 +7,7 @@ from pydantic import BeforeValidator, ValidationInfo
 from ruoyi_common.utils.base import DateUtil
 from ruoyi_common.utils.base import DateUtil
 
 
 
 
-def ids_to_list(value:str) -> Optional[List[int]]:
+def ids_to_list(value: str) -> Optional[List[int]]:
     """
     """
     验证ids转换为字符串列表
     验证ids转换为字符串列表
 
 
@@ -31,7 +28,13 @@ def to_datetime(format=None) -> Callable[[str | NoneType, ValidationInfo], datet
         format (str): 日期格式. Defaults to '%Y-%m-%d %H:%M:%S'.
         format (str): 日期格式. Defaults to '%Y-%m-%d %H:%M:%S'.
     """
     """
     if format is None:
     if format is None:
-        formats: List[str] = [DateUtil.YYYY_MM_DD_HH_MM_SS, DateUtil.YYYY_MM_DD]
+        # 默认支持常见的年月日格式,以及仅到月份的格式,方便 Excel 导入
+        formats: List[str] = [
+            DateUtil.YYYY_MM_DD_HH_MM_SS,
+            DateUtil.YYYY_MM_DD,
+            "%Y.%m",
+            "%Y-%m",
+        ]
     elif isinstance(format, (list, tuple, set)):
     elif isinstance(format, (list, tuple, set)):
         formats = list(format)
         formats = list(format)
     else:
     else:
@@ -63,10 +66,11 @@ def to_datetime(format=None) -> Callable[[str | NoneType, ValidationInfo], datet
                     continue
                     continue
             raise ValueError(f"time data '{value}' does not match formats: {formats}")
             raise ValueError(f"time data '{value}' does not match formats: {formats}")
         raise ValueError(f"Invalid datetime format: {value}")
         raise ValueError(f"Invalid datetime format: {value}")
+
     return validate_datetime
     return validate_datetime
 
 
 
 
-def str_to_int(value:str|NoneType, info:ValidationInfo) \
+def str_to_int(value: str | NoneType, info: ValidationInfo) \
         -> int:
         -> int:
     """
     """
     验证str是否为整数,并转换为整数
     验证str是否为整数,并转换为整数
@@ -90,11 +94,29 @@ def str_to_int(value:str|NoneType, info:ValidationInfo) \
     return value
     return value
 
 
 
 
-def int_to_str(value:int|NoneType)-> str:
+def str_to_float(value: str | NoneType, info: ValidationInfo) -> float | NoneType:
+    """
+    将字符串转换为浮点数;空值直接返回
+    """
+    if value is None or value == "":
+        return value
+    if isinstance(value, (int, float)):
+        return float(value)
+    if isinstance(value, str):
+        stripped = value.strip()
+        try:
+            return float(stripped)
+        except ValueError:
+            # 格式化不了就返回 None,避免抛出验证错误
+            return None
+    return value
+
+
+def int_to_str(value: int | NoneType) -> str:
     if isinstance(value, int):
     if isinstance(value, int):
         return str(value)
         return str(value)
     else:
     else:
         return value
         return value
 
 
 
 
-ids_convertor = Annotated[List[int],BeforeValidator(ids_to_list)]
+ids_convertor = Annotated[List[int], BeforeValidator(ids_to_list)]