validator.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. import inspect
  2. from abc import ABC, abstractmethod
  3. from functools import wraps
  4. from dataclasses import dataclass, field
  5. from typing import Annotated, Any, Callable, Dict, Tuple, Type, ClassVar, \
  6. Optional, Set
  7. from werkzeug.exceptions import BadRequest, InternalServerError
  8. from flask import has_request_context
  9. from pydantic import BaseModel, ValidationError, validate_call
  10. from pydantic.fields import FieldInfo
  11. from ruoyi_common.base.reqparser import BaseReqParser, BodyReqParser, \
  12. DownloadFileQueryReqParser, UploadFileFormReqParser, PathReqParser, \
  13. QueryReqParser, VoValidatorContext
  14. from ruoyi_common.base.schema_vo import ArbitrarySchemaFactory, \
  15. BaseSchemaFactory, BodySchemaFactory, PathSchemaFactory, QuerySchemaFactory
  16. class AbcValidatorFunction(ABC):
  17. @abstractmethod
  18. def validate_unbound_parameters(self):
  19. raise NotImplementedError()
  20. @abstractmethod
  21. def validate_function(self):
  22. raise NotImplementedError()
  23. @abstractmethod
  24. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  25. raise NotImplementedError()
  26. class ValidatorScopeFunction(AbcValidatorFunction):
  27. def __init__(self,func:Callable):
  28. self.func = func
  29. self.sig = inspect.signature(self.func)
  30. self._unbound_fields:Dict[str,Annotated] = {}
  31. self._unbound_model: \
  32. Optional[Tuple[str,Type[BaseModel],Type[BaseModel]]] = None
  33. self.args = ()
  34. self.kwargs = {}
  35. self.validate_unbound_parameters()
  36. self.validate_function()
  37. @property
  38. def unbound_model(self):
  39. return self._unbound_model
  40. def _validate_kind(self,kind):
  41. if kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
  42. raise Exception("参数必须是位置参数")
  43. def validate_unbound_parameters(self):
  44. index = 0
  45. for key in self.sig.parameters:
  46. param = self.sig.parameters[key]
  47. self._validate_kind(param.kind)
  48. if isinstance(param.annotation, BaseModel):
  49. self._unbound_model = (key,param.annotation)
  50. if index > 0:
  51. raise Exception(
  52. f"{self.func.__name__} 类型参数有且仅有第一个"
  53. )
  54. else:
  55. self._unbound_fields[key] = \
  56. FieldInfo.from_annotation(param.annotation)
  57. index += 1
  58. def validate_function(self):
  59. self.func = validate_call(self.func)
  60. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  61. self.args = args if args else ()
  62. self.kwargs = kwargs if kwargs else {}
  63. return self.func(*self.args, **self.kwargs)
  64. class ValidatorViewFunction(AbcValidatorFunction):
  65. def __init__(self,func:Callable):
  66. self.func = func
  67. self.sig = inspect.signature(self.func)
  68. self._unbound_fields:Dict[str,Annotated] = {}
  69. self._unbound_model: \
  70. Optional[Tuple[str,Type[BaseModel],Type[BaseModel]]] = None
  71. self._schema_factory = None
  72. self._data_parser = None
  73. self.args = ()
  74. self.kwargs = {}
  75. @property
  76. def unbound_model(self):
  77. return self._unbound_model
  78. def _validate_kind(self,kind):
  79. if kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
  80. raise Exception("参数必须是位置参数")
  81. def validate_unbound_parameters(self):
  82. index = 0
  83. for key in self.sig.parameters:
  84. param = self.sig.parameters[key]
  85. self._validate_kind(param.kind)
  86. if self._schema_factory:
  87. annotation = self._schema_factory. \
  88. validate_annotation(param.annotation)
  89. if annotation:
  90. self._unbound_model = (key,annotation)
  91. if index > 0:
  92. raise Exception(
  93. f"{self.func.__name__} 类型参数有且仅有第一个"
  94. )
  95. else:
  96. self._unbound_fields[key] = \
  97. FieldInfo.from_annotation(param.annotation)
  98. else:
  99. if isinstance(param.annotation, BaseModel):
  100. self._unbound_model = (key,param.annotation)
  101. if index > 0:
  102. raise Exception(
  103. f"{self.func.__name__} 类型参数有且仅有第一个"
  104. )
  105. else:
  106. self._unbound_fields[key] = \
  107. FieldInfo.from_annotation(param.annotation)
  108. index += 1
  109. def validate_function(self):
  110. if self._schema_factory:
  111. if self._unbound_model:
  112. self.func = validate_call(
  113. self.func,
  114. config=self._schema_factory.model_config)
  115. else:
  116. self.func = validate_call(
  117. self.func,
  118. config=self._schema_factory.model_config
  119. )
  120. else:
  121. self.func = validate_call(self.func)
  122. def unbound_data(
  123. self,
  124. data_parser: BaseReqParser
  125. ):
  126. self._data_parser = data_parser
  127. if self._schema_factory and self._data_parser:
  128. self._data_parser.prepare_factory(self._schema_factory)
  129. def unbound_schema(
  130. self,
  131. schema_factory:Optional[BaseSchemaFactory],
  132. ):
  133. self._schema_factory = schema_factory
  134. self.validate_unbound_parameters()
  135. self.validate_function()
  136. def bound_data(self, args:Tuple=(), kwargs:Dict={}):
  137. if self._unbound_model:
  138. key, bo_model = self._unbound_model
  139. obj = self._data_parser.cast_model(bo_model)
  140. kwargs[key] = obj
  141. else:
  142. data = self._data_parser.data()
  143. kwargs.clear()
  144. kwargs.update(data)
  145. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  146. self.args = args if args else ()
  147. self.kwargs = kwargs if kwargs else {}
  148. if not has_request_context:
  149. raise Exception("请在flask请求上下文中调用")
  150. try:
  151. if self._data_parser:
  152. self._data_parser.prepare()
  153. self.bound_data(self.args, self.kwargs)
  154. except ValidationError as e:
  155. return BadRequest(description=str(e))
  156. except TypeError as e:
  157. return BadRequest(description=str(e))
  158. except Exception as e:
  159. return InternalServerError(description=str(e))
  160. else:
  161. return self.func(*self.args, **self.kwargs)
  162. @dataclass
  163. class BaseValidator:
  164. data_parser:ClassVar = None
  165. schema_factory:ClassVar = None
  166. vo_context:VoValidatorContext = field(init=False)
  167. def __call__(self, func):
  168. view_function = ValidatorViewFunction(func)
  169. view_function.unbound_schema(self.schema_factory)
  170. view_function.unbound_data(self.data_parser)
  171. @wraps(func)
  172. def wrapper(*args, **kwargs):
  173. return view_function(*args, **kwargs)
  174. return wrapper
  175. @dataclass
  176. class PathValidator(BaseValidator):
  177. def __post_init__(self):
  178. self.schema_factory = PathSchemaFactory()
  179. self.data_parser = PathReqParser()
  180. @dataclass
  181. class QueryValidator(BaseValidator):
  182. is_page: bool = False
  183. include:Optional[Set[str]] = field(default=None)
  184. exclude:Optional[Set[str]] = field(default=None)
  185. extra_fields:Optional[Dict[str, FieldInfo]] = field(default=None)
  186. def __post_init__(self):
  187. vo_context = VoValidatorContext(
  188. exclude_data_alias=True,
  189. is_page=self.is_page,
  190. is_sort=self.is_page,
  191. include_fields=self.include,
  192. exclude_fields=self.exclude,
  193. )
  194. self.schema_factory = QuerySchemaFactory(
  195. vo_context,
  196. extra_strict_forbid=True,
  197. extra_allowed_fields=self.extra_fields
  198. )
  199. self.data_parser = QueryReqParser(vo_context)
  200. @dataclass
  201. class BodyValidator(BaseValidator):
  202. include:Optional[Set[str]] = field(default=None)
  203. exclude:Optional[Set[str]] = field(default=None)
  204. def __post_init__(self):
  205. vo_context = VoValidatorContext(
  206. include_fields=self.include,
  207. exclude_fields=self.exclude
  208. )
  209. self.schema_factory = BodySchemaFactory(vo_context)
  210. self.data_parser = BodyReqParser(vo_context)
  211. @dataclass
  212. class FileDownloadValidator(BaseValidator):
  213. def __post_init__(self):
  214. vo_context = VoValidatorContext(
  215. exclude_data_alias=True,
  216. is_page=True,
  217. )
  218. self.schema_factory = QuerySchemaFactory(vo_context)
  219. self.data_parser = DownloadFileQueryReqParser(vo_context)
  220. @dataclass
  221. class FileUploadValidator(BaseValidator):
  222. def __post_init__(self):
  223. self.schema_factory = ArbitrarySchemaFactory()
  224. self.data_parser = UploadFileFormReqParser(
  225. is_form=False, is_query=True, is_file=True
  226. )
  227. @dataclass
  228. class FileValidator(BaseValidator):
  229. include:Optional[Set[str]] = field(default=None)
  230. def __post_init__(self):
  231. self.schema_factory = ArbitrarySchemaFactory()
  232. self.data_parser = UploadFileFormReqParser()