contract_repo.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. import logging
  2. import json
  3. from datetime import datetime, timedelta
  4. from sqlalchemy import func, select, text
  5. from sqlalchemy.ext.asyncio import AsyncSession
  6. from sqlalchemy.exc import DBAPIError
  7. from alien_store.db.models.contract_store import ContractStore
  # Logger named after this module's dotted import path.
  logger = logging.getLogger(name=__name__)
  9. class ContractRepository:
  10. """合同数据访问层"""
  11. def __init__(self, db: AsyncSession):
  12. self.db = db
  13. async def _execute_with_retry(self, statement):
  14. """
  15. 封装一次重试:如果连接已失效(被 MySQL/网络关掉),回滚并重试一次。
  16. """
  17. try:
  18. return await self.db.execute(statement)
  19. except Exception as exc:
  20. if self._should_retry(exc):
  21. logger.warning("DB connection invalidated, retrying once: %s", exc)
  22. try:
  23. await self.db.rollback()
  24. except Exception:
  25. pass
  26. return await self.db.execute(statement)
  27. raise
  28. @staticmethod
  29. def _should_retry(exc: Exception) -> bool:
  30. # SQLAlchemy 标准:connection_invalidated=True
  31. if isinstance(exc, DBAPIError) and getattr(exc, "connection_invalidated", False):
  32. return True
  33. # 兜底判断常见“连接已关闭”文案(aiomysql/uvloop RuntimeError)
  34. txt = str(exc).lower()
  35. closed_keywords = ["closed", "lost connection", "connection was killed", "terminat"]
  36. return any(k in txt for k in closed_keywords)
  37. async def get_by_store_id(self, store_id: int):
  38. """根据店铺id查询所有合同"""
  39. result = await self._execute_with_retry(
  40. ContractStore.__table__.select().where(ContractStore.store_id == store_id)
  41. )
  42. # 返回列表[dict],避免 Pydantic 序列化 Row 对象出错
  43. return [dict(row) for row in result.mappings().all()]
  44. async def check_store_status(self, store_id: int) -> str | None:
  45. """
  46. 检查 store_info 表中对应 store_id 的 reason 字段
  47. """
  48. query = text("SELECT reason FROM store_info WHERE id = :store_id")
  49. result = await self._execute_with_retry(query.bindparams(store_id=store_id))
  50. row = result.fetchone()
  51. return row[0] if row else None
  52. async def get_all(self):
  53. """查询所有合同"""
  54. result = await self._execute_with_retry(ContractStore.__table__.select())
  55. return [dict(row) for row in result.mappings().all()]
  56. async def get_all_paged(self, page: int, page_size: int = 10):
  57. """分页查询所有合同,返回 (items, total)"""
  58. offset = (page - 1) * page_size
  59. # 查询总数
  60. count_result = await self._execute_with_retry(
  61. select(func.count()).select_from(ContractStore.__table__)
  62. )
  63. total = count_result.scalar() or 0
  64. # 查询分页数据
  65. result = await self._execute_with_retry(
  66. ContractStore.__table__.select().offset(offset).limit(page_size)
  67. )
  68. items = [dict(row) for row in result.mappings().all()]
  69. return items, total
  70. async def create(self, user_data):
  71. """创建未签署合同模板"""
  72. db_templates = ContractStore(
  73. store_id=user_data.store_id,
  74. store_name=getattr(user_data, "store_name", None),
  75. merchant_name=user_data.merchant_name,
  76. business_segment=user_data.business_segment,
  77. contact_phone=user_data.contact_phone,
  78. contract_url=user_data.contract_url,
  79. seal_url='0.0',
  80. signing_status='未签署'
  81. )
  82. self.db.add(db_templates)
  83. await self.db.commit()
  84. await self.db.refresh(db_templates)
  85. return db_templates
  86. async def mark_signed_by_phone(self, contact_phone: str, sign_flow_id: str, signing_time: datetime | None = None, contract_download_url: str | None = None):
  87. """
  88. 根据手机号 + sign_flow_id 将合同标记为已签署,只更新匹配的合同项
  89. 当 is_master 为 1 时,更新签署状态和时间字段
  90. 同时写入签署/生效/到期时间(签署时间=T,生效=T+1天0点,失效=生效+365天)
  91. 同时更新 contract_download_url 到对应的字典中
  92. """
  93. result = await self._execute_with_retry(
  94. ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
  95. )
  96. rows = result.mappings().all()
  97. updated = False
  98. for row in rows:
  99. contract_url_raw = row.get("contract_url")
  100. items = None
  101. if contract_url_raw:
  102. try:
  103. items = json.loads(contract_url_raw)
  104. except Exception:
  105. items = None
  106. changed = False
  107. matched_item = None
  108. if isinstance(items, list):
  109. for item in items:
  110. if item.get("sign_flow_id") == sign_flow_id:
  111. item["status"] = 1
  112. # 更新 contract_download_url
  113. if contract_download_url:
  114. item["contract_download_url"] = contract_download_url
  115. matched_item = item
  116. changed = True
  117. break
  118. # 只有当 is_master 为 1 时才更新时间字段
  119. if changed and matched_item and matched_item.get("is_master") == 1:
  120. # 时间处理
  121. signing_dt = signing_time
  122. effective_dt = expiry_dt = None
  123. if signing_dt:
  124. # effective_time 是 signing_time 第二天的 0 点
  125. effective_dt = (signing_dt + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
  126. expiry_dt = effective_dt + timedelta(days=365)
  127. # 更新 contract_url 中对应字典的时间字段
  128. matched_item["signing_time"] = signing_dt.strftime("%Y-%m-%d %H:%M:%S") if signing_dt else ""
  129. matched_item["effective_time"] = effective_dt.strftime("%Y-%m-%d %H:%M:%S") if effective_dt else ""
  130. matched_item["expiry_time"] = expiry_dt.strftime("%Y-%m-%d %H:%M:%S") if expiry_dt else ""
  131. await self._execute_with_retry(
  132. ContractStore.__table__.update()
  133. .where(ContractStore.id == row["id"])
  134. .values(
  135. signing_status="已签署",
  136. contract_url=json.dumps(items, ensure_ascii=False) if items else contract_url_raw,
  137. signing_time=signing_dt,
  138. effective_time=effective_dt,
  139. expiry_time=expiry_dt,
  140. )
  141. )
  142. updated = True
  143. elif changed:
  144. # is_master 不为 1 时,只更新 status,不更新时间字段
  145. await self._execute_with_retry(
  146. ContractStore.__table__.update()
  147. .where(ContractStore.id == row["id"])
  148. .values(
  149. contract_url=json.dumps(items, ensure_ascii=False) if items else contract_url_raw,
  150. )
  151. )
  152. updated = True
  153. if updated:
  154. await self.db.commit()
  155. return updated
  156. async def update_sign_url(self, contact_phone: str, sign_flow_id: str, sign_url: str):
  157. """
  158. 根据手机号 + sign_flow_id 更新 contract_url 列表中对应项的 sign_url
  159. """
  160. result = await self._execute_with_retry(
  161. ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
  162. )
  163. rows = result.mappings().all()
  164. updated = False
  165. for row in rows:
  166. contract_url_raw = row.get("contract_url")
  167. if not contract_url_raw:
  168. continue
  169. try:
  170. items = json.loads(contract_url_raw)
  171. except Exception:
  172. items = None
  173. if not isinstance(items, list):
  174. continue
  175. changed = False
  176. for item in items:
  177. if item.get("sign_flow_id") == sign_flow_id:
  178. item["sign_url"] = sign_url
  179. changed = True
  180. if changed:
  181. await self._execute_with_retry(
  182. ContractStore.__table__.update()
  183. .where(ContractStore.id == row["id"])
  184. .values(contract_url=json.dumps(items, ensure_ascii=False))
  185. )
  186. updated = True
  187. if updated:
  188. await self.db.commit()
  189. return updated
  190. async def append_contract_url(self, templates_data, contract_item: dict):
  191. """
  192. 根据手机号,向 contract_url(JSON 列表)追加新的合同信息;
  193. 若手机号不存在,则创建新记录。
  194. """
  195. contact_phone = getattr(templates_data, "contact_phone", None)
  196. result = await self._execute_with_retry(
  197. ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
  198. )
  199. rows = result.mappings().all()
  200. updated = False
  201. store_name = getattr(templates_data, "store_name", None)
  202. if rows:
  203. for row in rows:
  204. contract_url_raw = row.get("contract_url")
  205. try:
  206. items = json.loads(contract_url_raw) if contract_url_raw else []
  207. except Exception:
  208. items = []
  209. if not isinstance(items, list):
  210. items = []
  211. items.append(contract_item)
  212. update_values = {"contract_url": json.dumps(items, ensure_ascii=False)}
  213. if store_name:
  214. update_values["store_name"] = store_name
  215. await self._execute_with_retry(
  216. ContractStore.__table__.update()
  217. .where(ContractStore.id == row["id"])
  218. .values(**update_values)
  219. )
  220. updated = True
  221. if updated:
  222. await self.db.commit()
  223. return updated
  224. # 未找到则创建新记录
  225. new_record = ContractStore(
  226. store_id=getattr(templates_data, "store_id", None),
  227. store_name=store_name,
  228. business_segment=getattr(templates_data, "business_segment", None),
  229. merchant_name=getattr(templates_data, "merchant_name", None),
  230. contact_phone=contact_phone,
  231. contract_url=json.dumps([contract_item], ensure_ascii=False),
  232. seal_url='0.0',
  233. signing_status='未签署'
  234. )
  235. self.db.add(new_record)
  236. await self.db.commit()
  237. await self.db.refresh(new_record)
  238. return True