| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512 |
- # -*- coding: utf-8 -*-
- # @Author : YY
- from abc import ABC, abstractmethod
- from datetime import datetime
- from typing import Type, Dict, Optional, Set, Tuple, TypeVar
- from pydantic.alias_generators import to_camel
- from pydantic import AliasChoices, AliasGenerator, AliasPath, BaseModel, ConfigDict, Field, create_model
- from pydantic.fields import FieldInfo
- from pydantic_core import PydanticUndefined
- from pydantic import BaseModel
- from ruoyi_common.base.model import AllowedExtraModel, BaseEntity, BetOpt, ExtraModel, ExtraOpt, MultiFile, VoAccess, \
- VoValidatorContext
- from ruoyi_common.utils.base import DateUtil, get_final_model
# Generic type variable; not referenced in the visible part of this module.
T = TypeVar("T")

# Strict validation settings: unknown keys are rejected (extra="forbid"),
# instances are frozen, aliases are generated with ``to_camel`` and lookup by
# field name is disabled. Datetimes are encoded with the
# ``DateUtil.YYYY_MM_DD_HH_MM_SS`` format string.
strict_valid_config = ConfigDict(
    strict=True,
    frozen=True,
    extra="forbid",
    from_attributes=False,
    populate_by_name=False,
    alias_generator=to_camel,
    json_encoders={
        datetime: lambda value: value.strftime(DateUtil.YYYY_MM_DD_HH_MM_SS),
    },
)

# Same as the strict settings except that unknown keys are accepted
# (extra="allow") and no custom datetime encoder is registered.
query_valid_config = ConfigDict(
    strict=True,
    frozen=True,
    extra="allow",
    from_attributes=False,
    populate_by_name=False,
    alias_generator=to_camel,
)
def VoField(
    body=True,
    query=False,
    sort=False,
    *args,
    **kwargs
):
    """Create a pydantic ``Field`` that carries VO access metadata.

    A ``VoAccess`` object is built from ``body``, ``query`` and ``sort`` and
    stored under the ``"vo"`` key of the field's ``json_schema_extra``, which is
    where the field filters in this module look it up.

    Args:
        body: Body permission forwarded to ``VoAccess``.
        query: Query permission forwarded to ``VoAccess``. The query filter in
            this module also tests the stored value against ``BetOpt``, so a
            ``BetOpt`` instance is presumably accepted here as well (confirm
            against ``VoAccess``).
        sort: Sort permission forwarded to ``VoAccess``.
        *args: Extra positional arguments forwarded to ``Field``.
        **kwargs: Extra keyword arguments forwarded to ``Field``.

    Returns:
        The ``FieldInfo`` produced by ``Field``.
    """
    vo = VoAccess(
        body=body,
        query=query,
        sort=sort,
    )
    supplied_extra = kwargs.get("json_schema_extra")
    if isinstance(supplied_extra, dict):
        # pydantic only folds loose keyword extras into json_schema_extra when
        # no explicit (non-empty) json_schema_extra is given; otherwise they are
        # ignored. Merge "vo" into the caller's dict so it is not lost.
        kwargs["json_schema_extra"] = {**supplied_extra, "vo": vo}
        return Field(*args, **kwargs)
    # No explicit dict supplied: keep the original call, which lets pydantic
    # collect ``vo`` (and any other loose extras) into json_schema_extra.
    return Field(*args, vo=vo, **kwargs)
class AbcFieldFilter(ABC):
    """Abstract interface for per-field filters used while rebuilding models."""

    @abstractmethod
    def filter(self, name: str, info: FieldInfo) -> bool:
        """Decide how a single model field is treated.

        Marked abstract so that a subclass forgetting to implement it fails at
        instantiation instead of silently returning ``None``.

        Args:
            name (str): Field name.
            info (FieldInfo): Field metadata.

        Returns:
            bool: The filter verdict. The schema factories in this module
            re-declare a field on the rebuilt model only when it is truthy.
        """
class BaseFieldFilter(AbcFieldFilter):
    """Concrete default filter whose verdict is always ``False``."""

    def filter(self, name: str, info: FieldInfo) -> bool:
        """Return ``False`` for every field.

        Args:
            name (str): Field name (unused).
            info (FieldInfo): Field metadata (unused).

        Returns:
            bool: Always ``False``.
        """
        return False
class VoBodyFieldFilter(BaseFieldFilter):
    """Field filter used when a model is rebuilt for request-body validation."""

    def __init__(self):
        # Name of the VoAccess attribute that holds the permission read here.
        self.action = "body"

    def filter(self, name: str, info: FieldInfo) -> bool:
        """Evaluate the body permission of a field and adjust its requiredness.

        When the field carries ``"vo"`` metadata in ``json_schema_extra`` and
        its body permission is truthy, the field is switched to required or
        optional depending on the metadata's ``body_required`` attribute.

        NOTE: ``info`` is mutated in place. ``BaseSchemaFactory.rebuild_model``
        passes the ``FieldInfo`` objects taken straight from the source model's
        ``model_fields``.

        Args:
            name (str): Field name (unused by this filter).
            info (FieldInfo): Field metadata; may be modified.

        Returns:
            bool: ``True`` when there is no ``json_schema_extra`` or the ``"vo"``
            entry is falsy; otherwise the body permission value itself.
        """
        default = True
        extra = info.json_schema_extra
        if not extra:
            return default
        vo_access: VoAccess = extra.get("vo", default)
        if not vo_access:
            return default
        perm = getattr(vo_access, self.action, default)
        if perm:
            if getattr(vo_access, "body_required", False):
                self.change_to_required(info)
            else:
                self.change_to_optional(info)
        return perm

    @classmethod
    def change_to_required(cls, info: FieldInfo):
        """Make the field required by clearing its default and default factory.

        Args:
            info (FieldInfo): Field metadata, modified in place.
        """
        info.default = PydanticUndefined
        info.default_factory = None

    @classmethod
    def change_to_optional(cls, info: FieldInfo):
        """Make a required field optional by giving it a ``None`` default.

        Fields that already have a default or default factory are left alone.

        Args:
            info (FieldInfo): Field metadata, modified in place.
        """
        # ``is_required`` is a method on pydantic v2 FieldInfo; testing the bound
        # method itself is always truthy and would overwrite existing defaults.
        if info.is_required():
            info.default = None
class VoQueryFieldFilter(BaseFieldFilter):
    """Field filter used when a model is rebuilt for query validation.

    Besides giving a verdict per field, it records sortable fields and fields
    whose query permission is an extra-option object.
    """

    # Option classes that divert a field into ``extra_fields``.
    extra_opt_cls_list = [BetOpt]

    def __init__(self):
        # Name of the VoAccess attribute that holds the permission read here.
        self.action = "query"
        # Fields whose "sort" permission is truthy, keyed by field name.
        self.sort_fields: Dict[str, FieldInfo] = {}
        # Fields whose query permission is one of ``extra_opt_cls_list``.
        self.extra_fields: Dict[str, ExtraOpt] = {}

    def filter(self, name: str, info: FieldInfo) -> bool:
        """Evaluate the query permission of a field and collect side data.

        Args:
            name (str): Field name.
            info (FieldInfo): Field metadata.

        Returns:
            bool: ``False`` when the field has no ``json_schema_extra`` or when
            its query permission is an extra-option object (which is stored in
            ``extra_fields`` instead); otherwise the query permission value.
        """
        extra = info.json_schema_extra
        if not extra:
            return False

        vo_access: VoAccess = extra.get("vo", False)
        query_perm = getattr(vo_access, self.action, False)

        for opt_cls in self.extra_opt_cls_list:
            if not isinstance(query_perm, opt_cls):
                continue
            # Attach the owning field to the option object, remember it, and
            # exclude the field itself from the regular query model.
            query_perm.name = name
            query_perm.info = info
            self.extra_fields[name] = query_perm
            query_perm = False

        if getattr(vo_access, "sort", False):
            self.sort_fields[name] = info
        return query_perm
class AbcSchemaFactory(ABC):
    """Abstract interface for factories that turn an annotation into a model."""

    @abstractmethod
    def validate_annotation(self, annotation: Type) -> Optional[Type]:
        """Check whether an annotation is usable and return the resulting type.

        Declared as an instance method (with ``self``) so that its signature
        matches every implementation in this module.

        Args:
            annotation (Type): The annotation to inspect.

        Returns:
            Optional[Type]: The valid type, or ``None`` when not supported.
        """
class BaseSchemaFactory(AbcSchemaFactory):
    """Base factory that derives a "Vo" model from an entity model."""

    # Tag inserted (capitalised) into the generated model name.
    action = "base"
    model_config = query_valid_config

    def __init__(self):
        # Subclasses install a concrete field filter.
        self.field_filter = None
        # Suffix appended to every generated model name.
        self.model_suffix = "Vo"

    def validate_annotation(self, annotation: Type) -> Optional[Type]:
        """Rebuild the model behind ``annotation`` if it is a ``BaseEntity``.

        Args:
            annotation (Type): The annotation to inspect.

        Returns:
            Optional[Type]: The rebuilt model, or ``None`` when the resolved
            model is not a ``BaseEntity`` subclass.
        """
        bo_model = get_final_model(annotation)
        if not issubclass(bo_model, BaseEntity):
            return None
        return self.rebuild_model(model_cls=bo_model)

    def rebuild_model(self, model_cls: Type[BaseEntity]) -> Type[BaseEntity]:
        """Create a new model class derived from an existing one.

        Fields accepted by ``rebuild_field`` are re-declared on the new model
        with their (possibly modified) metadata. The new model inherits from
        ``model_cls``, or from an intermediate class carrying a copy of this
        factory's ``model_config`` when the factory action is ``"query"``.

        Args:
            model_cls: The existing model class.

        Returns:
            Type[BaseModel]: The newly created model class.
        """
        kept_fields = {
            field_name: (field_info.annotation, field_info)
            for field_name, field_info in model_cls.model_fields.items()
            if self.rebuild_field(field_name, field_info)
        }
        vo_name = "".join(
            (model_cls.__name__, self.action.capitalize(), self.model_suffix)
        )

        base_class = model_cls
        if getattr(self, "action", None) == "query" and hasattr(self, "model_config"):
            # Query models get their own config (allowing extra keys) through
            # an intermediate base class.
            class QueryBaseModel(model_cls):
                model_config = self.model_config.copy()

            base_class = QueryBaseModel

        return create_model(
            vo_name,
            __base__=base_class,
            __doc__=model_cls.__doc__,
            __module__=model_cls.__module__,
            **kept_fields
        )

    def rebuild_field(self, name: str, info: FieldInfo) -> bool:
        """Tell whether a field should be re-declared on the rebuilt model.

        Args:
            name (str): Field name.
            info (FieldInfo): Field metadata.

        Returns:
            bool: Always ``True`` in this base implementation.
        """
        return True
class BodySchemaFactory(BaseSchemaFactory):
    """Factory that rebuilds models for request-body validation."""

    action = "body"
    model_config = query_valid_config

    def __init__(self, context: VoValidatorContext):
        super().__init__()
        # Supplies the include_fields / exclude_fields restrictions.
        self.context = context
        self.field_filter = VoBodyFieldFilter()

    def validate_annotation(self, annotation: Type) -> Optional[Type[BaseModel]]:
        """Rebuild the model behind ``annotation`` if it is a ``BaseModel``.

        Args:
            annotation (Type): The annotation to inspect.

        Raises:
            Exception: The annotation is not a model but the context carries
                include or exclude fields.

        Returns:
            Optional[Type[BaseModel]]: The rebuilt model, or ``None`` for a
            non-model annotation without include/exclude restrictions.
        """
        bo_model = get_final_model(annotation)
        if issubclass(bo_model, BaseModel):
            return self.rebuild_model(model_cls=bo_model)
        if self.context.include_fields or self.context.exclude_fields:
            raise Exception(f"注解{annotation.__name__}不是模型,不支持include和exclude请求条件")
        return None

    def rebuild_field(self, name: str, info: FieldInfo) -> bool:
        """Apply include/exclude restrictions, then the body field filter.

        Args:
            name (str): Field name.
            info (FieldInfo): Field metadata.

        Returns:
            bool: ``False`` when the field is outside a non-empty include list
            or inside the exclude list; otherwise the field filter's verdict.
        """
        included = self.context.include_fields
        if included and name not in included:
            return False
        excluded = self.context.exclude_fields
        if excluded and name in excluded:
            return False
        return self.field_filter.filter(name, info)
class QuerySchemaFactory(BaseSchemaFactory):
    """Factory that rebuilds models for query validation.

    In addition to the rebuilt query model it registers the aliases of sortable
    fields on the context and builds a separate "extra" model for option-based
    query fields (such as ``BetOpt`` ranges) and explicitly allowed fields.
    """

    action = "query"
    model_config = query_valid_config

    def __init__(
        self,
        context,
        extra_strict_forbid=True,
        extra_allowed_fields: Optional[Dict[str, FieldInfo]] = None
    ):
        """Initialise the factory.

        Args:
            context: Validator context; ``include_fields``, ``exclude_fields``
                and ``include_sort_alias`` are read/updated on it.
            extra_strict_forbid: Selects ``ExtraModel`` (truthy) or
                ``AllowedExtraModel`` (falsy) as base of the extra model.
            extra_allowed_fields: Optional additional fields for the extra
                model, keyed by field name.
        """
        super().__init__()
        self.context = context
        self.field_filter = VoQueryFieldFilter()
        self.extra_strict_forbid = extra_strict_forbid
        self.extra_allowed_fields = extra_allowed_fields
        # Populated by ``rebuild_extra_model`` when there is something to build.
        self.extra_model = None

    def validate_annotation(self, annotation: Type) -> Optional[Type[BaseEntity]]:
        """Rebuild the query model behind ``annotation``.

        Args:
            annotation (Type): The annotation to inspect.

        Raises:
            Exception: The annotation is not an entity model but the context
                carries include or exclude fields.

        Returns:
            Optional[Type[BaseEntity]]: The rebuilt model, or ``None``.
        """
        bo_model = get_final_model(annotation)
        if issubclass(bo_model, BaseEntity):
            updated_model = self.rebuild_model(model_cls=bo_model)
            # Keep a per-instance copy of the query config rather than sharing
            # the class-level dict.
            self.model_config = self.model_config.copy()
            self._validate_context()
            self.rebuild_extra_model()
            return updated_model
        if self.context.include_fields or self.context.exclude_fields:
            raise Exception(f"注解{annotation.__name__}不是模型,不支持include和exclude请求条件")
        return None

    def _validate_context(self):
        """Add the aliases of every sortable field to the context."""
        for name, info in self.field_filter.sort_fields.items():
            alias_set = self.get_validate_alias(name, info)
            self.context.include_sort_alias = self.context.include_sort_alias | alias_set

    def get_validate_alias(self, name: str, info: FieldInfo) -> Set[str]:
        """Collect the names under which a field may be addressed.

        Args:
            name (str): Field name.
            info (FieldInfo): Field metadata.

        Raises:
            Exception: The field uses an ``AliasPath`` (directly or inside
                ``AliasChoices``), which is not supported.

        Returns:
            Set[str]: The alias set.
        """
        alias_set: Set[str] = set()
        alias = self.get_alias_from_config(name)
        if alias:
            alias_set.add(alias)

        validation_alias = info.validation_alias
        if validation_alias:
            if isinstance(validation_alias, str):
                alias_set.add(validation_alias)
            elif isinstance(validation_alias, AliasPath):
                raise Exception(f"模型字段{name}不支持AliasPath")
            elif isinstance(validation_alias, AliasChoices):
                for choice in validation_alias.choices:
                    # AliasPath entries are unhashable and unsupported; report
                    # them the same way as a top-level AliasPath.
                    if isinstance(choice, AliasPath):
                        raise Exception(f"模型字段{name}不支持AliasPath")
                    alias_set.add(choice)

        if self.model_config.get("populate_by_name"):
            alias_set.add(name)
        return alias_set

    def get_alias_from_config(self, name: str) -> Optional[str]:
        """Derive a field's alias from the configured alias generator.

        Args:
            name (str): Field name.

        Returns:
            Optional[str]: The generated alias, or ``None`` when no usable
            generator is configured.
        """
        # pydantic's config key is "alias_generator"; the previously used
        # "generate_alias" key is still honoured as a fallback.
        generator = self.model_config.get("alias_generator")
        if generator is None:
            generator = self.model_config.get("generate_alias")

        if isinstance(generator, AliasGenerator):
            alias, v_alias, _s_alias = generator.generate_aliases(name)
            if alias:
                return alias
            # The validation alias may be an AliasPath/AliasChoices object;
            # only a plain string can serve as an alias here.
            return v_alias if isinstance(v_alias, str) else None
        if callable(generator):
            return generator(name)
        return None

    def rebuild_extra_model(self) -> Optional[Type[ExtraModel]]:
        """Build the extra query model from collected option fields.

        Uses the option fields gathered by the field filter plus
        ``extra_allowed_fields``. The result is stored on ``self.extra_model``;
        nothing is built when there are no field definitions.

        Returns:
            Optional[Type[ExtraModel]]: The current ``self.extra_model``.
        """
        field_definitions = {}
        for name, opt in self.field_filter.extra_fields.items():
            if isinstance(opt, BetOpt):
                min_fieldinfo, max_fieldinfo = self.rebuild_bet_opt(name, opt)
                field_definitions[opt.min] = min_fieldinfo.annotation, min_fieldinfo
                field_definitions[opt.max] = max_fieldinfo.annotation, max_fieldinfo
            elif isinstance(opt, ExtraOpt):
                field_definitions[name] = opt.info.annotation, opt.info

        if self.extra_allowed_fields:
            for name, info in self.extra_allowed_fields.items():
                # Option-derived definitions take precedence.
                if name not in field_definitions:
                    field_definitions[name] = info.annotation, info

        if field_definitions:
            extra_model_cls = ExtraModel if self.extra_strict_forbid else AllowedExtraModel
            # The model name must be passed positionally; as a keyword it would
            # be taken for a field definition and the call would fail.
            self.extra_model = create_model(
                "ExtraModel",
                __base__=extra_model_cls,
                **field_definitions
            )
        return self.extra_model

    def rebuild_bet_opt(self, name: str, opt: BetOpt) -> Tuple[FieldInfo, FieldInfo]:
        """Create the lower- and upper-bound fields for a between option.

        Args:
            name (str): Field name (unused).
            opt (BetOpt): The between option attached to the field.

        Returns:
            Tuple[FieldInfo, FieldInfo]: Field metadata for the "min" and "max"
            bounds, each carrying its option under the ``"vo_opt"`` key.
        """
        min_opt = opt.replace(active="min")
        min_fieldinfo = FieldInfo.from_annotation(min_opt.info.annotation)
        min_fieldinfo.json_schema_extra = {"vo_opt": min_opt}
        max_opt = opt.replace(active="max")
        max_fieldinfo = FieldInfo.from_annotation(max_opt.info.annotation)
        max_fieldinfo.json_schema_extra = {"vo_opt": max_opt}
        return min_fieldinfo, max_fieldinfo

    def rebuild_field(self, name: str, info: FieldInfo) -> bool:
        """Apply include/exclude restrictions, then the query field filter.

        Args:
            name (str): Field name.
            info (FieldInfo): Field metadata.

        Returns:
            bool: ``False`` when the field is outside a non-empty include list
            or inside the exclude list; otherwise the field filter's verdict.
        """
        if self.context.include_fields and \
                name not in self.context.include_fields:
            return False
        if self.context.exclude_fields and \
                name in self.context.exclude_fields:
            return False
        return self.field_filter.filter(name, info)
class FormSchemaFactory(AbcSchemaFactory):
    """Placeholder factory for form data; performs no validation yet."""

    action = "form"
    model_config = strict_valid_config

    def validate_annotation(self, annotation: Type[BaseModel]) -> Optional[Type[BaseModel]]:
        """Not implemented: always returns ``None``."""
        return None
class PathSchemaFactory(AbcSchemaFactory):
    """Placeholder factory for path parameters; performs no validation yet."""

    action = "path"
    model_config = query_valid_config

    def validate_annotation(self, annotation: Type[BaseModel]) -> Optional[Type[BaseModel]]:
        """Not implemented: always returns ``None``."""
        return None
class ArbitrarySchemaFactory(AbcSchemaFactory):
    """Factory that accepts entity models as-is, allowing arbitrary types."""

    model_config = query_valid_config

    def __init__(self):
        super().__init__()
        # Per-instance config: the query settings plus arbitrary-type support.
        self.model_config = {
            **query_valid_config,
            "arbitrary_types_allowed": True,
        }

    def validate_annotation(self, annotation: Type) -> Optional[Type[BaseEntity]]:
        """Return the model behind ``annotation`` unchanged if it is an entity.

        Args:
            annotation (Type): The annotation to inspect.

        Returns:
            Optional[Type[BaseEntity]]: The resolved model when it subclasses
            ``BaseEntity``; otherwise ``None``.
        """
        bo_model = get_final_model(annotation)
        return bo_model if issubclass(bo_model, BaseEntity) else None
|