__init__.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. # -*- coding: utf-8 -*-
  2. # @Author : YY
  3. # @FileName: __init__.py
  4. from typing import List, Optional
  5. import json
  6. from datetime import datetime
  7. from dateutil import parser
  8. from ruoyi_common.utils import StringUtil
  9. from ruoyi_generator.domain.entity import GenTable, GenTableColumn
  10. from ruoyi_admin.ext import db
  11. from sqlalchemy import text
  12. from ruoyi_generator.util import GenUtils, to_underscore
  13. from ruoyi_generator.config import GeneratorConfig
  14. from ruoyi_common.sqlalchemy.model import ColumnEntityList
  15. from ruoyi_generator.domain.po import GenTablePo, GenTableColumnPo
  class GenTableMapper:
      """Data access for the ``gen_table`` metadata table and the live schema.

      Reads and writes code-generation table metadata through SQLAlchemy and
      inspects the current MySQL database (``SHOW TABLES`` /
      ``information_schema``) for tables and columns that can be imported.
      Most methods print the error and return a neutral value (``[]``,
      ``None``, ``False`` or ``0``) instead of raising; ``delete_by_id`` rolls
      back and re-raises.
      """

      default_fields = {
          "table_id", "table_name", "table_comment", "sub_table_name", "sub_table_fk_name",
          "class_name", "tpl_category", "package_name", "module_name", "business_name",
          "function_name", "function_author", "gen_type", "gen_path", "options",
          "create_by", "create_time", "update_by", "update_time", "remark",
      }
      default_columns = ColumnEntityList(GenTablePo, default_fields)

      # model_dump() keys never written on insert/update: the primary key,
      # paging parameters, related objects, and the tree_* settings (which are
      # not gen_table columns; the select methods derive them from ``options``).
      _excluded_write_fields = frozenset({
          'table_id', 'page_size', 'page_num', 'columns', 'pk_column',
          'tree_name', 'tree_code', 'tree_parent_code',
      })

      # information_schema DATA_TYPE -> Java type; unlisted types map to 'String'.
      _java_type_by_data_type = {
          'int': 'Integer', 'integer': 'Integer', 'tinyint': 'Integer',
          'smallint': 'Integer', 'mediumint': 'Integer',
          'bigint': 'Long',
          'float': 'BigDecimal', 'double': 'BigDecimal',
          'decimal': 'BigDecimal', 'numeric': 'BigDecimal',
          'date': 'Date', 'datetime': 'Date', 'timestamp': 'Date',
      }

      @staticmethod
      def _apply_tree_options(table: GenTable) -> GenTable:
          """Fill tree_name / tree_code / tree_parent_code from ``table.options``.

          ``options`` is read as a JSON object with the camelCase keys
          ``treeName``, ``treeCode`` and ``treeParentCode``. Empty options,
          malformed JSON, non-object JSON and rejected assignments are ignored
          so that one bad row does not break a listing.

          Args:
              table (GenTable): Table whose ``options`` should be parsed.

          Returns:
              GenTable: The same object, possibly with tree fields set.
          """
          if not table.options:
              return table
          try:
              options_dict = json.loads(table.options)
              table.tree_name = options_dict.get('treeName')
              table.tree_code = options_dict.get('treeCode')
              table.tree_parent_code = options_dict.get('treeParentCode')
          except (TypeError, ValueError, AttributeError):
              # TypeError: options not str/bytes; ValueError: invalid JSON (or a
              # validation error on assignment); AttributeError: JSON not an object.
              pass
          return table

      @staticmethod
      def _new_db_table(table_name: str, table_comment: Optional[str]) -> GenTable:
          """Build a not-yet-imported GenTable pre-filled with generator defaults.

          The defaults exist so the front end can display sensible values
          before the table is imported.
          """
          table = GenTable()
          table.table_name = table_name
          table.table_comment = table_comment
          # Optionally strip the configured table prefix before deriving names.
          clean_table_name = (
              GenUtils.remove_table_prefix(table_name)
              if GeneratorConfig.auto_remove_pre else table_name
          )
          has_underscore = "_" in clean_table_name
          # class_name deliberately uses underscore naming, not CamelCase.
          table.class_name = to_underscore(clean_table_name)
          table.package_name = GeneratorConfig.package_name
          table.module_name = (
              StringUtil.substring_before(clean_table_name, "_")
              if hasattr(StringUtil, 'substring_before') and has_underscore
              else clean_table_name
          )
          table.business_name = (
              StringUtil.substring_after(clean_table_name, "_")
              if hasattr(StringUtil, 'substring_after') and has_underscore
              else clean_table_name
          )
          table.function_name = table.business_name
          table.function_author = GeneratorConfig.author
          table.create_by = "admin"
          return table

      def select_list(self, gen_table: GenTable) -> List[GenTable]:
          """Query imported code-generation tables.

          Args:
              gen_table (GenTable): Filter object. Non-empty ``table_name`` /
                  ``table_comment`` match as literal substrings (``%`` and
                  ``_`` in the input are escaped, not wildcards). When both
                  ``page_num`` and ``page_size`` are truthy the result is paged.

          Returns:
              List[GenTable]: Matching tables ordered by ``table_id`` descending,
              with tree settings parsed from ``options``; ``[]`` on any error.
          """
          try:
              criterions = []
              if gen_table.table_name:
                  criterions.append(GenTablePo.table_name.contains(gen_table.table_name, autoescape=True))
              if gen_table.table_comment:
                  criterions.append(GenTablePo.table_comment.contains(gen_table.table_comment, autoescape=True))
              stmt = (
                  db.select(*self.default_columns)
                  .where(*criterions)
                  .order_by(GenTablePo.table_id.desc())
              )
              page_num = getattr(gen_table, 'page_num', None)
              page_size = getattr(gen_table, 'page_size', None)
              if page_num and page_size:
                  stmt = stmt.limit(page_size).offset((page_num - 1) * page_size)
              rows = db.session.execute(stmt).all()
              return [
                  self._apply_tree_options(self.default_columns.cast(row, GenTable))
                  for row in rows
              ]
          except Exception as e:
              print(f"查询代码生成表列表出错: {e}")
              return []

      def select_db_list(self, gen_table: GenTable) -> List[GenTable]:
          """List tables of the current database not yet imported into gen_table.

          Three queries in total: table names, all table comments, and the
          names already imported. The previous version issued two queries per
          table.

          Args:
              gen_table (GenTable): Filter object. Non-empty ``table_name`` /
                  ``table_comment`` keep only tables whose name / comment
                  contains that substring (case-sensitive).

          Returns:
              List[GenTable]: Unimported tables in ``SHOW TABLES`` order,
              pre-filled with generator defaults; ``[]`` on any error.
          """
          try:
              table_names = [row[0] for row in db.session.execute(text("SHOW TABLES")).fetchall()]
              comments = dict(db.session.execute(text(
                  "SELECT table_name, table_comment FROM information_schema.tables "
                  "WHERE table_schema = DATABASE()"
              )).fetchall())
              # NOTE(review): lower-casing mirrors the case-insensitive collation
              # the old per-table `table_name = :table_name` check relied on;
              # confirm gen_table.table_name uses a *_ci collation.
              imported = {
                  str(name).lower()
                  for name in db.session.execute(db.select(GenTablePo.table_name)).scalars()
                  if name
              }
              name_filter = gen_table.table_name
              comment_filter = gen_table.table_comment
              tables = []
              for table_name in table_names:
                  if table_name.lower() in imported:
                      continue
                  if name_filter and name_filter not in table_name:
                      continue
                  table_comment = comments.get(table_name)
                  # Missing or NULL comment falls back to the table name, so the
                  # substring filter below never runs against None.
                  if table_comment is None:
                      table_comment = table_name
                  if comment_filter and comment_filter not in table_comment:
                      continue
                  tables.append(self._new_db_table(table_name, table_comment))
              return tables
          except Exception as e:
              print(f"查询数据库表出错: {e}")
              return []

      def select_by_id(self, table_id: int) -> Optional[GenTable]:
          """Fetch one imported table by primary key.

          Args:
              table_id (int): gen_table primary key.

          Returns:
              Optional[GenTable]: The table with tree settings parsed from
              ``options``; ``None`` when not found or on error.
          """
          try:
              stmt = db.select(*self.default_columns).where(GenTablePo.table_id == table_id)
              row = db.session.execute(stmt).first()
              if row is None:
                  return None
              return self._apply_tree_options(self.default_columns.cast(row, GenTable))
          except Exception as e:
              print(f"根据ID查询代码生成表出错: {e}")
              return None

      def select_by_table_name(self, table_name: str) -> Optional[GenTable]:
          """Fetch one imported table by its exact table name.

          Args:
              table_name (str): Database table name.

          Returns:
              Optional[GenTable]: The table (tree settings parsed from
              ``options``, consistent with ``select_by_id``); ``None`` when
              not found or on error.
          """
          try:
              stmt = db.select(*self.default_columns).where(GenTablePo.table_name == table_name)
              row = db.session.execute(stmt).first()
              if row is None:
                  return None
              return self._apply_tree_options(self.default_columns.cast(row, GenTable))
          except Exception as e:
              print(f"根据表名查询代码生成表出错: {e}")
              return None

      def select_db_table_comment_by_name(self, table_name: str) -> Optional[str]:
          """Look up a table's comment in information_schema for the current database.

          Args:
              table_name (str): Database table name.

          Returns:
              Optional[str]: The comment, or ``None`` when the table is missing
              or on error.
          """
          try:
              result = db.session.execute(
                  text("SELECT table_comment FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = :table_name"),
                  {"table_name": table_name}
              ).fetchone()
              return result[0] if result else None
          except Exception as e:
              print(f"查询表注释出错: {e}")
              return None

      def exists_table(self, table_name: str) -> bool:
          """Return whether ``table_name`` has already been imported into gen_table.

          This checks the gen_table metadata table, not the database schema.

          Args:
              table_name (str): Database table name.

          Returns:
              bool: ``True`` if at least one gen_table row has this name;
              ``False`` otherwise or on error.
          """
          try:
              result = db.session.execute(
                  text("SELECT COUNT(1) FROM gen_table WHERE table_name = :table_name"),
                  {"table_name": table_name}
              ).fetchone()
              return result[0] > 0 if result else False
          except Exception as e:
              print(f"检查表是否存在出错: {e}")
              return False

      def insert(self, gen_table: GenTable) -> int:
          """Insert a gen_table row.

          Args:
              gen_table (GenTable): Table metadata; ``None`` fields are omitted.

          Returns:
              int: The generated ``table_id``, or ``0`` if the insert failed
              (the session is rolled back).
          """
          try:
              table_data = gen_table.model_dump(by_alias=False, exclude_none=True)
              for field in self._excluded_write_fields:
                  table_data.pop(field, None)
              # update_time is not written on insert.
              table_data.pop('update_time', None)
              table_data.setdefault('create_by', 'admin')
              table_data.setdefault('update_by', 'admin')
              # exclude_none=True above, so an absent key means "not provided".
              table_data.setdefault('create_time', datetime.now())
              gen_table_po = GenTablePo(**table_data)
              db.session.add(gen_table_po)
              # flush() makes the database assign the primary key before commit.
              db.session.flush()
              table_id = gen_table_po.table_id
              db.session.commit()
              return table_id
          except Exception as e:
              db.session.rollback()
              print(f"插入代码生成表出错: {e}")
              return 0

      def update(self, gen_table: GenTable):
          """Update the gen_table row identified by ``gen_table.table_id``.

          ``None`` fields are left untouched, creation audit fields are never
          overwritten, and ``update_time`` is stamped with the current time.
          Errors are printed and the session rolled back; nothing is raised.

          Args:
              gen_table (GenTable): Table metadata with ``table_id`` set.
          """
          try:
              table_data = gen_table.model_dump(by_alias=False, exclude_none=True)
              for field in self._excluded_write_fields:
                  table_data.pop(field, None)
              table_data.pop('create_time', None)
              table_data.pop('create_by', None)
              table_data.setdefault('update_by', 'admin')
              # Previously update_time was never refreshed on update.
              table_data['update_time'] = datetime.now()
              stmt = (
                  db.update(GenTablePo)
                  .where(GenTablePo.table_id == gen_table.table_id)
                  .values(**table_data)
              )
              db.session.execute(stmt)
              db.session.commit()
          except Exception as e:
              db.session.rollback()
              print(f"更新代码生成表出错: {e}")

      def delete_by_id(self, table_id: int):
          """Delete the gen_table row with the given primary key.

          Args:
              table_id (int): gen_table primary key.

          Raises:
              Exception: Any database error, re-raised after rollback.
          """
          try:
              stmt = db.delete(GenTablePo).where(GenTablePo.table_id == table_id)
              db.session.execute(stmt)
              db.session.commit()
          except Exception as e:
              db.session.rollback()
              print(f"根据ID删除代码生成表出错: {e}")
              raise

      def select_db_table_columns_by_name(self, table_name: str) -> List[GenTableColumn]:
          """Read a table's columns from information_schema as GenTableColumn objects.

          Args:
              table_name (str): Table in the current database.

          Returns:
              List[GenTableColumn]: Columns in ordinal order with generator
              defaults (insert/edit/list/query enabled, EQ query, ``input``
              widget, 1-based ``sort``); ``[]`` on error.
          """
          try:
              rows = db.session.execute(text("""
                  SELECT
                  column_name,
                  column_comment,
                  data_type,
                  is_nullable,
                  column_default,
                  column_key,
                  extra
                  FROM information_schema.columns
                  WHERE table_schema = DATABASE() AND table_name = :table_name
                  ORDER BY ordinal_position
              """), {"table_name": table_name}).fetchall()
              has_camel_case = hasattr(StringUtil, 'to_camel_case')
              columns = []
              for sort, (name, comment, data_type, is_nullable, _column_default, column_key, extra) in enumerate(rows, start=1):
                  column = GenTableColumn()
                  column.column_name = name
                  column.column_comment = comment if comment else name
                  # NOTE(review): this stores DATA_TYPE (e.g. 'varchar'), not the
                  # full COLUMN_TYPE (e.g. 'varchar(64)') — confirm that is intended.
                  column.column_type = data_type
                  column.java_type = self._java_type_by_data_type.get(data_type, 'String')
                  column.java_field = StringUtil.to_camel_case(name) if has_camel_case else name
                  column.is_pk = '1' if column_key == 'PRI' else '0'
                  # EXTRA may combine flags (e.g. 'auto_increment INVISIBLE').
                  column.is_increment = '1' if extra and 'auto_increment' in str(extra).lower() else '0'
                  column.is_required = '0' if is_nullable == 'YES' else '1'
                  column.is_insert = '1'
                  column.is_edit = '1'
                  column.is_list = '1'
                  column.is_query = '1'
                  column.query_type = 'EQ'
                  column.html_type = 'input'
                  column.sort = sort
                  columns.append(column)
              return columns
          except Exception as e:
              print(f"查询表列信息出错: {e}")
              return []
  class GenTableColumnMapper:
      """Data access for the ``gen_table_column`` metadata table.

      ``update``, ``delete_by_table_id`` and ``delete_by_id`` roll back on
      error. The first two re-raise; ``delete_by_id``, ``insert`` and
      ``select_list_by_table_id`` only print and return a neutral value.
      """

      default_fields = {
          "column_id", "table_id", "column_name", "column_comment", "column_type",
          "java_type", "java_field", "is_pk", "is_increment", "is_required",
          "is_insert", "is_edit", "is_list", "is_query", "query_type",
          "html_type", "dict_type", "sort", "create_by", "create_time",
          "update_by", "update_time", "remark",
      }
      default_columns = ColumnEntityList(GenTableColumnPo, default_fields)

      # model_dump() keys with no matching gen_table_column column.
      _non_column_fields = ('page_num', 'page_size')

      # Flag columns that fall back to "1" on insert when the model leaves them None.
      # NOTE(review): defaulting is_pk / is_increment to "1" looks suspicious —
      # confirm callers always set them explicitly.
      _flag_fields = ('is_pk', 'is_increment', 'is_required', 'is_insert',
                      'is_edit', 'is_list', 'is_query')

      def select_list_by_table_id(self, table_id: int) -> List[GenTableColumn]:
          """Fetch all columns of one gen_table row, ordered by ``sort``.

          Args:
              table_id (int): Owning gen_table primary key.

          Returns:
              List[GenTableColumn]: The columns; ``[]`` on error.
          """
          try:
              stmt = (
                  db.select(*self.default_columns)
                  .where(GenTableColumnPo.table_id == table_id)
                  .order_by(GenTableColumnPo.sort)
              )
              rows = db.session.execute(stmt).all()
              return [self.default_columns.cast(row, GenTableColumn) for row in rows]
          except Exception as e:
              print(f"根据表ID查询代码生成表列列表出错: {e}")
              return []

      def insert(self, gen_table_column: GenTableColumn) -> int:
          """Insert a gen_table_column row.

          Args:
              gen_table_column (GenTableColumn): Column metadata.

          Returns:
              int: The generated ``column_id``, or ``0`` if the insert failed
              (the session is rolled back).
          """
          try:
              # exclude_none=False: every model field is present, possibly as None.
              column_data = gen_table_column.model_dump(by_alias=False, exclude_none=False)
              for field in self._non_column_fields:
                  column_data.pop(field, None)
              for field in self._flag_fields:
                  if field in column_data and column_data[field] is None:
                      column_data[field] = "1"
              # update_time is not written on insert.
              column_data.pop('update_time', None)
              # The key is always present (see exclude_none=False), so test the
              # value; a key-membership test meant create_time was never set.
              if column_data.get('create_time') is None:
                  column_data['create_time'] = datetime.now()
              gen_table_column_po = GenTableColumnPo(**column_data)
              db.session.add(gen_table_column_po)
              # flush() makes the database assign the primary key before commit.
              db.session.flush()
              column_id = gen_table_column_po.column_id
              db.session.commit()
              return column_id
          except Exception as e:
              db.session.rollback()
              print(f"插入代码生成表列出错: {e}")
              return 0

      def update(self, gen_table_column: GenTableColumn):
          """Update the gen_table_column row identified by ``column_id``.

          All fields are written, including ``None`` values. Exceptions are
          ``create_time``/``create_by`` (immutable) and the primary key itself.
          ``update_time`` is stamped with the current time.

          Args:
              gen_table_column (GenTableColumn): Column metadata with
                  ``column_id`` set.

          Raises:
              Exception: Any database error, re-raised after rollback.
          """
          try:
              column_data = gen_table_column.model_dump(by_alias=False, exclude_none=False)
              column_data.pop('create_time', None)
              column_data.pop('create_by', None)
              for field in self._non_column_fields:
                  column_data.pop(field, None)
              # The primary key selects the row in WHERE; never rewrite it in SET.
              column_data.pop('column_id', None)
              # setdefault() was a no-op here because the key is always present
              # (as None or a stale value); always stamp the update time.
              column_data['update_time'] = datetime.now()
              stmt = (
                  db.update(GenTableColumnPo)
                  .where(GenTableColumnPo.column_id == gen_table_column.column_id)
                  .values(**column_data)
              )
              db.session.execute(stmt)
              db.session.commit()
          except Exception as e:
              db.session.rollback()
              print(f"更新代码生成表列出错: {e}")
              raise

      def delete_by_table_id(self, table_id: int):
          """Delete every column belonging to one gen_table row.

          Args:
              table_id (int): Owning gen_table primary key.

          Raises:
              Exception: Any database error, re-raised after rollback.
          """
          try:
              stmt = db.delete(GenTableColumnPo).where(GenTableColumnPo.table_id == table_id)
              db.session.execute(stmt)
              db.session.commit()
          except Exception as e:
              db.session.rollback()
              print(f"根据表ID删除代码生成表列出错: {e}")
              raise

      def delete_by_id(self, column_id: int):
          """Delete one gen_table_column row by primary key.

          Unlike ``delete_by_table_id``, errors are printed and swallowed after
          rollback. NOTE(review): confirm callers rely on this best-effort
          behavior.

          Args:
              column_id (int): gen_table_column primary key.
          """
          try:
              stmt = db.delete(GenTableColumnPo).where(GenTableColumnPo.column_id == column_id)
              db.session.execute(stmt)
              db.session.commit()
          except Exception as e:
              db.session.rollback()
              print(f"根据ID删除代码生成表列出错: {e}")
  # Module-level mapper instances, created once when this module is imported.
  gen_table_mapper, gen_table_column_mapper = GenTableMapper(), GenTableColumnMapper()