__init__.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. # -*- coding: utf-8 -*-
  2. # @Author : YY
  3. # @FileName: __init__.py
import json
import logging
from datetime import datetime
from typing import List, Optional, Tuple

from ruoyi_common.utils import DateUtil, StringUtil
from ruoyi_generator.config import GeneratorConfig
from ruoyi_generator.domain.entity import GenTable, GenTableColumn
from ruoyi_generator.mapper import gen_table_mapper, gen_table_column_mapper
from ruoyi_generator.util import GenUtils, to_underscore
  12. class GenTableService:
  13. def select_gen_table_list(self, gen_table: GenTable) -> Tuple[List[GenTable], int]:
  14. """
  15. 查询代码生成表列表
  16. Args:
  17. gen_table (GenTable): 代码生成表对象
  18. Returns:
  19. Tuple[List[GenTable], int]: 代码生成表列表和总数
  20. """
  21. # 查询列表
  22. gen_tables = gen_table_mapper.select_list(gen_table)
  23. # 查询总数
  24. # 注意:这里需要根据实际需求实现总数查询逻辑
  25. total = len(gen_tables)
  26. return gen_tables, total
  27. def select_db_table_list(self, gen_table: GenTable) -> Tuple[List[GenTable], int]:
  28. """
  29. 查询数据库表列表
  30. Args:
  31. gen_table (GenTable): 代码生成表对象
  32. Returns:
  33. Tuple[List[GenTable], int]: 数据库表列表和总数
  34. """
  35. # 查询列表
  36. gen_tables = gen_table_mapper.select_db_list(gen_table)
  37. # 查询总数
  38. # 注意:这里需要根据实际需求实现总数查询逻辑
  39. total = len(gen_tables)
  40. return gen_tables, total
  41. def select_gen_table_by_id(self, table_id: int) -> Optional[GenTable]:
  42. """
  43. 根据ID查询代码生成表
  44. Args:
  45. table_id (int): 表ID
  46. Returns:
  47. Optional[GenTable]: 代码生成表对象
  48. """
  49. gen_table = gen_table_mapper.select_by_id(table_id)
  50. if gen_table:
  51. gen_table.columns = gen_table_column_mapper.select_list_by_table_id(table_id)
  52. return gen_table
  53. def select_gen_table_by_name(self, table_name: str) -> Optional[GenTable]:
  54. """
  55. 根据表名查询代码生成表
  56. Args:
  57. table_name (str): 表名
  58. Returns:
  59. Optional[GenTable]: 代码生成表对象
  60. """
  61. return gen_table_mapper.select_by_table_name(table_name)
  62. def delete_gen_table_by_id(self, table_id: int):
  63. """
  64. 根据ID删除代码生成表
  65. Args:
  66. table_id (int): 表ID
  67. """
  68. # 先删除字段信息
  69. gen_table_column_mapper.delete_by_table_id(table_id)
  70. # 再删除表信息
  71. gen_table_mapper.delete_by_id(table_id)
  72. def delete_gen_table_by_ids(self, table_ids: List[int]):
  73. """
  74. 批量删除代码生成表
  75. Args:
  76. table_ids (List[int]): 表ID列表
  77. """
  78. for table_id in table_ids:
  79. # 先删除字段信息
  80. gen_table_column_mapper.delete_by_table_id(table_id)
  81. # 再删除表信息
  82. gen_table_mapper.delete_by_id(table_id)
  83. def import_gen_table(self, table_names: List[str]) -> int:
  84. """
  85. 导入代码生成表
  86. Args:
  87. table_names (List[str]): 表名列表
  88. Returns:
  89. int: 导入的表数量
  90. """
  91. success_count = 0
  92. for table_name in table_names:
  93. # 检查表是否已存在
  94. if gen_table_mapper.exists_table(table_name):
  95. continue
  96. # 创建GenTable对象
  97. table = GenTable()
  98. # 设置默认值
  99. table.table_name = table_name
  100. clean_table_name = GenUtils.remove_table_prefix(
  101. table_name) if GeneratorConfig.auto_remove_pre else table_name
  102. table.class_name = to_underscore(clean_table_name)
  103. table.business_name = GenUtils.get_business_name(clean_table_name)
  104. table.package_name = GeneratorConfig.package_name
  105. table.module_name = StringUtil.substring_before(clean_table_name, "_") if hasattr(StringUtil,
  106. 'substring_before') and "_" in clean_table_name else clean_table_name
  107. # 获取表注释
  108. try:
  109. result = gen_table_mapper.select_db_table_comment_by_name(table_name)
  110. table.table_comment = result if result else table_name
  111. except:
  112. table.table_comment = table_name
  113. table.function_name = table.table_comment
  114. table.function_author = GeneratorConfig.author
  115. table.create_by = "admin"
  116. table.create_time = datetime.now()
  117. # 保存表信息
  118. table_id = gen_table_mapper.insert(table)
  119. if table_id > 0:
  120. # 保存列信息
  121. columns = gen_table_mapper.select_db_table_columns_by_name(table_name)
  122. for column in columns:
  123. column.table_id = table_id
  124. column.create_by = "admin"
  125. column.create_time = datetime.now()
  126. gen_table_column_mapper.insert(column)
  127. success_count += 1
  128. return success_count
  129. def update_gen_table(self, gen_table: GenTable):
  130. """
  131. 更新代码生成表
  132. Args:
  133. gen_table (GenTable): 代码生成表对象
  134. """
  135. print("更新代码生成表:", gen_table.columns)
  136. # 获取列信息
  137. columns = gen_table.columns
  138. if columns:
  139. # 处理列信息
  140. for column in columns:
  141. print("处理列信息:", column)
  142. # 确保column是GenTableColumn对象而不是dict
  143. column_obj = None
  144. if isinstance(column, dict):
  145. # 直接使用字典数据创建对象,避免手动字段映射
  146. # 前端传过来的已经是正确的字符串格式,直接使用
  147. print("原始列数据:", column)
  148. # 使用 model_validate 方法处理别名字段
  149. column_obj = GenTableColumn.model_validate(column)
  150. print("处理后列对象:", column_obj.model_dump())
  151. else:
  152. column_obj = column
  153. # 设置表ID
  154. column_obj.table_id = gen_table.table_id
  155. column_obj.update_time = datetime.now()
  156. # 如果有column_id则更新,否则插入
  157. if hasattr(column_obj, 'column_id') and column_obj.column_id:
  158. # 更新列信息
  159. gen_table_column_mapper.update(column_obj)
  160. else:
  161. # 插入列信息
  162. column_obj.create_by = "admin"
  163. column_obj.create_time = datetime.now()
  164. gen_table_column_mapper.insert(column_obj)
  165. # 更新表信息
  166. gen_table.update_time = datetime.now()
  167. gen_table_mapper.update(gen_table)
  168. def synch_db(self, table_name: str):
  169. """
  170. 同步数据库表结构
  171. Args:
  172. table_name (str): 表名
  173. """
  174. try:
  175. # 查询表信息
  176. gen_table = gen_table_mapper.select_by_table_name(table_name)
  177. if not gen_table:
  178. raise Exception(f"表{table_name}不存在")
  179. # 查询数据库表列信息
  180. db_columns = gen_table_mapper.select_db_table_columns_by_name(table_name)
  181. # 查询代码生成表列信息
  182. gen_columns = gen_table_column_mapper.select_list_by_table_id(gen_table.table_id)
  183. # 处理新增和更新的列
  184. for db_column in db_columns:
  185. exist_column = None
  186. for gen_column in gen_columns:
  187. if db_column.column_name == gen_column.column_name:
  188. exist_column = gen_column
  189. break
  190. if exist_column:
  191. # 更新列信息
  192. exist_column.column_comment = db_column.column_comment
  193. exist_column.column_type = db_column.column_type
  194. exist_column.java_type = db_column.java_type
  195. exist_column.java_field = db_column.java_field
  196. exist_column.is_pk = db_column.is_pk
  197. exist_column.is_increment = db_column.is_increment
  198. exist_column.is_required = db_column.is_required
  199. exist_column.update_by = "admin"
  200. exist_column.update_time = datetime.now()
  201. gen_table_column_mapper.update(exist_column)
  202. else:
  203. # 新增列信息
  204. db_column.table_id = gen_table.table_id
  205. db_column.create_by = "admin"
  206. db_column.create_time = datetime.now()
  207. gen_table_column_mapper.insert(db_column)
  208. # 处理删除的列
  209. for gen_column in gen_columns:
  210. exist_column = False
  211. for db_column in db_columns:
  212. if gen_column.column_name == db_column.column_name:
  213. exist_column = True
  214. break
  215. if not exist_column:
  216. # 删除列信息
  217. gen_table_column_mapper.delete_by_id(gen_column.column_id)
  218. return True
  219. except Exception as e:
  220. print(f"同步数据库失败: {e}")
  221. raise e
  222. def generator_code(self, table_name: str) -> bytes:
  223. """
  224. 生成代码
  225. Args:
  226. table_name (str): 表名
  227. Returns:
  228. bytes: 生成的代码文件
  229. """
  230. # 查询表信息
  231. gen_table = gen_table_mapper.select_by_table_name(table_name)
  232. if not gen_table:
  233. raise Exception(f"表{table_name}不存在")
  234. # 查询列信息
  235. gen_table.columns = gen_table_column_mapper.select_list_by_table_id(gen_table.table_id)
  236. # 生成代码
  237. return GenUtils.generator_code(gen_table).getvalue()
  238. def batch_generator_code(self, table_names: List[str]) -> bytes:
  239. """
  240. 批量生成代码
  241. Args:
  242. table_names (List[str]): 表名列表
  243. Returns:
  244. bytes: 生成的代码文件
  245. """
  246. gen_tables = []
  247. for table_name in table_names:
  248. # 查询表信息
  249. gen_table = gen_table_mapper.select_by_table_name(table_name)
  250. if gen_table:
  251. # 查询列信息
  252. gen_table.columns = gen_table_column_mapper.select_list_by_table_id(gen_table.table_id)
  253. gen_tables.append(gen_table)
  254. # 生成代码
  255. return GenUtils.batch_generator_code(gen_tables).getvalue()
  256. def preview_code(self, table_id: int) -> dict:
  257. """
  258. 预览代码
  259. Args:
  260. table_id (int): 表ID
  261. Returns:
  262. dict: 预览代码
  263. """
  264. # 查询表信息
  265. gen_table = gen_table_mapper.select_by_id(table_id)
  266. if not gen_table:
  267. raise Exception(f"表ID{table_id}不存在")
  268. # 查询列信息
  269. gen_table.columns = gen_table_column_mapper.select_list_by_table_id(table_id)
  270. # 预览代码
  271. return GenUtils.preview_code(gen_table)
  272. def select_db_table_comment_by_name(self, table_name: str) -> Optional[str]:
  273. """
  274. 根据表名查询数据库表注释
  275. Args:
  276. table_name (str): 表名
  277. Returns:
  278. Optional[str]: 表注释
  279. """
  280. try:
  281. result = gen_table_mapper.session.execute(
  282. text(
  283. "SELECT table_comment FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = :table_name"),
  284. {"table_name": table_name}
  285. ).fetchone()
  286. return result[0] if result else None
  287. except Exception:
  288. return None
# Module-level service instance; every importer of this module shares this one object.
gen_table_service: GenTableService = GenTableService()