# transaction.py (3.9 KB)
  1. from functools import wraps
  2. from typing import Any
  3. from enum import Enum
  4. from sqlalchemy.orm.scoping import scoped_session
  5. from sqlalchemy.orm.session import Session
  class Propagation(Enum):
      """Transaction propagation policies applied by ``TransactionWrapper``.

      Every member's value is simply its own name as a string.
      """

      REQUIRED = 'REQUIRED'            # join the active transaction; begin one if none exists
      REQUIRES_NEW = 'REQUIRES_NEW'    # begin one if none exists; otherwise open a nested one
      NESTED = 'NESTED'                # always open a nested transaction
      SUPPORTS = 'SUPPORTS'            # run the same way with or without a transaction
      NOT_SUPPORTED = 'NOT_SUPPORTED'  # end any active transaction before running
      MANDATORY = 'MANDATORY'          # raise if no transaction is active
      NEVER = 'NEVER'                  # raise if a transaction is active
  class Transactional:
      """Decorator that executes a function under a transaction propagation policy.

      Example::

          @Transactional(db_session, propagation=Propagation.NESTED)
          def update_user(...):
              ...

      A new ``TransactionWrapper`` is built on every call of the decorated
      function. When *session* is a ``scoped_session``, the concrete
      ``Session`` it proxies is therefore resolved per call, not once at
      decoration time.
      """

      def __init__(self,
                   session: scoped_session,
                   propagation=Propagation.REQUIRED,
                   ):
          # Both are read by the wrapper at call time, not captured at decoration.
          self.session = session
          self.propagation = propagation

      def __call__(self, func) -> Any:
          """Return a ``functools.wraps``-preserving wrapper around *func*."""
          decorator = self

          @wraps(func)
          def _run_in_transaction(*args, **kwargs):
              runner = TransactionWrapper(
                  func,
                  session=decorator.session,
                  propagation=decorator.propagation,
              )
              return runner(*args, **kwargs)

          return _run_in_transaction

      def prepare(self):
          """No-op placeholder hook; nothing in this class calls it."""
  class TransactionWrapper:
      """Run a single call of a function under a ``Propagation`` policy.

      ``Transactional`` builds one of these per invocation of the decorated
      function. It can also be used directly::

          TransactionWrapper(func, session=db_session,
                             propagation=Propagation.NESTED)(*args, **kwargs)
      """

      def __init__(self, func, session: scoped_session | Session,
                   propagation: Propagation = Propagation.REQUIRED):
          """
          :param func: the callable to execute.
          :param session: a ``scoped_session`` (unwrapped to the ``Session`` it
              currently proxies) or a plain ``Session``.
          :param propagation: the transaction policy to apply when called.
          """
          self.func = func
          # Unwrap to the concrete Session. The private attributes read below
          # (``_transaction``, ``_trans_context_manager``) live on Session itself.
          self.session: Session = (
              session._proxied if isinstance(session, scoped_session) else session
          )
          self.propagation = propagation

      def top_transaction(self) -> bool:
          """Return True when the session's innermost transaction is its root one.

          ``get_transaction()`` returns the root transaction and ``_transaction``
          the innermost one. They differ while a nested transaction (e.g. from
          ``begin_nested()``) is active. When no transaction exists, both are
          None and this returns True.

          :return: bool
          """
          return self.session.get_transaction() is self.session._transaction

      def prepare_transaction(self):
          """Preparation step: close a leftover top-level "query" transaction.

          The session is closed when all of the following hold:

          - it is in a transaction;
          - no nested transaction is active;
          - no ``with session.begin()`` context is active
            (``_trans_context_manager`` is None).

          Such a transaction was presumably auto-begun by an earlier query.
          Closing it means the caller starts a fresh transaction instead of
          silently joining it.
          """
          session = self.session
          if (session.in_transaction()
                  and self.top_transaction()
                  and session._trans_context_manager is None):
              session.close()

      def __call__(self, *args, **kwargs) -> Any:
          """Invoke ``self.func(*args, **kwargs)`` under ``self.propagation``.

          :return: whatever ``self.func`` returns.
          :raises Exception: for MANDATORY when no transaction is active, or for
              NEVER when one is.
          :raises ValueError: if ``self.propagation`` is not a known policy.
          """
          session = self.session
          propagation = self.propagation

          if propagation is Propagation.REQUIRED:
              # Discard a leftover auto-begun transaction, then join an active
              # one or begin (and commit on success) our own.
              self.prepare_transaction()
              if session.in_transaction():
                  return self.func(*args, **kwargs)
              with session.begin():
                  return self.func(*args, **kwargs)

          if propagation is Propagation.REQUIRES_NEW:
              # NOTE: with an active transaction this opens a nested transaction
              # via begin_nested(), not an independent top-level transaction.
              if session.in_transaction():
                  with session.begin_nested():
                      return self.func(*args, **kwargs)
              with session.begin():
                  return self.func(*args, **kwargs)

          if propagation is Propagation.NESTED:
              with session.begin_nested():
                  return self.func(*args, **kwargs)

          if propagation is Propagation.SUPPORTS:
              # Behaves identically whether or not a transaction is active.
              return self.func(*args, **kwargs)

          if propagation is Propagation.NOT_SUPPORTED:
              if session.in_transaction():
                  # NOTE(review): this commits and thereby ends the enclosing
                  # transaction; it does not suspend and resume it. If that
                  # transaction belongs to an outer ``with session.begin()``
                  # (e.g. an outer REQUIRED call), SQLAlchemy may reject further
                  # work inside that context. Confirm this is intended.
                  session.commit()
              return self.func(*args, **kwargs)

          if propagation is Propagation.MANDATORY:
              if not session.in_transaction():
                  raise Exception("No existing transaction found")
              return self.func(*args, **kwargs)

          if propagation is Propagation.NEVER:
              if session.in_transaction():
                  raise Exception("Existing transaction found")
              # Forward the caller's arguments unchanged (no injected extras).
              return self.func(*args, **kwargs)

          raise ValueError(f"Unknown propagation level: {propagation}")