| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- from abc import ABC, abstractmethod
- from dataclasses import dataclass
- from typing import ClassVar, Dict, Iterable, Set, Type
- from flask import g, request
- from pydantic import AliasChoices, AliasPath, BaseModel
- from werkzeug.datastructures import ImmutableMultiDict
- from werkzeug.exceptions import BadRequest,UnsupportedMediaType
- from ruoyi_common.base.model import BaseEntity, CriterianMeta, ExtraModel, \
- BaseEntity, OrderModel, PageModel, VoValidatorContext
- from ruoyi_common.base.schema_vo import BaseSchemaFactory, QuerySchemaFactory
class AbsReqParser(ABC):
    """Abstract interface every request parser has to implement."""

    @abstractmethod
    def data(self) -> Dict:
        """Fetch the request parameters.

        Returns:
            Dict: Dictionary of request parameters.
        """

    @abstractmethod
    def cast_model(self, bo_model: BaseEntity) -> BaseModel:
        """Adapt the request parameters to a model.

        Args:
            bo_model (BaseEntity): Vo model to adapt to.

        Returns:
            BaseModel: The adapted model.
        """

    @abstractmethod
    def prepare_factory(self, factory: BaseSchemaFactory):
        """Prepare the parser from a schema factory.

        Args:
            factory (BaseSchemaFactory): The factory.
        """

    @abstractmethod
    def prepare(self):
        """Prepare the data."""
class BaseReqParser(AbsReqParser):
    """Concrete base parser whose hooks all do nothing and return ``None``.

    Subclasses override only the hooks they need.
    """

    def data(self) -> Dict:
        """Do nothing; overridden by parsers that read request parameters."""
        return None

    def cast_model(self, bo_model: BaseEntity) -> BaseModel:
        """Do nothing; overridden by parsers that validate into a model."""
        return None

    def prepare_factory(self, factory: BaseSchemaFactory):
        """Do nothing; overridden by parsers that read factory settings."""
        return None

    def prepare(self):
        """Do nothing; overridden by parsers that need per-request setup."""
        return None
-
class QueryReqParser(BaseReqParser):
    """Parse query-string parameters and split off paging/sorting/extra criteria.

    ``prepare`` has to run before ``data``: it creates the ``CriterianMeta``
    instance that ``data`` fills in.
    """

    def __init__(self, context: VoValidatorContext):
        """Store the validation context and start with the default extra model.

        Args:
            context (VoValidatorContext): Passed as ``context`` when validating
                the page/sort/extra models; its ``is_page`` and ``is_sort``
                flags switch the corresponding steps of ``data`` on.
        """
        self.context = context
        self.extra_model = ExtraModel

    def prepare_factory(self, factory: QuerySchemaFactory):
        """Adopt the factory's extra model when the factory defines one.

        Args:
            factory (QuerySchemaFactory): Factory whose ``extra_model`` is read.
        """
        if factory.extra_model:
            self.extra_model = factory.extra_model

    def prepare(self):
        """Create a fresh ``CriterianMeta`` and publish it as ``g.criterian_meta``."""
        self.criterian_meta = CriterianMeta()
        g.criterian_meta = self.criterian_meta

    def validate_request(self) -> Dict:
        """Return the query-string arguments as a flat dict."""
        return request.args.to_dict()

    def data(self) -> Dict:
        """Return the request parameters left after the criteria are split off.

        Page, sort and extra parameters are validated into their models,
        stored on ``self.criterian_meta`` and removed from the result.

        Returns:
            Dict: The remaining request parameters.
        """
        data = self.validate_request().copy()
        if self.context.is_page:
            self._absorb_criteria(data, PageModel, "page")
        if self.context.is_sort:
            self._absorb_criteria(data, OrderModel, "sort")
        if self.extra_model:
            self._absorb_criteria(data, self.extra_model, "extra")
        return data

    def cast_model(self, bo_model: BaseEntity) -> BaseModel:
        """Validate the remaining request parameters into ``bo_model``.

        Args:
            bo_model (BaseEntity): Model used to validate the parameters.

        Returns:
            BaseModel: The validated model.
        """
        return bo_model.model_validate(self.data())

    def _absorb_criteria(
        self,
        data: Dict[str, str],
        model_cls: Type[BaseModel],
        attr: str
    ) -> None:
        """Validate ``data`` into ``model_cls``, record it, strip its keys."""
        parsed = model_cls.model_validate(data, context=self.context)
        # Only record the model when the request actually supplied a field.
        if parsed.model_fields_set:
            setattr(self.criterian_meta, attr, parsed)
        self._remove_model_aliases(data, model_cls)

    def _remove_model_aliases(
        self,
        data: Dict[str, str],
        model_cls: Type[BaseModel]
    ) -> None:
        """Drop the keys ``model_cls`` has already consumed from ``data``.

        This keeps later model validation from reporting them as extra fields.

        Args:
            data (Dict[str, str]): Request parameters, modified in place.
            model_cls (Type[BaseModel]): Model whose aliases are removed.
        """
        if not data:
            return
        for alias in self._collect_aliases(model_cls):
            data.pop(alias, None)

    def _collect_aliases(self, model_cls: Type[BaseModel]) -> Set[str]:
        """Collect every input key that can populate a field of ``model_cls``."""
        populate_by_name = getattr(model_cls, "model_config", {}).get(
            "populate_by_name", False
        )
        alias_set: Set[str] = set()
        for name, info in model_cls.model_fields.items():
            alias_set.update(self._field_aliases(name, info, populate_by_name))
        return alias_set

    @staticmethod
    def _field_aliases(
        name: str,
        info,
        populate_by_name: bool
    ) -> Iterable[str]:
        """Return the top-level keys that can populate one field.

        Args:
            name (str): Field name.
            info: Field info object; ``alias`` and ``validation_alias`` are read.
            populate_by_name (bool): Whether the field name itself is accepted.

        Returns:
            Iterable[str]: Set of string keys for the field.
        """
        aliases: Set[str] = set()
        if getattr(info, "alias", None):
            aliases.add(info.alias)
        validation_alias = getattr(info, "validation_alias", None)
        if isinstance(validation_alias, str):
            aliases.add(validation_alias)
        elif isinstance(validation_alias, AliasChoices):
            # Choices may mix plain strings with AliasPath entries. The paths
            # are unhashable and are not top-level keys, so keep strings only.
            aliases.update(
                choice
                for choice in validation_alias.choices
                if isinstance(choice, str)
            )
        # A bare AliasPath contributes no key, as before.
        if populate_by_name:
            aliases.add(name)
        return aliases
-
-
@dataclass
class PathReqParser(BaseReqParser):
    """Parser for parameters captured from the URL path."""

    def data(self) -> Dict:
        """Return a shallow copy of the view arguments of the current request.

        Returns:
            Dict: Copy of ``request.view_args``; an empty dict when it is
            ``None`` or empty.
        """
        view_args = request.view_args
        # view_args can be None, where .copy() would raise AttributeError.
        return dict(view_args) if view_args else {}
-
@dataclass
class BodyReqParser(BaseReqParser):
    """Parser for JSON request bodies."""

    # Media type the request must declare.
    minetype: ClassVar[str] = "application/json"

    def __init__(self, context: VoValidatorContext):
        """Store the validation context handed to ``model_validate``.

        Args:
            context (VoValidatorContext): Validation context.
        """
        self.context = context

    def validate_request(self) -> Dict:
        """Check the content type and return the parsed JSON body.

        Returns:
            Dict: The parsed, non-empty JSON body.

        Raises:
            UnsupportedMediaType: The content type is not ``minetype``.
            BadRequest: The parsed body is empty.
        """
        content_type = request.headers.get("Content-Type", "").lower()
        # Whitespace before ";" is legal, so strip it before comparing;
        # otherwise "application/json ; charset=utf-8" is rejected.
        minetype = content_type.split(";", 1)[0].strip()
        if minetype != self.minetype:
            raise UnsupportedMediaType(
                description="content-type仅支持application/json"
            )
        body: dict | list = request.get_json()
        if not body:
            raise BadRequest(
                description="在{}, body数据不能为空".format(content_type),
            )
        return body

    def data(self) -> Dict:
        """Return a shallow copy of the validated body."""
        return self.validate_request().copy()

    def cast_model(self, bo_model: BaseEntity) -> BaseModel:
        """Validate the body into ``bo_model`` using the stored context.

        Args:
            bo_model (BaseEntity): Model used to validate the body.

        Returns:
            BaseModel: The validated model.
        """
        return bo_model.model_validate(self.data(), context=self.context)
@dataclass
class FormUrlencodedQueryReqParser(QueryReqParser):
    """Query parser that reads its parameters from a URL-encoded form body."""

    # Media type the request must declare.
    minetype: ClassVar[str] = "application/x-www-form-urlencoded"

    # Kept explicit: without it @dataclass would generate a field-less
    # __init__ that takes no context.
    def __init__(self, context: VoValidatorContext):
        """Forward the validation context to ``QueryReqParser``.

        Args:
            context (VoValidatorContext): Validation context.
        """
        super().__init__(context)

    def validate_request(self) -> Dict:
        """Check the content type and return the form fields as a flat dict.

        Returns:
            Dict: The form fields of the request.

        Raises:
            UnsupportedMediaType: The content type is not ``minetype``.
        """
        content_type = request.headers.get("Content-Type", "").lower()
        # Whitespace before ";" is legal, so strip it before comparing.
        minetype = content_type.split(";", 1)[0].strip()
        if minetype != self.minetype:
            raise UnsupportedMediaType(
                description="除了{},content-type不支持{}".format(self.minetype,minetype)
            )
        form: ImmutableMultiDict = request.form
        return form.to_dict()
@dataclass
class DownloadFileQueryReqParser(FormUrlencodedQueryReqParser):
    """Form-urlencoded query parser used under a separate name.

    It adds no behaviour of its own; everything comes from the parent class.
    """

    # Kept explicit: without it @dataclass would generate a field-less
    # __init__ that takes no context.
    def __init__(self, context: VoValidatorContext):
        """Pass the validation context straight to the parent parser."""
        super().__init__(context)
class FormReqParser(BaseReqParser):
    """Parser for ``multipart/form-data`` requests.

    Form fields, query arguments and uploaded files can each be merged into
    the result, controlled by the constructor flags.
    """

    # Media type the request must declare.
    minetype: ClassVar[str] = "multipart/form-data"

    def __init__(
        self,
        is_form:bool=True,
        is_query:bool=False,
        is_file:bool|None=None,
    ):
        """Choose which request parts are merged into the parsed data.

        Args:
            is_form (bool): Include ``request.form`` fields.
            is_query (bool): Include ``request.args`` parameters.
            is_file (bool | None): Include ``request.files``; any falsy value
                leaves them out.
        """
        self.is_form = is_form
        self.is_query = is_query
        self.is_file = is_file

    def validate_request(self) -> Dict:
        """Check the content type and merge the enabled request parts.

        Later parts overwrite equal keys of earlier ones, in the order form,
        query, files. Files are stored as lists (``flat=False``).

        Returns:
            Dict: The merged request data.

        Raises:
            UnsupportedMediaType: The content type is not ``minetype``.
        """
        content_type = request.headers.get("Content-Type", "").lower()
        # Whitespace before ";" is legal, so strip it before comparing.
        minetype = content_type.split(";", 1)[0].strip()
        if minetype != self.minetype:
            raise UnsupportedMediaType(
                description="除了{},content-type不支持{}".format(self.minetype,minetype)
            )
        new_data = {}
        if self.is_form:
            new_data.update(request.form.to_dict())
        if self.is_query:
            new_data.update(request.args.to_dict())
        if self.is_file:
            new_data.update(request.files.to_dict(flat=False))
        return new_data

    def data(self) -> Dict:
        """Return the merged request data."""
        return self.validate_request()
-
class UploadFileFormReqParser(FormReqParser):
    """Multipart form parser used for uploads.

    Both methods delegate unchanged to the ``FormReqParser`` behaviour.
    """

    def validate_request(self) -> Dict:
        """Return whatever the parent's request validation returns."""
        merged = super().validate_request()
        return merged

    def data(self) -> Dict:
        """Return the validated request data."""
        return self.validate_request()
-
-
class StreamReqParser(BaseReqParser):
    """Placeholder parser for ``application/octet-stream`` requests."""

    # Media type associated with this parser.
    minetype: ClassVar[str] = "application/octet-stream"

    def data(self, *args, **kwargs) -> Dict:
        """Accept any arguments and return ``None``; nothing is parsed yet."""
        return None
|