model.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. from collections import UserList
  2. from math import ceil
  3. from typing import List, Set, Type, TypeVar
  4. import typing as t
  5. from flask import abort
  6. from flask_sqlalchemy.model import Model
  7. from flask_sqlalchemy.pagination import Pagination as _Pagination
  8. from sqlalchemy import Column, ScalarSelect
  9. from sqlalchemy.orm.attributes import InstrumentedAttribute
  10. from ruoyi_common.base.model import BaseEntity, DbValidatorContext
# Generic type variable, bounded by BaseEntity, used for the target model class
# accepted and returned by ColumnEntityList.cast.
T = TypeVar('T', bound=BaseEntity)
  12. class ColumnEntityList(UserList[Column]):
  13. """
  14. 字段名称列表类
  15. """
  16. context_key = "db_columns_alias"
  17. def __init__(self, clz:type[Model], names:Set, alia_prefix:bool=True):
  18. self._clz = clz
  19. self._names = names
  20. self._alia_prefix = alia_prefix
  21. super(ColumnEntityList, self).__init__(self._columns())
  22. def _columns(self)-> List[Column]:
  23. """
  24. 获取字段列表
  25. Returns:
  26. List[Column]: 字段列表
  27. """
  28. columns = []
  29. for name in self._names:
  30. if not hasattr(self._clz, name):
  31. raise AttributeError(
  32. f"column {name} not found in {self._clz.__name__}"
  33. )
  34. column:InstrumentedAttribute = getattr(self._clz, name)
  35. if not isinstance(column, InstrumentedAttribute):
  36. raise AttributeError(
  37. f"column {name} is not a column in {self._clz.__name__}"
  38. )
  39. if self._alia_prefix:
  40. label_name = self.to_label(name)
  41. columns.append(column.label(label_name))
  42. else:
  43. columns.append(column)
  44. return columns
  45. def to_label(self, name:str) -> str:
  46. """
  47. 给字段列表添加别名前缀,生成标签名
  48. Args:
  49. name (str): 字段名
  50. Returns:
  51. str: 标签名
  52. """
  53. return "{}_{}".format(self._clz.__name__, name)
  54. def to_field(self, label:str) -> str:
  55. """
  56. 从标签名删除别名前缀,还原字段名
  57. Args:
  58. label (str): 标签名
  59. Returns:
  60. str: 字段名
  61. """
  62. alias_prefix = self._clz.__name__ + "_"
  63. return label[len(alias_prefix):]
  64. def check_prefix(self, label:str) -> bool:
  65. """
  66. 检查标签名是否以表名开头
  67. Args:
  68. label (str): 标签名
  69. Returns:
  70. bool: 是否以表名开头
  71. """
  72. alias_prefix = self._clz.__name__ + "_"
  73. return label.startswith(alias_prefix)
  74. def cast(self, row, to: Type[T]) -> T:
  75. """
  76. 获取字段值,并转换为指定数据模型对象
  77. Args:
  78. row (dict): 数据库查询结果
  79. to (Type[T]): 数据模型类
  80. Returns:
  81. T: 数据模型对象
  82. """
  83. data = to.model_validate(
  84. row,
  85. from_attributes=True,
  86. context=DbValidatorContext(col_entity_list=self)
  87. )
  88. return data
  89. def append_scalar(self, scalar:ScalarSelect):
  90. """
  91. 追加一个标量字段
  92. Args:
  93. scalar (ScalarSelect): 标量字段
  94. """
  95. if self._alia_prefix:
  96. self._names.add(self.to_label(scalar.name))
  97. else:
  98. self._names.add(scalar.name)
  99. self.append(scalar)