reqparser.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. from abc import ABC, abstractmethod
  2. from dataclasses import dataclass
  3. from typing import ClassVar, Dict, Iterable, Set, Type
  4. from flask import g, request
  5. from pydantic import AliasChoices, AliasPath, BaseModel
  6. from werkzeug.datastructures import ImmutableMultiDict
  7. from werkzeug.exceptions import BadRequest, UnsupportedMediaType
  8. from ruoyi_common.base.model import CriterianMeta, ExtraModel, \
  9. BaseEntity, OrderModel, PageModel, VoValidatorContext
  10. from ruoyi_common.base.schema_vo import BaseSchemaFactory, QuerySchemaFactory
  11. class AbsReqParser(ABC):
  12. @abstractmethod
  13. def data(self) -> Dict:
  14. """
  15. 获取请求参数
  16. Returns:
  17. Dict: 请求参数字典
  18. """
  19. @abstractmethod
  20. def cast_model(self, bo_model: BaseEntity) -> BaseModel:
  21. """
  22. 适配模型
  23. Args:
  24. bo_model (BaseEntity): Vo模型
  25. src_model (BaseModel): 源模型
  26. Returns:
  27. BaseModel: 适配后的模型
  28. """
  29. @abstractmethod
  30. def prepare_factory(self, factory: BaseSchemaFactory):
  31. """
  32. 准备工厂
  33. Args:
  34. factory (BaseSchemaFactory): 工厂
  35. """
  36. @abstractmethod
  37. def prepare(self):
  38. """
  39. 准备数据
  40. """
  41. class BaseReqParser(AbsReqParser):
  42. def data(self) -> Dict:
  43. pass
  44. def cast_model(self, bo_model: BaseEntity) -> BaseModel:
  45. pass
  46. def prepare_factory(self, factory: BaseSchemaFactory):
  47. pass
  48. def prepare(self):
  49. pass
  50. class QueryReqParser(BaseReqParser):
  51. def __init__(self, context: VoValidatorContext):
  52. self.context = context
  53. self.extra_model = ExtraModel
  54. def prepare_factory(self, factory: QuerySchemaFactory):
  55. if factory.extra_model:
  56. self.extra_model = factory.extra_model
  57. def prepare(self):
  58. self.criterian_meta = CriterianMeta()
  59. g.criterian_meta = self.criterian_meta
  60. def validate_request(self) -> Dict:
  61. data = request.args.to_dict()
  62. # 兼容前端传参形式 params[xxx]=yyy
  63. params_dict = {}
  64. for key, val in list(data.items()):
  65. if key.startswith("params[") and key.endswith("]"):
  66. inner = key[len("params["):-1]
  67. params_dict[inner] = val
  68. data.pop(key, None)
  69. if params_dict:
  70. data["params"] = params_dict
  71. return data
  72. def data(self) -> Dict:
  73. data = self.validate_request().copy()
  74. if self.context.is_page:
  75. page = PageModel.model_validate(data, context=self.context)
  76. if page.model_fields_set:
  77. self.criterian_meta.page = page
  78. self._remove_model_aliases(data, PageModel)
  79. if self.context.is_sort:
  80. sort = OrderModel.model_validate(data, context=self.context)
  81. if sort.model_fields_set:
  82. self.criterian_meta.sort = sort
  83. self._remove_model_aliases(data, OrderModel)
  84. if self.extra_model:
  85. extra = self.extra_model.model_validate(data, context=self.context)
  86. if extra.model_fields_set:
  87. self.criterian_meta.extra = extra
  88. self._remove_model_aliases(data, self.extra_model)
  89. return data
  90. def cast_model(self, bo_model: BaseEntity) -> BaseModel:
  91. data = self.data()
  92. # 对于查询参数,只保留模型中定义的字段和别名,忽略额外字段
  93. # 收集模型中所有字段名和别名
  94. model_fields = set()
  95. for name, info in bo_model.model_fields.items():
  96. model_fields.add(name)
  97. # 添加查询别名(camelCase)
  98. if hasattr(bo_model, 'model_config') and bo_model.model_config:
  99. alias_gen = bo_model.model_config.get('alias_generator')
  100. if callable(alias_gen):
  101. model_fields.add(alias_gen(name))
  102. # 添加 validation_alias
  103. if info.validation_alias:
  104. if isinstance(info.validation_alias, str):
  105. model_fields.add(info.validation_alias)
  106. elif hasattr(info.validation_alias, 'choices'):
  107. model_fields.update(info.validation_alias.choices)
  108. # 过滤掉未定义的字段
  109. filtered_data = {k: v for k, v in data.items() if k in model_fields}
  110. bo = bo_model.model_validate(filtered_data)
  111. return bo
  112. def _remove_model_aliases(
  113. self,
  114. data: Dict[str, str],
  115. model_cls: Type[BaseModel]
  116. ) -> None:
  117. """
  118. 删除已经用于解析的模型字段别名,避免后续模型校验时报额外字段错误
  119. """
  120. if not data:
  121. return
  122. for alias in self._collect_aliases(model_cls):
  123. data.pop(alias, None)
  124. def _collect_aliases(self, model_cls: Type[BaseModel]) -> Set[str]:
  125. alias_set: Set[str] = set()
  126. populate_by_name = getattr(model_cls, "model_config", {}).get(
  127. "populate_by_name", False
  128. )
  129. for name, info in model_cls.model_fields.items():
  130. alias_set.update(self._field_aliases(name, info, populate_by_name))
  131. return alias_set
  132. @staticmethod
  133. def _field_aliases(
  134. name: str,
  135. info,
  136. populate_by_name: bool
  137. ) -> Iterable[str]:
  138. aliases: Set[str] = set()
  139. if getattr(info, "alias", None):
  140. aliases.add(info.alias)
  141. validation_alias = getattr(info, "validation_alias", None)
  142. if isinstance(validation_alias, str):
  143. aliases.add(validation_alias)
  144. elif isinstance(validation_alias, AliasChoices):
  145. aliases.update(validation_alias.choices)
  146. elif isinstance(validation_alias, AliasPath):
  147. pass
  148. if populate_by_name:
  149. aliases.add(name)
  150. return aliases
  151. @dataclass
  152. class PathReqParser(BaseReqParser):
  153. def data(self) -> Dict:
  154. return request.view_args.copy()
  155. @dataclass
  156. class BodyReqParser(BaseReqParser):
  157. minetype: ClassVar[str] = "application/json"
  158. def __init__(self, context: VoValidatorContext):
  159. self.context = context
  160. def validate_request(self) -> Dict:
  161. content_type = request.headers.get("Content-Type", "").lower()
  162. minetype = content_type.split(";")[0]
  163. if minetype == self.minetype:
  164. body: dict | list = request.get_json()
  165. if not body:
  166. raise BadRequest(
  167. description="在{}, body数据不能为空".format(content_type),
  168. )
  169. else:
  170. raise UnsupportedMediaType(
  171. description="content-type仅支持application/json"
  172. )
  173. return body
  174. def data(self) -> Dict:
  175. data = self.validate_request().copy()
  176. return data
  177. def cast_model(self, bo_model: BaseEntity) -> BaseModel:
  178. data = self.data()
  179. bo = bo_model.model_validate(data, context=self.context)
  180. return bo
  181. @dataclass
  182. class FormUrlencodedQueryReqParser(QueryReqParser):
  183. minetype: ClassVar[str] = "application/x-www-form-urlencoded"
  184. def __init__(self, context: VoValidatorContext):
  185. super().__init__(context)
  186. def validate_request(self) -> Dict:
  187. content_type = request.headers.get("Content-Type", "").lower()
  188. minetype = content_type.split(";")[0]
  189. if minetype == self.minetype:
  190. form: ImmutableMultiDict = request.form
  191. body = form.to_dict()
  192. else:
  193. raise UnsupportedMediaType(
  194. description="除了{},content-type不支持{}".format(self.minetype, minetype)
  195. )
  196. return body
  197. @dataclass
  198. class DownloadFileQueryReqParser(FormUrlencodedQueryReqParser):
  199. def __init__(self, context: VoValidatorContext):
  200. super().__init__(context)
  201. class FormReqParser(BaseReqParser):
  202. minetype: ClassVar[str] = "multipart/form-data"
  203. def __init__(
  204. self,
  205. is_form: bool = True,
  206. is_query: bool = False,
  207. is_file: bool | None = None,
  208. ):
  209. self.is_form = is_form
  210. self.is_query = is_query
  211. self.is_file = is_file
  212. def validate_request(self) -> Dict:
  213. content_type = request.headers.get("Content-Type", "").lower()
  214. minetype = content_type.split(";")[0]
  215. new_data = {}
  216. if minetype == self.minetype:
  217. if self.is_form:
  218. new_data.update(request.form.to_dict())
  219. if self.is_query:
  220. new_data.update(request.args.to_dict())
  221. if self.is_file:
  222. new_data.update(request.files.to_dict(flat=False))
  223. else:
  224. raise UnsupportedMediaType(
  225. description="除了{},content-type不支持{}".format(self.minetype, minetype)
  226. )
  227. return new_data
  228. def data(self) -> Dict:
  229. data = self.validate_request()
  230. return data
  231. class UploadFileFormReqParser(FormReqParser):
  232. def validate_request(self) -> Dict:
  233. return super().validate_request()
  234. def data(self) -> Dict:
  235. data = self.validate_request()
  236. return data
  237. class StreamReqParser(BaseReqParser):
  238. minetype: ClassVar[str] = "application/octet-stream"
  239. def data(self, *args, **kwargs) -> Dict:
  240. pass