validator.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. import inspect
  2. from abc import ABC, abstractmethod
  3. from functools import wraps
  4. from dataclasses import dataclass, field
  5. from typing import Annotated, Any, Callable, Dict, Tuple, Type, ClassVar, \
  6. Optional, Set
  7. from werkzeug.exceptions import BadRequest, InternalServerError
  8. from flask import has_request_context, request
  9. from pydantic import BaseModel, ValidationError, validate_call
  10. from pydantic.fields import FieldInfo
  11. from ruoyi_common.base.model import MultiFile
  12. from ruoyi_common.base.reqparser import BaseReqParser, BodyReqParser, \
  13. DownloadFileQueryReqParser, UploadFileFormReqParser, PathReqParser, \
  14. QueryReqParser, VoValidatorContext
  15. from ruoyi_common.base.schema_vo import ArbitrarySchemaFactory, \
  16. BaseSchemaFactory, BodySchemaFactory, PathSchemaFactory, QuerySchemaFactory
  17. class AbcValidatorFunction(ABC):
  18. @abstractmethod
  19. def validate_unbound_parameters(self):
  20. raise NotImplementedError()
  21. @abstractmethod
  22. def validate_function(self):
  23. raise NotImplementedError()
  24. @abstractmethod
  25. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  26. raise NotImplementedError()
  27. class ValidatorScopeFunction(AbcValidatorFunction):
  28. def __init__(self,func:Callable):
  29. self.func = func
  30. self.sig = inspect.signature(self.func)
  31. self._unbound_fields:Dict[str,Annotated] = {}
  32. self._unbound_model: \
  33. Optional[Tuple[str,Type[BaseModel],Type[BaseModel]]] = None
  34. self.args = ()
  35. self.kwargs = {}
  36. self.validate_unbound_parameters()
  37. self.validate_function()
  38. @property
  39. def unbound_model(self):
  40. return self._unbound_model
  41. def _validate_kind(self,kind):
  42. if kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
  43. raise Exception("参数必须是位置参数")
  44. def validate_unbound_parameters(self):
  45. index = 0
  46. for key in self.sig.parameters:
  47. param = self.sig.parameters[key]
  48. self._validate_kind(param.kind)
  49. if isinstance(param.annotation, BaseModel):
  50. self._unbound_model = (key,param.annotation)
  51. if index > 0:
  52. raise Exception(
  53. f"{self.func.__name__} 类型参数有且仅有第一个"
  54. )
  55. else:
  56. self._unbound_fields[key] = \
  57. FieldInfo.from_annotation(param.annotation)
  58. index += 1
  59. def validate_function(self):
  60. self.func = validate_call(self.func)
  61. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  62. self.args = args if args else ()
  63. self.kwargs = kwargs if kwargs else {}
  64. return self.func(*self.args, **self.kwargs)
  65. class ValidatorViewFunction(AbcValidatorFunction):
  66. def __init__(self,func:Callable):
  67. self.func = func
  68. self.sig = inspect.signature(self.func)
  69. self._unbound_fields:Dict[str,Annotated] = {}
  70. self._unbound_model: \
  71. Optional[Tuple[str,Type[BaseModel],Type[BaseModel]]] = None
  72. self._schema_factory = None
  73. self._data_parser = None
  74. self.args = ()
  75. self.kwargs = {}
  76. @property
  77. def unbound_model(self):
  78. return self._unbound_model
  79. def _validate_kind(self,kind):
  80. if kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
  81. raise Exception("参数必须是位置参数")
  82. def validate_unbound_parameters(self):
  83. index = 0
  84. for key in self.sig.parameters:
  85. param = self.sig.parameters[key]
  86. self._validate_kind(param.kind)
  87. if self._schema_factory:
  88. annotation = self._schema_factory. \
  89. validate_annotation(param.annotation)
  90. if annotation:
  91. self._unbound_model = (key,annotation)
  92. if index > 0:
  93. raise Exception(
  94. f"{self.func.__name__} 类型参数有且仅有第一个"
  95. )
  96. else:
  97. self._unbound_fields[key] = \
  98. FieldInfo.from_annotation(param.annotation)
  99. else:
  100. if isinstance(param.annotation, BaseModel):
  101. self._unbound_model = (key,param.annotation)
  102. if index > 0:
  103. raise Exception(
  104. f"{self.func.__name__} 类型参数有且仅有第一个"
  105. )
  106. else:
  107. self._unbound_fields[key] = \
  108. FieldInfo.from_annotation(param.annotation)
  109. index += 1
  110. def validate_function(self):
  111. if self._schema_factory:
  112. if self._unbound_model:
  113. self.func = validate_call(
  114. self.func,
  115. config=self._schema_factory.model_config)
  116. else:
  117. self.func = validate_call(
  118. self.func,
  119. config=self._schema_factory.model_config
  120. )
  121. else:
  122. self.func = validate_call(self.func)
  123. def unbound_data(
  124. self,
  125. data_parser: BaseReqParser
  126. ):
  127. self._data_parser = data_parser
  128. if self._schema_factory and self._data_parser:
  129. self._data_parser.prepare_factory(self._schema_factory)
  130. def unbound_schema(
  131. self,
  132. schema_factory:Optional[BaseSchemaFactory],
  133. ):
  134. self._schema_factory = schema_factory
  135. self.validate_unbound_parameters()
  136. self.validate_function()
  137. def bound_data(self, args:Tuple=(), kwargs:Dict={}):
  138. if self._unbound_model:
  139. key, bo_model = self._unbound_model
  140. obj = self._data_parser.cast_model(bo_model)
  141. kwargs[key] = obj
  142. else:
  143. # 特殊处理文件上传参数(MultiFile),直接从 request.files 构造
  144. has_multi_file_param = any(
  145. param.annotation is MultiFile
  146. for param in self.sig.parameters.values()
  147. )
  148. if has_multi_file_param:
  149. files = MultiFile.from_obj(request.files)
  150. kwargs.clear()
  151. for name, param in self.sig.parameters.items():
  152. if param.annotation is MultiFile:
  153. kwargs[name] = files
  154. return
  155. data = self._data_parser.data()
  156. kwargs.clear()
  157. kwargs.update(data)
  158. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  159. self.args = args if args else ()
  160. self.kwargs = kwargs if kwargs else {}
  161. if not has_request_context:
  162. raise Exception("请在flask请求上下文中调用")
  163. try:
  164. if self._data_parser:
  165. self._data_parser.prepare()
  166. self.bound_data(self.args, self.kwargs)
  167. except ValidationError as e:
  168. return BadRequest(description=str(e))
  169. except TypeError as e:
  170. return BadRequest(description=str(e))
  171. except Exception as e:
  172. return InternalServerError(description=str(e))
  173. else:
  174. return self.func(*self.args, **self.kwargs)
  175. @dataclass
  176. class BaseValidator:
  177. data_parser:ClassVar = None
  178. schema_factory:ClassVar = None
  179. vo_context:VoValidatorContext = field(init=False)
  180. def __call__(self, func):
  181. view_function = ValidatorViewFunction(func)
  182. view_function.unbound_schema(self.schema_factory)
  183. view_function.unbound_data(self.data_parser)
  184. @wraps(func)
  185. def wrapper(*args, **kwargs):
  186. return view_function(*args, **kwargs)
  187. return wrapper
  188. @dataclass
  189. class PathValidator(BaseValidator):
  190. def __post_init__(self):
  191. self.schema_factory = PathSchemaFactory()
  192. self.data_parser = PathReqParser()
  193. @dataclass
  194. class QueryValidator(BaseValidator):
  195. is_page: bool = False
  196. include:Optional[Set[str]] = field(default=None)
  197. exclude:Optional[Set[str]] = field(default=None)
  198. extra_fields:Optional[Dict[str, FieldInfo]] = field(default=None)
  199. def __post_init__(self):
  200. vo_context = VoValidatorContext(
  201. exclude_data_alias=True,
  202. is_page=self.is_page,
  203. is_sort=self.is_page,
  204. include_fields=self.include,
  205. exclude_fields=self.exclude,
  206. )
  207. self.schema_factory = QuerySchemaFactory(
  208. vo_context,
  209. extra_strict_forbid=True,
  210. extra_allowed_fields=self.extra_fields
  211. )
  212. self.data_parser = QueryReqParser(vo_context)
  213. @dataclass
  214. class BodyValidator(BaseValidator):
  215. include:Optional[Set[str]] = field(default=None)
  216. exclude:Optional[Set[str]] = field(default=None)
  217. def __post_init__(self):
  218. vo_context = VoValidatorContext(
  219. include_fields=self.include,
  220. exclude_fields=self.exclude
  221. )
  222. self.schema_factory = BodySchemaFactory(vo_context)
  223. self.data_parser = BodyReqParser(vo_context)
  224. @dataclass
  225. class FileDownloadValidator(BaseValidator):
  226. def __post_init__(self):
  227. vo_context = VoValidatorContext(
  228. exclude_data_alias=True,
  229. is_page=True,
  230. )
  231. self.schema_factory = QuerySchemaFactory(vo_context)
  232. self.data_parser = DownloadFileQueryReqParser(vo_context)
  233. @dataclass
  234. class FileUploadValidator(BaseValidator):
  235. def __post_init__(self):
  236. self.schema_factory = ArbitrarySchemaFactory()
  237. self.data_parser = UploadFileFormReqParser(
  238. is_form=False, is_query=True, is_file=True
  239. )
  240. @dataclass
  241. class FileValidator(BaseValidator):
  242. include:Optional[Set[str]] = field(default=None)
  243. def __post_init__(self):
  244. self.schema_factory = ArbitrarySchemaFactory()
  245. self.data_parser = UploadFileFormReqParser()