mapper.py.vm 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # -*- coding: utf-8 -*-
  2. # @Author : {{ table.function_author }}
  3. # @FileName: {{ underscore(table.class_name) }}_mapper.py
  4. # @Time : {{ datetime }}
from datetime import datetime
from typing import List
from typing import Optional

from flask import g
from sqlalchemy import select, update, delete

from ruoyi_admin.ext import db
from {{ get_import_path(table.package_name, table.module_name, 'domain') }} import {{ class_name_pascal }}
from {{ get_import_path(table.package_name, table.module_name, 'domain', table.class_name) }} import {{ class_name_pascal }}Po
  12. class {{ class_name_pascal }}Mapper:
  13. """{{ table.function_name }}Mapper"""
  14. @staticmethod
  15. def select_{{ underscore(table.class_name) }}_list({{ underscore(table.business_name) }}: {{ class_name_pascal }}) -> List[{{ class_name_pascal }}]:
  16. """
  17. 查询{{ table.function_name }}列表
  18. Args:
  19. {{ underscore(table.business_name) }} ({{ table.class_name }}): {{ table.function_name }}对象
  20. Returns:
  21. List[{{ table.class_name }}]: {{ table.function_name }}列表
  22. """
  23. try:
  24. # 构建查询条件
  25. stmt = select({{ class_name_pascal }}Po)
  26. {% for column in table.columns %}
  27. {% if column.is_query %}
  28. {%- set field_name = underscore(column.java_field) %}
  29. {%- if column.query_type == 'EQ' %}
  30. if {{ underscore(table.business_name) }}.{{ field_name }} is not None:
  31. stmt = stmt.where({{ class_name_pascal }}Po.{{ field_name }} == {{ underscore(table.business_name) }}.{{ field_name }})
  32. {%- elif column.query_type == 'NE' %}
  33. if {{ underscore(table.business_name) }}.{{ field_name }} is not None:
  34. stmt = stmt.where({{ class_name_pascal }}Po.{{ field_name }} != {{ underscore(table.business_name) }}.{{ field_name }})
  35. {%- elif column.query_type == 'GT' %}
  36. if {{ underscore(table.business_name) }}.{{ field_name }} is not None:
  37. stmt = stmt.where({{ class_name_pascal }}Po.{{ field_name }} > {{ underscore(table.business_name) }}.{{ field_name }})
  38. {%- elif column.query_type == 'GTE' %}
  39. if {{ underscore(table.business_name) }}.{{ field_name }} is not None:
  40. stmt = stmt.where({{ class_name_pascal }}Po.{{ field_name }} >= {{ underscore(table.business_name) }}.{{ field_name }})
  41. {%- elif column.query_type == 'LT' %}
  42. if {{ underscore(table.business_name) }}.{{ field_name }} is not None:
  43. stmt = stmt.where({{ class_name_pascal }}Po.{{ field_name }} < {{ underscore(table.business_name) }}.{{ field_name }})
  44. {%- elif column.query_type == 'LTE' %}
  45. if {{ underscore(table.business_name) }}.{{ field_name }} is not None:
  46. stmt = stmt.where({{ class_name_pascal }}Po.{{ field_name }} <= {{ underscore(table.business_name) }}.{{ field_name }})
  47. {%- elif column.query_type == 'LIKE' %}
  48. if {{ underscore(table.business_name) }}.{{ field_name }}:
  49. stmt = stmt.where({{ class_name_pascal }}Po.{{ field_name }}.like("%" + str({{ underscore(table.business_name) }}.{{ field_name }}) + "%"))
  50. {%- elif column.query_type == 'BETWEEN' %}
  51. _params = getattr({{ underscore(table.business_name) }}, "params", {}) or {}
  52. begin_val = _params.get("begin{{ capitalize_first(column.java_field) }}")
  53. end_val = _params.get("end{{ capitalize_first(column.java_field) }}")
  54. if begin_val is not None:
  55. stmt = stmt.where({{ class_name_pascal }}Po.{{ field_name }} >= begin_val)
  56. if end_val is not None:
  57. stmt = stmt.where({{ class_name_pascal }}Po.{{ field_name }} <= end_val)
  58. {%- endif %}
  59. {% endif %}
  60. {% endfor %}
  61. if "criterian_meta" in g and g.criterian_meta.page:
  62. g.criterian_meta.page.stmt = stmt
  63. result = db.session.execute(stmt).scalars().all()
  64. return [{{ class_name_pascal }}.model_validate(item) for item in result] if result else []
  65. except Exception as e:
  66. print(f"查询{{ table.function_name }}列表出错: {e}")
  67. return []
  68. {% if table.pk_column %}
  69. @staticmethod
  70. def select_{{ underscore(table.class_name) }}_by_id({{ underscore(table.pk_column.java_field) }}: int) -> {{ class_name_pascal }}:
  71. """
  72. 根据ID查询{{ table.function_name }}
  73. Args:
  74. {{ underscore(table.pk_column.java_field) }} (int): {{ table.pk_column.column_comment }}
  75. Returns:
  76. {{ table.class_name }}: {{ table.function_name }}对象
  77. """
  78. try:
  79. result = db.session.get({{ class_name_pascal }}Po, {{ underscore(table.pk_column.java_field) }})
  80. return {{ class_name_pascal }}.model_validate(result) if result else None
  81. except Exception as e:
  82. print(f"根据ID查询{{ table.function_name }}出错: {e}")
  83. return None
  84. {% endif %}
  85. @staticmethod
  86. def insert_{{ underscore(table.class_name) }}({{ underscore(table.business_name) }}: {{ class_name_pascal }}) -> int:
  87. """
  88. 新增{{ table.function_name }}
  89. Args:
  90. {{ underscore(table.business_name) }} ({{ table.class_name }}): {{ table.function_name }}对象
  91. Returns:
  92. int: 插入的记录数
  93. """
  94. try:
  95. now = datetime.now()
  96. new_po = {{ class_name_pascal }}Po()
  97. {%- for column in table.columns %}
  98. {%- set attr = underscore(column.java_field) %}
  99. {%- if column.column_name in ['create_time', 'update_time'] %}
  100. new_po.{{ attr }} = {{ underscore(table.business_name) }}.{{ attr }} or now
  101. {%- else %}
  102. new_po.{{ attr }} = {{ underscore(table.business_name) }}.{{ attr }}
  103. {%- endif %}
  104. {%- endfor %}
  105. db.session.add(new_po)
  106. db.session.commit()
  107. {%- if table.pk_column %}
  108. {{ underscore(table.business_name) }}.{{ underscore(table.pk_column.java_field) }} = new_po.{{ underscore(table.pk_column.java_field) }}
  109. {%- endif %}
  110. return 1
  111. except Exception as e:
  112. db.session.rollback()
  113. print(f"新增{{ table.function_name }}出错: {e}")
  114. return 0
  115. {% if table.pk_column %}
  116. @staticmethod
  117. def update_{{ underscore(table.class_name) }}({{ underscore(table.business_name) }}: {{ class_name_pascal }}) -> int:
  118. """
  119. 修改{{ table.function_name }}
  120. Args:
  121. {{ underscore(table.business_name) }} ({{ table.class_name }}): {{ table.function_name }}对象
  122. Returns:
  123. int: 更新的记录数
  124. """
  125. try:
  126. {% if table.pk_column %}
  127. existing = db.session.get({{ class_name_pascal }}Po, {{ underscore(table.business_name) }}.{{ underscore(table.pk_column.java_field) }})
  128. if not existing:
  129. return 0
  130. now = datetime.now()
  131. {%- for column in table.columns %}
  132. {%- set attr = underscore(column.java_field) %}
  133. {%- if column.is_pk == '1' %}
  134. # 主键不参与更新
  135. {%- elif column.column_name == 'update_time' %}
  136. existing.{{ attr }} = {{ underscore(table.business_name) }}.{{ attr }} or now
  137. {%- else %}
  138. existing.{{ attr }} = {{ underscore(table.business_name) }}.{{ attr }}
  139. {%- endif %}
  140. {%- endfor %}
  141. db.session.commit()
  142. return 1
  143. {% else %}
  144. db.session.merge({{ underscore(table.business_name) }})
  145. db.session.commit()
  146. return 1
  147. {% endif %}
  148. except Exception as e:
  149. db.session.rollback()
  150. print(f"修改{{ table.function_name }}出错: {e}")
  151. return 0
  152. @staticmethod
  153. def delete_{{ underscore(table.class_name) }}_by_ids(ids: List[int]) -> int:
  154. """
  155. 批量删除{{ table.function_name }}
  156. Args:
  157. ids (List[int]): ID列表
  158. Returns:
  159. int: 删除的记录数
  160. """
  161. try:
  162. stmt = delete({{ class_name_pascal }}Po).where({{ class_name_pascal }}Po.{{ underscore(table.pk_column.java_field) }}.in_(ids))
  163. result = db.session.execute(stmt)
  164. db.session.commit()
  165. return result.rowcount
  166. except Exception as e:
  167. db.session.rollback()
  168. print(f"批量删除{{ table.function_name }}出错: {e}")
  169. return 0
  170. {% endif %}