import json
import logging
from datetime import datetime, timedelta

from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.asyncio import AsyncSession

from alien_lawyer.db.models.lawyer_contract import LawyerContract

logger = logging.getLogger(__name__)

# strftime format used for the time fields written into contract items.
_ITEM_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"


class LawyerContractRepository:
    """Async data access for ``LawyerContract`` rows.

    The ``contract_url`` column holds a JSON-encoded list of contract items
    (dicts looked up by their ``sign_flow_id`` key). Most methods here decode
    that list, modify the matching items and write the re-encoded list back.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def _execute_with_retry(self, statement, *, allow_retry: bool = True):
        """Execute *statement*, retrying once if the DB connection was lost.

        On a connection-level failure (see ``_should_retry``) the session is
        rolled back and the statement is executed a second time. Any other
        error, or a failure of the retry itself, propagates to the caller.

        Args:
            statement: A SQLAlchemy executable.
            allow_retry: Pass ``False`` when the current transaction already
                contains uncommitted writes. A lost connection discards those
                writes, so retrying only this statement and then committing
                would silently persist a partial update.
        """
        try:
            return await self.db.execute(statement)
        except Exception as exc:
            if not (allow_retry and self._should_retry(exc)):
                raise
            logger.warning("DB connection invalidated, retrying once: %s", exc)
            try:
                await self.db.rollback()
            except Exception:
                # Best effort: the connection is already broken, so the rollback may fail too.
                logger.debug("Rollback before retry failed", exc_info=True)
            return await self.db.execute(statement)

    @staticmethod
    def _should_retry(exc: Exception) -> bool:
        """Return True if *exc* looks like a dropped or closed DB connection."""
        if isinstance(exc, DBAPIError) and getattr(exc, "connection_invalidated", False):
            return True
        # Fallback heuristic on the message text, for errors not flagged as invalidated.
        txt = str(exc).lower()
        closed_keywords = ("closed", "lost connection", "connection was killed", "terminat")
        return any(k in txt for k in closed_keywords)

    @staticmethod
    def _parse_items(raw) -> list | None:
        """Decode a ``contract_url`` value into a list of contract items.

        Returns None when *raw* is empty, is not valid JSON, or does not
        decode to a JSON array.
        """
        if not raw:
            return None
        try:
            items = json.loads(raw)
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError subclass
            return None
        return items if isinstance(items, list) else None

    @staticmethod
    def _find_item(items: list, sign_flow_id: str) -> dict | None:
        """Return the first dict in *items* whose ``sign_flow_id`` equals *sign_flow_id*."""
        return next(
            (
                item
                for item in items
                if isinstance(item, dict) and item.get("sign_flow_id") == sign_flow_id
            ),
            None,
        )

    @staticmethod
    def _dump_items(items: list) -> str:
        """Encode contract items back to the JSON text stored in ``contract_url``."""
        return json.dumps(items, ensure_ascii=False)

    async def get_by_lawyer_id(self, lawyer_id: int):
        """Return all rows for *lawyer_id* as a list of plain dicts."""
        result = await self._execute_with_retry(
            LawyerContract.__table__.select().where(LawyerContract.lawyer_id == lawyer_id)
        )
        return [dict(row) for row in result.mappings().all()]

    async def get_contract_item_by_sign_flow_id(self, sign_flow_id: str):
        """Locate the contract item whose ``sign_flow_id`` matches.

        Returns:
            ``(row_dict, item, items)`` for the first match: the row as a dict,
            the matching item, and the full decoded item list the item belongs
            to. Returns ``(None, None, None)`` when nothing matches.
        """
        # NOTE(review): this loads every row and scans the JSON in Python; a
        # SQL-side prefilter on contract_url could avoid the full-table scan.
        result = await self._execute_with_retry(LawyerContract.__table__.select())
        for row in result.mappings().all():
            items = self._parse_items(row.get("contract_url"))
            if not items:
                continue
            item = self._find_item(items, sign_flow_id)
            if item is not None:
                return dict(row), item, items
        return None, None, None

    async def update_contract_items(self, row_id: int, items: list) -> bool:
        """Overwrite ``contract_url`` of row *row_id* with *items* and commit.

        Returns False without touching the DB when *items* is not a list.
        """
        if not isinstance(items, list):
            return False
        await self._execute_with_retry(
            LawyerContract.__table__.update()
            .where(LawyerContract.id == row_id)
            .values(contract_url=self._dump_items(items))
        )
        await self.db.commit()
        return True

    async def mark_signed_by_phone(
        self,
        contact_phone: str,
        sign_flow_id: str,
        signing_time: datetime | None = None,
        contract_download_url: str | None = None,
    ) -> bool:
        """Mark the item with *sign_flow_id* as signed on rows matching *contact_phone*.

        For every row with this contact phone whose item list contains
        *sign_flow_id*, the first matching item is updated:

        - ``status`` is set to 1.
        - ``contract_download_url`` is set too, when it is given.

        If that item has ``is_master == 1``, the row itself is also updated:

        - ``signing_status`` is set to "已签署" ("signed").
        - The signing, effective and expiry time columns are written. The
          effective time is midnight of the day after *signing_time*; the
          expiry time is 365 days after the effective time.

        Returns:
            True if at least one row was updated (the session is then committed).
        """
        result = await self._execute_with_retry(
            LawyerContract.__table__.select().where(LawyerContract.contact_phone == contact_phone)
        )
        updated = False
        for row in result.mappings().all():
            items = self._parse_items(row.get("contract_url"))
            matched_item = self._find_item(items, sign_flow_id) if items else None
            if matched_item is None:
                continue

            matched_item["status"] = 1
            if contract_download_url:
                matched_item["contract_download_url"] = contract_download_url

            row_values = {}
            if matched_item.get("is_master") == 1:
                effective_dt = expiry_dt = None
                if signing_time is not None:
                    effective_dt = (signing_time + timedelta(days=1)).replace(
                        hour=0, minute=0, second=0, microsecond=0
                    )
                    expiry_dt = effective_dt + timedelta(days=365)
                    matched_item["signing_time"] = signing_time.strftime(_ITEM_TIME_FORMAT)
                    matched_item["effective_time"] = effective_dt.strftime(_ITEM_TIME_FORMAT)
                    matched_item["expiry_time"] = expiry_dt.strftime(_ITEM_TIME_FORMAT)
                # NOTE(review): with signing_time=None this writes NULL to all three
                # time columns, clearing any previous values — confirm this is intended.
                row_values = {
                    "signing_status": "已签署",
                    "signing_time": signing_time,
                    "effective_time": effective_dt,
                    "expiry_time": expiry_dt,
                }
            # Serialize only after all item edits above, so they are persisted.
            row_values["contract_url"] = self._dump_items(items)

            await self._execute_with_retry(
                LawyerContract.__table__.update()
                .where(LawyerContract.id == row["id"])
                .values(**row_values),
                # Earlier rows' writes are still uncommitted; never retry past them.
                allow_retry=not updated,
            )
            updated = True

        if updated:
            await self.db.commit()
        return updated

    async def update_sign_url(self, contact_phone: str, sign_flow_id: str, sign_url: str):
        """Set ``sign_url`` on every item with *sign_flow_id*, in rows matching *contact_phone*.

        Returns:
            True if at least one row was updated (the session is then committed).
        """
        result = await self._execute_with_retry(
            LawyerContract.__table__.select().where(LawyerContract.contact_phone == contact_phone)
        )
        updated = False
        for row in result.mappings().all():
            items = self._parse_items(row.get("contract_url"))
            if not items:
                continue
            matches = [
                item
                for item in items
                if isinstance(item, dict) and item.get("sign_flow_id") == sign_flow_id
            ]
            if not matches:
                continue
            for item in matches:
                item["sign_url"] = sign_url
            await self._execute_with_retry(
                LawyerContract.__table__.update()
                .where(LawyerContract.id == row["id"])
                .values(contract_url=self._dump_items(items)),
                allow_retry=not updated,
            )
            updated = True

        if updated:
            await self.db.commit()
        return updated

    async def append_contract_url(self, templates_data, contract_item: dict):
        """Append *contract_item* to the contract list of ``templates_data.lawyer_id``.

        *templates_data* is read via ``getattr`` for the fields ``lawyer_id``,
        ``law_firm_name``, ``contact_phone``, ``business_segment``,
        ``contact_name`` and ``ord_id``. Missing attributes are treated as None.

        If rows already exist for the lawyer, the item is appended to each of
        them. Non-empty ``law_firm_name`` / ``contact_phone`` values are also
        written to those rows. Otherwise a new row is inserted with
        ``signing_status`` "未签署" ("not signed").

        Returns:
            False when ``lawyer_id`` is missing, otherwise True after committing.
        """
        lawyer_id = getattr(templates_data, "lawyer_id", None)
        if lawyer_id is None:
            return False

        result = await self._execute_with_retry(
            LawyerContract.__table__.select().where(LawyerContract.lawyer_id == lawyer_id)
        )
        rows = result.mappings().all()
        law_firm_name = getattr(templates_data, "law_firm_name", None)
        contact_phone = getattr(templates_data, "contact_phone", None)

        if rows:
            # NOTE(review): the item is appended to *every* row of this lawyer — confirm
            # a lawyer is expected to have at most one row.
            updated = False
            for row in rows:
                items = self._parse_items(row.get("contract_url")) or []
                items.append(contract_item)
                update_values = {"contract_url": self._dump_items(items)}
                if law_firm_name:
                    update_values["law_firm_name"] = law_firm_name
                if contact_phone:
                    update_values["contact_phone"] = contact_phone
                await self._execute_with_retry(
                    LawyerContract.__table__.update()
                    .where(LawyerContract.id == row["id"])
                    .values(**update_values),
                    allow_retry=not updated,
                )
                updated = True
            await self.db.commit()
            return True

        new_record = LawyerContract(
            lawyer_id=lawyer_id,
            law_firm_name=law_firm_name,
            business_segment=getattr(templates_data, "business_segment", None),
            contact_name=getattr(templates_data, "contact_name", None),
            contact_phone=contact_phone,
            contract_url=self._dump_items([contract_item]),
            ord_id=getattr(templates_data, "ord_id", None),
            signing_status="未签署",
        )
        self.db.add(new_record)
        await self.db.commit()
        await self.db.refresh(new_record)
        return True