| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304 |
- import inspect
- from abc import ABC, abstractmethod
- from functools import wraps
- from dataclasses import dataclass, field
- from typing import Annotated, Any, Callable, Dict, Tuple, Type, ClassVar, \
- Optional, Set
- from werkzeug.exceptions import BadRequest, InternalServerError
- from flask import has_request_context, request
- from pydantic import BaseModel, ValidationError, validate_call
- from pydantic.fields import FieldInfo
- from ruoyi_common.base.model import MultiFile
- from ruoyi_common.base.reqparser import BaseReqParser, BodyReqParser, \
- DownloadFileQueryReqParser, UploadFileFormReqParser, PathReqParser, \
- QueryReqParser, VoValidatorContext
- from ruoyi_common.base.schema_vo import ArbitrarySchemaFactory, \
- BaseSchemaFactory, BodySchemaFactory, PathSchemaFactory, QuerySchemaFactory
class AbcValidatorFunction(ABC):
    """Abstract interface for callables that validate a wrapped function.

    Concrete subclasses check the wrapped function's parameters, wrap the
    function with argument validation, and are then invoked in its place.
    """

    @abstractmethod
    def validate_unbound_parameters(self):
        """Check the wrapped function's parameters (subclass hook)."""
        raise NotImplementedError

    @abstractmethod
    def validate_function(self):
        """Wrap the target function with validation (subclass hook)."""
        raise NotImplementedError

    @abstractmethod
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped function (subclass hook)."""
        raise NotImplementedError
class ValidatorScopeFunction(AbcValidatorFunction):
    """Validate the arguments of a plain (non-view) callable with pydantic.

    On construction the callable's signature is inspected:

    * every parameter must be positional-or-keyword;
    * a parameter annotated with a pydantic model class is recorded as the
      model parameter and is only allowed in the first position;
    * every other parameter is recorded as a ``FieldInfo``.

    The callable is then wrapped with ``pydantic.validate_call``, so each call
    validates its arguments.
    """

    def __init__(self, func: Callable):
        self.func = func
        self.sig = inspect.signature(self.func)
        # Parameter name -> FieldInfo for every non-model parameter.
        self._unbound_fields: Dict[str, FieldInfo] = {}
        # (parameter name, model class) of the model parameter, if any.
        self._unbound_model: Optional[Tuple[str, Type[BaseModel]]] = None
        # Arguments of the most recent call (informational only; __call__
        # forwards its own locals, never these attributes).
        self.args = ()
        self.kwargs = {}
        self.validate_unbound_parameters()
        self.validate_function()

    @property
    def unbound_model(self):
        """``(name, model)`` of the model parameter, or None if there is none."""
        return self._unbound_model

    @staticmethod
    def _is_model_class(annotation) -> bool:
        """Return True if ``annotation`` is a pydantic model class.

        Annotations are classes, not instances, so
        ``isinstance(annotation, BaseModel)`` never matches; ``issubclass`` is
        the right test. Non-class annotations (generics, ``Annotated[...]``,
        strings) may make ``issubclass`` raise TypeError: treat them as
        "not a model".
        """
        try:
            return issubclass(annotation, BaseModel)
        except TypeError:
            return False

    def _validate_kind(self, kind):
        """Reject any parameter that is not positional-or-keyword."""
        if kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
            raise Exception("参数必须是位置参数")

    def validate_unbound_parameters(self):
        """Split the signature into the model parameter and plain fields.

        Raises:
            Exception: a parameter is not positional-or-keyword, or the
                model-typed parameter is not the first one.
        """
        for index, (key, param) in enumerate(self.sig.parameters.items()):
            self._validate_kind(param.kind)
            if self._is_model_class(param.annotation):
                self._unbound_model = (key, param.annotation)
                if index > 0:
                    raise Exception(
                        f"{self.func.__name__} 类型参数有且仅有第一个"
                    )
            else:
                self._unbound_fields[key] = FieldInfo.from_annotation(
                    param.annotation
                )

    def validate_function(self):
        """Replace ``self.func`` with its ``validate_call`` wrapper."""
        self.func = validate_call(self.func)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Call the validated function with exactly the given arguments."""
        self.args = args
        self.kwargs = kwargs
        # Forward the locals, not self.args/self.kwargs: another call on the
        # same instance could overwrite the attributes in between.
        return self.func(*args, **kwargs)
class ValidatorViewFunction(AbcValidatorFunction):
    """Wrap a Flask view so its arguments are built from the current request.

    Configuration happens after construction:

    * ``unbound_schema(factory)`` inspects the view signature and wraps the
      view with ``pydantic.validate_call`` (with ``factory.model_config`` when
      a factory is given);
    * ``unbound_data(parser)`` attaches the request parser used per call.

    When used through ``BaseValidator`` one instance is created per decorated
    view and shared by every request to it, so per-call arguments are kept in
    locals rather than read back from ``self``.
    """

    def __init__(self, func: Callable):
        self.func = func
        self.sig = inspect.signature(self.func)
        # Parameter name -> FieldInfo for every non-model parameter.
        self._unbound_fields: Dict[str, FieldInfo] = {}
        # (parameter name, model type) of the model parameter, if any.
        self._unbound_model: Optional[Tuple[str, Type[BaseModel]]] = None
        self._schema_factory = None
        self._data_parser = None
        # Arguments of the most recent call (informational only; __call__
        # forwards its own locals, never these attributes).
        self.args = ()
        self.kwargs = {}

    @property
    def unbound_model(self):
        """``(name, model)`` of the model parameter, or None if there is none."""
        return self._unbound_model

    @staticmethod
    def _is_model_class(annotation) -> bool:
        """Return True if ``annotation`` is a pydantic model class.

        Annotations are classes, not instances, so
        ``isinstance(annotation, BaseModel)`` never matches; ``issubclass`` is
        the right test. Non-class annotations (generics, ``Annotated[...]``,
        strings) may make ``issubclass`` raise TypeError: treat them as
        "not a model".
        """
        try:
            return issubclass(annotation, BaseModel)
        except TypeError:
            return False

    def _validate_kind(self, kind):
        """Reject any parameter that is not positional-or-keyword."""
        if kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
            raise Exception("参数必须是位置参数")

    def validate_unbound_parameters(self):
        """Split the view's parameters into the model parameter and fields.

        With a schema factory, its ``validate_annotation`` decides: a truthy
        return marks the parameter as the model parameter and is used as its
        model type. Without a factory, a pydantic model class annotation does.
        All other parameters are recorded as ``FieldInfo``.

        Raises:
            Exception: a parameter is not positional-or-keyword, or the model
                parameter is not the first one.
        """
        for index, (key, param) in enumerate(self.sig.parameters.items()):
            self._validate_kind(param.kind)
            if self._schema_factory:
                model = self._schema_factory.validate_annotation(
                    param.annotation
                )
            elif self._is_model_class(param.annotation):
                model = param.annotation
            else:
                model = None
            if model:
                self._unbound_model = (key, model)
                if index > 0:
                    raise Exception(
                        f"{self.func.__name__} 类型参数有且仅有第一个"
                    )
            else:
                self._unbound_fields[key] = FieldInfo.from_annotation(
                    param.annotation
                )

    def validate_function(self):
        """Replace ``self.func`` with its ``validate_call`` wrapper."""
        if self._schema_factory:
            self.func = validate_call(
                self.func,
                config=self._schema_factory.model_config,
            )
        else:
            self.func = validate_call(self.func)

    def unbound_data(
        self,
        data_parser: BaseReqParser
    ):
        """Attach the request parser, preparing it with the schema factory."""
        self._data_parser = data_parser
        if self._schema_factory and self._data_parser:
            self._data_parser.prepare_factory(self._schema_factory)

    def unbound_schema(
        self,
        schema_factory: Optional[BaseSchemaFactory],
    ):
        """Set the schema factory, then inspect and wrap the view."""
        self._schema_factory = schema_factory
        self.validate_unbound_parameters()
        self.validate_function()

    def bound_data(
        self, args: Tuple = (), kwargs: Optional[Dict] = None
    ) -> Dict:
        """Fill ``kwargs`` in place with the view's arguments for this request.

        * model parameter present: set ``kwargs[name]`` to the object returned
          by the parser's ``cast_model``; other entries (e.g. Flask URL
          variables) are kept;
        * otherwise, any ``MultiFile`` parameter: replace ``kwargs`` with the
          uploaded files bound to every such parameter;
        * otherwise: replace ``kwargs`` with the parser's ``data()``.

        ``args`` is accepted for interface compatibility and unused.
        Returns the filled dict (a fresh one when ``kwargs`` is omitted, so
        no shared default dict is ever mutated).
        """
        if kwargs is None:
            kwargs = {}
        if self._unbound_model:
            key, bo_model = self._unbound_model
            kwargs[key] = self._data_parser.cast_model(bo_model)
            return kwargs
        file_params = [
            name
            for name, param in self.sig.parameters.items()
            if param.annotation is MultiFile
        ]
        if file_params:
            # Upload parameters are built straight from request.files.
            files = MultiFile.from_obj(request.files)
            kwargs.clear()
            for name in file_params:
                kwargs[name] = files
            return kwargs
        data = self._data_parser.data()
        kwargs.clear()
        kwargs.update(data)
        return kwargs

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Parse the current request and call the wrapped view.

        Must run inside a Flask request context. Failures while parsing the
        request are returned (not raised): a pydantic ``ValidationError`` or
        ``TypeError`` as werkzeug ``BadRequest``, anything else as
        ``InternalServerError``. Errors raised by the wrapped view itself
        (including its ``validate_call`` validation) propagate.
        """
        self.args = args
        self.kwargs = kwargs
        # has_request_context is a function and must be called; the bare
        # name is always truthy, which silently disabled this guard.
        if not has_request_context():
            raise Exception("请在flask请求上下文中调用")
        try:
            if self._data_parser:
                self._data_parser.prepare()
                # kwargs is this call's own dict, so filling it in place is
                # safe even when requests run concurrently.
                self.bound_data(args, kwargs)
        except (ValidationError, TypeError) as e:
            return BadRequest(description=str(e))
        except Exception as e:
            return InternalServerError(description=str(e))
        # Forward the locals, not self.args/self.kwargs: the instance is
        # shared, so a concurrent request could overwrite the attributes.
        return self.func(*args, **kwargs)
@dataclass
class BaseValidator:
    """Base for decorators that validate a Flask view's arguments.

    Subclasses set ``schema_factory`` and ``data_parser`` (the ones in this
    module do so in ``__post_init__``). Decorating a view wraps it in a
    ``ValidatorViewFunction`` configured with both.
    """
    data_parser: ClassVar = None
    schema_factory: ClassVar = None
    # Context built by subclasses that use one. It defaults to None so the
    # dataclass-generated __repr__/__eq__ (which read every field) do not
    # raise AttributeError when a subclass never sets it.
    vo_context: Optional[VoValidatorContext] = field(default=None, init=False)

    def __call__(self, func):
        """Return ``func`` wrapped so its arguments come from the request."""
        view_function = ValidatorViewFunction(func)
        view_function.unbound_schema(self.schema_factory)
        view_function.unbound_data(self.data_parser)

        @wraps(func)
        def wrapper(*args, **kwargs):
            return view_function(*args, **kwargs)

        return wrapper
@dataclass
class PathValidator(BaseValidator):
    """Decorator validating a view's URL path parameters."""

    def __post_init__(self):
        factory = PathSchemaFactory()
        parser = PathReqParser()
        self.schema_factory, self.data_parser = factory, parser
@dataclass
class QueryValidator(BaseValidator):
    """Decorator validating a view's query-string parameters.

    Attributes:
        is_page: passed to VoValidatorContext as both ``is_page`` and
            ``is_sort``.
        include / exclude: forwarded as ``include_fields`` /
            ``exclude_fields``.
        extra_fields: forwarded to QuerySchemaFactory as
            ``extra_allowed_fields`` (the factory is built with
            ``extra_strict_forbid=True``).
    """
    is_page: bool = False
    include: Optional[Set[str]] = field(default=None)
    exclude: Optional[Set[str]] = field(default=None)
    extra_fields: Optional[Dict[str, FieldInfo]] = field(default=None)

    def __post_init__(self):
        vo_context = VoValidatorContext(
            exclude_data_alias=True,
            is_page=self.is_page,
            is_sort=self.is_page,
            include_fields=self.include,
            exclude_fields=self.exclude,
        )
        # Populate the declared ``vo_context`` field so it is inspectable and
        # the dataclass-generated __repr__/__eq__ can read it.
        self.vo_context = vo_context
        self.schema_factory = QuerySchemaFactory(
            vo_context,
            extra_strict_forbid=True,
            extra_allowed_fields=self.extra_fields,
        )
        self.data_parser = QueryReqParser(vo_context)
@dataclass
class BodyValidator(BaseValidator):
    """Decorator validating a view's request body.

    ``include`` / ``exclude`` are forwarded to VoValidatorContext as
    ``include_fields`` / ``exclude_fields``.
    """
    include: Optional[Set[str]] = field(default=None)
    exclude: Optional[Set[str]] = field(default=None)

    def __post_init__(self):
        vo_context = VoValidatorContext(
            include_fields=self.include,
            exclude_fields=self.exclude,
        )
        # Populate the declared ``vo_context`` field so it is inspectable and
        # the dataclass-generated __repr__/__eq__ can read it.
        self.vo_context = vo_context
        self.schema_factory = BodySchemaFactory(vo_context)
        self.data_parser = BodyReqParser(vo_context)
@dataclass
class FileDownloadValidator(BaseValidator):
    """Decorator validating the query parameters of a file-download view.

    Uses a query schema with ``is_page=True`` and the download-specific
    query parser.
    """

    def __post_init__(self):
        vo_context = VoValidatorContext(
            exclude_data_alias=True,
            is_page=True,
        )
        # Populate the declared ``vo_context`` field so it is inspectable and
        # the dataclass-generated __repr__/__eq__ can read it.
        self.vo_context = vo_context
        self.schema_factory = QuerySchemaFactory(vo_context)
        self.data_parser = DownloadFileQueryReqParser(vo_context)
@dataclass
class FileUploadValidator(BaseValidator):
    """Decorator for file-upload views.

    The parser is configured with query and file parsing on and form parsing
    off; the schema factory is ArbitrarySchemaFactory.
    """

    def __post_init__(self):
        factory = ArbitrarySchemaFactory()
        parser = UploadFileFormReqParser(
            is_form=False,
            is_query=True,
            is_file=True,
        )
        self.schema_factory, self.data_parser = factory, parser
@dataclass
class FileValidator(BaseValidator):
    """Decorator using UploadFileFormReqParser with its default settings."""
    # NOTE(review): ``include`` is accepted but not used by __post_init__.
    include: Optional[Set[str]] = field(default=None)

    def __post_init__(self):
        factory = ArbitrarySchemaFactory()
        parser = UploadFileFormReqParser()
        self.schema_factory, self.data_parser = factory, parser
|