import logging
import json
from datetime import datetime, timedelta

from sqlalchemy import func, select, text, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import DBAPIError, InterfaceError, OperationalError

from alien_store.db.models.contract_store import ContractStore

logger = logging.getLogger(__name__)

# Substrings identifying a "connection already closed" failure that was not
# surfaced as an invalidated SQLAlchemy connection (e.g. aiomysql / uvloop
# raising a RuntimeError about a closed transport).
_CLOSED_CONNECTION_KEYWORDS = ("closed", "lost connection", "connection was killed", "terminat")


class ContractRepository:
    """Data-access layer for contract records stored in ``ContractStore``."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def _execute_with_retry(self, statement):
        """Execute ``statement``, retrying once if the DB connection was lost.

        The retry (rollback + re-execute) only happens when the failing
        statement was the first one of its transaction. If the session already
        had a transaction in progress, rolling back would silently discard the
        statements executed earlier in it, and a later commit would persist
        only part of the work. In that case the error is re-raised instead.
        """
        # Must be sampled before execute(): autobegin flips it to True.
        replay_is_safe = not self.db.in_transaction()
        try:
            return await self.db.execute(statement)
        except Exception as exc:
            if not (replay_is_safe and self._should_retry(exc)):
                raise
            logger.warning("DB connection invalidated, retrying once: %s", exc)
            try:
                await self.db.rollback()
            except Exception:
                # Best effort: the connection is already broken, so a failed
                # rollback is expected; keep a trace instead of hiding it.
                logger.debug("rollback before retry failed", exc_info=True)
            return await self.db.execute(statement)

    @staticmethod
    def _should_retry(exc: Exception) -> bool:
        """Return True if ``exc`` looks like a dropped / closed DB connection."""
        if isinstance(exc, DBAPIError):
            # SQLAlchemy's own disconnect detection is authoritative.
            if getattr(exc, "connection_invalidated", False):
                return True
            # Integrity/data/programming errors are never a lost connection,
            # whatever text their message happens to contain.
            if not isinstance(exc, (OperationalError, InterfaceError)):
                return False
            # str() of the SQLAlchemy wrapper embeds the SQL and bound
            # parameters (user data); only inspect the driver's own error.
            source = exc.orig if exc.orig is not None else exc
        else:
            # Fallback for raw driver / event-loop errors (aiomysql, uvloop).
            source = exc
        txt = str(source).lower()
        return any(k in txt for k in _CLOSED_CONNECTION_KEYWORDS)

    @staticmethod
    def _parse_items(raw) -> list | None:
        """Decode a ``contract_url`` column value into its list of contract items.

        Returns None when the value is empty, is not valid JSON, or does not
        decode to a list.
        """
        if not raw:
            return None
        try:
            items = json.loads(raw)
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError
            return None
        return items if isinstance(items, list) else None

    @staticmethod
    def _format_dt(value: datetime | None) -> str:
        """Format ``value`` as ``YYYY-mm-dd HH:MM:SS``; empty string for None."""
        return value.strftime("%Y-%m-%d %H:%M:%S") if value else ""

    async def get_by_store_id(self, store_id: int):
        """Return all contract rows of ``store_id`` as a list of plain dicts.

        Plain dicts (not Row objects) keep Pydantic serialization working.
        """
        result = await self._execute_with_retry(
            ContractStore.__table__.select().where(ContractStore.store_id == store_id)
        )
        return [dict(row) for row in result.mappings().all()]

    async def check_store_status(self, store_id: int) -> str | None:
        """Return ``store_info.reason`` for ``store_id``, or None if no such row."""
        query = text("SELECT reason FROM store_info WHERE id = :store_id")
        result = await self._execute_with_retry(query.bindparams(store_id=store_id))
        row = result.fetchone()
        return row[0] if row else None

    async def get_contract_item_by_sign_flow_id(self, sign_flow_id: str):
        """Find the contract item whose ``sign_flow_id`` matches.

        Returns ``(row, item, items)``: the owning row as a dict, the matching
        item dict, and the full decoded item list of that row (so the caller
        can mutate ``item`` and write ``items`` back). Returns
        ``(None, None, None)`` when nothing matches.
        """
        # NOTE(review): scans the whole table because items live inside a JSON
        # text column; consider an indexed lookup column if this grows.
        result = await self._execute_with_retry(ContractStore.__table__.select())
        for row in result.mappings().all():
            items = self._parse_items(row.get("contract_url"))
            if items is None:
                continue
            for item in items:
                if isinstance(item, dict) and item.get("sign_flow_id") == sign_flow_id:
                    return dict(row), item, items
        return None, None, None

    async def update_contract_items(self, row_id: int, items: list) -> bool:
        """Overwrite the ``contract_url`` item list of row ``row_id`` and commit.

        Returns False (and writes nothing) if ``items`` is not a list.
        """
        if not isinstance(items, list):
            return False
        await self._execute_with_retry(
            ContractStore.__table__.update()
            .where(ContractStore.id == row_id)
            .values(contract_url=json.dumps(items, ensure_ascii=False))
        )
        await self.db.commit()
        return True

    async def get_all(self):
        """Return every contract row as a list of plain dicts."""
        result = await self._execute_with_retry(ContractStore.__table__.select())
        return [dict(row) for row in result.mappings().all()]

    async def get_all_paged(
        self,
        page: int,
        page_size: int = 10,
        store_name: str | None = None,
        merchant_name: str | None = None,
        signing_status: str | None = None,
        business_segment: str | None = None,
        store_status: str | None = None,
        expiry_start: datetime | None = None,
        expiry_end: datetime | None = None,
    ):
        """Return one page of contracts matching the filters, as ``(items, total)``.

        ``page`` is 1-based. ``store_name`` / ``merchant_name`` are substring
        matches. ``store_status`` "正常" selects signed contracts, and "禁用"
        selects unsigned ones (including a NULL signing status).
        ``total`` counts all matching rows, not just this page.
        """
        # Clamp so page <= 0 cannot produce a negative OFFSET (an SQL error).
        offset = (max(page, 1) - 1) * page_size
        table = ContractStore.__table__

        conditions = []
        if store_name:
            # autoescape: user-typed "%" / "_" match literally, not as wildcards.
            conditions.append(table.c.store_name.contains(store_name, autoescape=True))
        if merchant_name:
            conditions.append(table.c.merchant_name.contains(merchant_name, autoescape=True))
        if signing_status:
            conditions.append(table.c.signing_status == signing_status)
        if business_segment:
            conditions.append(table.c.business_segment == business_segment)
        if store_status:
            if store_status == "正常":
                conditions.append(table.c.signing_status == "已签署")
            elif store_status == "禁用":
                # "!=" alone would drop NULL rows (NULL comparisons are unknown).
                conditions.append(
                    or_(table.c.signing_status != "已签署", table.c.signing_status.is_(None))
                )
        if expiry_start:
            conditions.append(table.c.expiry_time >= expiry_start)
        if expiry_end:
            conditions.append(table.c.expiry_time <= expiry_end)

        # Total number of matching rows.
        count_stmt = select(func.count()).select_from(table)
        if conditions:
            count_stmt = count_stmt.where(*conditions)
        count_result = await self._execute_with_retry(count_stmt)
        total = count_result.scalar() or 0

        # Requested page. Without ORDER BY the database may return rows in any
        # order, so consecutive pages could repeat or skip rows.
        data_stmt = table.select()
        if conditions:
            data_stmt = data_stmt.where(*conditions)
        data_stmt = data_stmt.order_by(table.c.id).offset(offset).limit(page_size)
        result = await self._execute_with_retry(data_stmt)
        items = [dict(row) for row in result.mappings().all()]
        return items, total

    async def create(self, user_data):
        """Insert a new unsigned ('未签署') contract record and return it.

        ``user_data`` must provide store_id, merchant_name, business_segment,
        contact_phone, contract_url and ord_id; store_name is optional.
        """
        db_templates = ContractStore(
            store_id=user_data.store_id,
            store_name=getattr(user_data, "store_name", None),
            merchant_name=user_data.merchant_name,
            business_segment=user_data.business_segment,
            contact_phone=user_data.contact_phone,
            contract_url=user_data.contract_url,
            ord_id=user_data.ord_id,
            signing_status='未签署',
        )
        self.db.add(db_templates)
        await self.db.commit()
        await self.db.refresh(db_templates)
        return db_templates

    async def mark_signed_by_phone(self, contact_phone: str, sign_flow_id: str,
                                   signing_time: datetime | None = None,
                                   contract_download_url: str | None = None):
        """Mark the item with ``sign_flow_id`` as signed in the phone's contracts.

        For every row with ``contact_phone``, the first item whose
        ``sign_flow_id`` matches gets:
        - ``status = 1``;
        - ``contract_download_url``, if one was given.

        If that item has ``is_master == 1``, the row is also set to '已签署'
        and the time fields are written both to the item (as strings) and to
        the row columns:
        - signing = ``signing_time``;
        - effective = 00:00 of the following day;
        - expiry = effective + 365 days.

        Returns True if any row was updated; changes are committed in that case.
        """
        result = await self._execute_with_retry(
            ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
        )
        updated = False
        for row in result.mappings().all():
            items = self._parse_items(row.get("contract_url"))
            if items is None:
                continue
            matched_item = next(
                (item for item in items
                 if isinstance(item, dict) and item.get("sign_flow_id") == sign_flow_id),
                None,
            )
            if matched_item is None:
                continue

            matched_item["status"] = 1
            if contract_download_url:
                matched_item["contract_download_url"] = contract_download_url

            values = {}
            # Only the master contract drives the row's signing status and dates.
            if matched_item.get("is_master") == 1:
                signing_dt = signing_time
                effective_dt = expiry_dt = None
                if signing_dt:
                    # Effective from midnight of the day after signing.
                    effective_dt = (signing_dt + timedelta(days=1)).replace(
                        hour=0, minute=0, second=0, microsecond=0
                    )
                    expiry_dt = effective_dt + timedelta(days=365)
                matched_item["signing_time"] = self._format_dt(signing_dt)
                matched_item["effective_time"] = self._format_dt(effective_dt)
                matched_item["expiry_time"] = self._format_dt(expiry_dt)
                values.update(
                    signing_status="已签署",
                    signing_time=signing_dt,
                    effective_time=effective_dt,
                    expiry_time=expiry_dt,
                )
            values["contract_url"] = json.dumps(items, ensure_ascii=False)

            await self._execute_with_retry(
                ContractStore.__table__.update()
                .where(ContractStore.id == row["id"])
                .values(**values)
            )
            updated = True

        if updated:
            await self.db.commit()
        return updated

    async def update_sign_url(self, contact_phone: str, sign_flow_id: str, sign_url: str):
        """Set ``sign_url`` on every item with ``sign_flow_id`` in the phone's contracts.

        Returns True if any row was updated; changes are committed in that case.
        """
        result = await self._execute_with_retry(
            ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
        )
        updated = False
        for row in result.mappings().all():
            items = self._parse_items(row.get("contract_url"))
            if items is None:
                continue
            changed = False
            for item in items:
                if isinstance(item, dict) and item.get("sign_flow_id") == sign_flow_id:
                    item["sign_url"] = sign_url
                    changed = True
            if changed:
                await self._execute_with_retry(
                    ContractStore.__table__.update()
                    .where(ContractStore.id == row["id"])
                    .values(contract_url=json.dumps(items, ensure_ascii=False))
                )
                updated = True

        if updated:
            await self.db.commit()
        return updated

    async def append_contract_url(self, templates_data, contract_item: dict):
        """Append ``contract_item`` to the contract list of ``templates_data.store_id``.

        Every existing row of that store gets the item appended. A non-empty
        store_name / contact_phone from ``templates_data`` is also written to
        those rows. If the store has no row yet, a new unsigned record is
        created. Returns False only when ``store_id`` is missing.
        """
        store_id = getattr(templates_data, "store_id", None)
        if store_id is None:
            logger.error("append_contract_url missing store_id")
            return False

        result = await self._execute_with_retry(
            ContractStore.__table__.select().where(ContractStore.store_id == store_id)
        )
        rows = result.mappings().all()
        store_name = getattr(templates_data, "store_name", None)
        contact_phone = getattr(templates_data, "contact_phone", None)

        if rows:
            for row in rows:
                # A missing or corrupt list is treated as empty rather than failing.
                items = self._parse_items(row.get("contract_url")) or []
                items.append(contract_item)
                update_values = {"contract_url": json.dumps(items, ensure_ascii=False)}
                if store_name:
                    update_values["store_name"] = store_name
                if contact_phone:
                    update_values["contact_phone"] = contact_phone
                await self._execute_with_retry(
                    ContractStore.__table__.update()
                    .where(ContractStore.id == row["id"])
                    .values(**update_values)
                )
            await self.db.commit()
            return True

        # No record for this store yet: create one.
        new_record = ContractStore(
            store_id=store_id,
            store_name=store_name,
            business_segment=getattr(templates_data, "business_segment", None),
            merchant_name=getattr(templates_data, "merchant_name", None),
            contact_phone=contact_phone,
            contract_url=json.dumps([contract_item], ensure_ascii=False),
            ord_id=getattr(templates_data, "ord_id", None),
            signing_status='未签署',
        )
        self.db.add(new_record)
        await self.db.commit()
        await self.db.refresh(new_record)
        return True