Explorar o código

修复数据库池长期未使用 再次进行数据库连接时 数据库会报错的bug
异常捕获:捕获到数据库连接异常会再次发起请求

mengqiankang hai 2 meses
pai
achega
b960ea26a3
Modificáronse 3 ficheiros con 46 adicións e 17 borrados
  1. 2 0
      alien_gateway/config.py
  2. 43 13
      alien_store/repositories/contract_repo.py
  3. 1 4
      alien_util/celery_app.py

+ 2 - 0
alien_gateway/config.py

@@ -18,6 +18,8 @@ class Settings(BaseSettings):
     DB_PORT: int = 30001
     DB_NAME: str = "alien_sit"
 
+    # redis配置
+    REDIS_URL = "redis://:Alien123456@172.31.154.180:30002/0"
     @property
     def SQLALCHEMY_DATABASE_URI(self) -> str:
         return f"mysql+pymysql://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"

+ 43 - 13
alien_store/repositories/contract_repo.py

@@ -1,7 +1,11 @@
-from sqlalchemy.ext.asyncio import AsyncSession
-from alien_store.db.models.contract_store import ContractStore
+import logging
 import json
 from datetime import datetime, timedelta
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.exc import DBAPIError
+from alien_store.db.models.contract_store import ContractStore
+
+logger = logging.getLogger(__name__)
 
 
 class ContractRepository:
@@ -10,9 +14,35 @@ class ContractRepository:
     def __init__(self, db: AsyncSession):
         self.db = db
 
+    async def _execute_with_retry(self, statement):
+        """
+        封装一次重试:如果连接已失效(被 MySQL/网络关掉),回滚并重试一次。
+        """
+        try:
+            return await self.db.execute(statement)
+        except Exception as exc:
+            if self._should_retry(exc):
+                logger.warning("DB connection invalidated, retrying once: %s", exc)
+                try:
+                    await self.db.rollback()
+                except Exception:
+                    pass
+                return await self.db.execute(statement)
+            raise
+
+    @staticmethod
+    def _should_retry(exc: Exception) -> bool:
+        # SQLAlchemy 标准:connection_invalidated=True
+        if isinstance(exc, DBAPIError) and getattr(exc, "connection_invalidated", False):
+            return True
+        # 兜底判断常见“连接已关闭”文案(aiomysql/uvloop RuntimeError)
+        txt = str(exc).lower()
+        closed_keywords = ["closed", "lost connection", "connection was killed", "terminat"]
+        return any(k in txt for k in closed_keywords)
+
     async def get_by_store_id(self, store_id: int):
         """根据店铺id查询所有合同"""
-        result = await self.db.execute(
+        result = await self._execute_with_retry(
             ContractStore.__table__.select().where(ContractStore.store_id == store_id)
         )
         # 返回列表[dict],避免 Pydantic 序列化 Row 对象出错
@@ -20,7 +50,7 @@ class ContractRepository:
 
     async def get_all(self):
         """查询所有合同"""
-        result = await self.db.execute(ContractStore.__table__.select())
+        result = await self._execute_with_retry(ContractStore.__table__.select())
         return [dict(row) for row in result.mappings().all()]
 
     async def get_all_paged(self, page: int, page_size: int = 10):
@@ -28,12 +58,12 @@ class ContractRepository:
         offset = (page - 1) * page_size
         # 查询总数
         from sqlalchemy import func, select
-        count_result = await self.db.execute(
+        count_result = await self._execute_with_retry(
             select(func.count()).select_from(ContractStore.__table__)
         )
         total = count_result.scalar() or 0
         # 查询分页数据
-        result = await self.db.execute(
+        result = await self._execute_with_retry(
             ContractStore.__table__.select().offset(offset).limit(page_size)
         )
         items = [dict(row) for row in result.mappings().all()]
@@ -63,7 +93,7 @@ class ContractRepository:
         同时写入签署/生效/到期时间(签署时间=T,生效=T+1天0点,失效=生效+365天)
         同时更新 contract_download_url 到对应的字典中
         """
-        result = await self.db.execute(
+        result = await self._execute_with_retry(
             ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
         )
         rows = result.mappings().all()
@@ -104,7 +134,7 @@ class ContractRepository:
                     matched_item["effective_time"] = effective_dt.strftime("%Y-%m-%d %H:%M:%S") if effective_dt else ""
                     matched_item["expiry_time"] = expiry_dt.strftime("%Y-%m-%d %H:%M:%S") if expiry_dt else ""
                 
-                await self.db.execute(
+                await self._execute_with_retry(
                     ContractStore.__table__.update()
                     .where(ContractStore.id == row["id"])
                     .values(
@@ -118,7 +148,7 @@ class ContractRepository:
                 updated = True
             elif changed:
                 # is_master 不为 1 时,只更新 status,不更新时间字段
-                await self.db.execute(
+                await self._execute_with_retry(
                     ContractStore.__table__.update()
                     .where(ContractStore.id == row["id"])
                     .values(
@@ -134,7 +164,7 @@ class ContractRepository:
         """
         根据手机号 + sign_flow_id 更新 contract_url 列表中对应项的 sign_url
         """
-        result = await self.db.execute(
+        result = await self._execute_with_retry(
             ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
         )
         rows = result.mappings().all()
@@ -155,7 +185,7 @@ class ContractRepository:
                     item["sign_url"] = sign_url
                     changed = True
             if changed:
-                await self.db.execute(
+                await self._execute_with_retry(
                     ContractStore.__table__.update()
                     .where(ContractStore.id == row["id"])
                     .values(contract_url=json.dumps(items, ensure_ascii=False))
@@ -171,7 +201,7 @@ class ContractRepository:
         若手机号不存在,则创建新记录。
         """
         contact_phone = getattr(templates_data, "contact_phone", None)
-        result = await self.db.execute(
+        result = await self._execute_with_retry(
             ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
         )
         rows = result.mappings().all()
@@ -190,7 +220,7 @@ class ContractRepository:
                 update_values = {"contract_url": json.dumps(items, ensure_ascii=False)}
                 if store_name:
                     update_values["store_name"] = store_name
-                await self.db.execute(
+                await self._execute_with_retry(
                     ContractStore.__table__.update()
                     .where(ContractStore.id == row["id"])
                     .values(**update_values)

+ 1 - 4
alien_util/celery_app.py

@@ -1,12 +1,9 @@
-"""
-Celery 应用配置
-"""
 from celery import Celery
 from celery.schedules import crontab
 from alien_gateway.config import settings
 
 # Redis 配置(可以从环境变量读取,这里使用默认配置)
-REDIS_URL = "redis://:Alien123456@172.31.154.180:30002/0"
+REDIS_URL = settings.REDIS_URL
 
 # 创建 Celery 应用
 celery_app = Celery(