reqparser.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. from abc import ABC, abstractmethod
  2. from dataclasses import dataclass
  3. from typing import ClassVar, Dict, Iterable, Set, Type
  4. from flask import g, request
  5. from pydantic import AliasChoices, AliasPath, BaseModel
  6. from werkzeug.datastructures import ImmutableMultiDict
  7. from werkzeug.exceptions import BadRequest,UnsupportedMediaType
  8. from ruoyi_common.base.model import BaseEntity, CriterianMeta, ExtraModel, \
  9. BaseEntity, OrderModel, PageModel, VoValidatorContext
  10. from ruoyi_common.base.schema_vo import BaseSchemaFactory, QuerySchemaFactory
  11. class AbsReqParser(ABC):
  12. @abstractmethod
  13. def data(self) -> Dict:
  14. """
  15. 获取请求参数
  16. Returns:
  17. Dict: 请求参数字典
  18. """
  19. @abstractmethod
  20. def cast_model(self, bo_model:BaseEntity) -> BaseModel:
  21. """
  22. 适配模型
  23. Args:
  24. bo_model (BaseEntity): Vo模型
  25. src_model (BaseModel): 源模型
  26. Returns:
  27. BaseModel: 适配后的模型
  28. """
  29. @abstractmethod
  30. def prepare_factory(self, factory:BaseSchemaFactory):
  31. """
  32. 准备工厂
  33. Args:
  34. factory (BaseSchemaFactory): 工厂
  35. """
  36. @abstractmethod
  37. def prepare(self):
  38. """
  39. 准备数据
  40. """
  41. class BaseReqParser(AbsReqParser):
  42. def data(self) -> Dict:
  43. pass
  44. def cast_model(self, bo_model:BaseEntity) -> BaseModel:
  45. pass
  46. def prepare_factory(self, factory:BaseSchemaFactory):
  47. pass
  48. def prepare(self):
  49. pass
  50. class QueryReqParser(BaseReqParser):
  51. def __init__(self, context:VoValidatorContext):
  52. self.context = context
  53. self.extra_model = ExtraModel
  54. def prepare_factory(self, factory: QuerySchemaFactory):
  55. if factory.extra_model:
  56. self.extra_model = factory.extra_model
  57. def prepare(self):
  58. self.criterian_meta = CriterianMeta()
  59. g.criterian_meta = self.criterian_meta
  60. def validate_request(self) -> Dict:
  61. return request.args.to_dict()
  62. def data(self) -> Dict:
  63. data = self.validate_request().copy()
  64. if self.context.is_page:
  65. page = PageModel.model_validate(data,context=self.context)
  66. if page.model_fields_set:
  67. self.criterian_meta.page = page
  68. self._remove_model_aliases(data, PageModel)
  69. if self.context.is_sort:
  70. sort = OrderModel.model_validate(data,context=self.context)
  71. if sort.model_fields_set:
  72. self.criterian_meta.sort = sort
  73. self._remove_model_aliases(data, OrderModel)
  74. if self.extra_model:
  75. extra = self.extra_model.model_validate(data,context=self.context)
  76. if extra.model_fields_set:
  77. self.criterian_meta.extra = extra
  78. self._remove_model_aliases(data, self.extra_model)
  79. return data
  80. def cast_model(self, bo_model:BaseEntity) -> BaseModel:
  81. data = self.data()
  82. bo = bo_model.model_validate(data)
  83. return bo
  84. def _remove_model_aliases(
  85. self,
  86. data: Dict[str, str],
  87. model_cls: Type[BaseModel]
  88. ) -> None:
  89. """
  90. 删除已经用于解析的模型字段别名,避免后续模型校验时报额外字段错误
  91. """
  92. if not data:
  93. return
  94. for alias in self._collect_aliases(model_cls):
  95. data.pop(alias, None)
  96. def _collect_aliases(self, model_cls: Type[BaseModel]) -> Set[str]:
  97. alias_set: Set[str] = set()
  98. populate_by_name = getattr(model_cls, "model_config", {}).get(
  99. "populate_by_name", False
  100. )
  101. for name, info in model_cls.model_fields.items():
  102. alias_set.update(self._field_aliases(name, info, populate_by_name))
  103. return alias_set
  104. @staticmethod
  105. def _field_aliases(
  106. name: str,
  107. info,
  108. populate_by_name: bool
  109. ) -> Iterable[str]:
  110. aliases: Set[str] = set()
  111. if getattr(info, "alias", None):
  112. aliases.add(info.alias)
  113. validation_alias = getattr(info, "validation_alias", None)
  114. if isinstance(validation_alias, str):
  115. aliases.add(validation_alias)
  116. elif isinstance(validation_alias, AliasChoices):
  117. aliases.update(validation_alias.choices)
  118. elif isinstance(validation_alias, AliasPath):
  119. pass
  120. if populate_by_name:
  121. aliases.add(name)
  122. return aliases
  123. @dataclass
  124. class PathReqParser(BaseReqParser):
  125. def data(self) -> Dict:
  126. return request.view_args.copy()
  127. @dataclass
  128. class BodyReqParser(BaseReqParser):
  129. minetype: ClassVar[str] = "application/json"
  130. def __init__(self, context:VoValidatorContext):
  131. self.context = context
  132. def validate_request(self) -> Dict:
  133. content_type = request.headers.get("Content-Type", "").lower()
  134. minetype = content_type.split(";")[0]
  135. if minetype == self.minetype:
  136. body: dict | list = request.get_json()
  137. if not body:
  138. raise BadRequest(
  139. description="在{}, body数据不能为空".format(content_type),
  140. )
  141. else:
  142. raise UnsupportedMediaType(
  143. description="content-type仅支持application/json"
  144. )
  145. return body
  146. def data(self) -> Dict:
  147. data = self.validate_request().copy()
  148. return data
  149. def cast_model(self, bo_model:BaseEntity) -> BaseModel:
  150. data = self.data()
  151. bo = bo_model.model_validate(data, context=self.context)
  152. return bo
  153. @dataclass
  154. class FormUrlencodedQueryReqParser(QueryReqParser):
  155. minetype: ClassVar[str] = "application/x-www-form-urlencoded"
  156. def __init__(self, context:VoValidatorContext):
  157. super().__init__(context)
  158. def validate_request(self) -> Dict:
  159. content_type = request.headers.get("Content-Type", "").lower()
  160. minetype = content_type.split(";")[0]
  161. if minetype == self.minetype:
  162. form:ImmutableMultiDict = request.form
  163. body = form.to_dict()
  164. else:
  165. raise UnsupportedMediaType(
  166. description="除了{},content-type不支持{}".format(self.minetype,minetype)
  167. )
  168. return body
  169. @dataclass
  170. class DownloadFileQueryReqParser(FormUrlencodedQueryReqParser):
  171. def __init__(self, context:VoValidatorContext):
  172. super().__init__(context)
  173. class FormReqParser(BaseReqParser):
  174. minetype: ClassVar[str] = "multipart/form-data"
  175. def __init__(
  176. self,
  177. is_form:bool=True,
  178. is_query:bool=False,
  179. is_file:bool|None=None,
  180. ):
  181. self.is_form = is_form
  182. self.is_query = is_query
  183. self.is_file = is_file
  184. def validate_request(self) -> Dict:
  185. content_type = request.headers.get("Content-Type", "").lower()
  186. minetype = content_type.split(";")[0]
  187. new_data = {}
  188. if minetype == self.minetype:
  189. if self.is_form:
  190. new_data.update(request.form.to_dict())
  191. if self.is_query:
  192. new_data.update(request.args.to_dict())
  193. if self.is_file:
  194. new_data.update(request.files.to_dict(flat=False))
  195. else:
  196. raise UnsupportedMediaType(
  197. description="除了{},content-type不支持{}".format(self.minetype,minetype)
  198. )
  199. return new_data
  200. def data(self) -> Dict:
  201. data = self.validate_request()
  202. return data
  203. class UploadFileFormReqParser(FormReqParser):
  204. def validate_request(self) -> Dict:
  205. return super().validate_request()
  206. def data(self) -> Dict:
  207. data = self.validate_request()
  208. return data
  209. class StreamReqParser(BaseReqParser):
  210. minetype: ClassVar[str] = "application/octet-stream"
  211. def data(self, *args, **kwargs) -> Dict:
  212. pass