# query.py (2.8 KB)
from collections import UserDict, UserList
from contextlib import contextmanager
from threading import Condition, Lock, get_ident

import sqlalchemy as sa
import sqlalchemy.orm as sa_orm
from pydantic.dataclasses import dataclass
from sqlalchemy.orm import Session
  8. DEFAULT_PAGE_SIZE = 10
  9. DEFAULT_PAGE_NUM = 1
  10. class WriteReadLock:
  11. """读写锁实现-写优先"""
  12. def __init__(self):
  13. self._read_lock = Lock()
  14. self._write_lock = Lock()
  15. self._write_count = 0
  16. @contextmanager
  17. def write_lock(self):
  18. """
  19. 获取写锁
  20. """
  21. try:
  22. with self._write_lock:
  23. self._write_count += 1
  24. if self._write_count == 1:
  25. self._read_lock.acquire()
  26. yield
  27. finally:
  28. with self._write_lock:
  29. self._write_count -= 1
  30. if self._write_count == 0:
  31. self._read_lock.release()
  32. @contextmanager
  33. def read_lock(self):
  34. """
  35. 获取读锁
  36. """
  37. try:
  38. self._read_lock.acquire()
  39. yield
  40. finally:
  41. self._read_lock.release()
  42. class SafeUserDict(UserDict):
  43. """使用读写锁的线程安全字典"""
  44. def __init__(self, *args, **kwargs):
  45. self._lock = WriteReadLock()
  46. super().__init__(*args, **kwargs)
  47. @contextmanager
  48. def write(self):
  49. with self._lock.write_lock():
  50. yield self
  51. @contextmanager
  52. def read(self):
  53. with self._lock.read_lock():
  54. yield self
  55. class CriterianDict(SafeUserDict):
  56. pass
  57. @dataclass
  58. class Pagination:
  59. page_size: int
  60. page_num: int
  61. def __post_init__(self):
  62. self.page_size = self.page_size or DEFAULT_PAGE_SIZE
  63. self.page_num = self.page_num or DEFAULT_PAGE_NUM
  64. @property
  65. def offset(self) -> int:
  66. '''
  67. 偏移量
  68. Returns:
  69. int: 偏移量
  70. '''
  71. return (self.page_num - 1) * self.page_size
  72. def compute_count(self,stmt:sa.Select,session:Session) -> int:
  73. """
  74. 计算总数
  75. Args:
  76. stmt (sa.Select): 选择表达式
  77. session (Session): 数据库会话
  78. Returns:
  79. int: 总数
  80. """
  81. sub = stmt.options(sa_orm.lazyload("*")).order_by(None).subquery()
  82. count_stmt = sa.select(sa.func.count()).select_from(sub)
  83. return session.execute(count_stmt).scalar_one_or_none() or 0
  84. def rebuild(self,stmt:sa.Select) -> sa.Select:
  85. """
  86. 重新构建选择表达式
  87. Args:
  88. stmt (sa.Select): 选择表达式
  89. Returns:
  90. sa.Select: 选择表达式
  91. """
  92. new_stmt = stmt.limit(self.page_size).offset(self.offset)
  93. return new_stmt