contract_repo.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. from sqlalchemy.ext.asyncio import AsyncSession
  2. from alien_store.db.models.contract_store import ContractStore
  3. import json
  4. from datetime import datetime, timedelta
  5. class ContractRepository:
  6. """合同数据访问层"""
  7. def __init__(self, db: AsyncSession):
  8. self.db = db
  9. async def get_by_store_id(self, store_id: int):
  10. """根据店铺id查询所有合同"""
  11. result = await self.db.execute(
  12. ContractStore.__table__.select().where(ContractStore.store_id == store_id)
  13. )
  14. # 返回列表[dict],避免 Pydantic 序列化 Row 对象出错
  15. return [dict(row) for row in result.mappings().all()]
  16. async def get_all(self):
  17. """查询所有合同"""
  18. result = await self.db.execute(ContractStore.__table__.select())
  19. return [dict(row) for row in result.mappings().all()]
  20. async def get_all_paged(self, page: int, page_size: int = 10):
  21. """分页查询所有合同,返回 (items, total)"""
  22. offset = (page - 1) * page_size
  23. # 查询总数
  24. from sqlalchemy import func, select
  25. count_result = await self.db.execute(
  26. select(func.count()).select_from(ContractStore.__table__)
  27. )
  28. total = count_result.scalar() or 0
  29. # 查询分页数据
  30. result = await self.db.execute(
  31. ContractStore.__table__.select().offset(offset).limit(page_size)
  32. )
  33. items = [dict(row) for row in result.mappings().all()]
  34. return items, total
  35. async def create(self, user_data):
  36. """创建未签署合同模板"""
  37. db_templates = ContractStore(
  38. store_id=user_data.store_id,
  39. store_name=getattr(user_data, "store_name", None),
  40. merchant_name=user_data.merchant_name,
  41. business_segment=user_data.business_segment,
  42. contact_phone=user_data.contact_phone,
  43. contract_url=user_data.contract_url,
  44. seal_url='0.0',
  45. signing_status='未签署'
  46. )
  47. self.db.add(db_templates)
  48. await self.db.commit()
  49. await self.db.refresh(db_templates)
  50. return db_templates
  51. async def mark_signed_by_phone(self, contact_phone: str, sign_flow_id: str, signing_time: datetime | None = None, contract_download_url: str | None = None):
  52. """
  53. 根据手机号 + sign_flow_id 将合同标记为已签署,只更新匹配的合同项
  54. 当 is_master 为 1 时,更新签署状态和时间字段
  55. 同时写入签署/生效/到期时间(签署时间=T,生效=T+1天0点,失效=生效+365天)
  56. 同时更新 contract_download_url 到对应的字典中
  57. """
  58. result = await self.db.execute(
  59. ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
  60. )
  61. rows = result.mappings().all()
  62. updated = False
  63. for row in rows:
  64. contract_url_raw = row.get("contract_url")
  65. items = None
  66. if contract_url_raw:
  67. try:
  68. items = json.loads(contract_url_raw)
  69. except Exception:
  70. items = None
  71. changed = False
  72. matched_item = None
  73. if isinstance(items, list):
  74. for item in items:
  75. if item.get("sign_flow_id") == sign_flow_id:
  76. item["status"] = 1
  77. # 更新 contract_download_url
  78. if contract_download_url:
  79. item["contract_download_url"] = contract_download_url
  80. matched_item = item
  81. changed = True
  82. break
  83. # 只有当 is_master 为 1 时才更新时间字段
  84. if changed and matched_item and matched_item.get("is_master") == 1:
  85. # 时间处理
  86. signing_dt = signing_time
  87. effective_dt = expiry_dt = None
  88. if signing_dt:
  89. # effective_time 是 signing_time 第二天的 0 点
  90. effective_dt = (signing_dt + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
  91. expiry_dt = effective_dt + timedelta(days=365)
  92. # 更新 contract_url 中对应字典的时间字段
  93. matched_item["signing_time"] = signing_dt.strftime("%Y-%m-%d %H:%M:%S") if signing_dt else ""
  94. matched_item["effective_time"] = effective_dt.strftime("%Y-%m-%d %H:%M:%S") if effective_dt else ""
  95. matched_item["expiry_time"] = expiry_dt.strftime("%Y-%m-%d %H:%M:%S") if expiry_dt else ""
  96. await self.db.execute(
  97. ContractStore.__table__.update()
  98. .where(ContractStore.id == row["id"])
  99. .values(
  100. signing_status="已签署",
  101. contract_url=json.dumps(items, ensure_ascii=False) if items else contract_url_raw,
  102. signing_time=signing_dt,
  103. effective_time=effective_dt,
  104. expiry_time=expiry_dt,
  105. )
  106. )
  107. updated = True
  108. elif changed:
  109. # is_master 不为 1 时,只更新 status,不更新时间字段
  110. await self.db.execute(
  111. ContractStore.__table__.update()
  112. .where(ContractStore.id == row["id"])
  113. .values(
  114. contract_url=json.dumps(items, ensure_ascii=False) if items else contract_url_raw,
  115. )
  116. )
  117. updated = True
  118. if updated:
  119. await self.db.commit()
  120. return updated
  121. async def update_sign_url(self, contact_phone: str, sign_flow_id: str, sign_url: str):
  122. """
  123. 根据手机号 + sign_flow_id 更新 contract_url 列表中对应项的 sign_url
  124. """
  125. result = await self.db.execute(
  126. ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
  127. )
  128. rows = result.mappings().all()
  129. updated = False
  130. for row in rows:
  131. contract_url_raw = row.get("contract_url")
  132. if not contract_url_raw:
  133. continue
  134. try:
  135. items = json.loads(contract_url_raw)
  136. except Exception:
  137. items = None
  138. if not isinstance(items, list):
  139. continue
  140. changed = False
  141. for item in items:
  142. if item.get("sign_flow_id") == sign_flow_id:
  143. item["sign_url"] = sign_url
  144. changed = True
  145. if changed:
  146. await self.db.execute(
  147. ContractStore.__table__.update()
  148. .where(ContractStore.id == row["id"])
  149. .values(contract_url=json.dumps(items, ensure_ascii=False))
  150. )
  151. updated = True
  152. if updated:
  153. await self.db.commit()
  154. return updated
  155. async def append_contract_url(self, templates_data, contract_item: dict):
  156. """
  157. 根据手机号,向 contract_url(JSON 列表)追加新的合同信息;
  158. 若手机号不存在,则创建新记录。
  159. """
  160. contact_phone = getattr(templates_data, "contact_phone", None)
  161. result = await self.db.execute(
  162. ContractStore.__table__.select().where(ContractStore.contact_phone == contact_phone)
  163. )
  164. rows = result.mappings().all()
  165. updated = False
  166. store_name = getattr(templates_data, "store_name", None)
  167. if rows:
  168. for row in rows:
  169. contract_url_raw = row.get("contract_url")
  170. try:
  171. items = json.loads(contract_url_raw) if contract_url_raw else []
  172. except Exception:
  173. items = []
  174. if not isinstance(items, list):
  175. items = []
  176. items.append(contract_item)
  177. update_values = {"contract_url": json.dumps(items, ensure_ascii=False)}
  178. if store_name:
  179. update_values["store_name"] = store_name
  180. await self.db.execute(
  181. ContractStore.__table__.update()
  182. .where(ContractStore.id == row["id"])
  183. .values(**update_values)
  184. )
  185. updated = True
  186. if updated:
  187. await self.db.commit()
  188. return updated
  189. # 未找到则创建新记录
  190. new_record = ContractStore(
  191. store_id=getattr(templates_data, "store_id", None),
  192. store_name=store_name,
  193. business_segment=getattr(templates_data, "business_segment", None),
  194. merchant_name=getattr(templates_data, "merchant_name", None),
  195. contact_phone=contact_phone,
  196. contract_url=json.dumps([contract_item], ensure_ascii=False),
  197. seal_url='0.0',
  198. signing_status='未签署'
  199. )
  200. self.db.add(new_record)
  201. await self.db.commit()
  202. await self.db.refresh(new_record)
  203. return True