contract_repo.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import json
  2. from typing import Any
  3. from sqlalchemy import select, func
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from alien_contract.db.models.bundle import ContractBundle
  6. from alien_contract.db.models.document import ContractDocument
  7. from alien_contract.db.models.event import ContractEvent
  8. class ContractRepository:
  9. def __init__(self, db: AsyncSession):
  10. self.db = db
  11. async def create_bundle(self, data: dict[str, Any]) -> ContractBundle:
  12. bundle = ContractBundle(**data)
  13. self.db.add(bundle)
  14. await self.db.flush()
  15. return bundle
  16. async def create_documents(self, bundle_id: int, items: list[dict[str, Any]]) -> list[ContractDocument]:
  17. docs: list[ContractDocument] = []
  18. for item in items:
  19. doc = ContractDocument(
  20. bundle_id=bundle_id,
  21. contract_type=item["contract_type"],
  22. contract_name=item["contract_name"],
  23. is_primary=item["is_master"],
  24. status=item["status"],
  25. sign_flow_id=item["sign_flow_id"],
  26. file_id=item["file_id"],
  27. template_url=item["contract_url"],
  28. sign_url=item.get("sign_url", ""),
  29. download_url=item.get("contract_download_url", ""),
  30. )
  31. self.db.add(doc)
  32. docs.append(doc)
  33. await self.db.flush()
  34. return docs
  35. async def set_primary_document(self, bundle_id: int, document_id: int) -> None:
  36. stmt = (
  37. ContractBundle.__table__.update()
  38. .where(ContractBundle.id == bundle_id)
  39. .values(primary_document_id=document_id)
  40. )
  41. await self.db.execute(stmt)
  42. async def list_bundles(self, subject_type: str, subject_id: int, page: int, page_size: int):
  43. conditions = [
  44. ContractBundle.subject_type == subject_type,
  45. ContractBundle.subject_id == subject_id,
  46. ContractBundle.delete_flag == 0,
  47. ]
  48. count_stmt = select(func.count()).select_from(ContractBundle).where(*conditions)
  49. total_result = await self.db.execute(count_stmt)
  50. total = total_result.scalar() or 0
  51. stmt = (
  52. select(ContractBundle)
  53. .where(*conditions)
  54. .order_by(ContractBundle.id.desc())
  55. .offset((page - 1) * page_size)
  56. .limit(page_size)
  57. )
  58. result = await self.db.execute(stmt)
  59. bundles = result.scalars().all()
  60. return bundles, total
  61. async def list_documents_by_bundle_ids(self, bundle_ids: list[int]) -> dict[int, list[ContractDocument]]:
  62. if not bundle_ids:
  63. return {}
  64. stmt = select(ContractDocument).where(ContractDocument.bundle_id.in_(bundle_ids), ContractDocument.delete_flag == 0)
  65. result = await self.db.execute(stmt)
  66. documents = result.scalars().all()
  67. grouped: dict[int, list[ContractDocument]] = {}
  68. for doc in documents:
  69. grouped.setdefault(doc.bundle_id, []).append(doc)
  70. return grouped
  71. async def get_document_and_bundle(self, sign_flow_id: str):
  72. stmt = (
  73. select(ContractDocument, ContractBundle)
  74. .join(ContractBundle, ContractDocument.bundle_id == ContractBundle.id)
  75. .where(ContractDocument.sign_flow_id == sign_flow_id, ContractDocument.delete_flag == 0, ContractBundle.delete_flag == 0)
  76. )
  77. result = await self.db.execute(stmt)
  78. row = result.first()
  79. if not row:
  80. return None, None
  81. return row[0], row[1]
  82. async def update_document_urls(self, document_id: int, template_url: str | None = None, sign_url: str | None = None, download_url: str | None = None):
  83. values: dict[str, Any] = {}
  84. if template_url is not None:
  85. values["template_url"] = template_url
  86. if sign_url is not None:
  87. values["sign_url"] = sign_url
  88. if download_url is not None:
  89. values["download_url"] = download_url
  90. if values:
  91. stmt = ContractDocument.__table__.update().where(ContractDocument.id == document_id).values(**values)
  92. await self.db.execute(stmt)
  93. async def mark_document_signed(self, document_id: int, signing_time, effective_time, expiry_time, download_url: str | None):
  94. values: dict[str, Any] = {
  95. "status": 1,
  96. "signing_time": signing_time,
  97. "effective_time": effective_time,
  98. "expiry_time": expiry_time,
  99. }
  100. if download_url:
  101. values["download_url"] = download_url
  102. stmt = ContractDocument.__table__.update().where(ContractDocument.id == document_id).values(**values)
  103. await self.db.execute(stmt)
  104. async def recalc_bundle_status(self, bundle_id: int) -> str:
  105. stmt = select(ContractDocument.status).where(ContractDocument.bundle_id == bundle_id, ContractDocument.delete_flag == 0)
  106. result = await self.db.execute(stmt)
  107. statuses = [row[0] for row in result.fetchall()]
  108. if not statuses:
  109. status = "pending"
  110. elif all(s == 1 for s in statuses):
  111. status = "all_signed"
  112. elif any(s == 1 for s in statuses):
  113. status = "partially_signed"
  114. else:
  115. status = "pending"
  116. update_stmt = ContractBundle.__table__.update().where(ContractBundle.id == bundle_id).values(status=status)
  117. await self.db.execute(update_stmt)
  118. return status
  119. async def create_event(self, bundle_id: int | None, document_id: int | None, sign_flow_id: str, event_type: str, payload: dict[str, Any]):
  120. event = ContractEvent(
  121. bundle_id=bundle_id,
  122. document_id=document_id,
  123. sign_flow_id=sign_flow_id,
  124. event_type=event_type,
  125. payload_json=json.dumps(payload, ensure_ascii=False),
  126. )
  127. self.db.add(event)
  128. async def commit(self):
  129. await self.db.commit()