__init__.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. # -*- coding: utf-8 -*-
  2. # @Author : YY
  3. # @FileName: __init__.py
  4. from typing import List, Optional
  5. import json
  6. from datetime import datetime
  7. from dateutil import parser
  8. from ruoyi_common.utils import StringUtil
  9. from ruoyi_generator.domain.entity import GenTable, GenTableColumn
  10. from ruoyi_admin.ext import db
  11. from sqlalchemy import text
  12. from ruoyi_generator.util import GenUtils, to_underscore
  13. from ruoyi_generator.config import GeneratorConfig
  14. from ruoyi_common.sqlalchemy.model import ColumnEntityList
  15. from ruoyi_generator.domain.po import GenTablePo, GenTableColumnPo
  class GenTableMapper:
      """Data access for code-generation table metadata (GenTablePo / ``gen_table``).

      Besides ORM queries on GenTablePo, several methods read MySQL
      ``information_schema`` for the schema selected by ``DATABASE()``.
      Query methods print the error and return an empty/neutral value on failure.
      """

      default_fields = {
          "table_id", "table_name", "table_comment", "sub_table_name", "sub_table_fk_name",
          "class_name", "tpl_category", "package_name", "module_name", "business_name",
          "function_name", "function_author", "gen_type", "gen_path", "options",
          "create_by", "create_time", "update_by", "update_time", "remark"
      }
      default_columns = ColumnEntityList(GenTablePo, default_fields)

      # GenTable attributes that insert/update never write to GenTablePo: the
      # primary key, paging parameters, child objects, the tree_* values (parsed
      # out of ``options`` on read) and parent_menu_id. One shared set keeps
      # insert and update in agreement about what is a column.
      _non_column_fields = frozenset({
          'table_id', 'page_size', 'page_num', 'columns', 'pk_column',
          'tree_name', 'tree_code', 'tree_parent_code', 'parent_menu_id',
      })

      # information_schema DATA_TYPE -> Java type used by the generator;
      # any data type not listed maps to 'String'.
      _java_type_by_data_type = {
          'int': 'Integer', 'integer': 'Integer', 'tinyint': 'Integer',
          'smallint': 'Integer', 'mediumint': 'Integer',
          'bigint': 'Long',
          'float': 'BigDecimal', 'double': 'BigDecimal',
          'decimal': 'BigDecimal', 'numeric': 'BigDecimal',
          'date': 'Date', 'datetime': 'Date', 'timestamp': 'Date',
      }

      @staticmethod
      def _apply_tree_options(table: GenTable) -> GenTable:
          """Fill ``tree_name``/``tree_code``/``tree_parent_code`` from ``table.options``.

          ``options`` is read as a JSON object with the keys treeName, treeCode
          and treeParentCode. Empty, malformed or non-object JSON leaves the tree
          attributes untouched, so one bad row never breaks the surrounding query.

          Returns:
              The same ``table`` instance (convenient inside expressions).
          """
          if not table.options:
              return table
          try:
              options_dict = json.loads(table.options)
              if isinstance(options_dict, dict):
                  table.tree_name = options_dict.get('treeName')
                  table.tree_code = options_dict.get('treeCode')
                  table.tree_parent_code = options_dict.get('treeParentCode')
          except (TypeError, ValueError):
              # ValueError covers json.JSONDecodeError (and, presumably, model
              # validation errors if GenTable validates assignments).
              pass
          return table

      @staticmethod
      def _new_db_table(table_name, table_comment, create_time, update_time) -> GenTable:
          """Build an unsaved GenTable for a schema table, pre-filled with generator defaults."""
          table = GenTable()
          table.table_name = table_name
          table.table_comment = table_comment
          if create_time:
              table.create_time = create_time
          if update_time:
              table.update_time = update_time
          clean_table_name = (GenUtils.remove_table_prefix(table_name)
                              if GeneratorConfig.auto_remove_pre else table_name)
          # Class names use underscore naming rather than CamelCase.
          table.class_name = to_underscore(clean_table_name)
          table.package_name = GeneratorConfig.package_name
          # The configured model name doubles as the module name.
          table.module_name = GeneratorConfig.model_name
          if hasattr(StringUtil, 'substring_after') and "_" in clean_table_name:
              table.business_name = StringUtil.substring_after(clean_table_name, "_")
          else:
              table.business_name = clean_table_name
          table.function_name = table.business_name
          table.function_author = GeneratorConfig.author
          table.create_by = "admin"
          return table

      def select_list(self, gen_table: GenTable) -> List[GenTable]:
          """Query imported code-generation tables.

          Args:
              gen_table (GenTable): filter object. A non-empty ``table_name`` /
                  ``table_comment`` is matched with SQL ``LIKE '%value%'``. When
                  ``page_num`` and ``page_size`` are both set and truthy, only
                  that page is returned.

          Returns:
              List[GenTable]: matches ordered by ``table_id`` descending, with the
              tree attributes filled from ``options``; an empty list on any error.
          """
          try:
              criterions = []
              if gen_table.table_name:
                  criterions.append(GenTablePo.table_name.like(f"%{gen_table.table_name}%"))
              if gen_table.table_comment:
                  criterions.append(GenTablePo.table_comment.like(f"%{gen_table.table_comment}%"))
              stmt = (
                  db.select(*self.default_columns)
                  .where(*criterions)
                  .order_by(GenTablePo.table_id.desc())
              )
              page_num = getattr(gen_table, 'page_num', None)
              page_size = getattr(gen_table, 'page_size', None)
              if page_num and page_size:
                  stmt = stmt.limit(page_size).offset((page_num - 1) * page_size)
              rows = db.session.execute(stmt).all()
              return [
                  self._apply_tree_options(self.default_columns.cast(row, GenTable))
                  for row in rows
              ]
          except Exception as e:
              print(f"查询代码生成表列表出错: {e}")
              return []

      def select_db_list(self, gen_table: GenTable) -> List[GenTable]:
          """List tables of the current schema that are not imported into gen_table yet.

          Args:
              gen_table (GenTable): filter object. A non-empty ``table_name`` /
                  ``table_comment`` keeps only tables whose name / comment
                  contains it (case-sensitive substring). An empty comment falls
                  back to the table name before filtering.

          Returns:
              List[GenTable]: new, unsaved tables pre-filled with generator
              defaults, newest ``create_time`` first; an empty list on any error.
          """
          try:
              # A single query replaces one COUNT(gen_table) query per schema table.
              # NOT EXISTS (rather than NOT IN) is used because NOT IN returns no
              # rows at all if any gen_table.table_name is NULL.
              result = db.session.execute(text("""
                  SELECT t.table_name, t.table_comment, t.create_time, t.update_time
                  FROM information_schema.tables t
                  WHERE t.table_schema = DATABASE()
                    AND NOT EXISTS (
                        SELECT 1 FROM gen_table g WHERE g.table_name = t.table_name
                    )
                  ORDER BY t.create_time DESC
              """)).fetchall()
              name_filter = gen_table.table_name
              comment_filter = gen_table.table_comment
              tables = []
              for table_name, raw_comment, create_time, update_time in result:
                  table_comment = raw_comment if raw_comment else table_name
                  # Filter before building the (comparatively costly) GenTable.
                  if name_filter and name_filter not in table_name:
                      continue
                  if comment_filter and comment_filter not in table_comment:
                      continue
                  tables.append(self._new_db_table(table_name, table_comment, create_time, update_time))
              return tables
          except Exception as e:
              print(f"查询数据库表出错: {e}")
              return []

      def select_by_id(self, table_id: int) -> Optional[GenTable]:
          """Fetch one imported table by primary key.

          Args:
              table_id (int): table ID.

          Returns:
              Optional[GenTable]: the table with tree attributes filled from
              ``options``, or None when absent or on error.
          """
          try:
              stmt = db.select(*self.default_columns).where(GenTablePo.table_id == table_id)
              row = db.session.execute(stmt).first()
              if row is None:
                  return None
              return self._apply_tree_options(self.default_columns.cast(row, GenTable))
          except Exception as e:
              print(f"根据ID查询代码生成表出错: {e}")
              return None

      def select_by_table_name(self, table_name: str) -> Optional[GenTable]:
          """Fetch one imported table by its table name.

          Args:
              table_name (str): table name.

          Returns:
              Optional[GenTable]: the first matching table with tree attributes
              filled from ``options`` (as select_by_id does), or None when absent
              or on error.
          """
          try:
              stmt = db.select(*self.default_columns).where(GenTablePo.table_name == table_name)
              row = db.session.execute(stmt).first()
              if row is None:
                  return None
              return self._apply_tree_options(self.default_columns.cast(row, GenTable))
          except Exception as e:
              print(f"根据表名查询代码生成表出错: {e}")
              return None

      def select_db_table_comment_by_name(self, table_name: str) -> Optional[str]:
          """Return the information_schema comment of ``table_name`` in the current schema.

          Args:
              table_name (str): table name.

          Returns:
              Optional[str]: the comment, or None when the table is absent or on error.
          """
          try:
              result = db.session.execute(
                  text("SELECT table_comment FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = :table_name"),
                  {"table_name": table_name}
              ).fetchone()
              return result[0] if result else None
          except Exception as e:
              print(f"查询表注释出错: {e}")
              return None

      def exists_table(self, table_name: str) -> bool:
          """Tell whether ``table_name`` has already been imported into gen_table.

          Args:
              table_name (str): table name.

          Returns:
              bool: True if at least one gen_table row has this name; False
              otherwise or on error.
          """
          try:
              result = db.session.execute(
                  text("SELECT COUNT(1) FROM gen_table WHERE table_name = :table_name"),
                  {"table_name": table_name}
              ).fetchone()
              return result[0] > 0 if result else False
          except Exception as e:
              print(f"检查表是否存在出错: {e}")
              return False

      def insert(self, gen_table: GenTable) -> int:
          """Insert a code-generation table row and commit.

          Args:
              gen_table (GenTable): table to insert; None fields are omitted.

          Returns:
              int: the generated ``table_id``, or 0 on failure (the session is
              rolled back and the error printed).
          """
          try:
              # Snake_case field names; None values are dropped.
              table_data = gen_table.model_dump(by_alias=False, exclude_none=True)
              for field in self._non_column_fields:
                  table_data.pop(field, None)
              # update_time is not written on insert.
              table_data.pop('update_time', None)
              table_data.setdefault('create_by', 'admin')
              table_data.setdefault('update_by', 'admin')
              # exclude_none=True removed a None create_time, so absence means "unset".
              table_data.setdefault('create_time', datetime.now())
              gen_table_po = GenTablePo(**table_data)
              db.session.add(gen_table_po)
              # Flush so the generated table_id is available before commit.
              db.session.flush()
              table_id = gen_table_po.table_id
              db.session.commit()
              return table_id
          except Exception as e:
              db.session.rollback()
              print(f"插入代码生成表出错: {e}")
              return 0

      def update(self, gen_table: GenTable):
          """Update the row identified by ``gen_table.table_id`` and commit.

          Only non-None fields are written, so an empty-string ``options`` is
          still persisted. ``create_time``/``create_by`` and non-column fields
          are never written. ``update_by`` defaults to 'admin' and
          ``update_time`` to the current time when not supplied.
          Errors are printed and the session rolled back; nothing is raised.

          Args:
              gen_table (GenTable): table carrying the new values.
          """
          try:
              table_data = gen_table.model_dump(by_alias=False, exclude_none=True)
              for field in self._non_column_fields:
                  table_data.pop(field, None)
              table_data.pop('create_time', None)
              table_data.pop('create_by', None)
              table_data.setdefault('update_by', 'admin')
              # Stamp the modification time unless the caller supplied one.
              table_data.setdefault('update_time', datetime.now())
              stmt = (
                  db.update(GenTablePo)
                  .where(GenTablePo.table_id == gen_table.table_id)
                  .values(**table_data)
              )
              db.session.execute(stmt)
              db.session.commit()
          except Exception as e:
              db.session.rollback()
              print(f"更新代码生成表出错: {e}")

      def delete_by_id(self, table_id: int):
          """Delete the gen_table row with ``table_id`` and commit.

          Args:
              table_id (int): table ID.

          Raises:
              Exception: any database error, re-raised after rollback.
          """
          try:
              stmt = db.delete(GenTablePo).where(GenTablePo.table_id == table_id)
              db.session.execute(stmt)
              db.session.commit()
          except Exception as e:
              db.session.rollback()
              print(f"根据ID删除代码生成表出错: {e}")
              raise

      def select_db_table_columns_by_name(self, table_name: str) -> List[GenTableColumn]:
          """Read the columns of ``table_name`` from information_schema as GenTableColumn objects.

          Each column is pre-filled for generation:
          - Java type from DATA_TYPE and a camelCase Java field name.
          - Primary-key, auto-increment and required flags from COLUMN_KEY,
            EXTRA and IS_NULLABLE.
          - Every insert/edit/list/query flag set to '1'.
          - Query type 'EQ', HTML type 'input', and a 1-based ``sort`` in
            ordinal-position order.

          Args:
              table_name (str): table name in the current schema.

          Returns:
              List[GenTableColumn]: columns in table order; an empty list on any error.
          """
          try:
              result = db.session.execute(text("""
                  SELECT
                  column_name,
                  column_comment,
                  data_type,
                  is_nullable,
                  column_default,
                  column_key,
                  extra
                  FROM information_schema.columns
                  WHERE table_schema = DATABASE() AND table_name = :table_name
                  ORDER BY ordinal_position
              """), {"table_name": table_name}).fetchall()
              columns = []
              for sort, (column_name, column_comment, data_type, is_nullable,
                         _column_default, column_key, extra) in enumerate(result, start=1):
                  column = GenTableColumn()
                  column.column_name = column_name
                  column.column_comment = column_comment if column_comment else column_name
                  column.column_type = data_type
                  column.java_type = self._java_type_by_data_type.get(data_type, 'String')
                  column.java_field = GenUtils.to_camel_case(column_name)
                  column.is_pk = '1' if column_key == 'PRI' else '0'
                  # EXTRA can hold several space-separated attributes (newer MySQL
                  # versions may append e.g. INVISIBLE), so test containment.
                  column.is_increment = '1' if 'auto_increment' in str(extra or '') else '0'
                  column.is_required = '0' if is_nullable == 'YES' else '1'
                  column.is_insert = '1'
                  column.is_edit = '1'
                  column.is_list = '1'
                  column.is_query = '1'
                  column.query_type = 'EQ'
                  column.html_type = 'input'
                  column.sort = sort
                  columns.append(column)
              return columns
          except Exception as e:
              print(f"查询表列信息出错: {e}")
              return []
  class GenTableColumnMapper:
      """Data access for the column metadata of code-generation tables (GenTableColumnPo)."""

      default_fields = {
          "column_id", "table_id", "column_name", "column_comment", "column_type",
          "java_type", "java_field", "is_pk", "is_increment", "is_required",
          "is_insert", "is_edit", "is_list", "is_query", "query_type",
          "html_type", "dict_type", "sort", "create_by", "create_time",
          "update_by", "update_time", "remark"
      }
      default_columns = ColumnEntityList(GenTableColumnPo, default_fields)

      def select_list_by_table_id(self, table_id: int) -> List[GenTableColumn]:
          """Return the columns of table ``table_id`` ordered by ``sort``.

          Args:
              table_id (int): owning table ID.

          Returns:
              List[GenTableColumn]: the columns; an empty list on any error.
          """
          try:
              stmt = (
                  db.select(*self.default_columns)
                  .where(GenTableColumnPo.table_id == table_id)
                  .order_by(GenTableColumnPo.sort)
              )
              rows = db.session.execute(stmt).all()
              return [self.default_columns.cast(row, GenTableColumn) for row in rows]
          except Exception as e:
              print(f"根据表ID查询代码生成表列列表出错: {e}")
              return []

      def insert(self, gen_table_column: GenTableColumn) -> int:
          """Insert one column row and commit.

          Args:
              gen_table_column (GenTableColumn): column to insert.

          Returns:
              int: the generated ``column_id``, or 0 on failure (the session is
              rolled back and the error printed).
          """
          try:
              # exclude_none=False: every model field is present, unset ones as None.
              column_data = gen_table_column.model_dump(by_alias=False, exclude_none=False)
              # Paging parameters are not database columns.
              column_data.pop('page_num', None)
              column_data.pop('page_size', None)
              # Fill unset flag fields with "1".
              # NOTE(review): this also defaults is_pk / is_increment / is_required
              # to "1"; confirm that is intended.
              for field in ('is_pk', 'is_increment', 'is_required', 'is_insert',
                            'is_edit', 'is_list', 'is_query'):
                  if field in column_data and column_data[field] is None:
                      column_data[field] = "1"
              column_data.pop('update_time', None)
              # The key is always present (exclude_none=False), so test the value;
              # a key-membership check would never stamp the creation time.
              if column_data.get('create_time') is None:
                  column_data['create_time'] = datetime.now()
              gen_table_column_po = GenTableColumnPo(**column_data)
              db.session.add(gen_table_column_po)
              # Flush so the generated column_id is available before commit.
              db.session.flush()
              column_id = gen_table_column_po.column_id
              db.session.commit()
              return column_id
          except Exception as e:
              db.session.rollback()
              print(f"插入代码生成表列出错: {e}")
              return 0

      def update(self, gen_table_column: GenTableColumn):
          """Overwrite the row identified by ``column_id`` with the model's values and commit.

          Every field is written, including ones left as None. ``create_time``
          and ``create_by`` are never changed. ``update_time`` is set to the
          current time when the model does not carry one.

          Args:
              gen_table_column (GenTableColumn): column carrying the new values.

          Raises:
              Exception: any database error, re-raised after rollback.
          """
          try:
              column_data = gen_table_column.model_dump(by_alias=False, exclude_none=False)
              column_data.pop('create_time', None)
              column_data.pop('create_by', None)
              # Paging parameters are not database columns.
              column_data.pop('page_num', None)
              column_data.pop('page_size', None)
              # setdefault() would never fire: exclude_none=False always emits the
              # key, so check the value instead of the key.
              if column_data.get('update_time') is None:
                  column_data['update_time'] = datetime.now()
              stmt = (
                  db.update(GenTableColumnPo)
                  .where(GenTableColumnPo.column_id == gen_table_column.column_id)
                  .values(**column_data)
              )
              db.session.execute(stmt)
              db.session.commit()
          except Exception as e:
              db.session.rollback()
              print(f"更新代码生成表列出错: {e}")
              raise

      def delete_by_table_id(self, table_id: int):
          """Delete every column row of table ``table_id`` and commit.

          Args:
              table_id (int): owning table ID.

          Raises:
              Exception: any database error, re-raised after rollback.
          """
          try:
              stmt = db.delete(GenTableColumnPo).where(GenTableColumnPo.table_id == table_id)
              db.session.execute(stmt)
              db.session.commit()
          except Exception as e:
              db.session.rollback()
              print(f"根据表ID删除代码生成表列出错: {e}")
              raise

      def delete_by_id(self, column_id: int):
          """Delete the column row with ``column_id`` and commit.

          Errors are printed and the session rolled back; nothing is raised.

          Args:
              column_id (int): column ID.
          """
          try:
              stmt = db.delete(GenTableColumnPo).where(GenTableColumnPo.column_id == column_id)
              db.session.execute(stmt)
              db.session.commit()
          except Exception as e:
              db.session.rollback()
              # NOTE(review): unlike delete_by_table_id this swallows the error;
              # confirm callers do not need to know the delete failed.
              print(f"根据ID删除代码生成表列出错: {e}")
  # Shared module-level instances of the two mappers.
  gen_table_mapper, gen_table_column_mapper = GenTableMapper(), GenTableColumnMapper()