# contract_repo.py
import json
import logging
from datetime import datetime, timedelta
from typing import Any

from sqlalchemy import func, select, text
from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.asyncio import AsyncSession

from alien_store.db.models.contract_store import ContractStore

logger = logging.getLogger(__name__)
  9. class ContractRepository:
  10. """合同数据访问层"""
  11. def __init__(self, db: AsyncSession):
  12. self.db = db
  13. async def _execute_with_retry(self, statement):
  14. """
  15. 封装一次重试:如果连接已失效(被 MySQL/网络关掉),回滚并重试一次。
  16. """
  17. try:
  18. return await self.db.execute(statement)
  19. except Exception as exc:
  20. if self._should_retry(exc):
  21. logger.warning("DB connection invalidated, retrying once: %s", exc)
  22. try:
  23. await self.db.rollback()
  24. except Exception:
  25. pass
  26. return await self.db.execute(statement)
  27. raise
  28. @staticmethod
  29. def _should_retry(exc: Exception) -> bool:
  30. # SQLAlchemy 标准:connection_invalidated=True
  31. if isinstance(exc, DBAPIError) and getattr(exc, "connection_invalidated", False):
  32. return True
  33. # 兜底判断常见“连接已关闭”文案(aiomysql/uvloop RuntimeError)
  34. txt = str(exc).lower()
  35. closed_keywords = ["closed", "lost connection", "connection was killed", "terminat"]
  36. return any(k in txt for k in closed_keywords)
  37. async def get_by_store_id(self, store_id: int):
  38. """根据店铺id查询所有合同"""
  39. result = await self._execute_with_retry(
  40. ContractStore.__table__.select().where(ContractStore.store_id == store_id)
  41. )
  42. # 返回列表[dict],避免 Pydantic 序列化 Row 对象出错
  43. return [dict(row) for row in result.mappings().all()]
  44. async def check_store_status(self, store_id: int) -> str | None:
  45. """
  46. 检查 store_info 表中对应 store_id 的 reason 字段
  47. """
  48. query = text("SELECT reason FROM store_info WHERE id = :store_id")
  49. result = await self._execute_with_retry(query.bindparams(store_id=store_id))
  50. row = result.fetchone()
  51. return row[0] if row else None
  52. async def get_contract_item_by_sign_flow_id(self, sign_flow_id: str):
  53. """
  54. 根据 sign_flow_id 查找合同项,返回 (row, item, items)
  55. """
  56. result = await self._execute_with_retry(ContractStore.__table__.select())
  57. rows = result.mappings().all()
  58. for row in rows:
  59. contract_url_raw = row.get("contract_url")
  60. if not contract_url_raw:
  61. continue
  62. try:
  63. items = json.loads(contract_url_raw)
  64. except Exception:
  65. items = None
  66. if not isinstance(items, list):
  67. continue
  68. for item in items:
  69. if item.get("sign_flow_id") == sign_flow_id:
  70. return dict(row), item, items
  71. return None, None, None
  72. async def update_contract_items(self, row_id: int, items: list) -> bool:
  73. """
  74. 更新指定记录的 contract_url 列表
  75. """
  76. if not isinstance(items, list):
  77. return False
  78. await self._execute_with_retry(
  79. ContractStore.__table__.update()
  80. .where(ContractStore.id == row_id)
  81. .values(contract_url=json.dumps(items, ensure_ascii=False))
  82. )
  83. await self.db.commit()
  84. return True
  85. async def get_all(self):
  86. """查询所有合同"""
  87. result = await self._execute_with_retry(ContractStore.__table__.select())
  88. return [dict(row) for row in result.mappings().all()]
  89. async def get_all_paged(self, page: int, page_size: int = 10):
  90. """分页查询所有合同,返回 (items, total)"""
  91. offset = (page - 1) * page_size
  92. # 查询总数
  93. count_result = await self._execute_with_retry(
  94. select(func.count()).select_from(ContractStore.__table__)
  95. )
  96. total = count_result.scalar() or 0
  97. # 查询分页数据
  98. result = await self._execute_with_retry(
  99. ContractStore.__table__.select().offset(offset).limit(page_size)
  100. )
  101. items = [dict(row) for row in result.mappings().all()]
  102. return items, total
  103. async def create(self, user_data):
  104. """创建未签署合同模板"""
  105. db_templates = ContractStore(
  106. store_id=user_data.store_id,
  107. store_name=getattr(user_data, "store_name", None),
  108. merchant_name=user_data.merchant_name,
  109. business_segment=user_data.business_segment,
  110. contact_phone=user_data.contact_phone,
  111. contract_url=user_data.contract_url,
  112. seal_url='0.0',
  113. signing_status='未签署'
  114. )
  115. self.db.add(db_templates)
  116. await self.db.commit()
  117. await self.db.refresh(db_templates)
  118. return db_templates
  119. async def mark_signed_by_phone(self, contact_phone: str, sign_flow_id: str, signing_time: datetime | None = None, contract_download_url: str | None = None):
  120. """
  121. 根据手机号 + sign_flow_id 将合同标记为已签署,只更新匹配的合同项
  122. 当 is_master 为 1 时,更新签署状态和时间字段
  123. 同时写入签署/生效/到期时间(签署时间=T,生效=T+1天0点,失效=生效+365天)
  124. 同时更新 contract_download_url 到对应的字典中
  125. """
  126. result = await self._execute_with_retry(
  127. ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
  128. )
  129. rows = result.mappings().all()
  130. updated = False
  131. for row in rows:
  132. contract_url_raw = row.get("contract_url")
  133. items = None
  134. if contract_url_raw:
  135. try:
  136. items = json.loads(contract_url_raw)
  137. except Exception:
  138. items = None
  139. changed = False
  140. matched_item = None
  141. if isinstance(items, list):
  142. for item in items:
  143. if item.get("sign_flow_id") == sign_flow_id:
  144. item["status"] = 1
  145. # 更新 contract_download_url
  146. if contract_download_url:
  147. item["contract_download_url"] = contract_download_url
  148. matched_item = item
  149. changed = True
  150. break
  151. # 只有当 is_master 为 1 时才更新时间字段
  152. if changed and matched_item and matched_item.get("is_master") == 1:
  153. # 时间处理
  154. signing_dt = signing_time
  155. effective_dt = expiry_dt = None
  156. if signing_dt:
  157. # effective_time 是 signing_time 第二天的 0 点
  158. effective_dt = (signing_dt + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
  159. expiry_dt = effective_dt + timedelta(days=365)
  160. # 更新 contract_url 中对应字典的时间字段
  161. matched_item["signing_time"] = signing_dt.strftime("%Y-%m-%d %H:%M:%S") if signing_dt else ""
  162. matched_item["effective_time"] = effective_dt.strftime("%Y-%m-%d %H:%M:%S") if effective_dt else ""
  163. matched_item["expiry_time"] = expiry_dt.strftime("%Y-%m-%d %H:%M:%S") if expiry_dt else ""
  164. await self._execute_with_retry(
  165. ContractStore.__table__.update()
  166. .where(ContractStore.id == row["id"])
  167. .values(
  168. signing_status="已签署",
  169. contract_url=json.dumps(items, ensure_ascii=False) if items else contract_url_raw,
  170. signing_time=signing_dt,
  171. effective_time=effective_dt,
  172. expiry_time=expiry_dt,
  173. )
  174. )
  175. updated = True
  176. elif changed:
  177. # is_master 不为 1 时,只更新 status,不更新时间字段
  178. await self._execute_with_retry(
  179. ContractStore.__table__.update()
  180. .where(ContractStore.id == row["id"])
  181. .values(
  182. contract_url=json.dumps(items, ensure_ascii=False) if items else contract_url_raw,
  183. )
  184. )
  185. updated = True
  186. if updated:
  187. await self.db.commit()
  188. return updated
  189. async def update_sign_url(self, contact_phone: str, sign_flow_id: str, sign_url: str):
  190. """
  191. 根据手机号 + sign_flow_id 更新 contract_url 列表中对应项的 sign_url
  192. """
  193. result = await self._execute_with_retry(
  194. ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
  195. )
  196. rows = result.mappings().all()
  197. updated = False
  198. for row in rows:
  199. contract_url_raw = row.get("contract_url")
  200. if not contract_url_raw:
  201. continue
  202. try:
  203. items = json.loads(contract_url_raw)
  204. except Exception:
  205. items = None
  206. if not isinstance(items, list):
  207. continue
  208. changed = False
  209. for item in items:
  210. if item.get("sign_flow_id") == sign_flow_id:
  211. item["sign_url"] = sign_url
  212. changed = True
  213. if changed:
  214. await self._execute_with_retry(
  215. ContractStore.__table__.update()
  216. .where(ContractStore.id == row["id"])
  217. .values(contract_url=json.dumps(items, ensure_ascii=False))
  218. )
  219. updated = True
  220. if updated:
  221. await self.db.commit()
  222. return updated
  223. async def append_contract_url(self, templates_data, contract_item: dict):
  224. """
  225. 根据手机号,向 contract_url(JSON 列表)追加新的合同信息;
  226. 若手机号不存在,则创建新记录。
  227. """
  228. contact_phone = getattr(templates_data, "contact_phone", None)
  229. result = await self._execute_with_retry(
  230. ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
  231. )
  232. rows = result.mappings().all()
  233. updated = False
  234. store_name = getattr(templates_data, "store_name", None)
  235. if rows:
  236. for row in rows:
  237. contract_url_raw = row.get("contract_url")
  238. try:
  239. items = json.loads(contract_url_raw) if contract_url_raw else []
  240. except Exception:
  241. items = []
  242. if not isinstance(items, list):
  243. items = []
  244. items.append(contract_item)
  245. update_values = {"contract_url": json.dumps(items, ensure_ascii=False)}
  246. if store_name:
  247. update_values["store_name"] = store_name
  248. await self._execute_with_retry(
  249. ContractStore.__table__.update()
  250. .where(ContractStore.id == row["id"])
  251. .values(**update_values)
  252. )
  253. updated = True
  254. if updated:
  255. await self.db.commit()
  256. return updated
  257. # 未找到则创建新记录
  258. new_record = ContractStore(
  259. store_id=getattr(templates_data, "store_id", None),
  260. store_name=store_name,
  261. business_segment=getattr(templates_data, "business_segment", None),
  262. merchant_name=getattr(templates_data, "merchant_name", None),
  263. contact_phone=contact_phone,
  264. contract_url=json.dumps([contract_item], ensure_ascii=False),
  265. seal_url='0.0',
  266. signing_status='未签署'
  267. )
  268. self.db.add(new_record)
  269. await self.db.commit()
  270. await self.db.refresh(new_record)
  271. return True