schema_vo.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. # -*- coding: utf-8 -*-
  2. # @Author : YY
  3. from abc import ABC, abstractmethod
  4. from datetime import datetime
  5. from typing import Type, Dict, Optional, Set, Tuple, TypeVar
  6. from pydantic.alias_generators import to_camel
  7. from pydantic import AliasChoices, AliasGenerator, AliasPath, BaseModel, ConfigDict, Field, create_model
  8. from pydantic.fields import FieldInfo
  9. from pydantic_core import PydanticUndefined
  10. from pydantic import BaseModel
  11. from ruoyi_common.base.model import AllowedExtraModel, BaseEntity, BetOpt, ExtraModel, ExtraOpt, MultiFile, VoAccess, \
  12. VoValidatorContext
  13. from ruoyi_common.utils.base import DateUtil, get_final_model
  14. T = TypeVar("T")
  15. strict_valid_config = ConfigDict(
  16. from_attributes=False,
  17. alias_generator=to_camel,
  18. frozen=True,
  19. extra="forbid",
  20. strict=True,
  21. populate_by_name=False,
  22. json_encoders={
  23. datetime: lambda v: v.strftime(DateUtil.YYYY_MM_DD_HH_MM_SS)
  24. },
  25. )
  26. query_valid_config = ConfigDict(
  27. from_attributes=False,
  28. alias_generator=to_camel,
  29. frozen=True,
  30. extra="allow",
  31. strict=True,
  32. populate_by_name=False,
  33. )
  34. def VoField(
  35. body=True,
  36. query=False,
  37. sort=False,
  38. *args,
  39. **kwargs
  40. ):
  41. vo = VoAccess(
  42. body=body,
  43. query=query,
  44. sort=sort,
  45. )
  46. return Field(vo=vo, *args, **kwargs)
  47. class AbcFieldFilter(ABC):
  48. def filter(self, name: str, info: FieldInfo) -> bool:
  49. """
  50. 过滤字段
  51. Args:
  52. name(str): 字段名称
  53. info(FieldInfo): 字段元信息
  54. Returns:
  55. bool: 是否过滤
  56. """
  57. class BaseFieldFilter(AbcFieldFilter):
  58. def filter(self, name: str, info: FieldInfo) -> bool:
  59. """
  60. 过滤字段
  61. Args:
  62. name(str): 字段名称
  63. info(FieldInfo): 字段元信息
  64. Returns:
  65. bool: 是否过滤
  66. """
  67. return False
  68. class VoBodyFieldFilter(BaseFieldFilter):
  69. def __init__(self):
  70. self.action = "body"
  71. def filter(self, name: str, info: FieldInfo) -> bool:
  72. """
  73. 过滤字段
  74. Args:
  75. name(str): 字段名称
  76. info(FieldInfo): 字段元信息
  77. Returns:
  78. bool: 是否过滤
  79. """
  80. default = True
  81. if info.json_schema_extra:
  82. vo_access: VoAccess = info.json_schema_extra.get("vo", default)
  83. if vo_access:
  84. perm = getattr(vo_access, self.action, default)
  85. if perm:
  86. is_required = getattr(vo_access, "body_required", False)
  87. if is_required:
  88. self.change_to_required(info)
  89. else:
  90. self.change_to_optional(info)
  91. flag = perm
  92. else:
  93. flag = default
  94. else:
  95. flag = default
  96. return flag
  97. @classmethod
  98. def change_to_required(cls, info: FieldInfo):
  99. """
  100. 将注解转换为必选项
  101. Args:
  102. info(FieldInfo): 字段元信息
  103. """
  104. info.default = PydanticUndefined
  105. info.default_factory = None
  106. @classmethod
  107. def change_to_optional(cls, info: FieldInfo):
  108. """
  109. 将注解转换为可选项
  110. Args:
  111. info(FieldInfo): 字段元信息
  112. """
  113. if info.is_required:
  114. info.default = None
  115. class VoQueryFieldFilter(BaseFieldFilter):
  116. extra_opt_cls_list = [BetOpt]
  117. def __init__(self):
  118. self.action = "query"
  119. self.sort_fields: Dict[str, FieldInfo] = {}
  120. self.extra_fields: Dict[str, ExtraOpt] = {}
  121. def filter(self, name: str, info: FieldInfo) -> bool:
  122. """
  123. 根据权限信息,重置字段
  124. Args:
  125. name(str): 字段名称
  126. info(FieldInfo): 字段元信息
  127. Returns:
  128. Tuple[bool,bool]: 是否支持查询,是否支持排序
  129. """
  130. default = False
  131. if info.json_schema_extra:
  132. vo_access: VoAccess = info.json_schema_extra.get("vo", False)
  133. query_perm = getattr(vo_access, self.action, default)
  134. for extra_opt_cls in self.extra_opt_cls_list:
  135. if isinstance(query_perm, extra_opt_cls):
  136. query_perm.name = name
  137. query_perm.info = info
  138. self.extra_fields[name] = query_perm
  139. query_perm = False
  140. else:
  141. continue
  142. sort_perm = getattr(vo_access, "sort", default)
  143. if sort_perm:
  144. self.sort_fields[name] = info
  145. return query_perm
  146. else:
  147. return default
  148. class AbcSchemaFactory(ABC):
  149. @abstractmethod
  150. def validate_annotation(annotation: Type) -> Optional[Type]:
  151. """
  152. 检查注解是否有效
  153. Args:
  154. annotation(Type): 注解
  155. Returns:
  156. Optional[Type]: 合法的类型
  157. """
  158. pass
  159. class BaseSchemaFactory(AbcSchemaFactory):
  160. action = "base"
  161. model_config = query_valid_config
  162. def __init__(self):
  163. self.field_filter = None
  164. self.model_suffix = "Vo"
  165. def validate_annotation(self, annotation: Type) -> Optional[Type]:
  166. """
  167. 检查注解是否有效
  168. Args:
  169. annotation(Type): 注解
  170. Returns:
  171. Optional[Type]: 合法的类型
  172. """
  173. bo_model = get_final_model(annotation)
  174. if issubclass(bo_model, BaseEntity):
  175. model = self.rebuild_model(model_cls=bo_model)
  176. return model
  177. else:
  178. return None
  179. def rebuild_model(self, model_cls: Type[BaseEntity]) -> Type[BaseEntity]:
  180. """
  181. 从已有模型类,创建新的模型类
  182. Args:
  183. model_cls: 已有模型类
  184. Returns:
  185. Type[BaseModel]: 新的模型类
  186. """
  187. field_definitions = {}
  188. for name, info in model_cls.model_fields.items():
  189. flag = self.rebuild_field(name, info)
  190. if flag:
  191. field_definitions[name] = info.annotation, info
  192. vo_name = model_cls.__name__ + self.action.capitalize() + \
  193. self.model_suffix
  194. # 如果是 QuerySchemaFactory,使用 query_valid_config 允许额外字段
  195. if hasattr(self, 'action') and self.action == 'query' and hasattr(self, 'model_config'):
  196. # 创建一个新的基类,使用 query_valid_config
  197. class QueryBaseModel(model_cls):
  198. model_config = self.model_config.copy()
  199. base_class = QueryBaseModel
  200. else:
  201. base_class = model_cls
  202. new_model = create_model(
  203. vo_name,
  204. __base__=base_class,
  205. __doc__=model_cls.__doc__,
  206. __module__=model_cls.__module__,
  207. **field_definitions
  208. )
  209. return new_model
  210. def rebuild_field(self, name: str, info: FieldInfo) -> bool:
  211. """
  212. 重置字段
  213. Args:
  214. name(str): 字段名称
  215. info(FieldInfo): 字段元信息
  216. Returns:
  217. bool: 是否支持重置
  218. """
  219. return True
  220. class BodySchemaFactory(BaseSchemaFactory):
  221. action = "body"
  222. model_config = query_valid_config
  223. def __init__(self, context: VoValidatorContext):
  224. super().__init__()
  225. self.context = context
  226. self.field_filter = VoBodyFieldFilter()
  227. def validate_annotation(self, annotation: Type) -> Optional[Type[BaseModel]]:
  228. bo_model = get_final_model(annotation)
  229. if issubclass(bo_model, BaseModel):
  230. updated_model = self.rebuild_model(model_cls=bo_model)
  231. return updated_model
  232. else:
  233. if self.context.include_fields or self.context.exclude_fields:
  234. raise Exception(f"注解{annotation.__name__}不是模型,不支持include和exclude请求条件")
  235. return None
  236. def rebuild_field(self, name: str, info: FieldInfo) -> bool:
  237. """
  238. 重置字段
  239. Args:
  240. name(str): 字段名称
  241. info(FieldInfo): 字段元信息
  242. Returns:
  243. bool: 是否支持重置
  244. """
  245. if self.context.include_fields and \
  246. name not in self.context.include_fields:
  247. return False
  248. if self.context.exclude_fields and \
  249. name in self.context.exclude_fields:
  250. return False
  251. flag = self.field_filter.filter(name, info)
  252. return flag
  253. class QuerySchemaFactory(BaseSchemaFactory):
  254. action = "query"
  255. model_config = query_valid_config
  256. def __init__(
  257. self,
  258. context,
  259. extra_strict_forbid=True,
  260. extra_allowed_fields: Type[Dict[str, FieldInfo]] = None
  261. ):
  262. super().__init__()
  263. self.context = context
  264. self.field_filter = VoQueryFieldFilter()
  265. self.extra_strict_forbid = extra_strict_forbid
  266. self.extra_allowed_fields = extra_allowed_fields
  267. self.extra_model = None
  268. def validate_annotation(self, annotation: Type) -> Optional[Type[BaseEntity]]:
  269. bo_model = get_final_model(annotation)
  270. if issubclass(bo_model, BaseEntity):
  271. updated_model = self.rebuild_model(model_cls=bo_model)
  272. # 使用 query_valid_config 而不是原模型的配置,允许额外字段
  273. self.model_config = self.model_config.copy()
  274. self._validate_context()
  275. self.rebuild_extra_model()
  276. return updated_model
  277. else:
  278. if self.context.include_fields or self.context.exclude_fields:
  279. raise Exception(f"注解{annotation.__name__}不是模型,不支持include和exclude请求条件")
  280. return None
  281. def _validate_context(self):
  282. """
  283. 验证上下文信息
  284. """
  285. for name, info in self.field_filter.sort_fields.items():
  286. alias_set = self.get_validate_alias(name, info)
  287. self.context.include_sort_alias = self.context.include_sort_alias | alias_set
  288. def get_validate_alias(self, name: str, info: FieldInfo) -> Set[str]:
  289. """
  290. 获取验证别名
  291. Args:
  292. name (str): 字段名称
  293. info (FieldInfo): 字段元信息
  294. Raises:
  295. Exception: 模型不支持AliasPath
  296. Returns:
  297. Set[str]: 别名集合
  298. """
  299. alias_set = set()
  300. alias = self.get_alias_from_config(name)
  301. if alias:
  302. alias_set.add(alias)
  303. if info.validation_alias:
  304. if isinstance(info.validation_alias, str):
  305. alias_set.add(info.validation_alias)
  306. elif isinstance(info.validation_alias, AliasPath):
  307. raise Exception(f"模型字段{name}不支持AliasPath")
  308. elif isinstance(info.validation_alias, AliasChoices):
  309. alias_set = alias_set | \
  310. set(info.validation_alias.choices)
  311. if "populate_by_name" in self.model_config \
  312. and self.model_config["populate_by_name"]:
  313. alias_set.add(name)
  314. return alias_set
  315. def get_alias_from_config(self, name: str) -> Optional[str]:
  316. """
  317. 从配置中获取别名
  318. Args:
  319. name (str): 字段名称
  320. Returns:
  321. Optional[str]: 别名
  322. """
  323. if "generate_alias" in self.model_config:
  324. generate_alias = self.model_config["generate_alias"]
  325. if callable(generate_alias):
  326. alias = generate_alias(name)
  327. return alias
  328. elif isinstance(generate_alias, AliasGenerator):
  329. alias, v_alias, s_alias = generate_alias(name)
  330. return alias or v_alias
  331. def rebuild_extra_model(self) -> Optional[Type[ExtraModel]]:
  332. """
  333. 重新构建Extra查询模型
  334. Args:
  335. fields (Dict[str,ExtraOpt]): 字段元信息
  336. Returns:
  337. Optional[Type[ExtraModel]]: extra查询模型
  338. """
  339. field_defintions = {}
  340. for name, opt in self.field_filter.extra_fields.items():
  341. if isinstance(opt, BetOpt):
  342. min_fieldinfo, max_fieldinfo = self.rebuild_bet_opt(name, opt)
  343. field_defintions[opt.min] = min_fieldinfo.annotation, min_fieldinfo
  344. field_defintions[opt.max] = max_fieldinfo.annotation, max_fieldinfo
  345. elif isinstance(opt, ExtraOpt):
  346. field_defintions[name] = opt.info.annotation, opt.info
  347. else:
  348. continue
  349. if self.extra_allowed_fields:
  350. for name, info in self.extra_allowed_fields.items():
  351. if name not in field_defintions:
  352. field_defintions[name] = info.annotation, info
  353. if field_defintions:
  354. extra_model_cls = ExtraModel if self.extra_strict_forbid else AllowedExtraModel
  355. self.extra_model = create_model(
  356. model_name="ExtraModel",
  357. __base__=extra_model_cls,
  358. **field_defintions
  359. )
  360. def rebuild_bet_opt(self, name: str, opt: BetOpt) -> Tuple[FieldInfo, FieldInfo]:
  361. """
  362. 重新构建BetOpt
  363. Args:
  364. name (str): 字段名称
  365. opt (BetOpt): 字段元信息
  366. Returns:
  367. Tuple[FieldInfo,FieldInfo]: Between查询条件信息
  368. """
  369. min_opt = opt.replace(active="min")
  370. min_fieldinfo = FieldInfo.from_annotation(min_opt.info.annotation)
  371. min_fieldinfo.json_schema_extra = {"vo_opt": min_opt}
  372. max_opt = opt.replace(active="max")
  373. max_fieldinfo = FieldInfo.from_annotation(max_opt.info.annotation)
  374. max_fieldinfo.json_schema_extra = {"vo_opt": max_opt}
  375. return min_fieldinfo, max_fieldinfo
  376. def rebuild_field(self, name: str, info: FieldInfo) -> bool:
  377. """
  378. 重置字段
  379. Args:
  380. name(str): 字段名称
  381. info(FieldInfo): 字段元信息
  382. Returns:
  383. bool: 是否支持重置
  384. """
  385. if self.context.include_fields and \
  386. name not in self.context.include_fields:
  387. return False
  388. if self.context.exclude_fields and \
  389. name in self.context.exclude_fields:
  390. return False
  391. flag = self.field_filter.filter(name, info)
  392. return flag
  393. class FormSchemaFactory(AbcSchemaFactory):
  394. action = "form"
  395. model_config = strict_valid_config
  396. def validate_annotation(self, annotation: Type[BaseModel]) -> Optional[Type[BaseModel]]:
  397. pass
  398. class PathSchemaFactory(AbcSchemaFactory):
  399. action = "path"
  400. model_config = query_valid_config
  401. def validate_annotation(self, annotation: Type[BaseModel]) -> Optional[Type[BaseModel]]:
  402. pass
  403. class ArbitrarySchemaFactory(AbcSchemaFactory):
  404. model_config = query_valid_config
  405. def __init__(self):
  406. super().__init__()
  407. query_valid_config_copy = query_valid_config.copy()
  408. query_valid_config_copy.update({
  409. "arbitrary_types_allowed": True
  410. })
  411. self.model_config = query_valid_config_copy
  412. def validate_annotation(self, annotation: Type) -> Optional[Type[BaseEntity]]:
  413. bo_model = get_final_model(annotation)
  414. if issubclass(bo_model, BaseEntity):
  415. return bo_model
  416. else:
  417. return None