schema_vo.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  1. # -*- coding: utf-8 -*-
  2. # @Author : YY
  3. from abc import ABC, abstractmethod
  4. from datetime import datetime
  5. from typing import Type, Dict, Optional, Set, Tuple, TypeVar
  6. from pydantic.alias_generators import to_camel
  7. from pydantic import AliasChoices, AliasGenerator, AliasPath, BaseModel, ConfigDict, Field, create_model
  8. from pydantic.fields import FieldInfo
  9. from pydantic_core import PydanticUndefined
  10. from pydantic import BaseModel
  11. from ruoyi_common.base.model import AllowedExtraModel, BaseEntity, BetOpt, ExtraModel, ExtraOpt, MultiFile, VoAccess, VoValidatorContext
  12. from ruoyi_common.utils.base import DateUtil, get_final_model
  13. T = TypeVar("T")
  14. strict_valid_config = ConfigDict(
  15. from_attributes = False,
  16. alias_generator = to_camel,
  17. frozen = True,
  18. extra = "forbid",
  19. strict = True,
  20. populate_by_name = False,
  21. json_encoders = {
  22. datetime: lambda v: v.strftime(DateUtil.YYYY_MM_DD_HH_MM_SS)
  23. },
  24. )
  25. query_valid_config = ConfigDict(
  26. from_attributes = False,
  27. alias_generator = to_camel,
  28. frozen = True,
  29. extra = "allow",
  30. strict = True,
  31. populate_by_name = False,
  32. )
  33. def VoField(
  34. body=True,
  35. query=False,
  36. sort=False,
  37. *args,
  38. **kwargs
  39. ):
  40. vo = VoAccess(
  41. body=body,
  42. query=query,
  43. sort=sort,
  44. )
  45. return Field(vo=vo,*args, **kwargs)
  46. class AbcFieldFilter(ABC):
  47. def filter(self, name:str, info:FieldInfo) -> bool:
  48. """
  49. 过滤字段
  50. Args:
  51. name(str): 字段名称
  52. info(FieldInfo): 字段元信息
  53. Returns:
  54. bool: 是否过滤
  55. """
  56. class BaseFieldFilter(AbcFieldFilter):
  57. def filter(self, name:str, info:FieldInfo) -> bool:
  58. """
  59. 过滤字段
  60. Args:
  61. name(str): 字段名称
  62. info(FieldInfo): 字段元信息
  63. Returns:
  64. bool: 是否过滤
  65. """
  66. return False
  67. class VoBodyFieldFilter(BaseFieldFilter):
  68. def __init__(self):
  69. self.action = "body"
  70. def filter(self, name:str, info:FieldInfo) -> bool:
  71. """
  72. 过滤字段
  73. Args:
  74. name(str): 字段名称
  75. info(FieldInfo): 字段元信息
  76. Returns:
  77. bool: 是否过滤
  78. """
  79. default = True
  80. if info.json_schema_extra:
  81. vo_access:VoAccess = info.json_schema_extra.get("vo",default)
  82. if vo_access:
  83. perm = getattr(vo_access,self.action,default)
  84. if perm:
  85. is_required = getattr(vo_access,"body_required",False)
  86. if is_required:
  87. self.change_to_required(info)
  88. else:
  89. self.change_to_optional(info)
  90. flag = perm
  91. else:
  92. flag = default
  93. else:
  94. flag = default
  95. return flag
  96. @classmethod
  97. def change_to_required(cls, info:FieldInfo):
  98. """
  99. 将注解转换为必选项
  100. Args:
  101. info(FieldInfo): 字段元信息
  102. """
  103. info.default = PydanticUndefined
  104. info.default_factory = None
  105. @classmethod
  106. def change_to_optional(cls, info:FieldInfo):
  107. """
  108. 将注解转换为可选项
  109. Args:
  110. info(FieldInfo): 字段元信息
  111. """
  112. if info.is_required:
  113. info.default = None
  114. class VoQueryFieldFilter(BaseFieldFilter):
  115. extra_opt_cls_list = [BetOpt]
  116. def __init__(self):
  117. self.action = "query"
  118. self.sort_fields:Dict[str,FieldInfo] = {}
  119. self.extra_fields:Dict[str,ExtraOpt] = {}
  120. def filter(self, name:str, info:FieldInfo) -> bool:
  121. """
  122. 根据权限信息,重置字段
  123. Args:
  124. name(str): 字段名称
  125. info(FieldInfo): 字段元信息
  126. Returns:
  127. Tuple[bool,bool]: 是否支持查询,是否支持排序
  128. """
  129. default = False
  130. if info.json_schema_extra:
  131. vo_access:VoAccess = info.json_schema_extra.get("vo",False)
  132. query_perm = getattr(vo_access,self.action,default)
  133. for extra_opt_cls in self.extra_opt_cls_list:
  134. if isinstance(query_perm, extra_opt_cls):
  135. query_perm.name = name
  136. query_perm.info = info
  137. self.extra_fields[name] = query_perm
  138. query_perm = False
  139. else:
  140. continue
  141. sort_perm = getattr(vo_access,"sort",default)
  142. if sort_perm:
  143. self.sort_fields[name] = info
  144. return query_perm
  145. else:
  146. return default
  147. class AbcSchemaFactory(ABC):
  148. @abstractmethod
  149. def validate_annotation(annotation:Type) -> Optional[Type]:
  150. """
  151. 检查注解是否有效
  152. Args:
  153. annotation(Type): 注解
  154. Returns:
  155. Optional[Type]: 合法的类型
  156. """
  157. pass
  158. class BaseSchemaFactory(AbcSchemaFactory):
  159. action = "base"
  160. model_config = query_valid_config
  161. def __init__(self):
  162. self.field_filter = None
  163. self.model_suffix = "Vo"
  164. def validate_annotation(self, annotation:Type) -> Optional[Type]:
  165. """
  166. 检查注解是否有效
  167. Args:
  168. annotation(Type): 注解
  169. Returns:
  170. Optional[Type]: 合法的类型
  171. """
  172. bo_model = get_final_model(annotation)
  173. if issubclass(bo_model, BaseEntity):
  174. model = self.rebuild_model(model_cls=bo_model)
  175. return model
  176. else:
  177. return None
  178. def rebuild_model(self, model_cls:Type[BaseEntity]) -> Type[BaseEntity]:
  179. """
  180. 从已有模型类,创建新的模型类
  181. Args:
  182. model_cls: 已有模型类
  183. Returns:
  184. Type[BaseModel]: 新的模型类
  185. """
  186. field_definitions = {}
  187. for name,info in model_cls.model_fields.items():
  188. flag = self.rebuild_field(name,info)
  189. if flag:
  190. field_definitions[name] = info.annotation,info
  191. vo_name = model_cls.__name__ + self.action.capitalize() + \
  192. self.model_suffix
  193. return create_model(
  194. vo_name,
  195. __base__= model_cls,
  196. __doc__ = model_cls.__doc__,
  197. __module__= model_cls.__module__,
  198. **field_definitions
  199. )
  200. def rebuild_field(self, name:str, info:FieldInfo) -> bool:
  201. """
  202. 重置字段
  203. Args:
  204. name(str): 字段名称
  205. info(FieldInfo): 字段元信息
  206. Returns:
  207. bool: 是否支持重置
  208. """
  209. return True
  210. class BodySchemaFactory(BaseSchemaFactory):
  211. action = "body"
  212. model_config = query_valid_config
  213. def __init__(self,context:VoValidatorContext):
  214. super().__init__()
  215. self.context = context
  216. self.field_filter = VoBodyFieldFilter()
  217. def validate_annotation(self, annotation: Type) -> Optional[Type[BaseModel]]:
  218. bo_model = get_final_model(annotation)
  219. if issubclass(bo_model, BaseModel):
  220. updated_model = self.rebuild_model(model_cls=bo_model)
  221. return updated_model
  222. else:
  223. if self.context.include_fields or self.context.exclude_fields:
  224. raise Exception(f"注解{annotation.__name__}不是模型,不支持include和exclude请求条件")
  225. return None
  226. def rebuild_field(self, name:str, info:FieldInfo) -> bool:
  227. """
  228. 重置字段
  229. Args:
  230. name(str): 字段名称
  231. info(FieldInfo): 字段元信息
  232. Returns:
  233. bool: 是否支持重置
  234. """
  235. if self.context.include_fields and \
  236. name not in self.context.include_fields:
  237. return False
  238. if self.context.exclude_fields and \
  239. name in self.context.exclude_fields:
  240. return False
  241. flag = self.field_filter.filter(name,info)
  242. return flag
  243. class QuerySchemaFactory(BaseSchemaFactory):
  244. action = "query"
  245. model_config = query_valid_config
  246. def __init__(
  247. self,
  248. context,
  249. extra_strict_forbid=True,
  250. extra_allowed_fields:Type[Dict[str,FieldInfo]]=None
  251. ):
  252. super().__init__()
  253. self.context = context
  254. self.field_filter = VoQueryFieldFilter()
  255. self.extra_strict_forbid = extra_strict_forbid
  256. self.extra_allowed_fields = extra_allowed_fields
  257. self.extra_model = None
  258. def validate_annotation(self, annotation: Type) -> Optional[Type[BaseEntity]]:
  259. bo_model = get_final_model(annotation)
  260. if issubclass(bo_model, BaseEntity):
  261. updated_model = self.rebuild_model(model_cls=bo_model)
  262. self.model_config = annotation.model_config.copy()
  263. self._validate_context()
  264. self.rebuild_extra_model()
  265. return updated_model
  266. else:
  267. if self.context.include_fields or self.context.exclude_fields:
  268. raise Exception(f"注解{annotation.__name__}不是模型,不支持include和exclude请求条件")
  269. return None
  270. def _validate_context(self):
  271. """
  272. 验证上下文信息
  273. """
  274. for name,info in self.field_filter.sort_fields.items():
  275. alias_set = self.get_validate_alias(name,info)
  276. self.context.include_sort_alias = self.context.include_sort_alias | alias_set
  277. def get_validate_alias(self, name:str,info:FieldInfo) -> Set[str]:
  278. """
  279. 获取验证别名
  280. Args:
  281. name (str): 字段名称
  282. info (FieldInfo): 字段元信息
  283. Raises:
  284. Exception: 模型不支持AliasPath
  285. Returns:
  286. Set[str]: 别名集合
  287. """
  288. alias_set = set()
  289. alias = self.get_alias_from_config(name)
  290. if alias:
  291. alias_set.add(alias)
  292. if info.validation_alias:
  293. if isinstance(info.validation_alias, str):
  294. alias_set.add(info.validation_alias)
  295. elif isinstance(info.validation_alias, AliasPath):
  296. raise Exception(f"模型字段{name}不支持AliasPath")
  297. elif isinstance(info.validation_alias, AliasChoices):
  298. alias_set = alias_set | \
  299. set(info.validation_alias.choices)
  300. if "populate_by_name" in self.model_config \
  301. and self.model_config["populate_by_name"]:
  302. alias_set.add(name)
  303. return alias_set
  304. def get_alias_from_config(self,name:str)-> Optional[str]:
  305. """
  306. 从配置中获取别名
  307. Args:
  308. name (str): 字段名称
  309. Returns:
  310. Optional[str]: 别名
  311. """
  312. if "generate_alias" in self.model_config:
  313. generate_alias = self.model_config["generate_alias"]
  314. if callable(generate_alias):
  315. alias = generate_alias(name)
  316. return alias
  317. elif isinstance(generate_alias, AliasGenerator):
  318. alias,v_alias,s_alias = generate_alias(name)
  319. return alias or v_alias
  320. def rebuild_extra_model(self) -> Optional[Type[ExtraModel]]:
  321. """
  322. 重新构建Extra查询模型
  323. Args:
  324. fields (Dict[str,ExtraOpt]): 字段元信息
  325. Returns:
  326. Optional[Type[ExtraModel]]: extra查询模型
  327. """
  328. field_defintions = {}
  329. for name,opt in self.field_filter.extra_fields.items():
  330. if isinstance(opt, BetOpt):
  331. min_fieldinfo,max_fieldinfo = self.rebuild_bet_opt(name,opt)
  332. field_defintions[opt.min] = min_fieldinfo.annotation,min_fieldinfo
  333. field_defintions[opt.max] = max_fieldinfo.annotation,max_fieldinfo
  334. elif isinstance(opt, ExtraOpt):
  335. field_defintions[name] = opt.info.annotation,opt.info
  336. else:
  337. continue
  338. if self.extra_allowed_fields:
  339. for name,info in self.extra_allowed_fields.items():
  340. if name not in field_defintions:
  341. field_defintions[name] = info.annotation,info
  342. if field_defintions:
  343. extra_model_cls = ExtraModel if self.extra_strict_forbid else AllowedExtraModel
  344. self.extra_model = create_model(
  345. model_name="ExtraModel",
  346. __base__=extra_model_cls,
  347. **field_defintions
  348. )
  349. def rebuild_bet_opt(self, name:str, opt:BetOpt) -> Tuple[FieldInfo,FieldInfo]:
  350. """
  351. 重新构建BetOpt
  352. Args:
  353. name (str): 字段名称
  354. opt (BetOpt): 字段元信息
  355. Returns:
  356. Tuple[FieldInfo,FieldInfo]: Between查询条件信息
  357. """
  358. min_opt = opt.replace(active="min")
  359. min_fieldinfo = FieldInfo.from_annotation(min_opt.info.annotation)
  360. min_fieldinfo.json_schema_extra={"vo_opt":min_opt}
  361. max_opt = opt.replace(active="max")
  362. max_fieldinfo = FieldInfo.from_annotation(max_opt.info.annotation)
  363. max_fieldinfo.json_schema_extra={"vo_opt":max_opt}
  364. return min_fieldinfo,max_fieldinfo
  365. def rebuild_field(self, name:str, info:FieldInfo) -> bool:
  366. """
  367. 重置字段
  368. Args:
  369. name(str): 字段名称
  370. info(FieldInfo): 字段元信息
  371. Returns:
  372. bool: 是否支持重置
  373. """
  374. if self.context.include_fields and \
  375. name not in self.context.include_fields:
  376. return False
  377. if self.context.exclude_fields and \
  378. name in self.context.exclude_fields:
  379. return False
  380. flag = self.field_filter.filter(name,info)
  381. return flag
  382. class FormSchemaFactory(AbcSchemaFactory):
  383. action = "form"
  384. model_config = strict_valid_config
  385. def validate_annotation(self, annotation: Type[BaseModel]) -> Optional[Type[BaseModel]]:
  386. pass
  387. class PathSchemaFactory(AbcSchemaFactory):
  388. action = "path"
  389. model_config = query_valid_config
  390. def validate_annotation(self, annotation: Type[BaseModel]) -> Optional[Type[BaseModel]]:
  391. pass
  392. class ArbitrarySchemaFactory(AbcSchemaFactory):
  393. model_config = query_valid_config
  394. def __init__(self):
  395. super().__init__()
  396. query_valid_config_copy = query_valid_config.copy()
  397. query_valid_config_copy.update({
  398. "arbitrary_types_allowed":True
  399. })
  400. self.model_config = query_valid_config_copy
  401. def validate_annotation(self, annotation: Type) -> Optional[Type[BaseEntity]]:
  402. bo_model = get_final_model(annotation)
  403. if issubclass(bo_model, BaseEntity):
  404. return bo_model
  405. else:
  406. return None