util.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. import os
  2. from typing import List
  3. from jinja2 import Template
  4. from ruoyi_common.utils import StringUtil
  5. from ruoyi_generator.domain.entity import GenTable
  6. from ruoyi_generator.config import GeneratorConfig
  7. import zipfile
  8. from io import BytesIO
  9. from datetime import datetime
  10. import re
  11. def to_underscore(name: str) -> str:
  12. """
  13. 将驼峰命名转换为下划线命名
  14. Args:
  15. name (str): 驼峰命名的字符串
  16. Returns:
  17. str: 下划线命名的字符串
  18. """
  19. # 在大写字母前添加下划线,然后转为小写
  20. s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
  21. return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
  22. class GenUtils:
  23. @staticmethod
  24. def get_file_name(template_file: str, table: GenTable) -> str:
  25. """
  26. 根据模板文件名和表信息生成文件名
  27. Args:
  28. template_file (str): 模板文件名
  29. table (GenTable): 表信息
  30. Returns:
  31. str: 生成的文件名
  32. """
  33. # 标准化路径分隔符
  34. template_file = template_file.replace('\\', '/')
  35. # 移除.vm后缀
  36. base_name = template_file[:-3] if template_file.endswith('.vm') else template_file
  37. # 根据模板类型生成文件名和路径
  38. if 'py/entity.py' in template_file:
  39. # 根据包名生成目录结构
  40. package_path = table.package_name.replace('.', '/') if table.package_name else ''
  41. return f"{package_path}/domain/{table.class_name}.py"
  42. elif 'py/controller.py' in template_file:
  43. package_path = table.package_name.replace('.', '/') if table.package_name else ''
  44. # 使用下划线命名法
  45. controller_name = f"{to_underscore(table.class_name)}_controller"
  46. return f"{package_path}/controller/{controller_name}.py"
  47. elif 'py/service.py' in template_file:
  48. package_path = table.package_name.replace('.', '/') if table.package_name else ''
  49. # 使用下划线命名法
  50. service_name = f"{to_underscore(table.class_name)}_service"
  51. return f"{package_path}/service/{service_name}.py"
  52. elif 'py/mapper.py' in template_file:
  53. package_path = table.package_name.replace('.', '/') if table.package_name else ''
  54. # 使用下划线命名法
  55. mapper_name = f"{to_underscore(table.class_name)}_mapper"
  56. return f"{package_path}/mapper/{mapper_name}.py"
  57. elif 'py/po.py' in template_file:
  58. package_path = table.package_name.replace('.', '/') if table.package_name else ''
  59. # PO文件使用表名作为文件名
  60. po_name = f"{table.class_name}PO"
  61. return f"{package_path}/domain/{po_name}.py"
  62. elif 'vue/index.vue' in template_file:
  63. # 无论是树表还是普通表,Vue文件名都是index.vue
  64. return f"vue/{table.business_name}/index.vue"
  65. elif 'js/api.js' in template_file:
  66. return f"js/api/{table.business_name}.js"
  67. elif 'sql/menu.sql' in template_file:
  68. return f"sql/{table.business_name}_menu.sql"
  69. elif 'README.md' in template_file:
  70. return f"{table.business_name}_README.md"
  71. else:
  72. # 处理其他模板文件,保持原有目录结构
  73. filename = os.path.basename(base_name)
  74. if '.' not in filename:
  75. filename += '.py' # 默认添加.py扩展名
  76. return filename
  77. @staticmethod
  78. def get_table_prefix() -> str:
  79. """
  80. 获取表前缀
  81. Returns:
  82. str: 表前缀
  83. """
  84. return GeneratorConfig.table_prefix or ""
  85. @staticmethod
  86. def remove_table_prefix(table_name: str) -> str:
  87. """
  88. 移除表前缀
  89. Args:
  90. table_name (str): 表名
  91. Returns:
  92. str: 移除前缀后的表名
  93. """
  94. prefix = GenUtils.get_table_prefix()
  95. if prefix and table_name.startswith(prefix):
  96. return table_name[len(prefix):]
  97. return table_name
  98. @staticmethod
  99. def table_to_class_name(table_name: str) -> str:
  100. """
  101. 将表名转换为类名
  102. Args:
  103. table_name (str): 表名
  104. Returns:
  105. str: 类名
  106. """
  107. # 移除表前缀
  108. clean_table_name = GenUtils.remove_table_prefix(table_name)
  109. # 转换为驼峰命名
  110. return GenUtils.to_camel_case(clean_table_name)
  111. @staticmethod
  112. def get_business_name(table_name: str) -> str:
  113. """
  114. 获取业务名
  115. Args:
  116. table_name (str): 表名
  117. Returns:
  118. str: 业务名
  119. """
  120. # 移除表前缀
  121. clean_table_name = GenUtils.remove_table_prefix(table_name)
  122. # 获取下划线分隔的第一部分
  123. return GenUtils.substring_before(clean_table_name, "_") if "_" in clean_table_name else clean_table_name
  124. @staticmethod
  125. def get_import_path(package_name: str, module_type: str, class_name: str = None) -> str:
  126. """
  127. 生成导入路径
  128. Args:
  129. package_name (str): 包名,如 "com.yy.project" 或 "ruoyi_generator"
  130. module_type (str): 模块类型,如 "domain", "service", "mapper", "controller"
  131. class_name (str): 类名(可选,用于PO导入)
  132. Returns:
  133. str: 导入路径,Python包名保持点分隔格式
  134. """
  135. if not package_name:
  136. return f"ruoyi_generator.{module_type}"
  137. # Python导入路径使用点分隔,保持原样
  138. # 例如: "com.yy.project" -> "com.yy.project"
  139. # 例如: "ruoyi_generator" -> "ruoyi_generator" (保持不变)
  140. python_package = package_name
  141. # 生成导入路径
  142. if module_type == "domain" and class_name:
  143. return f"{python_package}.domain.po"
  144. elif module_type == "domain":
  145. return f"{python_package}.domain.entity"
  146. else:
  147. return f"{python_package}.{module_type}"
  148. @staticmethod
  149. def to_camel_case(name: str) -> str:
  150. """
  151. 将下划线命名转换为驼峰命名
  152. Args:
  153. name (str): 下划线命名
  154. Returns:
  155. str: 驼峰命名
  156. """
  157. if hasattr(StringUtil, 'to_camel_case'):
  158. return StringUtil.to_camel_case(name)
  159. # 如果StringUtil没有to_camel_case方法,则手动实现
  160. parts = name.split('_')
  161. if len(parts) == 1:
  162. return parts[0]
  163. return parts[0] + ''.join(word.capitalize() for word in parts[1:])
  164. @staticmethod
  165. def substring_before(string: str, separator: str) -> str:
  166. """
  167. 获取字符串中分隔符之前的部分
  168. Args:
  169. string (str): 输入字符串
  170. separator (str): 分隔符
  171. Returns:
  172. str: 分隔符之前的部分
  173. """
  174. if hasattr(StringUtil, 'substring_before'):
  175. return StringUtil.substring_before(string, separator)
  176. # 如果StringUtil没有substring_before方法,则手动实现
  177. if separator in string:
  178. return string.split(separator, 1)[0]
  179. return string
  180. @staticmethod
  181. def substring_after(string: str, separator: str) -> str:
  182. """
  183. 获取字符串中分隔符之后的部分
  184. Args:
  185. string (str): 输入字符串
  186. separator (str): 分隔符
  187. Returns:
  188. str: 分隔符之后的部分
  189. """
  190. if hasattr(StringUtil, 'substring_after'):
  191. return StringUtil.substring_after(string, separator)
  192. # 如果StringUtil没有substring_after方法,则手动实现
  193. if separator in string:
  194. return string.split(separator, 1)[1]
  195. return ""
  196. @staticmethod
  197. def generator_code(table: GenTable) -> BytesIO:
  198. """
  199. 生成代码
  200. Args:
  201. table (GenTable): 表信息
  202. Returns:
  203. BytesIO: 生成的代码文件
  204. """
  205. # 设置列的 list_index 属性
  206. GenUtils.set_column_list_index(table)
  207. # 设置主键列
  208. pk_columns = [column for column in table.columns if column.is_pk == '1']
  209. if pk_columns:
  210. table.pk_column = pk_columns[0]
  211. else:
  212. table.pk_column = None
  213. # 获取模板目录
  214. template_dir = os.path.join(os.path.dirname(__file__), 'vm')
  215. # 定义核心模板文件
  216. core_templates = [
  217. 'py/entity.py.vm',
  218. 'py/po.py.vm',
  219. 'py/controller.py.vm',
  220. 'py/service.py.vm',
  221. 'py/mapper.py.vm',
  222. 'js/api.js.vm',
  223. 'sql/menu.sql.vm'
  224. ]
  225. # 根据表类型添加相应的Vue模板
  226. if table.tpl_category == 'tree':
  227. core_templates.append('vue/index-tree.vue.vm')
  228. else:
  229. core_templates.append('vue/index.vue.vm')
  230. # 创建内存中的ZIP文件
  231. zip_buffer = BytesIO()
  232. with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
  233. # 处理每个核心模板文件
  234. for relative_path in core_templates:
  235. template_path = os.path.join(template_dir, relative_path)
  236. if os.path.exists(template_path):
  237. # 读取模板内容
  238. try:
  239. with open(template_path, 'r', encoding='utf-8') as f:
  240. template_content = f.read()
  241. # 准备模板上下文
  242. context = {
  243. 'table': table,
  244. 'datetime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
  245. 'underscore': to_underscore, # 添加自定义过滤器
  246. 'get_import_path': GenUtils.get_import_path # 添加导入路径生成函数
  247. }
  248. # 使用Jinja2渲染模板
  249. template = Template(template_content)
  250. rendered_content = template.render(**context)
  251. # 生成文件名
  252. output_file_name = GenUtils.get_file_name(relative_path, table)
  253. # 检查渲染后的内容是否为空
  254. if rendered_content.strip():
  255. # 将渲染后的内容写入ZIP文件
  256. zip_file.writestr(output_file_name, rendered_content)
  257. else:
  258. print(f"警告: 模板 {relative_path} 渲染后内容为空")
  259. except Exception as e:
  260. print(f"处理模板 {relative_path} 时出错: {e}")
  261. zip_buffer.seek(0)
  262. return zip_buffer
  263. @staticmethod
  264. def batch_generator_code(tables: List[GenTable]) -> BytesIO:
  265. """
  266. 批量生成代码
  267. Args:
  268. tables (List[GenTable]): 表列表
  269. Returns:
  270. BytesIO: 生成的代码文件
  271. """
  272. # 为每个表设置列的 list_index 属性和主键列
  273. for table in tables:
  274. GenUtils.set_column_list_index(table)
  275. # 设置主键列
  276. pk_columns = [column for column in table.columns if column.is_pk == '1']
  277. if pk_columns:
  278. table.pk_column = pk_columns[0]
  279. else:
  280. table.pk_column = None
  281. # 定义核心模板文件
  282. core_templates = [
  283. 'py/entity.py.vm',
  284. 'py/po.py.vm',
  285. 'py/controller.py.vm',
  286. 'py/service.py.vm',
  287. 'py/mapper.py.vm',
  288. 'js/api.js.vm',
  289. 'sql/menu.sql.vm'
  290. ]
  291. # 创建内存中的ZIP文件
  292. zip_buffer = BytesIO()
  293. # 获取模板目录
  294. template_dir = os.path.join(os.path.dirname(__file__), 'vm')
  295. with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
  296. # 跟踪已添加的文件名以避免重复
  297. added_files = set()
  298. # 处理每个表
  299. for table in tables:
  300. # 根据表类型添加相应的Vue模板
  301. if table.tpl_category == 'tree':
  302. current_templates = core_templates + ['vue/index-tree.vue.vm']
  303. else:
  304. current_templates = core_templates + ['vue/index.vue.vm']
  305. # 处理每个核心模板文件
  306. for relative_path in current_templates:
  307. template_path = os.path.join(template_dir, relative_path)
  308. if os.path.exists(template_path):
  309. # 读取模板内容
  310. try:
  311. with open(template_path, 'r', encoding='utf-8') as f:
  312. template_content = f.read()
  313. # 准备模板上下文
  314. context = {
  315. 'table': table,
  316. 'datetime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
  317. 'underscore': to_underscore, # 添加自定义过滤器
  318. 'get_import_path': GenUtils.get_import_path # 添加导入路径生成函数
  319. }
  320. # 使用Jinja2渲染模板
  321. template = Template(template_content)
  322. rendered_content = template.render(**context)
  323. # 生成文件名
  324. output_file_name = GenUtils.get_file_name(relative_path, table)
  325. # 检查是否已添加同名文件
  326. if output_file_name in added_files:
  327. # 为重复文件添加序号
  328. name, ext = os.path.splitext(output_file_name)
  329. counter = 1
  330. new_name = f"{name}_{counter}{ext}"
  331. while new_name in added_files:
  332. counter += 1
  333. new_name = f"{name}_{counter}{ext}"
  334. output_file_name = new_name
  335. # 检查渲染后的内容是否为空
  336. if rendered_content.strip():
  337. # 将渲染后的内容写入ZIP文件
  338. zip_file.writestr(output_file_name, rendered_content)
  339. added_files.add(output_file_name)
  340. else:
  341. print(f"警告: 模板 {relative_path} 渲染后内容为空")
  342. except Exception as e:
  343. print(f"处理表 {table.table_name} 的模板 {relative_path} 时出错: {e}")
  344. zip_buffer.seek(0)
  345. return zip_buffer
  346. @staticmethod
  347. def set_column_list_index(table: GenTable):
  348. """
  349. 为表的列设置 list_index 属性,用于 Vue 模板中的 columns 数组索引
  350. Args:
  351. table (GenTable): 表信息
  352. """
  353. if not table.columns:
  354. return
  355. list_index = 0
  356. for column in table.columns:
  357. if column.is_list == '1':
  358. # 使用 setattr 动态添加属性
  359. setattr(column, 'list_index', list_index)
  360. list_index += 1
  361. @staticmethod
  362. def preview_code(table: GenTable) -> dict:
  363. """
  364. 预览代码
  365. Args:
  366. table (GenTable): 表信息
  367. Returns:
  368. dict: 预览代码
  369. """
  370. # 设置列的 list_index 属性
  371. GenUtils.set_column_list_index(table)
  372. # 设置主键列
  373. pk_columns = [column for column in table.columns if column.is_pk == '1']
  374. if pk_columns:
  375. table.pk_column = pk_columns[0]
  376. else:
  377. table.pk_column = None
  378. # 获取模板目录
  379. template_dir = os.path.join(os.path.dirname(__file__), 'vm')
  380. # 存储预览代码的字典
  381. preview_data = {}
  382. # 定义需要预览的核心模板文件
  383. core_templates = [
  384. 'py/entity.py.vm',
  385. 'py/po.py.vm',
  386. 'py/controller.py.vm',
  387. 'py/service.py.vm',
  388. 'py/mapper.py.vm',
  389. 'js/api.js.vm',
  390. 'sql/menu.sql.vm'
  391. ]
  392. # 根据表类型添加相应的Vue模板,但预览时都使用index.vue.vm作为文件名
  393. if table.tpl_category == 'tree':
  394. core_templates.append('vue/index-tree.vue.vm')
  395. else:
  396. core_templates.append('vue/index.vue.vm')
  397. # 处理每个核心模板文件
  398. for relative_path in core_templates:
  399. template_path = os.path.join(template_dir, relative_path)
  400. if os.path.exists(template_path):
  401. # 读取模板内容
  402. try:
  403. with open(template_path, 'r', encoding='utf-8') as f:
  404. template_content = f.read()
  405. # 准备模板上下文
  406. context = {
  407. 'table': table,
  408. 'datetime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
  409. 'underscore': to_underscore, # 添加自定义过滤器
  410. 'get_import_path': GenUtils.get_import_path # 添加导入路径生成函数
  411. }
  412. # 使用Jinja2渲染模板
  413. template = Template(template_content)
  414. rendered_content = template.render(**context)
  415. # 存储渲染后的内容
  416. preview_data[relative_path] = rendered_content
  417. except Exception as e:
  418. # 如果渲染失败,存储错误信息
  419. preview_data[relative_path] = f"模板渲染失败: {str(e)}"
  420. else:
  421. preview_data[relative_path] = "模板文件不存在"
  422. return preview_data