contract_repo.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import json
  2. from typing import Any
  3. from sqlalchemy import select, func
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from alien_contract.db.models.bundle import ContractBundle
  6. from alien_contract.db.models.document import ContractDocument
  7. from alien_contract.db.models.event import ContractEvent
  8. class ContractRepository:
  9. def __init__(self, db: AsyncSession):
  10. self.db = db
  11. async def create_bundle(self, data: dict[str, Any]) -> ContractBundle:
  12. bundle = ContractBundle(**data)
  13. self.db.add(bundle)
  14. await self.db.flush()
  15. return bundle
  16. async def create_documents(self, bundle_id: int, items: list[dict[str, Any]]) -> list[ContractDocument]:
  17. docs: list[ContractDocument] = []
  18. for item in items:
  19. doc = ContractDocument(
  20. bundle_id=bundle_id,
  21. contract_type=item["contract_type"],
  22. contract_name=item["contract_name"],
  23. is_primary=item["is_master"],
  24. status=item["status"],
  25. sign_flow_id=item["sign_flow_id"],
  26. file_id=item["file_id"],
  27. template_url=item["contract_url"],
  28. sign_url=item.get("sign_url", ""),
  29. download_url=item.get("contract_download_url", ""),
  30. )
  31. self.db.add(doc)
  32. docs.append(doc)
  33. await self.db.flush()
  34. return docs
  35. async def set_primary_document(self, bundle_id: int, document_id: int) -> None:
  36. stmt = (
  37. ContractBundle.__table__.update()
  38. .where(ContractBundle.id == bundle_id)
  39. .values(primary_document_id=document_id)
  40. )
  41. await self.db.execute(stmt)
  42. async def list_bundles(self, subject_type: str, subject_id: int, page: int, page_size: int):
  43. conditions = [
  44. ContractBundle.subject_type == subject_type,
  45. ContractBundle.subject_id == subject_id,
  46. ContractBundle.delete_flag == 0,
  47. ]
  48. count_stmt = select(func.count()).select_from(ContractBundle).where(*conditions)
  49. total_result = await self.db.execute(count_stmt)
  50. total = total_result.scalar() or 0
  51. stmt = (
  52. select(ContractBundle)
  53. .where(*conditions)
  54. .order_by(ContractBundle.id.desc())
  55. .offset((page - 1) * page_size)
  56. .limit(page_size)
  57. )
  58. result = await self.db.execute(stmt)
  59. bundles = result.scalars().all()
  60. return bundles, total
  61. async def list_documents_by_bundle_ids(self, bundle_ids: list[int], *, doc_status: int | None = None) -> dict[int, list[ContractDocument]]:
  62. if not bundle_ids:
  63. return {}
  64. conditions = [ContractDocument.bundle_id.in_(bundle_ids), ContractDocument.delete_flag == 0]
  65. if doc_status is not None:
  66. conditions.append(ContractDocument.status == doc_status)
  67. stmt = select(ContractDocument).where(*conditions)
  68. result = await self.db.execute(stmt)
  69. documents = result.scalars().all()
  70. grouped: dict[int, list[ContractDocument]] = {}
  71. for doc in documents:
  72. grouped.setdefault(doc.bundle_id, []).append(doc)
  73. return grouped
  74. async def get_document_and_bundle(self, sign_flow_id: str):
  75. stmt = (
  76. select(ContractDocument, ContractBundle)
  77. .join(ContractBundle, ContractDocument.bundle_id == ContractBundle.id)
  78. .where(ContractDocument.sign_flow_id == sign_flow_id, ContractDocument.delete_flag == 0, ContractBundle.delete_flag == 0)
  79. )
  80. result = await self.db.execute(stmt)
  81. row = result.first()
  82. if not row:
  83. return None, None
  84. return row[0], row[1]
  85. async def update_document_urls(self, document_id: int, template_url: str | None = None, sign_url: str | None = None, download_url: str | None = None):
  86. values: dict[str, Any] = {}
  87. if template_url is not None:
  88. values["template_url"] = template_url
  89. if sign_url is not None:
  90. values["sign_url"] = sign_url
  91. if download_url is not None:
  92. values["download_url"] = download_url
  93. if values:
  94. stmt = ContractDocument.__table__.update().where(ContractDocument.id == document_id).values(**values)
  95. await self.db.execute(stmt)
  96. async def mark_document_signed(self, document_id: int, signing_time, effective_time, expiry_time, download_url: str | None):
  97. values: dict[str, Any] = {
  98. "status": 1,
  99. "signing_time": signing_time,
  100. "effective_time": effective_time,
  101. "expiry_time": expiry_time,
  102. }
  103. if download_url:
  104. values["download_url"] = download_url
  105. stmt = ContractDocument.__table__.update().where(ContractDocument.id == document_id).values(**values)
  106. await self.db.execute(stmt)
  107. async def recalc_bundle_status(self, bundle_id: int) -> str:
  108. stmt = select(ContractDocument.status).where(ContractDocument.bundle_id == bundle_id, ContractDocument.delete_flag == 0)
  109. result = await self.db.execute(stmt)
  110. statuses = [row[0] for row in result.fetchall()]
  111. if not statuses:
  112. status = "未签署"
  113. elif all(s == 1 for s in statuses):
  114. status = "已签署"
  115. elif any(s == 1 for s in statuses):
  116. status = "审核中"
  117. else:
  118. status = "未签署"
  119. update_stmt = ContractBundle.__table__.update().where(ContractBundle.id == bundle_id).values(status=status)
  120. await self.db.execute(update_stmt)
  121. return status
  122. async def create_event(self, bundle_id: int | None, document_id: int | None, sign_flow_id: str, event_type: str, payload: dict[str, Any]):
  123. event = ContractEvent(
  124. bundle_id=bundle_id,
  125. document_id=document_id,
  126. sign_flow_id=sign_flow_id,
  127. event_type=event_type,
  128. payload_json=json.dumps(payload, ensure_ascii=False),
  129. )
  130. self.db.add(event)
  131. async def commit(self):
  132. await self.db.commit()
  133. async def rollback(self):
  134. await self.db.rollback()