ruoyi_generator.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. # -*- coding: utf-8 -*-
  2. # @Author : YY
  3. # @FileName: ruoyi_generator.py
  4. import json
  5. import os
  6. from typing import List
  7. from zipfile import ZipFile
  8. from io import BytesIO
  9. from jinja2 import Environment, FileSystemLoader
  10. from ruoyi_common.utils import StringUtil
  11. from ruoyi_common.constant import Constants
  12. from ruoyi_generator.domain.entity import GenTable, GenTableColumn
  13. from ruoyi_generator.mapper import gen_table_mapper, gen_table_column_mapper
  14. from ruoyi_generator.util import GenUtils
  15. from ruoyi_generator.config import GeneratorConfig
  16. from datetime import datetime
  17. from ruoyi_admin.ext import db
  18. from sqlalchemy import text
  19. from flask import Flask
  20. from ruoyi_admin import create_app
  21. from ruoyi_generator.mapper import gen_table_mapper
  22. from ruoyi_generator.domain.entity import GenTable
  23. from ruoyi_generator.config import GeneratorConfig
  24. from ruoyi_common.utils import StringUtil
  25. from ruoyi_generator.util import to_underscore
  26. class RuoYiGenerator:
  27. def __init__(self):
  28. # 初始化模板引擎
  29. self.template_env = Environment(
  30. loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), 'vm')),
  31. autoescape=False # 关闭自动转义,避免HTML转义字符
  32. )
  33. def get_template_data(self, table_id: int) -> dict:
  34. """
  35. 获取模板数据
  36. Args:
  37. table_id (int): 表ID
  38. Returns:
  39. dict: 模板数据
  40. """
  41. # 查询表信息
  42. table = gen_table_mapper.select_by_id(table_id)
  43. if not table:
  44. raise Exception(f"表ID {table_id} 不存在")
  45. # 查询列信息
  46. columns = gen_table_column_mapper.select_list_by_table_id(table_id)
  47. table.columns = columns
  48. # 设置列的 list_index 属性
  49. from ruoyi_generator.util import GenUtils
  50. GenUtils.set_column_list_index(table)
  51. # 设置主键列
  52. pk_columns = [column for column in columns if column.is_pk == '1']
  53. if pk_columns:
  54. table.pk_column = pk_columns[0]
  55. # 设置其他属性
  56. if table.options:
  57. try:
  58. table.options = json.loads(table.options) if isinstance(table.options, str) else table.options
  59. table.tree_name = table.options.get('treeName')
  60. table.tree_code = table.options.get('treeCode')
  61. table.tree_parent_code = table.options.get('treeParentCode')
  62. # 从 options 中提取 parentMenuId
  63. if 'parentMenuId' in table.options:
  64. table.parent_menu_id = table.options.get('parentMenuId')
  65. except Exception:
  66. pass
  67. # 强制使用前端模块名(modelName),而不是 Python 模块名
  68. # module_name 必须使用 modelName(test),不能使用 pythonModelName(ruoyi_test)
  69. # 如果 module_name 是空的、等于 python_model_name 或包含 python_model_name,强制替换为 model_name
  70. original_module_name = table.module_name
  71. if not table.module_name or table.module_name == GeneratorConfig.python_model_name or (table.module_name and GeneratorConfig.python_model_name in table.module_name):
  72. table.module_name = GeneratorConfig.model_name
  73. if original_module_name != table.module_name:
  74. print(f"警告:table.module_name 从 '{original_module_name}' 强制替换为 '{table.module_name}'(前端模块名)")
  75. # 设置模板数据
  76. data = {
  77. 'table': table,
  78. 'constants': Constants,
  79. 'datetime': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  80. }
  81. return data
  82. def generate_files(self, table_id: int) -> dict:
  83. """
  84. 生成文件内容
  85. Args:
  86. table_id (int): 表ID
  87. Returns:
  88. dict: 生成的文件内容,key为文件路径,value为文件内容
  89. """
  90. # 获取模板数据
  91. data = self.get_template_data(table_id)
  92. table = data['table']
  93. # 获取所有模板文件
  94. template_files = self.get_template_files()
  95. # 生成文件
  96. generated_files = {}
  97. for template_file in template_files:
  98. try:
  99. # 渲染模板
  100. template = self.template_env.get_template(template_file)
  101. content = template.render(**data)
  102. # 生成文件名
  103. file_name = GenUtils.get_file_name(template_file, table)
  104. generated_files[file_name] = content
  105. except Exception as e:
  106. print(f"渲染模板 {template_file} 失败: {e}")
  107. continue
  108. return generated_files
  109. def get_template_files(self) -> List[str]:
  110. """
  111. 获取模板文件列表
  112. Returns:
  113. List[str]: 模板文件列表
  114. """
  115. template_files = []
  116. vm_dir = os.path.join(os.path.dirname(__file__), 'vm')
  117. # 递归遍历vm目录下的所有.vm文件
  118. for root, dirs, files in os.walk(vm_dir):
  119. for file in files:
  120. if file.endswith('.vm'):
  121. # 获取相对于vm目录的路径
  122. rel_path = os.path.relpath(os.path.join(root, file), vm_dir)
  123. template_files.append(rel_path.replace('\\', '/'))
  124. return template_files
  125. def preview_code(self, table_id: int) -> dict:
  126. """
  127. 预览代码
  128. Args:
  129. table_id (int): 表ID
  130. Returns:
  131. dict: 生成的代码
  132. """
  133. try:
  134. # 生成文件内容
  135. generated_files = self.generate_files(table_id)
  136. # 返回生成的代码
  137. return generated_files
  138. except Exception as e:
  139. print(f"预览代码失败: {e}")
  140. return {}
  141. def download_code(self, table_id: int) -> bytes:
  142. """
  143. 下载代码
  144. Args:
  145. table_id (int): 表ID
  146. Returns:
  147. bytes: 生成的代码压缩包
  148. """
  149. # 生成文件内容
  150. generated_files = self.generate_files(table_id)
  151. # 创建ZIP文件
  152. zip_buffer = BytesIO()
  153. with ZipFile(zip_buffer, 'w') as zip_file:
  154. for file_path, content in generated_files.items():
  155. zip_file.writestr(file_path, content)
  156. zip_buffer.seek(0)
  157. return zip_buffer.getvalue()
  158. def import_table(self, table_name: str) -> bool:
  159. """导入表"""
  160. try:
  161. # 检查表是否已存在
  162. if gen_table_mapper.exists_table(table_name):
  163. # 如果表已存在,直接同步字段信息
  164. from ruoyi_generator.service import GenTableService
  165. service = GenTableService()
  166. return service.synch_db(table_name)
  167. # 创建表信息
  168. table = GenTable()
  169. table.table_name = table_name
  170. # 获取表注释
  171. try:
  172. table_comment_result = db.session.execute(
  173. text("SELECT table_comment FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = :table_name"),
  174. {"table_name": table_name}
  175. ).fetchone()
  176. table.table_comment = table_comment_result[0] if table_comment_result and table_comment_result[0] else table_name
  177. except Exception as e:
  178. print(f"获取表注释失败: {e}")
  179. table.table_comment = table_name
  180. # 处理表名前缀
  181. clean_table_name = GenUtils.remove_table_prefix(table_name) if GeneratorConfig.auto_remove_pre else table_name
  182. # 使用下划线命名法而不是驼峰命名法
  183. table.class_name = to_underscore(clean_table_name)
  184. table.package_name = GeneratorConfig.package_name
  185. # 使用配置中的 modelName 作为模块名
  186. table.module_name = GeneratorConfig.model_name
  187. table.business_name = StringUtil.substring_after(clean_table_name, "_") if hasattr(StringUtil, 'substring_after') and "_" in clean_table_name else clean_table_name
  188. table.function_name = table.business_name
  189. table.function_author = GeneratorConfig.author
  190. table.create_by = "admin"
  191. # 插入表信息到数据库
  192. table_id = gen_table_mapper.insert(table)
  193. # 获取表列信息
  194. columns = gen_table_mapper.select_db_table_columns_by_name(table_name)
  195. # 即使没有列信息也继续处理
  196. if not columns:
  197. print(f"警告:未能获取到表 {table_name} 的列信息")
  198. for i, column in enumerate(columns or []):
  199. column.table_id = table_id
  200. column.sort = i + 1
  201. column.create_by = "admin"
  202. # 设置默认的字段属性
  203. if not column.java_type:
  204. if column.column_type in ['int', 'integer', 'tinyint', 'smallint', 'mediumint']:
  205. column.java_type = 'Integer'
  206. elif column.column_type in ['bigint']:
  207. column.java_type = 'Long'
  208. elif column.column_type in ['float', 'double', 'decimal', 'numeric']:
  209. column.java_type = 'BigDecimal'
  210. elif column.column_type in ['date', 'datetime', 'timestamp']:
  211. column.java_type = 'Date'
  212. else:
  213. column.java_type = 'String'
  214. if not column.java_field:
  215. column.java_field = GenUtils.to_camel_case(column.column_name)
  216. if not column.html_type:
  217. if column.column_type in ['date', 'datetime', 'timestamp']:
  218. column.html_type = 'datetime'
  219. elif column.column_type in ['text', 'longtext', 'mediumtext']:
  220. column.html_type = 'textarea'
  221. elif column.column_type in ['tinyint'] and column.column_name in ['status', 'is_delete', 'is_enabled']:
  222. column.html_type = 'radio'
  223. else:
  224. column.html_type = 'input'
  225. if not column.query_type:
  226. if column.column_type in ['varchar', 'char', 'text']:
  227. column.query_type = 'LIKE'
  228. else:
  229. column.query_type = 'EQ'
  230. gen_table_column_mapper.insert(column)
  231. return True
  232. except Exception as e:
  233. print(f"导入表失败: {e}")
  234. return False