# reqparser.py - request parameter parsers (query string, path, JSON body, form data)

  1. from abc import ABC, abstractmethod
  2. from dataclasses import dataclass
  3. from typing import ClassVar, Dict, Iterable, Set, Type
  4. from flask import g, request
  5. from pydantic import AliasChoices, AliasPath, BaseModel
  6. from werkzeug.datastructures import ImmutableMultiDict
  7. from werkzeug.exceptions import BadRequest,UnsupportedMediaType
  8. from ruoyi_common.base.model import BaseEntity, CriterianMeta, ExtraModel, \
  9. BaseEntity, OrderModel, PageModel, VoValidatorContext
  10. from ruoyi_common.base.schema_vo import BaseSchemaFactory, QuerySchemaFactory
  11. class AbsReqParser(ABC):
  12. @abstractmethod
  13. def data(self) -> Dict:
  14. """
  15. 获取请求参数
  16. Returns:
  17. Dict: 请求参数字典
  18. """
  19. @abstractmethod
  20. def cast_model(self, bo_model:BaseEntity) -> BaseModel:
  21. """
  22. 适配模型
  23. Args:
  24. bo_model (BaseEntity): Vo模型
  25. src_model (BaseModel): 源模型
  26. Returns:
  27. BaseModel: 适配后的模型
  28. """
  29. @abstractmethod
  30. def prepare_factory(self, factory:BaseSchemaFactory):
  31. """
  32. 准备工厂
  33. Args:
  34. factory (BaseSchemaFactory): 工厂
  35. """
  36. @abstractmethod
  37. def prepare(self):
  38. """
  39. 准备数据
  40. """
  41. class BaseReqParser(AbsReqParser):
  42. def data(self) -> Dict:
  43. pass
  44. def cast_model(self, bo_model:BaseEntity) -> BaseModel:
  45. pass
  46. def prepare_factory(self, factory:BaseSchemaFactory):
  47. pass
  48. def prepare(self):
  49. pass
  50. class QueryReqParser(BaseReqParser):
  51. def __init__(self, context:VoValidatorContext):
  52. self.context = context
  53. self.extra_model = ExtraModel
  54. def prepare_factory(self, factory: QuerySchemaFactory):
  55. if factory.extra_model:
  56. self.extra_model = factory.extra_model
  57. def prepare(self):
  58. self.criterian_meta = CriterianMeta()
  59. g.criterian_meta = self.criterian_meta
  60. def validate_request(self) -> Dict:
  61. return request.args.to_dict()
  62. def data(self) -> Dict:
  63. data = self.validate_request().copy()
  64. if self.context.is_page:
  65. page = PageModel.model_validate(data,context=self.context)
  66. if page.model_fields_set:
  67. self.criterian_meta.page = page
  68. self._remove_model_aliases(data, PageModel)
  69. if self.context.is_sort:
  70. sort = OrderModel.model_validate(data,context=self.context)
  71. if sort.model_fields_set:
  72. self.criterian_meta.sort = sort
  73. self._remove_model_aliases(data, OrderModel)
  74. if self.extra_model:
  75. extra = self.extra_model.model_validate(data,context=self.context)
  76. if extra.model_fields_set:
  77. self.criterian_meta.extra = extra
  78. self._remove_model_aliases(data, self.extra_model)
  79. return data
  80. def cast_model(self, bo_model:BaseEntity) -> BaseModel:
  81. data = self.data()
  82. # 对于查询参数,只保留模型中定义的字段和别名,忽略额外字段
  83. # 收集模型中所有字段名和别名
  84. model_fields = set()
  85. for name, info in bo_model.model_fields.items():
  86. model_fields.add(name)
  87. # 添加查询别名(camelCase)
  88. if hasattr(bo_model, 'model_config') and bo_model.model_config:
  89. alias_gen = bo_model.model_config.get('alias_generator')
  90. if callable(alias_gen):
  91. model_fields.add(alias_gen(name))
  92. # 添加 validation_alias
  93. if info.validation_alias:
  94. if isinstance(info.validation_alias, str):
  95. model_fields.add(info.validation_alias)
  96. elif hasattr(info.validation_alias, 'choices'):
  97. model_fields.update(info.validation_alias.choices)
  98. # 过滤掉未定义的字段
  99. filtered_data = {k: v for k, v in data.items() if k in model_fields}
  100. bo = bo_model.model_validate(filtered_data)
  101. return bo
  102. def _remove_model_aliases(
  103. self,
  104. data: Dict[str, str],
  105. model_cls: Type[BaseModel]
  106. ) -> None:
  107. """
  108. 删除已经用于解析的模型字段别名,避免后续模型校验时报额外字段错误
  109. """
  110. if not data:
  111. return
  112. for alias in self._collect_aliases(model_cls):
  113. data.pop(alias, None)
  114. def _collect_aliases(self, model_cls: Type[BaseModel]) -> Set[str]:
  115. alias_set: Set[str] = set()
  116. populate_by_name = getattr(model_cls, "model_config", {}).get(
  117. "populate_by_name", False
  118. )
  119. for name, info in model_cls.model_fields.items():
  120. alias_set.update(self._field_aliases(name, info, populate_by_name))
  121. return alias_set
  122. @staticmethod
  123. def _field_aliases(
  124. name: str,
  125. info,
  126. populate_by_name: bool
  127. ) -> Iterable[str]:
  128. aliases: Set[str] = set()
  129. if getattr(info, "alias", None):
  130. aliases.add(info.alias)
  131. validation_alias = getattr(info, "validation_alias", None)
  132. if isinstance(validation_alias, str):
  133. aliases.add(validation_alias)
  134. elif isinstance(validation_alias, AliasChoices):
  135. aliases.update(validation_alias.choices)
  136. elif isinstance(validation_alias, AliasPath):
  137. pass
  138. if populate_by_name:
  139. aliases.add(name)
  140. return aliases
  141. @dataclass
  142. class PathReqParser(BaseReqParser):
  143. def data(self) -> Dict:
  144. return request.view_args.copy()
  145. @dataclass
  146. class BodyReqParser(BaseReqParser):
  147. minetype: ClassVar[str] = "application/json"
  148. def __init__(self, context:VoValidatorContext):
  149. self.context = context
  150. def validate_request(self) -> Dict:
  151. content_type = request.headers.get("Content-Type", "").lower()
  152. minetype = content_type.split(";")[0]
  153. if minetype == self.minetype:
  154. body: dict | list = request.get_json()
  155. if not body:
  156. raise BadRequest(
  157. description="在{}, body数据不能为空".format(content_type),
  158. )
  159. else:
  160. raise UnsupportedMediaType(
  161. description="content-type仅支持application/json"
  162. )
  163. return body
  164. def data(self) -> Dict:
  165. data = self.validate_request().copy()
  166. return data
  167. def cast_model(self, bo_model:BaseEntity) -> BaseModel:
  168. data = self.data()
  169. bo = bo_model.model_validate(data, context=self.context)
  170. return bo
  171. @dataclass
  172. class FormUrlencodedQueryReqParser(QueryReqParser):
  173. minetype: ClassVar[str] = "application/x-www-form-urlencoded"
  174. def __init__(self, context:VoValidatorContext):
  175. super().__init__(context)
  176. def validate_request(self) -> Dict:
  177. content_type = request.headers.get("Content-Type", "").lower()
  178. minetype = content_type.split(";")[0]
  179. if minetype == self.minetype:
  180. form:ImmutableMultiDict = request.form
  181. body = form.to_dict()
  182. else:
  183. raise UnsupportedMediaType(
  184. description="除了{},content-type不支持{}".format(self.minetype,minetype)
  185. )
  186. return body
  187. @dataclass
  188. class DownloadFileQueryReqParser(FormUrlencodedQueryReqParser):
  189. def __init__(self, context:VoValidatorContext):
  190. super().__init__(context)
  191. class FormReqParser(BaseReqParser):
  192. minetype: ClassVar[str] = "multipart/form-data"
  193. def __init__(
  194. self,
  195. is_form:bool=True,
  196. is_query:bool=False,
  197. is_file:bool|None=None,
  198. ):
  199. self.is_form = is_form
  200. self.is_query = is_query
  201. self.is_file = is_file
  202. def validate_request(self) -> Dict:
  203. content_type = request.headers.get("Content-Type", "").lower()
  204. minetype = content_type.split(";")[0]
  205. new_data = {}
  206. if minetype == self.minetype:
  207. if self.is_form:
  208. new_data.update(request.form.to_dict())
  209. if self.is_query:
  210. new_data.update(request.args.to_dict())
  211. if self.is_file:
  212. new_data.update(request.files.to_dict(flat=False))
  213. else:
  214. raise UnsupportedMediaType(
  215. description="除了{},content-type不支持{}".format(self.minetype,minetype)
  216. )
  217. return new_data
  218. def data(self) -> Dict:
  219. data = self.validate_request()
  220. return data
  221. class UploadFileFormReqParser(FormReqParser):
  222. def validate_request(self) -> Dict:
  223. return super().validate_request()
  224. def data(self) -> Dict:
  225. data = self.validate_request()
  226. return data
  227. class StreamReqParser(BaseReqParser):
  228. minetype: ClassVar[str] = "application/octet-stream"
  229. def data(self, *args, **kwargs) -> Dict:
  230. pass