| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
import json
from typing import Any

from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from alien_contract.db.models.bundle import ContractBundle
from alien_contract.db.models.document import ContractDocument
from alien_contract.db.models.event import ContractEvent
class ContractRepository:
    """Data-access helpers for contract bundles, their documents and events.

    Every method works on the injected ``AsyncSession``. Write methods only
    add / flush / execute statements; nothing is committed until
    :meth:`commit` is called, so the caller owns the transaction boundary.
    """

    # Document ``status`` value written by ``mark_document_signed`` and
    # counted as "signed" by ``recalc_bundle_status``.
    SIGNED_STATUS = 1

    def __init__(self, db: AsyncSession):
        """Bind the repository to an async SQLAlchemy session."""
        self.db = db

    async def _update_by_id(
        self,
        model: type[ContractBundle] | type[ContractDocument],
        row_id: int,
        values: dict[str, Any],
    ) -> None:
        """Run an ORM-enabled ``UPDATE model SET values WHERE id = row_id``.

        Unlike a Core ``model.__table__.update()``, an ORM-enabled UPDATE
        autoflushes pending changes first and synchronizes instances already
        loaded in this session, so they do not keep stale attribute values.
        """
        stmt = update(model).where(model.id == row_id).values(**values)
        await self.db.execute(stmt)

    async def create_bundle(self, data: dict[str, Any]) -> ContractBundle:
        """Add a new ``ContractBundle`` built from ``data`` and return it.

        The session is flushed so database-generated values (such as the
        primary key) are populated on the returned object. Not committed.
        """
        bundle = ContractBundle(**data)
        self.db.add(bundle)
        await self.db.flush()
        return bundle

    async def create_documents(
        self, bundle_id: int, items: list[dict[str, Any]]
    ) -> list[ContractDocument]:
        """Add one ``ContractDocument`` per entry of ``items`` under ``bundle_id``.

        Input keys are mapped onto model fields as follows:
        ``contract_type``, ``contract_name``, ``status``, ``sign_flow_id`` and
        ``file_id`` are copied as-is; ``is_master`` -> ``is_primary``;
        ``contract_url`` -> ``template_url``; optional ``sign_url`` and
        ``contract_download_url`` (-> ``download_url``) default to ``""``.

        Raises:
            KeyError: if an item lacks a required key. All documents are built
                before any is added, so nothing is added to the session then.

        Returns:
            The created documents, flushed (not committed).
        """
        docs = [
            ContractDocument(
                bundle_id=bundle_id,
                contract_type=item["contract_type"],
                contract_name=item["contract_name"],
                is_primary=item["is_master"],
                status=item["status"],
                sign_flow_id=item["sign_flow_id"],
                file_id=item["file_id"],
                template_url=item["contract_url"],
                sign_url=item.get("sign_url", ""),
                download_url=item.get("contract_download_url", ""),
            )
            for item in items
        ]
        self.db.add_all(docs)
        await self.db.flush()
        return docs

    async def set_primary_document(self, bundle_id: int, document_id: int) -> None:
        """Set ``primary_document_id`` of bundle ``bundle_id`` to ``document_id``."""
        await self._update_by_id(
            ContractBundle, bundle_id, {"primary_document_id": document_id}
        )

    async def list_bundles(
        self, subject_type: str, subject_id: int, page: int, page_size: int
    ) -> tuple[list[ContractBundle], int]:
        """Return one page of a subject's bundles plus the total match count.

        Only rows with ``delete_flag == 0`` are considered (presumably the
        not-soft-deleted ones).

        Args:
            subject_type: Value matched against ``ContractBundle.subject_type``.
            subject_id: Value matched against ``ContractBundle.subject_id``.
            page: 1-based page number. Values below 1 are treated as 1 so the
                computed OFFSET can never be negative.
            page_size: Maximum number of bundles per page (expected positive).

        Returns:
            ``(bundles, total)``: bundles ordered by ``id`` descending, and the
            number of matching bundles across all pages.
        """
        conditions = (
            ContractBundle.subject_type == subject_type,
            ContractBundle.subject_id == subject_id,
            ContractBundle.delete_flag == 0,
        )

        count_stmt = select(func.count()).select_from(ContractBundle).where(*conditions)
        total = (await self.db.execute(count_stmt)).scalar() or 0

        # A negative OFFSET is rejected by most databases; clamp to page 1.
        page = max(page, 1)
        stmt = (
            select(ContractBundle)
            .where(*conditions)
            .order_by(ContractBundle.id.desc())
            .offset((page - 1) * page_size)
            .limit(page_size)
        )
        result = await self.db.execute(stmt)
        return list(result.scalars().all()), total

    async def list_documents_by_bundle_ids(
        self, bundle_ids: list[int]
    ) -> dict[int, list[ContractDocument]]:
        """Load documents with ``delete_flag == 0`` for the given bundles.

        Returns:
            A dict mapping ``bundle_id`` to that bundle's documents, ordered by
            document ``id``. Bundles without documents are absent from the
            dict. An empty ``bundle_ids`` returns ``{}`` without a query.
        """
        if not bundle_ids:
            return {}
        stmt = (
            select(ContractDocument)
            .where(
                ContractDocument.bundle_id.in_(bundle_ids),
                ContractDocument.delete_flag == 0,
            )
            # Deterministic order so callers see a stable document sequence.
            .order_by(ContractDocument.bundle_id, ContractDocument.id)
        )
        result = await self.db.execute(stmt)
        grouped: dict[int, list[ContractDocument]] = {}
        for doc in result.scalars().all():
            grouped.setdefault(doc.bundle_id, []).append(doc)
        return grouped

    async def get_document_and_bundle(
        self, sign_flow_id: str
    ) -> tuple[ContractDocument | None, ContractBundle | None]:
        """Find the document with ``sign_flow_id`` together with its bundle.

        Both the document and the bundle must have ``delete_flag == 0``.

        Returns:
            ``(document, bundle)``, or ``(None, None)`` when nothing matches.
            If several documents share the flow id, an arbitrary one is
            returned (the query has no ORDER BY).
        """
        stmt = (
            select(ContractDocument, ContractBundle)
            .join(ContractBundle, ContractDocument.bundle_id == ContractBundle.id)
            .where(
                ContractDocument.sign_flow_id == sign_flow_id,
                ContractDocument.delete_flag == 0,
                ContractBundle.delete_flag == 0,
            )
            # Only the first row is used; don't fetch the rest.
            .limit(1)
        )
        row = (await self.db.execute(stmt)).first()
        if not row:
            return None, None
        return row[0], row[1]

    async def update_document_urls(
        self,
        document_id: int,
        template_url: str | None = None,
        sign_url: str | None = None,
        download_url: str | None = None,
    ) -> None:
        """Overwrite whichever URL fields of a document are not ``None``.

        An empty string is written as-is; ``None`` leaves the field untouched.
        No statement is issued when every argument is ``None``.
        """
        candidates = {
            "template_url": template_url,
            "sign_url": sign_url,
            "download_url": download_url,
        }
        values = {key: val for key, val in candidates.items() if val is not None}
        if values:
            await self._update_by_id(ContractDocument, document_id, values)

    async def mark_document_signed(
        self,
        document_id: int,
        signing_time,
        effective_time,
        expiry_time,
        download_url: str | None,
    ) -> None:
        """Set a document's status to ``SIGNED_STATUS`` and record its times.

        ``download_url`` is only written when truthy, so ``None`` or ``""``
        keeps the existing value.
        """
        values: dict[str, Any] = {
            "status": self.SIGNED_STATUS,
            "signing_time": signing_time,
            "effective_time": effective_time,
            "expiry_time": expiry_time,
        }
        if download_url:
            values["download_url"] = download_url
        await self._update_by_id(ContractDocument, document_id, values)

    async def recalc_bundle_status(self, bundle_id: int) -> str:
        """Derive a bundle's status from its documents, persist and return it.

        Only documents with ``delete_flag == 0`` are considered:

        * ``"all_signed"``: at least one document, and all are signed.
        * ``"partially_signed"``: some, but not all, are signed.
        * ``"pending"``: no documents, or none signed.
        """
        stmt = select(ContractDocument.status).where(
            ContractDocument.bundle_id == bundle_id,
            ContractDocument.delete_flag == 0,
        )
        statuses = (await self.db.execute(stmt)).scalars().all()

        if statuses and all(s == self.SIGNED_STATUS for s in statuses):
            status = "all_signed"
        elif any(s == self.SIGNED_STATUS for s in statuses):
            status = "partially_signed"
        else:
            status = "pending"

        await self._update_by_id(ContractBundle, bundle_id, {"status": status})
        return status

    async def create_event(
        self,
        bundle_id: int | None,
        document_id: int | None,
        sign_flow_id: str,
        event_type: str,
        payload: dict[str, Any],
    ) -> ContractEvent:
        """Add a ``ContractEvent`` to the session and return it (not flushed).

        ``payload`` is stored as JSON with non-ASCII characters kept verbatim.

        Raises:
            TypeError: if ``payload`` is not JSON-serializable.
        """
        event = ContractEvent(
            bundle_id=bundle_id,
            document_id=document_id,
            sign_flow_id=sign_flow_id,
            event_type=event_type,
            payload_json=json.dumps(payload, ensure_ascii=False),
        )
        self.db.add(event)
        return event

    async def commit(self) -> None:
        """Commit the current transaction on the underlying session."""
        await self.db.commit()
|