|
|
@@ -1,7 +1,11 @@
|
|
|
-from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
-from alien_store.db.models.contract_store import ContractStore
|
|
|
+import logging
|
|
|
import json
|
|
|
from datetime import datetime, timedelta
|
|
|
+from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
+from sqlalchemy.exc import DBAPIError
|
|
|
+from alien_store.db.models.contract_store import ContractStore
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class ContractRepository:
|
|
|
@@ -10,9 +14,35 @@ class ContractRepository:
|
|
|
def __init__(self, db: AsyncSession):
|
|
|
self.db = db
|
|
|
|
|
|
+ async def _execute_with_retry(self, statement):
|
|
|
+ """
|
|
|
+ 封装一次重试:如果连接已失效(被 MySQL/网络关掉),回滚并重试一次。
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ return await self.db.execute(statement)
|
|
|
+ except Exception as exc:
|
|
|
+ if self._should_retry(exc):
|
|
|
+ logger.warning("DB connection invalidated, retrying once: %s", exc)
|
|
|
+ try:
|
|
|
+ await self.db.rollback()
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+ return await self.db.execute(statement)
|
|
|
+ raise
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _should_retry(exc: Exception) -> bool:
|
|
|
+ # SQLAlchemy 标准:connection_invalidated=True
|
|
|
+ if isinstance(exc, DBAPIError) and getattr(exc, "connection_invalidated", False):
|
|
|
+ return True
|
|
|
+ # 兜底判断常见“连接已关闭”文案(aiomysql/uvloop RuntimeError)
|
|
|
+ txt = str(exc).lower()
|
|
|
+ closed_keywords = ["closed", "lost connection", "connection was killed", "terminat"]
|
|
|
+ return any(k in txt for k in closed_keywords)
|
|
|
+
|
|
|
async def get_by_store_id(self, store_id: int):
|
|
|
"""根据店铺id查询所有合同"""
|
|
|
- result = await self.db.execute(
|
|
|
+ result = await self._execute_with_retry(
|
|
|
ContractStore.__table__.select().where(ContractStore.store_id == store_id)
|
|
|
)
|
|
|
# 返回列表[dict],避免 Pydantic 序列化 Row 对象出错
|
|
|
@@ -20,7 +50,7 @@ class ContractRepository:
|
|
|
|
|
|
async def get_all(self):
|
|
|
"""查询所有合同"""
|
|
|
- result = await self.db.execute(ContractStore.__table__.select())
|
|
|
+ result = await self._execute_with_retry(ContractStore.__table__.select())
|
|
|
return [dict(row) for row in result.mappings().all()]
|
|
|
|
|
|
async def get_all_paged(self, page: int, page_size: int = 10):
|
|
|
@@ -28,12 +58,12 @@ class ContractRepository:
|
|
|
offset = (page - 1) * page_size
|
|
|
# 查询总数
|
|
|
from sqlalchemy import func, select
|
|
|
- count_result = await self.db.execute(
|
|
|
+ count_result = await self._execute_with_retry(
|
|
|
select(func.count()).select_from(ContractStore.__table__)
|
|
|
)
|
|
|
total = count_result.scalar() or 0
|
|
|
# 查询分页数据
|
|
|
- result = await self.db.execute(
|
|
|
+ result = await self._execute_with_retry(
|
|
|
ContractStore.__table__.select().offset(offset).limit(page_size)
|
|
|
)
|
|
|
items = [dict(row) for row in result.mappings().all()]
|
|
|
@@ -63,7 +93,7 @@ class ContractRepository:
|
|
|
同时写入签署/生效/到期时间(签署时间=T,生效=T+1天0点,失效=生效+365天)
|
|
|
同时更新 contract_download_url 到对应的字典中
|
|
|
"""
|
|
|
- result = await self.db.execute(
|
|
|
+ result = await self._execute_with_retry(
|
|
|
ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
|
|
|
)
|
|
|
rows = result.mappings().all()
|
|
|
@@ -104,7 +134,7 @@ class ContractRepository:
|
|
|
matched_item["effective_time"] = effective_dt.strftime("%Y-%m-%d %H:%M:%S") if effective_dt else ""
|
|
|
matched_item["expiry_time"] = expiry_dt.strftime("%Y-%m-%d %H:%M:%S") if expiry_dt else ""
|
|
|
|
|
|
- await self.db.execute(
|
|
|
+ await self._execute_with_retry(
|
|
|
ContractStore.__table__.update()
|
|
|
.where(ContractStore.id == row["id"])
|
|
|
.values(
|
|
|
@@ -118,7 +148,7 @@ class ContractRepository:
|
|
|
updated = True
|
|
|
elif changed:
|
|
|
# is_master 不为 1 时,只更新 status,不更新时间字段
|
|
|
- await self.db.execute(
|
|
|
+ await self._execute_with_retry(
|
|
|
ContractStore.__table__.update()
|
|
|
.where(ContractStore.id == row["id"])
|
|
|
.values(
|
|
|
@@ -134,7 +164,7 @@ class ContractRepository:
|
|
|
"""
|
|
|
根据手机号 + sign_flow_id 更新 contract_url 列表中对应项的 sign_url
|
|
|
"""
|
|
|
- result = await self.db.execute(
|
|
|
+ result = await self._execute_with_retry(
|
|
|
ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
|
|
|
)
|
|
|
rows = result.mappings().all()
|
|
|
@@ -155,7 +185,7 @@ class ContractRepository:
|
|
|
item["sign_url"] = sign_url
|
|
|
changed = True
|
|
|
if changed:
|
|
|
- await self.db.execute(
|
|
|
+ await self._execute_with_retry(
|
|
|
ContractStore.__table__.update()
|
|
|
.where(ContractStore.id == row["id"])
|
|
|
.values(contract_url=json.dumps(items, ensure_ascii=False))
|
|
|
@@ -171,7 +201,7 @@ class ContractRepository:
|
|
|
若手机号不存在,则创建新记录。
|
|
|
"""
|
|
|
contact_phone = getattr(templates_data, "contact_phone", None)
|
|
|
- result = await self.db.execute(
|
|
|
+ result = await self._execute_with_retry(
|
|
|
ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
|
|
|
)
|
|
|
rows = result.mappings().all()
|
|
|
@@ -190,7 +220,7 @@ class ContractRepository:
|
|
|
update_values = {"contract_url": json.dumps(items, ensure_ascii=False)}
|
|
|
if store_name:
|
|
|
update_values["store_name"] = store_name
|
|
|
- await self.db.execute(
|
|
|
+ await self._execute_with_retry(
|
|
|
ContractStore.__table__.update()
|
|
|
.where(ContractStore.id == row["id"])
|
|
|
.values(**update_values)
|