contract_repo.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. import logging
  2. import json
  3. from datetime import datetime, timedelta
  4. from sqlalchemy import func, select, text, or_
  5. from sqlalchemy.ext.asyncio import AsyncSession
  6. from sqlalchemy.exc import DBAPIError
  7. from alien_store.db.models.contract_store import ContractStore
  logger = logging.getLogger(__name__)


  class ContractRepository:
      """Data-access layer for contract records in the ``contract_store`` table.

      The ``contract_url`` column holds a JSON-encoded list of contract-item
      dicts. Several methods decode that list, modify the item(s) whose
      ``sign_flow_id`` matches, and write the re-encoded list back.
      """

      def __init__(self, db: AsyncSession):
          self.db = db
          # True while this repository has executed an INSERT/UPDATE/DELETE that
          # has not been committed yet. It disables the reconnect retry in
          # _execute_with_retry, because that retry rolls the session back and
          # would silently discard those earlier writes.
          self._has_uncommitted_writes = False

      # ----------------------------------------------------------------------
      # Execution helpers
      # ----------------------------------------------------------------------

      async def _execute_with_retry(self, statement):
          """Execute ``statement``, retrying once if the DB connection was lost.

          If execution fails with an error that ``_should_retry`` classifies as a
          closed/invalidated connection (e.g. closed by MySQL or the network),
          the session is rolled back and the statement is executed once more.

          No retry happens when this repository already has uncommitted writes
          in the current transaction: the rollback would discard them, and a
          later commit would then persist only part of a multi-statement
          update. The original error is re-raised instead.

          Returns:
              The result of ``AsyncSession.execute``.

          Raises:
              The original exception when it is not retryable or a retry is not
              safe; the retry's exception if the second attempt also fails.
          """
          try:
              result = await self.db.execute(statement)
          except Exception as exc:
              if self._has_uncommitted_writes or not self._should_retry(exc):
                  raise
              logger.warning("DB connection invalidated, retrying once: %s", exc)
              try:
                  await self.db.rollback()
              except Exception:
                  # Best effort: the connection is already broken, so a failing
                  # rollback is expected and must not prevent the retry.
                  logger.debug("Rollback before retry failed", exc_info=True)
              result = await self.db.execute(statement)
          # SQLAlchemy insert/update/delete constructs set ``is_dml``; SELECTs
          # and text() clauses do not.
          if getattr(statement, "is_dml", False):
              self._has_uncommitted_writes = True
          return result

      async def _commit(self):
          """Commit the session and clear the uncommitted-writes marker."""
          try:
              await self.db.commit()
          finally:
              # After a commit attempt the pending writes are either persisted or
              # must be rolled back by the caller; either way the marker no
              # longer describes this repository's pending work.
              self._has_uncommitted_writes = False

      @staticmethod
      def _should_retry(exc: Exception) -> bool:
          """Return True if ``exc`` indicates a closed or invalidated connection."""
          # SQLAlchemy's standard signal: DBAPIError.connection_invalidated is True.
          if isinstance(exc, DBAPIError) and getattr(exc, "connection_invalidated", False):
              return True
          # Fallback: match common "connection closed" messages, e.g. the
          # RuntimeErrors raised by aiomysql/uvloop.
          message = str(exc).lower()
          closed_keywords = ("closed", "lost connection", "connection was killed", "terminat")
          return any(keyword in message for keyword in closed_keywords)

      # ----------------------------------------------------------------------
      # contract_url (JSON list) helpers
      # ----------------------------------------------------------------------

      @staticmethod
      def _load_items(raw):
          """Decode a ``contract_url`` value into a list.

          Returns None when ``raw`` is empty, is not valid JSON, or does not
          decode to a list.
          """
          if not raw:
              return None
          try:
              items = json.loads(raw)
          except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
              return None
          return items if isinstance(items, list) else None

      @staticmethod
      def _dump_items(items) -> str:
          """Encode a contract-item list for the ``contract_url`` column."""
          return json.dumps(items, ensure_ascii=False)

      @staticmethod
      def _matching_items(items, sign_flow_id):
          """Yield the dict entries of ``items`` whose ``sign_flow_id`` matches.

          Non-dict entries are skipped instead of raising AttributeError.
          """
          for item in items:
              if isinstance(item, dict) and item.get("sign_flow_id") == sign_flow_id:
                  yield item

      @staticmethod
      def _validity_window(signing_dt):
          """Return ``(effective_dt, expiry_dt)`` for a signing time.

          The effective time is 00:00 on the day after ``signing_dt``; the expiry
          time is 365 days after that. Both are None when ``signing_dt`` is None.
          """
          if signing_dt is None:
              return None, None
          effective_dt = (signing_dt + timedelta(days=1)).replace(
              hour=0, minute=0, second=0, microsecond=0
          )
          return effective_dt, effective_dt + timedelta(days=365)

      @staticmethod
      def _format_dt(value) -> str:
          """Format a datetime as ``YYYY-MM-DD HH:MM:SS``; empty string for None."""
          return value.strftime("%Y-%m-%d %H:%M:%S") if value else ""

      # ----------------------------------------------------------------------
      # Queries and updates
      # ----------------------------------------------------------------------

      async def get_by_store_id(self, store_id: int):
          """Return all contract rows for ``store_id`` as a list of dicts.

          Plain dicts are returned instead of Row objects to avoid Pydantic
          serialization errors.
          """
          result = await self._execute_with_retry(
              ContractStore.__table__.select().where(ContractStore.store_id == store_id)
          )
          return [dict(row) for row in result.mappings().all()]

      async def check_store_status(self, store_id: int) -> str | None:
          """Return ``store_info.reason`` for the store whose id is ``store_id``.

          Returns None when no such store exists (or its reason is NULL).
          """
          query = text("SELECT reason FROM store_info WHERE id = :store_id")
          result = await self._execute_with_retry(query.bindparams(store_id=store_id))
          row = result.fetchone()
          return row[0] if row else None

      async def get_contract_item_by_sign_flow_id(self, sign_flow_id: str):
          """Find the contract item whose ``sign_flow_id`` matches.

          Returns:
              ``(row, item, items)`` for the first match — the owning row as a
              dict, the matching item, and the row's full decoded item list — or
              ``(None, None, None)`` if nothing matches.
          """
          # NOTE(review): this loads and decodes every contract row; consider a
          # SQL-side pre-filter if the table grows large.
          result = await self._execute_with_retry(ContractStore.__table__.select())
          for row in result.mappings().all():
              items = self._load_items(row.get("contract_url"))
              if items is None:
                  continue
              item = next(self._matching_items(items, sign_flow_id), None)
              if item is not None:
                  return dict(row), item, items
          return None, None, None

      async def update_contract_items(self, row_id: int, items: list) -> bool:
          """Overwrite the ``contract_url`` list of the row whose id is ``row_id``.

          Returns False (without touching the DB) when ``items`` is not a list;
          otherwise commits and returns True.
          """
          if not isinstance(items, list):
              return False
          await self._execute_with_retry(
              ContractStore.__table__.update()
              .where(ContractStore.id == row_id)
              .values(contract_url=self._dump_items(items))
          )
          await self._commit()
          return True

      async def get_all(self):
          """Return every contract row as a list of dicts."""
          result = await self._execute_with_retry(ContractStore.__table__.select())
          return [dict(row) for row in result.mappings().all()]

      async def get_all_paged(
          self,
          page: int,
          page_size: int = 10,
          store_name: str | None = None,
          merchant_name: str | None = None,
          signing_status: str | None = None,
          business_segment: str | None = None,
          store_status: str | None = None,
          expiry_start: datetime | None = None,
          expiry_end: datetime | None = None,
      ):
          """Return one page of contracts matching the filters, plus the total.

          Args:
              page: 1-based page number; values below 1 are treated as 1.
              page_size: rows per page; negative values are treated as 0.
              store_name, merchant_name: substring match; ``%`` and ``_`` in the
                  input are matched literally, not as LIKE wildcards.
              signing_status, business_segment: exact match.
              store_status: "正常" selects rows whose signing_status is "已签署";
                  "禁用" selects all other rows, including a NULL status.
                  Any other value adds no filter.
              expiry_start, expiry_end: inclusive bounds on ``expiry_time``.

          Returns:
              ``(items, total)``: the page's rows as dicts ordered by ``id``, and
              the number of rows matching the filters.
          """
          page = max(page, 1)
          page_size = max(page_size, 0)
          offset = (page - 1) * page_size
          table = ContractStore.__table__

          conditions = []
          if store_name:
              conditions.append(table.c.store_name.contains(store_name, autoescape=True))
          if merchant_name:
              conditions.append(table.c.merchant_name.contains(merchant_name, autoescape=True))
          if signing_status:
              conditions.append(table.c.signing_status == signing_status)
          if business_segment:
              conditions.append(table.c.business_segment == business_segment)
          if store_status == "正常":
              conditions.append(table.c.signing_status == "已签署")
          elif store_status == "禁用":
              conditions.append(
                  or_(table.c.signing_status != "已签署", table.c.signing_status.is_(None))
              )
          if expiry_start:
              conditions.append(table.c.expiry_time >= expiry_start)
          if expiry_end:
              conditions.append(table.c.expiry_time <= expiry_end)

          # Total number of matching rows.
          count_stmt = select(func.count()).select_from(table)
          if conditions:
              count_stmt = count_stmt.where(*conditions)
          count_result = await self._execute_with_retry(count_stmt)
          total = count_result.scalar() or 0

          # OFFSET/LIMIT pagination needs a deterministic order; without ORDER BY
          # the database may return rows in any order, so pages could overlap or
          # skip rows.
          data_stmt = table.select()
          if conditions:
              data_stmt = data_stmt.where(*conditions)
          data_stmt = data_stmt.order_by(table.c.id).offset(offset).limit(page_size)
          result = await self._execute_with_retry(data_stmt)
          items = [dict(row) for row in result.mappings().all()]
          return items, total

      async def create(self, user_data):
          """Create a contract record with signing_status "未签署" (unsigned).

          Reads store_id, merchant_name, business_segment, contact_phone,
          contract_url and ord_id from ``user_data``; store_name is optional.
          Commits and returns the refreshed ORM object.
          """
          db_templates = ContractStore(
              store_id=user_data.store_id,
              store_name=getattr(user_data, "store_name", None),
              merchant_name=user_data.merchant_name,
              business_segment=user_data.business_segment,
              contact_phone=user_data.contact_phone,
              contract_url=user_data.contract_url,
              ord_id=user_data.ord_id,
              signing_status='未签署'
          )
          self.db.add(db_templates)
          await self._commit()
          await self.db.refresh(db_templates)
          return db_templates

      async def mark_signed_by_phone(self, contact_phone: str, sign_flow_id: str, signing_time: datetime | None = None, contract_download_url: str | None = None):
          """Mark the item matching ``contact_phone`` + ``sign_flow_id`` as signed.

          For every row with ``contact_phone``, the first item whose
          ``sign_flow_id`` matches gets ``status = 1`` and, if provided,
          ``contract_download_url``. Only that item is changed.

          When the matched item has ``is_master == 1``, the row is additionally
          set to signing_status "已签署" and its signing/effective/expiry times
          are written, both as row columns and as formatted strings in the item:
          signing = ``signing_time``, effective = 00:00 of the following day,
          expiry = effective + 365 days. With no ``signing_time`` the columns
          become NULL and the item fields empty strings.

          Returns:
              True if at least one row was updated (the changes are committed),
              otherwise False.
          """
          table = ContractStore.__table__
          result = await self._execute_with_retry(
              table.select().where(ContractStore.contact_phone == contact_phone)
          )
          updated = False
          for row in result.mappings().all():
              items = self._load_items(row.get("contract_url"))
              if items is None:
                  continue
              matched_item = next(self._matching_items(items, sign_flow_id), None)
              if matched_item is None:
                  continue

              matched_item["status"] = 1
              if contract_download_url:
                  matched_item["contract_download_url"] = contract_download_url

              values = {}
              # Only the master contract drives the row's status and dates.
              if matched_item.get("is_master") == 1:
                  effective_dt, expiry_dt = self._validity_window(signing_time)
                  matched_item["signing_time"] = self._format_dt(signing_time)
                  matched_item["effective_time"] = self._format_dt(effective_dt)
                  matched_item["expiry_time"] = self._format_dt(expiry_dt)
                  values.update(
                      signing_status="已签署",
                      signing_time=signing_time,
                      effective_time=effective_dt,
                      expiry_time=expiry_dt,
                  )
              # Serialize after all item mutations above so they are persisted.
              values["contract_url"] = self._dump_items(items)

              await self._execute_with_retry(
                  table.update().where(ContractStore.id == row["id"]).values(**values)
              )
              updated = True

          if updated:
              await self._commit()
          return updated

      async def update_sign_url(self, contact_phone: str, sign_flow_id: str, sign_url: str):
          """Set ``sign_url`` on every item matching ``contact_phone`` + ``sign_flow_id``.

          Returns:
              True if at least one row was updated (the changes are committed),
              otherwise False.
          """
          table = ContractStore.__table__
          result = await self._execute_with_retry(
              table.select().where(ContractStore.contact_phone == contact_phone)
          )
          updated = False
          for row in result.mappings().all():
              items = self._load_items(row.get("contract_url"))
              if items is None:
                  continue
              changed = False
              for item in self._matching_items(items, sign_flow_id):
                  item["sign_url"] = sign_url
                  changed = True
              if not changed:
                  continue
              await self._execute_with_retry(
                  table.update()
                  .where(ContractStore.id == row["id"])
                  .values(contract_url=self._dump_items(items))
              )
              updated = True

          if updated:
              await self._commit()
          return updated

      async def append_contract_url(self, templates_data, contract_item: dict):
          """Append ``contract_item`` to the contract list of a store.

          Every existing row with ``templates_data.store_id`` gets
          ``contract_item`` appended to its ``contract_url`` list (an empty or
          undecodable list is replaced by a fresh one); store_name and
          contact_phone are updated when ``templates_data`` provides them.
          If no row exists, a new record with signing_status "未签署" is created.

          Returns:
              False when ``templates_data`` has no store_id, otherwise True.
          """
          store_id = getattr(templates_data, "store_id", None)
          if store_id is None:
              logger.error("append_contract_url missing store_id")
              return False

          table = ContractStore.__table__
          result = await self._execute_with_retry(
              table.select().where(ContractStore.store_id == store_id)
          )
          rows = result.mappings().all()
          store_name = getattr(templates_data, "store_name", None)
          contact_phone = getattr(templates_data, "contact_phone", None)

          if rows:
              for row in rows:
                  items = self._load_items(row.get("contract_url")) or []
                  items.append(contract_item)
                  update_values = {"contract_url": self._dump_items(items)}
                  if store_name:
                      update_values["store_name"] = store_name
                  if contact_phone:
                      update_values["contact_phone"] = contact_phone
                  await self._execute_with_retry(
                      table.update()
                      .where(ContractStore.id == row["id"])
                      .values(**update_values)
                  )
              await self._commit()
              return True

          # No record for this store yet: create one.
          new_record = ContractStore(
              store_id=store_id,
              store_name=store_name,
              business_segment=getattr(templates_data, "business_segment", None),
              merchant_name=getattr(templates_data, "merchant_name", None),
              contact_phone=contact_phone,
              contract_url=self._dump_items([contract_item]),
              ord_id=getattr(templates_data, "ord_id", None),
              signing_status='未签署'
          )
          self.db.add(new_record)
          await self._commit()
          await self.db.refresh(new_record)
          return True