| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342 |
- import logging
- import json
- from datetime import datetime, timedelta
- from sqlalchemy import func, select, text, or_
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy.exc import DBAPIError
- from alien_store.db.models.contract_store import ContractStore
# Module-level logger for this repository (DB reconnect retries, bad input).
logger = logging.getLogger(__name__)


class ContractRepository:
    """Data-access layer for ``ContractStore`` contract records.

    Every statement runs through :meth:`_execute_with_retry`, which retries
    once if the database connection turns out to have been dropped. The
    ``contract_url`` column holds a JSON-encoded list of contract items
    (dicts keyed by e.g. ``sign_flow_id``, ``status``, ``sign_url``,
    ``is_master``).
    """

    # Values written to and compared against the ``signing_status`` column.
    _STATUS_SIGNED = "已签署"  # "signed"
    _STATUS_UNSIGNED = "未签署"  # "not signed"
    # strftime format for the time strings stored inside contract items.
    _ITEM_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, db: AsyncSession):
        """Wrap the given SQLAlchemy async session."""
        self.db = db

    async def _execute_with_retry(self, statement):
        """Execute ``statement``, retrying once if the connection was lost.

        If the first attempt raises an error that :meth:`_should_retry`
        classifies as a dropped/closed connection, the session is rolled back
        and the statement is executed one more time. Any other error, or a
        failure of the retry itself, propagates to the caller.

        NOTE(review): statements issued earlier in the same, still-uncommitted
        transaction are lost together with the dropped connection, and the
        retry re-runs only the current statement. Methods that issue several
        writes before a single commit (e.g. ``mark_signed_by_phone``) can
        therefore commit only part of their work. Consider enabling
        ``pool_pre_ping`` on the engine to catch stale connections up front.
        """
        try:
            return await self.db.execute(statement)
        except Exception as exc:
            if not self._should_retry(exc):
                raise
            logger.warning("DB connection invalidated, retrying once: %s", exc)
            try:
                await self.db.rollback()
            except Exception:
                # Best effort: on a dead connection the rollback itself may
                # fail. Record it instead of hiding it, and still retry.
                logger.debug("Rollback before retry failed", exc_info=True)
            return await self.db.execute(statement)

    @staticmethod
    def _should_retry(exc: Exception) -> bool:
        """Return True if ``exc`` looks like a dropped or closed DB connection."""
        # SQLAlchemy's own signal that it has invalidated the connection.
        if isinstance(exc, DBAPIError) and getattr(exc, "connection_invalidated", False):
            return True
        # Fallback: look for common "connection closed" wording (e.g. a
        # RuntimeError raised by aiomysql/uvloop). For a wrapped DBAPIError only
        # the driver's own error is inspected: str() of the wrapper also embeds
        # the SQL text and bound parameters, where a word like "closed" may be
        # ordinary data and would trigger a bogus re-execution.
        source = exc
        if isinstance(exc, DBAPIError) and getattr(exc, "orig", None) is not None:
            source = exc.orig
        message = str(source).lower()
        closed_keywords = ("closed", "lost connection", "connection was killed", "terminat")
        return any(keyword in message for keyword in closed_keywords)

    @staticmethod
    def _rows_as_dicts(result) -> list[dict]:
        """Materialise every row of ``result`` as a plain dict.

        Plain dicts avoid Pydantic errors when serialising ``Row`` objects.
        """
        return [dict(row) for row in result.mappings().all()]

    @staticmethod
    def _parse_items(raw, row_id=None) -> list | None:
        """Decode a ``contract_url`` value into its list of contract items.

        Args:
            raw: The column value (expected to be a JSON string).
            row_id: Row id, used only in log messages.

        Returns:
            The decoded list, or None if ``raw`` is empty, is not valid JSON,
            or does not decode to a list (the latter two are logged).
        """
        if not raw:
            return None
        try:
            items = json.loads(raw)
        except (TypeError, ValueError):
            logger.warning("ContractStore row %s: contract_url is not valid JSON; ignoring it", row_id)
            return None
        if not isinstance(items, list):
            logger.warning("ContractStore row %s: contract_url is not a JSON list; ignoring it", row_id)
            return None
        return items

    @staticmethod
    def _find_item(items, sign_flow_id: str) -> dict | None:
        """Return the first dict in ``items`` whose ``sign_flow_id`` matches, else None.

        Non-dict elements are skipped rather than crashing on ``.get``.
        """
        for item in items or ():
            if isinstance(item, dict) and item.get("sign_flow_id") == sign_flow_id:
                return item
        return None

    @classmethod
    def _format_item_time(cls, value: datetime | None) -> str:
        """Format ``value`` for storage inside a contract item ("" when missing)."""
        return value.strftime(cls._ITEM_TIME_FORMAT) if value else ""

    @staticmethod
    def _contract_term(signing_dt: datetime | None) -> tuple[datetime | None, datetime | None]:
        """Return ``(effective_dt, expiry_dt)`` for a contract signed at ``signing_dt``.

        The contract takes effect at 00:00 on the day after signing and expires
        365 days after that. Both values are None when ``signing_dt`` is None.
        """
        if not signing_dt:
            return None, None
        next_day = signing_dt + timedelta(days=1)
        effective_dt = next_day.replace(hour=0, minute=0, second=0, microsecond=0)
        return effective_dt, effective_dt + timedelta(days=365)

    async def get_by_store_id(self, store_id: int):
        """Return every contract row for ``store_id`` as a list of dicts."""
        result = await self._execute_with_retry(
            ContractStore.__table__.select().where(ContractStore.store_id == store_id)
        )
        return self._rows_as_dicts(result)

    async def check_store_status(self, store_id: int) -> str | None:
        """Return the ``reason`` column of ``store_info`` for the store ``store_id``.

        Returns None when no such store row exists (or its ``reason`` is NULL).
        """
        query = text("SELECT reason FROM store_info WHERE id = :store_id")
        result = await self._execute_with_retry(query.bindparams(store_id=store_id))
        # scalar(): first column of the first row, or None when there are no rows.
        return result.scalar()

    async def get_contract_item_by_sign_flow_id(self, sign_flow_id: str):
        """Find the contract item whose ``sign_flow_id`` matches.

        Returns:
            ``(row, item, items)``: the owning row as a dict, the first matching
            item, and that row's full decoded item list (``item`` is an element
            of ``items``, so changes to ``item`` are reflected in ``items``).
            ``(None, None, None)`` when no item matches.
        """
        table = ContractStore.__table__
        # NOTE(review): still decodes every non-NULL contract_url in Python;
        # this is a full scan that grows with the table.
        result = await self._execute_with_retry(
            table.select().where(table.c.contract_url.isnot(None))
        )
        for row in result.mappings().all():
            items = self._parse_items(row.get("contract_url"), row.get("id"))
            item = self._find_item(items, sign_flow_id)
            if item is not None:
                return dict(row), item, items
        return None, None, None

    async def update_contract_items(self, row_id: int, items: list) -> bool:
        """Overwrite the ``contract_url`` list of row ``row_id`` and commit.

        Returns False without touching the database if ``items`` is not a
        list; True otherwise (the affected row count is not checked).
        """
        if not isinstance(items, list):
            return False
        await self._execute_with_retry(
            ContractStore.__table__.update()
            .where(ContractStore.id == row_id)
            .values(contract_url=json.dumps(items, ensure_ascii=False))
        )
        await self.db.commit()
        return True

    async def get_all(self):
        """Return every contract row as a list of dicts."""
        result = await self._execute_with_retry(ContractStore.__table__.select())
        return self._rows_as_dicts(result)

    async def get_all_paged(
        self,
        page: int,
        page_size: int = 10,
        store_name: str | None = None,
        merchant_name: str | None = None,
        signing_status: str | None = None,
        business_segment: str | None = None,
        store_status: str | None = None,
        expiry_start: datetime | None = None,
        expiry_end: datetime | None = None,
    ):
        """Return one page of contracts matching the given filters.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Maximum number of rows on the page.
            store_name, merchant_name: Substring match; ``%`` and ``_`` in the
                search text are matched literally.
            signing_status, business_segment: Exact match.
            store_status: ``"正常"`` (active) selects signed contracts;
                ``"禁用"`` (disabled) selects contracts that are not signed or
                have a NULL status; any other value adds no filter.
            expiry_start, expiry_end: Inclusive bounds on ``expiry_time``.

        Returns:
            ``(items, total)``: the page's rows as dicts, ordered by ``id``, and
            the total number of rows matching the filters.
        """
        # page < 1 would produce a negative OFFSET, which MySQL rejects.
        page = max(page, 1)
        offset = (page - 1) * page_size
        table = ContractStore.__table__

        conditions = []
        if store_name:
            # autoescape: user-typed '%' / '_' are literal, not LIKE wildcards.
            conditions.append(table.c.store_name.contains(store_name, autoescape=True))
        if merchant_name:
            conditions.append(table.c.merchant_name.contains(merchant_name, autoescape=True))
        if signing_status:
            conditions.append(table.c.signing_status == signing_status)
        if business_segment:
            conditions.append(table.c.business_segment == business_segment)
        if store_status == "正常":
            conditions.append(table.c.signing_status == self._STATUS_SIGNED)
        elif store_status == "禁用":
            conditions.append(
                or_(table.c.signing_status != self._STATUS_SIGNED, table.c.signing_status.is_(None))
            )
        if expiry_start:
            conditions.append(table.c.expiry_time >= expiry_start)
        if expiry_end:
            conditions.append(table.c.expiry_time <= expiry_end)

        count_stmt = select(func.count()).select_from(table)
        data_stmt = table.select()
        if conditions:
            count_stmt = count_stmt.where(*conditions)
            data_stmt = data_stmt.where(*conditions)

        count_result = await self._execute_with_retry(count_stmt)
        total = count_result.scalar() or 0

        # Without ORDER BY the row order is unspecified, so OFFSET/LIMIT pages
        # could overlap or skip rows; order by id for stable pagination.
        data_stmt = data_stmt.order_by(table.c.id).offset(offset).limit(page_size)
        result = await self._execute_with_retry(data_stmt)
        return self._rows_as_dicts(result), total

    async def create(self, user_data):
        """Insert a new, unsigned contract record built from ``user_data``.

        ``store_name`` is optional on ``user_data``. ``store_id``,
        ``merchant_name``, ``business_segment``, ``contact_phone``,
        ``contract_url`` and ``ord_id`` are required attributes. Commits and
        returns the refreshed ORM object.
        """
        record = ContractStore(
            store_id=user_data.store_id,
            store_name=getattr(user_data, "store_name", None),
            merchant_name=user_data.merchant_name,
            business_segment=user_data.business_segment,
            contact_phone=user_data.contact_phone,
            contract_url=user_data.contract_url,
            ord_id=user_data.ord_id,
            signing_status=self._STATUS_UNSIGNED,
        )
        self.db.add(record)
        await self.db.commit()
        await self.db.refresh(record)
        return record

    async def mark_signed_by_phone(self, contact_phone: str, sign_flow_id: str, signing_time: datetime | None = None, contract_download_url: str | None = None):
        """Mark the contract item identified by ``sign_flow_id`` as signed.

        Every row whose ``contact_phone`` matches is examined. In each row, only
        the first item whose ``sign_flow_id`` matches is updated:

        * The item's ``status`` is set to 1. If ``contract_download_url`` is
          given, it is stored on the item.
        * If the item's ``is_master`` is 1, the row's ``signing_status`` becomes
          signed. The signing, effective and expiry times are written to the
          row columns, and also to the item as ``"%Y-%m-%d %H:%M:%S"`` strings
          (or ``""``). The effective time is 00:00 on the day after
          ``signing_time``; expiry is 365 days after that. When
          ``signing_time`` is None, these row columns are set to NULL.

        Returns:
            True if at least one row was updated (the session is then committed).
        """
        result = await self._execute_with_retry(
            ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
        )
        updated = False
        for row in result.mappings().all():
            items = self._parse_items(row.get("contract_url"), row.get("id"))
            item = self._find_item(items, sign_flow_id)
            if item is None:
                continue

            item["status"] = 1
            if contract_download_url:
                item["contract_download_url"] = contract_download_url

            values = {}
            # Only the master contract drives the row-level status and dates.
            if item.get("is_master") == 1:
                effective_dt, expiry_dt = self._contract_term(signing_time)
                item["signing_time"] = self._format_item_time(signing_time)
                item["effective_time"] = self._format_item_time(effective_dt)
                item["expiry_time"] = self._format_item_time(expiry_dt)
                values.update(
                    signing_status=self._STATUS_SIGNED,
                    signing_time=signing_time,
                    effective_time=effective_dt,
                    expiry_time=expiry_dt,
                )
            values["contract_url"] = json.dumps(items, ensure_ascii=False)

            await self._execute_with_retry(
                ContractStore.__table__.update()
                .where(ContractStore.id == row["id"])
                .values(**values)
            )
            updated = True
        if updated:
            await self.db.commit()
        return updated

    async def update_sign_url(self, contact_phone: str, sign_flow_id: str, sign_url: str):
        """Set ``sign_url`` on the contract items matching ``sign_flow_id``.

        Rows are selected by ``contact_phone``. Unlike ``mark_signed_by_phone``,
        every matching item in a row is updated, not just the first.

        Returns:
            True if at least one row was updated (the session is then committed).
        """
        result = await self._execute_with_retry(
            ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
        )
        updated = False
        for row in result.mappings().all():
            items = self._parse_items(row.get("contract_url"), row.get("id"))
            if items is None:
                continue
            changed = False
            for item in items:
                if isinstance(item, dict) and item.get("sign_flow_id") == sign_flow_id:
                    item["sign_url"] = sign_url
                    changed = True
            if changed:
                await self._execute_with_retry(
                    ContractStore.__table__.update()
                    .where(ContractStore.id == row["id"])
                    .values(contract_url=json.dumps(items, ensure_ascii=False))
                )
                updated = True
        if updated:
            await self.db.commit()
        return updated

    async def append_contract_url(self, templates_data, contract_item: dict):
        """Append ``contract_item`` to the ``contract_url`` list of a store.

        Every row whose ``store_id`` equals ``templates_data.store_id`` has the
        item appended. A missing or malformed list is replaced by a new one.
        The row's ``store_name`` / ``contact_phone`` are refreshed when
        ``templates_data`` provides them. If no row exists, a new unsigned
        record holding just this item is created.

        Returns:
            False if ``templates_data`` has no ``store_id``; True otherwise.
        """
        store_id = getattr(templates_data, "store_id", None)
        if store_id is None:
            logger.error("append_contract_url missing store_id")
            return False

        result = await self._execute_with_retry(
            ContractStore.__table__.select().where(ContractStore.store_id == store_id)
        )
        rows = result.mappings().all()
        store_name = getattr(templates_data, "store_name", None)
        contact_phone = getattr(templates_data, "contact_phone", None)

        if rows:
            for row in rows:
                items = self._parse_items(row.get("contract_url"), row.get("id")) or []
                items.append(contract_item)
                update_values = {"contract_url": json.dumps(items, ensure_ascii=False)}
                if store_name:
                    update_values["store_name"] = store_name
                if contact_phone:
                    update_values["contact_phone"] = contact_phone
                await self._execute_with_retry(
                    ContractStore.__table__.update()
                    .where(ContractStore.id == row["id"])
                    .values(**update_values)
                )
            await self.db.commit()
            return True

        # No record for this store yet: create one.
        new_record = ContractStore(
            store_id=store_id,
            store_name=store_name,
            business_segment=getattr(templates_data, "business_segment", None),
            merchant_name=getattr(templates_data, "merchant_name", None),
            contact_phone=contact_phone,
            contract_url=json.dumps([contract_item], ensure_ascii=False),
            ord_id=getattr(templates_data, "ord_id", None),
            signing_status=self._STATUS_UNSIGNED,
        )
        self.db.add(new_record)
        await self.db.commit()
        await self.db.refresh(new_record)
        return True
|