__init__.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
# -*- coding: utf-8 -*-
# @Author : YY
# @FileName: __init__.py
import json
import logging
from datetime import datetime
from typing import List, Optional

from dateutil import parser
from sqlalchemy import text

# Local imports kept in their original relative order (import order may matter
# for application initialisation — not reordered deliberately).
from ruoyi_common.utils import StringUtil
from ruoyi_generator.domain.entity import GenTable, GenTableColumn
from ruoyi_admin.ext import db
from ruoyi_generator.util import GenUtils, to_underscore
from ruoyi_generator.config import GeneratorConfig
from ruoyi_common.sqlalchemy.model import ColumnEntityList
from ruoyi_generator.domain.po import GenTablePo, GenTableColumnPo
  16. class GenTableMapper:
  17. default_fields = {
  18. "table_id", "table_name", "table_comment", "sub_table_name", "sub_table_fk_name",
  19. "class_name", "tpl_category", "package_name", "module_name", "business_name",
  20. "function_name", "function_author", "gen_type", "gen_path", "options",
  21. "create_by", "create_time", "update_by", "update_time", "remark"
  22. }
  23. default_columns = ColumnEntityList(GenTablePo, default_fields)
  24. def select_list(self, gen_table: GenTable) -> List[GenTable]:
  25. """
  26. 查询代码生成表列表
  27. Args:
  28. gen_table (GenTable): 代码生成表对象
  29. Returns:
  30. List[GenTable]: 代码生成表列表
  31. """
  32. try:
  33. criterions = []
  34. if gen_table.table_name:
  35. criterions.append(GenTablePo.table_name.like(f"%{gen_table.table_name}%"))
  36. if gen_table.table_comment:
  37. criterions.append(GenTablePo.table_comment.like(f"%{gen_table.table_comment}%"))
  38. stmt = db.select(*self.default_columns).where(*criterions)
  39. # 分页查询
  40. if hasattr(gen_table, 'page_num') and hasattr(gen_table, 'page_size') and gen_table.page_num and gen_table.page_size:
  41. offset = (gen_table.page_num - 1) * gen_table.page_size
  42. stmt = stmt.limit(gen_table.page_size).offset(offset)
  43. stmt = stmt.order_by(GenTablePo.table_id.desc())
  44. result = db.session.execute(stmt).all()
  45. tables = []
  46. for row in result:
  47. table = self.default_columns.cast(row, GenTable)
  48. # 解析options字段以设置tree相关属性
  49. if table.options:
  50. try:
  51. options_dict = json.loads(table.options)
  52. table.tree_name = options_dict.get('treeName')
  53. table.tree_code = options_dict.get('treeCode')
  54. table.tree_parent_code = options_dict.get('treeParentCode')
  55. except Exception:
  56. pass
  57. tables.append(table)
  58. return tables
  59. except Exception as e:
  60. print(f"查询代码生成表列表出错: {e}")
  61. # 返回空列表而不是模拟数据
  62. return []
  63. def select_db_list(self, gen_table: GenTable) -> List[GenTable]:
  64. """
  65. 查询数据库表列表
  66. Args:
  67. gen_table (GenTable): 代码生成表对象
  68. Returns:
  69. List[GenTable]: 数据库表列表
  70. """
  71. # 查询真实的数据库表信息
  72. try:
  73. # 查询所有表名
  74. result = db.session.execute(text("SHOW TABLES")).fetchall()
  75. table_names = [row[0] for row in result]
  76. tables = []
  77. for table_name in table_names:
  78. # 检查是否已导入
  79. exists_result = db.session.execute(
  80. text("SELECT COUNT(1) FROM gen_table WHERE table_name = :table_name"),
  81. {"table_name": table_name}
  82. ).fetchone()
  83. exists = exists_result[0] > 0 if exists_result else False
  84. if not exists:
  85. # 获取表注释
  86. table_comment_result = db.session.execute(
  87. text("SELECT table_comment FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = :table_name"),
  88. {"table_name": table_name}
  89. ).fetchone()
  90. table_comment = table_comment_result[0] if table_comment_result else table_name
  91. table = GenTable()
  92. table.table_name = table_name
  93. table.table_comment = table_comment
  94. # 设置默认值,以便前端显示
  95. clean_table_name = GenUtils.remove_table_prefix(table_name) if GeneratorConfig.auto_remove_pre else table_name
  96. # 使用下划线命名法而不是驼峰命名法
  97. table.class_name = to_underscore(clean_table_name)
  98. table.package_name = GeneratorConfig.package_name
  99. table.module_name = StringUtil.substring_before(clean_table_name, "_") if hasattr(StringUtil, 'substring_before') and "_" in clean_table_name else clean_table_name
  100. table.business_name = StringUtil.substring_after(clean_table_name, "_") if hasattr(StringUtil, 'substring_after') and "_" in clean_table_name else clean_table_name
  101. table.function_name = table.business_name
  102. table.function_author = GeneratorConfig.author
  103. table.create_by = "admin"
  104. tables.append(table)
  105. # 应用过滤条件
  106. filtered_tables = []
  107. for table in tables:
  108. # 表名过滤
  109. if gen_table.table_name and table.table_name.find(gen_table.table_name) == -1:
  110. continue
  111. # 表注释过滤
  112. if gen_table.table_comment and table.table_comment.find(gen_table.table_comment) == -1:
  113. continue
  114. filtered_tables.append(table)
  115. return filtered_tables
  116. except Exception as e:
  117. # 出现异常时返回空列表
  118. print(f"查询数据库表出错: {e}")
  119. # 返回空列表而不是模拟数据
  120. return []
  121. def select_by_id(self, table_id: int) -> Optional[GenTable]:
  122. """
  123. 根据ID查询代码生成表
  124. Args:
  125. table_id (int): 表ID
  126. Returns:
  127. Optional[GenTable]: 代码生成表对象
  128. """
  129. try:
  130. stmt = db.select(*self.default_columns).where(GenTablePo.table_id == table_id)
  131. row = db.session.execute(stmt).first()
  132. if row:
  133. table = self.default_columns.cast(row, GenTable)
  134. # 解析options字段以设置tree相关属性
  135. if table.options:
  136. try:
  137. options_dict = json.loads(table.options)
  138. table.tree_name = options_dict.get('treeName')
  139. table.tree_code = options_dict.get('treeCode')
  140. table.tree_parent_code = options_dict.get('treeParentCode')
  141. except Exception:
  142. pass
  143. return table
  144. return None
  145. except Exception as e:
  146. print(f"根据ID查询代码生成表出错: {e}")
  147. return None
  148. def select_by_table_name(self, table_name: str) -> Optional[GenTable]:
  149. """
  150. 根据表名查询代码生成表
  151. Args:
  152. table_name (str): 表名
  153. Returns:
  154. Optional[GenTable]: 代码生成表对象
  155. """
  156. try:
  157. stmt = db.select(*self.default_columns).where(GenTablePo.table_name == table_name)
  158. row = db.session.execute(stmt).first()
  159. if row:
  160. table = self.default_columns.cast(row, GenTable)
  161. return table
  162. return None
  163. except Exception as e:
  164. print(f"根据表名查询代码生成表出错: {e}")
  165. return None
  166. def select_db_table_comment_by_name(self, table_name: str) -> Optional[str]:
  167. """
  168. 根据表名查询数据库表注释
  169. Args:
  170. table_name (str): 表名
  171. Returns:
  172. Optional[str]: 表注释
  173. """
  174. try:
  175. result = db.session.execute(
  176. text("SELECT table_comment FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = :table_name"),
  177. {"table_name": table_name}
  178. ).fetchone()
  179. return result[0] if result else None
  180. except Exception as e:
  181. print(f"查询表注释出错: {e}")
  182. return None
  183. def exists_table(self, table_name: str) -> bool:
  184. """
  185. 检查表是否存在
  186. Args:
  187. table_name (str): 表名
  188. Returns:
  189. bool: 是否存在
  190. """
  191. try:
  192. result = db.session.execute(
  193. text("SELECT COUNT(1) FROM gen_table WHERE table_name = :table_name"),
  194. {"table_name": table_name}
  195. ).fetchone()
  196. return result[0] > 0 if result else False
  197. except Exception as e:
  198. print(f"检查表是否存在出错: {e}")
  199. return False
  200. def insert(self, gen_table: GenTable) -> int:
  201. """
  202. 插入代码生成表
  203. Args:
  204. gen_table (GenTable): 代码生成表对象
  205. Returns:
  206. int: 插入的表ID
  207. """
  208. try:
  209. # 使用model_dump方法直接获取所有字段的值,使用下划线命名
  210. table_data = gen_table.model_dump(by_alias=False, exclude_none=True)
  211. # 移除不需要插入的字段
  212. exclude_fields = {'table_id', 'page_size', 'page_num', 'columns', 'pk_column', 'tree_name', 'tree_code', 'tree_parent_code'}
  213. for field in exclude_fields:
  214. table_data.pop(field, None)
  215. # 移除不需要插入的字段
  216. table_data.pop('update_time', None)
  217. # 确保必要的字段有默认值
  218. table_data.setdefault('create_by', 'admin')
  219. table_data.setdefault('update_by', 'admin')
  220. # 设置创建时间
  221. if 'create_time' not in table_data:
  222. table_data['create_time'] = datetime.now()
  223. # 使用SQLAlchemy ORM方式插入数据
  224. gen_table_po = GenTablePo(**table_data)
  225. db.session.add(gen_table_po)
  226. db.session.flush()
  227. table_id = gen_table_po.table_id
  228. db.session.commit()
  229. return table_id
  230. except Exception as e:
  231. db.session.rollback()
  232. print(f"插入代码生成表出错: {e}")
  233. return 0
  234. def update(self, gen_table: GenTable):
  235. """
  236. 更新代码生成表
  237. Args:
  238. gen_table (GenTable): 代码生成表对象
  239. """
  240. try:
  241. # 使用model_dump方法直接获取所有字段的值,使用下划线命名
  242. table_data = gen_table.model_dump(by_alias=False, exclude_none=True)
  243. # 移除不需要更新的字段
  244. exclude_fields = {'table_id', 'page_size', 'page_num', 'columns', 'pk_column', 'tree_name', 'tree_code', 'tree_parent_code'}
  245. for field in exclude_fields:
  246. table_data.pop(field, None)
  247. table_data.pop('create_time', None)
  248. table_data.pop('create_by', None)
  249. # 确保必要的字段有默认值
  250. table_data.setdefault('update_by', 'admin')
  251. # 使用ORM方式更新数据
  252. stmt = db.update(GenTablePo).where(GenTablePo.table_id == gen_table.table_id).values(**table_data)
  253. db.session.execute(stmt)
  254. db.session.commit()
  255. except Exception as e:
  256. db.session.rollback()
  257. print(f"更新代码生成表出错: {e}")
  258. def delete_by_id(self, table_id: int):
  259. """
  260. 根据ID删除代码生成表
  261. Args:
  262. table_id (int): 表ID
  263. """
  264. try:
  265. stmt = db.delete(GenTablePo).where(GenTablePo.table_id == table_id)
  266. db.session.execute(stmt)
  267. db.session.commit()
  268. except Exception as e:
  269. db.session.rollback()
  270. print(f"根据ID删除代码生成表出错: {e}")
  271. raise e
  272. def select_db_table_columns_by_name(self, table_name: str) -> List[GenTableColumn]:
  273. """
  274. 根据表名查询数据库表列信息
  275. Args:
  276. table_name (str): 表名
  277. Returns:
  278. List[GenTableColumn]: 表列信息列表
  279. """
  280. try:
  281. # 查询表的列信息
  282. result = db.session.execute(text("""
  283. SELECT
  284. column_name,
  285. column_comment,
  286. data_type,
  287. is_nullable,
  288. column_default,
  289. column_key,
  290. extra
  291. FROM information_schema.columns
  292. WHERE table_schema = DATABASE() AND table_name = :table_name
  293. ORDER BY ordinal_position
  294. """), {"table_name": table_name}).fetchall()
  295. columns = []
  296. for i, row in enumerate(result):
  297. column = GenTableColumn()
  298. column.column_name = row[0]
  299. column.column_comment = row[1] if row[1] else row[0]
  300. column.column_type = row[2]
  301. # 设置Java类型
  302. if row[2] in ['int', 'integer', 'tinyint', 'smallint', 'mediumint']:
  303. column.java_type = 'Integer'
  304. elif row[2] in ['bigint']:
  305. column.java_type = 'Long'
  306. elif row[2] in ['float', 'double', 'decimal', 'numeric']:
  307. column.java_type = 'BigDecimal'
  308. elif row[2] in ['date', 'datetime', 'timestamp']:
  309. column.java_type = 'Date'
  310. else:
  311. column.java_type = 'String'
  312. # 使用 GenUtils.to_camel_case 方法转换字段名
  313. column.java_field = GenUtils.to_camel_case(row[0])
  314. column.is_pk = '1' if row[5] == 'PRI' else '0'
  315. column.is_increment = '1' if row[6] == 'auto_increment' else '0'
  316. column.is_required = '0' if row[3] == 'YES' else '1'
  317. column.is_insert = '1'
  318. column.is_edit = '1'
  319. column.is_list = '1'
  320. column.is_query = '1'
  321. column.query_type = 'EQ'
  322. column.html_type = 'input'
  323. column.sort = i + 1
  324. columns.append(column)
  325. return columns
  326. except Exception as e:
  327. print(f"查询表列信息出错: {e}")
  328. # 返回空列表而不是模拟数据
  329. return []
  330. class GenTableColumnMapper:
  331. default_fields = {
  332. "column_id", "table_id", "column_name", "column_comment", "column_type",
  333. "java_type", "java_field", "is_pk", "is_increment", "is_required",
  334. "is_insert", "is_edit", "is_list", "is_query", "query_type",
  335. "html_type", "dict_type", "sort", "create_by", "create_time",
  336. "update_by", "update_time", "remark"
  337. }
  338. default_columns = ColumnEntityList(GenTableColumnPo, default_fields)
  339. def select_list_by_table_id(self, table_id: int) -> List[GenTableColumn]:
  340. """
  341. 根据表ID查询代码生成表列列表
  342. Args:
  343. table_id (int): 表ID
  344. Returns:
  345. List[GenTableColumn]: 代码生成表列列表
  346. """
  347. try:
  348. stmt = db.select(*self.default_columns).where(GenTableColumnPo.table_id == table_id).order_by(GenTableColumnPo.sort)
  349. result = db.session.execute(stmt).all()
  350. columns = []
  351. for row in result:
  352. column = self.default_columns.cast(row, GenTableColumn)
  353. columns.append(column)
  354. return columns
  355. except Exception as e:
  356. print(f"根据表ID查询代码生成表列列表出错: {e}")
  357. return []
  358. def insert(self, gen_table_column: GenTableColumn) -> int:
  359. """
  360. 插入代码生成表列
  361. Args:
  362. gen_table_column (GenTableColumn): 代码生成表列对象
  363. Returns:
  364. int: 插入的列ID
  365. """
  366. try:
  367. # 使用ORM方式插入数据,使用下划线命名
  368. column_data = gen_table_column.model_dump(by_alias=False, exclude_none=False)
  369. # 移除数据库表中不存在的字段
  370. column_data.pop('page_num', None)
  371. column_data.pop('page_size', None)
  372. # 确保布尔字段有默认值
  373. bool_fields = ['is_pk', 'is_increment', 'is_required', 'is_insert',
  374. 'is_edit', 'is_list', 'is_query']
  375. for field in bool_fields:
  376. if field in column_data and column_data[field] is None:
  377. column_data[field] = "1"
  378. # 移除不需要插入的字段
  379. column_data.pop('update_time', None)
  380. # 设置创建时间
  381. if 'create_time' not in column_data:
  382. column_data['create_time'] = datetime.now()
  383. gen_table_column_po = GenTableColumnPo(**column_data)
  384. db.session.add(gen_table_column_po)
  385. db.session.flush()
  386. column_id = gen_table_column_po.column_id
  387. db.session.commit()
  388. return column_id
  389. except Exception as e:
  390. db.session.rollback()
  391. print(f"插入代码生成表列出错: {e}")
  392. return 0
  393. def update(self, gen_table_column: GenTableColumn):
  394. """
  395. 更新代码生成表列
  396. Args:
  397. gen_table_column (GenTableColumn): 代码生成表列对象
  398. """
  399. try:
  400. # 使用model_dump方法直接获取所有字段的值,使用下划线命名
  401. column_data = gen_table_column.model_dump(by_alias=False, exclude_none=False)
  402. # 移除不需要更新的字段
  403. column_data.pop('create_time', None)
  404. column_data.pop('create_by', None)
  405. # 移除数据库表中不存在的字段
  406. column_data.pop('page_num', None)
  407. column_data.pop('page_size', None)
  408. # 设置更新时间
  409. column_data.setdefault('update_time', datetime.now())
  410. # 使用ORM方式更新数据
  411. stmt = db.update(GenTableColumnPo).where(GenTableColumnPo.column_id == gen_table_column.column_id).values(**column_data)
  412. db.session.execute(stmt)
  413. db.session.commit()
  414. except Exception as e:
  415. db.session.rollback()
  416. print(f"更新代码生成表列出错: {e}")
  417. raise e
  418. def delete_by_table_id(self, table_id: int):
  419. """
  420. 根据表ID删除代码生成表列
  421. Args:
  422. table_id (int): 表ID
  423. """
  424. try:
  425. stmt = db.delete(GenTableColumnPo).where(GenTableColumnPo.table_id == table_id)
  426. db.session.execute(stmt)
  427. db.session.commit()
  428. except Exception as e:
  429. db.session.rollback()
  430. print(f"根据表ID删除代码生成表列出错: {e}")
  431. raise e
  432. def delete_by_id(self, column_id: int):
  433. """
  434. 根据ID删除代码生成表列
  435. Args:
  436. column_id (int): 列ID
  437. """
  438. try:
  439. stmt = db.delete(GenTableColumnPo).where(GenTableColumnPo.column_id == column_id)
  440. db.session.execute(stmt)
  441. db.session.commit()
  442. except Exception as e:
  443. db.session.rollback()
  444. print(f"根据ID删除代码生成表列出错: {e}")
  445. # 实例化Mapper
  446. gen_table_mapper = GenTableMapper()
  447. gen_table_column_mapper = GenTableColumnMapper()