import json
from typing import Any

from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from alien_contract.db.models.bundle import ContractBundle
from alien_contract.db.models.document import ContractDocument
from alien_contract.db.models.event import ContractEvent


class ContractRepository:
    """Data-access layer for contract bundles, their documents, and events.

    All methods operate on the injected ``AsyncSession``. Nothing is committed
    implicitly; callers decide when to persist by calling :meth:`commit`.

    Writes that target existing rows use ORM-enabled ``update(Model)``
    statements rather than ``Model.__table__.update()``. That keeps instances
    already loaded into the session (e.g. those returned by
    :meth:`create_bundle` or :meth:`get_document_and_bundle`) in sync with the
    new column values instead of leaving them stale.
    """

    # Document ``status`` value written by ``mark_document_signed`` and
    # treated as "signed" by ``recalc_bundle_status``.
    DOCUMENT_STATUS_SIGNED = 1

    def __init__(self, db: AsyncSession) -> None:
        self.db = db

    async def create_bundle(self, data: dict[str, Any]) -> ContractBundle:
        """Insert a new bundle built from ``data`` and flush it.

        ``data`` is passed straight to the ``ContractBundle`` constructor, so
        its keys must be mapped attribute names. The flush emits the INSERT so
        server-generated values (presumably the primary key) are available on
        the returned instance.
        """
        bundle = ContractBundle(**data)
        self.db.add(bundle)
        await self.db.flush()
        return bundle

    async def create_documents(
        self, bundle_id: int, items: list[dict[str, Any]]
    ) -> list[ContractDocument]:
        """Insert one ``ContractDocument`` per item under ``bundle_id`` and flush.

        Required item keys: ``contract_type``, ``contract_name``,
        ``is_master`` (stored as ``is_primary``), ``status``, ``sign_flow_id``,
        ``file_id`` and ``contract_url`` (stored as ``template_url``).
        Optional keys ``sign_url`` and ``contract_download_url`` (stored as
        ``download_url``) default to ``""``.

        Raises:
            KeyError: if an item lacks a required key. All documents are
                built before any is added, so a bad item leaves nothing
                half-added to the session.
        """
        docs = [
            ContractDocument(
                bundle_id=bundle_id,
                contract_type=item["contract_type"],
                contract_name=item["contract_name"],
                is_primary=item["is_master"],
                status=item["status"],
                sign_flow_id=item["sign_flow_id"],
                file_id=item["file_id"],
                template_url=item["contract_url"],
                sign_url=item.get("sign_url", ""),
                download_url=item.get("contract_download_url", ""),
            )
            for item in items
        ]
        self.db.add_all(docs)
        await self.db.flush()
        return docs

    async def set_primary_document(self, bundle_id: int, document_id: int) -> None:
        """Point the bundle's ``primary_document_id`` at ``document_id``."""
        stmt = (
            update(ContractBundle)
            .where(ContractBundle.id == bundle_id)
            .values(primary_document_id=document_id)
        )
        await self.db.execute(stmt)

    async def list_bundles(
        self, subject_type: str, subject_id: int, page: int, page_size: int
    ) -> tuple[list[ContractBundle], int]:
        """Return one page of non-deleted bundles for a subject plus the total.

        Bundles are ordered newest-first by ``id``. ``page`` is 1-based.
        A ``page`` below 1 is treated as 1 and a negative ``page_size`` as 0.
        Otherwise the query would carry a negative OFFSET/LIMIT, which
        databases generally reject.

        Returns:
            ``(bundles, total)`` where ``total`` counts every matching bundle,
            not just those on the returned page.
        """
        page = max(page, 1)
        page_size = max(page_size, 0)
        conditions = [
            ContractBundle.subject_type == subject_type,
            ContractBundle.subject_id == subject_id,
            ContractBundle.delete_flag == 0,
        ]

        count_stmt = select(func.count()).select_from(ContractBundle).where(*conditions)
        total = (await self.db.execute(count_stmt)).scalar() or 0

        offset = (page - 1) * page_size
        if offset >= total:
            # The requested page is empty; skip the second round-trip.
            return [], total

        stmt = (
            select(ContractBundle)
            .where(*conditions)
            .order_by(ContractBundle.id.desc())
            .offset(offset)
            .limit(page_size)
        )
        result = await self.db.execute(stmt)
        return list(result.scalars().all()), total

    async def list_documents_by_bundle_ids(
        self, bundle_ids: list[int]
    ) -> dict[int, list[ContractDocument]]:
        """Load non-deleted documents for ``bundle_ids``, grouped by bundle id.

        Bundles with no matching documents are absent from the result. An
        empty ``bundle_ids`` returns ``{}`` without querying.
        """
        if not bundle_ids:
            return {}
        stmt = select(ContractDocument).where(
            ContractDocument.bundle_id.in_(bundle_ids),
            ContractDocument.delete_flag == 0,
        )
        result = await self.db.execute(stmt)
        grouped: dict[int, list[ContractDocument]] = {}
        for doc in result.scalars().all():
            grouped.setdefault(doc.bundle_id, []).append(doc)
        return grouped

    async def get_document_and_bundle(
        self, sign_flow_id: str
    ) -> tuple[ContractDocument | None, ContractBundle | None]:
        """Find the live document for ``sign_flow_id`` together with its bundle.

        Both the document and its bundle must have ``delete_flag == 0``.
        Returns ``(None, None)`` when no such pair exists. If several
        documents share the flow id, an arbitrary one is returned.
        """
        stmt = (
            select(ContractDocument, ContractBundle)
            .join(ContractBundle, ContractDocument.bundle_id == ContractBundle.id)
            .where(
                ContractDocument.sign_flow_id == sign_flow_id,
                ContractDocument.delete_flag == 0,
                ContractBundle.delete_flag == 0,
            )
            # Only the first row is used; don't fetch any others.
            .limit(1)
        )
        row = (await self.db.execute(stmt)).first()
        if row is None:
            return None, None
        return row[0], row[1]

    async def update_document_urls(
        self,
        document_id: int,
        template_url: str | None = None,
        sign_url: str | None = None,
        download_url: str | None = None,
    ) -> None:
        """Overwrite whichever of a document's URL columns are given.

        ``None`` leaves a column untouched; an empty string is written as-is.
        When every URL argument is ``None`` no statement is issued.
        """
        candidates = {
            "template_url": template_url,
            "sign_url": sign_url,
            "download_url": download_url,
        }
        values = {key: value for key, value in candidates.items() if value is not None}
        if not values:
            return
        stmt = (
            update(ContractDocument)
            .where(ContractDocument.id == document_id)
            .values(**values)
        )
        await self.db.execute(stmt)

    async def mark_document_signed(
        self,
        document_id: int,
        signing_time,
        effective_time,
        expiry_time,
        download_url: str | None,
    ) -> None:
        """Set a document's status to signed and record its timestamps.

        The three time values are stored unchanged (presumably ``datetime``
        objects; confirm against callers). ``download_url`` is only written
        when truthy, so ``None`` or ``""`` keeps the currently stored URL.
        """
        values: dict[str, Any] = {
            "status": self.DOCUMENT_STATUS_SIGNED,
            "signing_time": signing_time,
            "effective_time": effective_time,
            "expiry_time": expiry_time,
        }
        if download_url:
            values["download_url"] = download_url
        stmt = (
            update(ContractDocument)
            .where(ContractDocument.id == document_id)
            .values(**values)
        )
        await self.db.execute(stmt)

    async def recalc_bundle_status(self, bundle_id: int) -> str:
        """Derive a bundle's status from its non-deleted documents and store it.

        - ``"all_signed"``: at least one document, and every one is signed.
        - ``"partially_signed"``: some, but not all, documents are signed.
        - ``"pending"``: no documents, or none of them signed.

        Returns:
            The status string that was written to the bundle.
        """
        stmt = select(ContractDocument.status).where(
            ContractDocument.bundle_id == bundle_id,
            ContractDocument.delete_flag == 0,
        )
        result = await self.db.execute(stmt)
        signed_flags = [
            s == self.DOCUMENT_STATUS_SIGNED for s in result.scalars().all()
        ]

        if signed_flags and all(signed_flags):
            status = "all_signed"
        elif any(signed_flags):
            status = "partially_signed"
        else:
            status = "pending"

        update_stmt = (
            update(ContractBundle)
            .where(ContractBundle.id == bundle_id)
            .values(status=status)
        )
        await self.db.execute(update_stmt)
        return status

    async def create_event(
        self,
        bundle_id: int | None,
        document_id: int | None,
        sign_flow_id: str,
        event_type: str,
        payload: dict[str, Any],
    ) -> None:
        """Stage a ``ContractEvent`` that stores ``payload`` as JSON text.

        ``ensure_ascii=False`` keeps non-ASCII text verbatim. The payload must
        be JSON-serializable, or ``json.dumps`` raises ``TypeError``. The
        event is only added to the session; it is written on the next flush
        or :meth:`commit`.
        """
        event = ContractEvent(
            bundle_id=bundle_id,
            document_id=document_id,
            sign_flow_id=sign_flow_id,
            event_type=event_type,
            payload_json=json.dumps(payload, ensure_ascii=False),
        )
        self.db.add(event)

    async def commit(self) -> None:
        """Commit the session's current transaction."""
        await self.db.commit()