SpringSunYY пре 4 месеци
родитељ
комит
77b60bc20d
2 измењених фајлова са 64 додато и 35 уклоњено
  1. 34 27
      ruoyi_common/base/reqparser.py
  2. 30 8
      ruoyi_common/base/transformer.py

+ 34 - 27
ruoyi_common/base/reqparser.py

@@ -1,13 +1,13 @@
-
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import ClassVar, Dict, Iterable, Set, Type
+
 from flask import g, request
 from pydantic import AliasChoices, AliasPath, BaseModel
 from werkzeug.datastructures import ImmutableMultiDict
-from werkzeug.exceptions import BadRequest,UnsupportedMediaType
+from werkzeug.exceptions import BadRequest, UnsupportedMediaType
 
-from ruoyi_common.base.model import BaseEntity, CriterianMeta, ExtraModel, \
+from ruoyi_common.base.model import CriterianMeta, ExtraModel, \
     BaseEntity, OrderModel, PageModel, VoValidatorContext
 from ruoyi_common.base.schema_vo import BaseSchemaFactory, QuerySchemaFactory
 
@@ -24,7 +24,7 @@ class AbsReqParser(ABC):
         """
 
     @abstractmethod
-    def cast_model(self, bo_model:BaseEntity) -> BaseModel:
+    def cast_model(self, bo_model: BaseEntity) -> BaseModel:
         """
         适配模型
 
@@ -37,7 +37,7 @@ class AbsReqParser(ABC):
         """
 
     @abstractmethod
-    def prepare_factory(self, factory:BaseSchemaFactory):
+    def prepare_factory(self, factory: BaseSchemaFactory):
         """
         准备工厂
 
@@ -51,15 +51,16 @@ class AbsReqParser(ABC):
         准备数据
         """
 
+
 class BaseReqParser(AbsReqParser):
 
     def data(self) -> Dict:
         pass
 
-    def cast_model(self, bo_model:BaseEntity) -> BaseModel:
+    def cast_model(self, bo_model: BaseEntity) -> BaseModel:
         pass
 
-    def prepare_factory(self, factory:BaseSchemaFactory):
+    def prepare_factory(self, factory: BaseSchemaFactory):
         pass
 
     def prepare(self):
@@ -68,7 +69,7 @@ class BaseReqParser(AbsReqParser):
 
 class QueryReqParser(BaseReqParser):
 
-    def __init__(self, context:VoValidatorContext):
+    def __init__(self, context: VoValidatorContext):
         self.context = context
         self.extra_model = ExtraModel
 
@@ -81,28 +82,38 @@ class QueryReqParser(BaseReqParser):
         g.criterian_meta = self.criterian_meta
 
     def validate_request(self) -> Dict:
-        return request.args.to_dict()
+        data = request.args.to_dict()
+        # 兼容前端传参形式 params[xxx]=yyy
+        params_dict = {}
+        for key, val in list(data.items()):
+            if key.startswith("params[") and key.endswith("]"):
+                inner = key[len("params["):-1]
+                params_dict[inner] = val
+                data.pop(key, None)
+        if params_dict:
+            data["params"] = params_dict
+        return data
 
     def data(self) -> Dict:
         data = self.validate_request().copy()
         if self.context.is_page:
-            page = PageModel.model_validate(data,context=self.context)
+            page = PageModel.model_validate(data, context=self.context)
             if page.model_fields_set:
                 self.criterian_meta.page = page
             self._remove_model_aliases(data, PageModel)
         if self.context.is_sort:
-            sort = OrderModel.model_validate(data,context=self.context)
+            sort = OrderModel.model_validate(data, context=self.context)
             if sort.model_fields_set:
                 self.criterian_meta.sort = sort
             self._remove_model_aliases(data, OrderModel)
         if self.extra_model:
-            extra = self.extra_model.model_validate(data,context=self.context)
+            extra = self.extra_model.model_validate(data, context=self.context)
             if extra.model_fields_set:
                 self.criterian_meta.extra = extra
             self._remove_model_aliases(data, self.extra_model)
         return data
 
-    def cast_model(self, bo_model:BaseEntity) -> BaseModel:
+    def cast_model(self, bo_model: BaseEntity) -> BaseModel:
         data = self.data()
         # 对于查询参数,只保留模型中定义的字段和别名,忽略额外字段
         # 收集模型中所有字段名和别名
@@ -177,10 +188,9 @@ class PathReqParser(BaseReqParser):
 
 @dataclass
 class BodyReqParser(BaseReqParser):
-
     minetype: ClassVar[str] = "application/json"
 
-    def __init__(self, context:VoValidatorContext):
+    def __init__(self, context: VoValidatorContext):
         self.context = context
 
     def validate_request(self) -> Dict:
@@ -202,7 +212,7 @@ class BodyReqParser(BaseReqParser):
         data = self.validate_request().copy()
         return data
 
-    def cast_model(self, bo_model:BaseEntity) -> BaseModel:
+    def cast_model(self, bo_model: BaseEntity) -> BaseModel:
         data = self.data()
         bo = bo_model.model_validate(data, context=self.context)
         return bo
@@ -210,21 +220,20 @@ class BodyReqParser(BaseReqParser):
 
 @dataclass
 class FormUrlencodedQueryReqParser(QueryReqParser):
-
     minetype: ClassVar[str] = "application/x-www-form-urlencoded"
 
-    def __init__(self, context:VoValidatorContext):
+    def __init__(self, context: VoValidatorContext):
         super().__init__(context)
 
     def validate_request(self) -> Dict:
         content_type = request.headers.get("Content-Type", "").lower()
         minetype = content_type.split(";")[0]
         if minetype == self.minetype:
-            form:ImmutableMultiDict = request.form
+            form: ImmutableMultiDict = request.form
             body = form.to_dict()
         else:
             raise UnsupportedMediaType(
-                description="除了{},content-type不支持{}".format(self.minetype,minetype)
+                description="除了{},content-type不支持{}".format(self.minetype, minetype)
             )
         return body
 
@@ -232,19 +241,18 @@ class FormUrlencodedQueryReqParser(QueryReqParser):
 @dataclass
 class DownloadFileQueryReqParser(FormUrlencodedQueryReqParser):
 
-    def __init__(self, context:VoValidatorContext):
+    def __init__(self, context: VoValidatorContext):
         super().__init__(context)
 
 
 class FormReqParser(BaseReqParser):
-
     minetype: ClassVar[str] = "multipart/form-data"
 
     def __init__(
             self,
-            is_form:bool=True,
-            is_query:bool=False,
-            is_file:bool|None=None,
+            is_form: bool = True,
+            is_query: bool = False,
+            is_file: bool | None = None,
     ):
         self.is_form = is_form
         self.is_query = is_query
@@ -263,7 +271,7 @@ class FormReqParser(BaseReqParser):
                 new_data.update(request.files.to_dict(flat=False))
         else:
             raise UnsupportedMediaType(
-                description="除了{},content-type不支持{}".format(self.minetype,minetype)
+                description="除了{},content-type不支持{}".format(self.minetype, minetype)
             )
         return new_data
 
@@ -283,7 +291,6 @@ class UploadFileFormReqParser(FormReqParser):
 
 
 class StreamReqParser(BaseReqParser):
-
     minetype: ClassVar[str] = "application/octet-stream"
 
     def data(self, *args, **kwargs) -> Dict:

+ 30 - 8
ruoyi_common/base/transformer.py

@@ -1,6 +1,3 @@
-# -*- coding: utf-8 -*-
-# @Author  : YY
-
 from types import NoneType
 from typing import Callable, List, Optional
 from datetime import datetime
@@ -10,7 +7,7 @@ from pydantic import BeforeValidator, ValidationInfo
 from ruoyi_common.utils.base import DateUtil
 
 
-def ids_to_list(value:str) -> Optional[List[int]]:
+def ids_to_list(value: str) -> Optional[List[int]]:
     """
     验证ids转换为字符串列表
 
@@ -31,7 +28,13 @@ def to_datetime(format=None) -> Callable[[str | NoneType, ValidationInfo], datet
         format (str): 日期格式. Defaults to '%Y-%m-%d %H:%M:%S'.
     """
     if format is None:
-        formats: List[str] = [DateUtil.YYYY_MM_DD_HH_MM_SS, DateUtil.YYYY_MM_DD]
+        # 默认支持常见的年月日格式,以及仅到月份的格式,方便 Excel 导入
+        formats: List[str] = [
+            DateUtil.YYYY_MM_DD_HH_MM_SS,
+            DateUtil.YYYY_MM_DD,
+            "%Y.%m",
+            "%Y-%m",
+        ]
     elif isinstance(format, (list, tuple, set)):
         formats = list(format)
     else:
@@ -63,10 +66,11 @@ def to_datetime(format=None) -> Callable[[str | NoneType, ValidationInfo], datet
                     continue
             raise ValueError(f"time data '{value}' does not match formats: {formats}")
         raise ValueError(f"Invalid datetime format: {value}")
+
     return validate_datetime
 
 
-def str_to_int(value:str|NoneType, info:ValidationInfo) \
+def str_to_int(value: str | NoneType, info: ValidationInfo) \
         -> int:
     """
     验证str是否为整数,并转换为整数
@@ -90,11 +94,29 @@ def str_to_int(value:str|NoneType, info:ValidationInfo) \
     return value
 
 
-def int_to_str(value:int|NoneType)-> str:
+def str_to_float(value: str | NoneType, info: ValidationInfo) -> float | NoneType:
+    """
+    将字符串转换为浮点数;空值直接返回
+    """
+    if value is None or value == "":
+        return value
+    if isinstance(value, (int, float)):
+        return float(value)
+    if isinstance(value, str):
+        stripped = value.strip()
+        try:
+            return float(stripped)
+        except ValueError:
+            # 格式化不了就返回 None,避免抛出验证错误
+            return None
+    return value
+
+
+def int_to_str(value: int | NoneType) -> str:
     if isinstance(value, int):
         return str(value)
     else:
         return value
 
 
-ids_convertor = Annotated[List[int],BeforeValidator(ids_to_list)]
+ids_convertor = Annotated[List[int], BeforeValidator(ids_to_list)]