util.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. import os
  2. from typing import List
  3. from jinja2 import Template
  4. from ruoyi_common.utils import StringUtil
  5. from ruoyi_generator.domain.entity import GenTable
  6. from ruoyi_generator.config import GeneratorConfig
  7. import zipfile
  8. from io import BytesIO
  9. from datetime import datetime
  10. import re
  11. def to_underscore(name: str) -> str:
  12. """
  13. 将驼峰命名转换为下划线命名
  14. Args:
  15. name (str): 驼峰命名的字符串
  16. Returns:
  17. str: 下划线命名的字符串
  18. """
  19. # 在大写字母前添加下划线,然后转为小写
  20. s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
  21. return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
  22. class GenUtils:
  23. @staticmethod
  24. def get_file_name(template_file: str, table: GenTable) -> str:
  25. """
  26. 根据模板文件名和表信息生成文件名
  27. Args:
  28. template_file (str): 模板文件名
  29. table (GenTable): 表信息
  30. Returns:
  31. str: 生成的文件名
  32. """
  33. # 标准化路径分隔符
  34. template_file = template_file.replace('\\', '/')
  35. # 移除.vm后缀
  36. base_name = template_file[:-3] if template_file.endswith('.vm') else template_file
  37. # 根据模板类型生成文件名和路径
  38. if 'py/entity.py' in template_file:
  39. # 根据包名生成目录结构
  40. package_path = table.package_name.replace('.', '/') if table.package_name else ''
  41. return f"{package_path}/domain/{table.class_name}.py"
  42. elif 'py/controller.py' in template_file:
  43. package_path = table.package_name.replace('.', '/') if table.package_name else ''
  44. # 使用下划线命名法
  45. controller_name = f"{to_underscore(table.class_name)}_controller"
  46. return f"{package_path}/controller/{controller_name}.py"
  47. elif 'py/service.py' in template_file:
  48. package_path = table.package_name.replace('.', '/') if table.package_name else ''
  49. # 使用下划线命名法
  50. service_name = f"{to_underscore(table.class_name)}_service"
  51. return f"{package_path}/service/{service_name}.py"
  52. elif 'py/mapper.py' in template_file:
  53. package_path = table.package_name.replace('.', '/') if table.package_name else ''
  54. # 使用下划线命名法
  55. mapper_name = f"{to_underscore(table.class_name)}_mapper"
  56. return f"{package_path}/mapper/{mapper_name}.py"
  57. elif 'py/po.py' in template_file:
  58. package_path = table.package_name.replace('.', '/') if table.package_name else ''
  59. # PO文件使用表名作为文件名
  60. po_name = f"{table.class_name}PO"
  61. return f"{package_path}/domain/{po_name}.py"
  62. elif 'vue/index.vue' in template_file:
  63. return f"vue/{table.business_name}/index.vue"
  64. elif 'js/api.js' in template_file:
  65. return f"js/api/{table.business_name}.js"
  66. elif 'sql/menu.sql' in template_file:
  67. return f"sql/{table.business_name}_menu.sql"
  68. elif 'README.md' in template_file:
  69. return f"{table.business_name}_README.md"
  70. else:
  71. # 处理其他模板文件,保持原有目录结构
  72. filename = os.path.basename(base_name)
  73. if '.' not in filename:
  74. filename += '.py' # 默认添加.py扩展名
  75. return filename
  76. @staticmethod
  77. def get_table_prefix() -> str:
  78. """
  79. 获取表前缀
  80. Returns:
  81. str: 表前缀
  82. """
  83. return GeneratorConfig.table_prefix or ""
  84. @staticmethod
  85. def remove_table_prefix(table_name: str) -> str:
  86. """
  87. 移除表前缀
  88. Args:
  89. table_name (str): 表名
  90. Returns:
  91. str: 移除前缀后的表名
  92. """
  93. prefix = GenUtils.get_table_prefix()
  94. if prefix and table_name.startswith(prefix):
  95. return table_name[len(prefix):]
  96. return table_name
  97. @staticmethod
  98. def table_to_class_name(table_name: str) -> str:
  99. """
  100. 将表名转换为类名
  101. Args:
  102. table_name (str): 表名
  103. Returns:
  104. str: 类名
  105. """
  106. # 移除表前缀
  107. clean_table_name = GenUtils.remove_table_prefix(table_name)
  108. # 转换为驼峰命名
  109. return GenUtils.to_camel_case(clean_table_name)
  110. @staticmethod
  111. def get_business_name(table_name: str) -> str:
  112. """
  113. 获取业务名
  114. Args:
  115. table_name (str): 表名
  116. Returns:
  117. str: 业务名
  118. """
  119. # 移除表前缀
  120. clean_table_name = GenUtils.remove_table_prefix(table_name)
  121. # 获取下划线分隔的第一部分
  122. return GenUtils.substring_before(clean_table_name, "_") if "_" in clean_table_name else clean_table_name
  123. @staticmethod
  124. def to_camel_case(name: str) -> str:
  125. """
  126. 将下划线命名转换为驼峰命名
  127. Args:
  128. name (str): 下划线命名
  129. Returns:
  130. str: 驼峰命名
  131. """
  132. if hasattr(StringUtil, 'to_camel_case'):
  133. return StringUtil.to_camel_case(name)
  134. # 如果StringUtil没有to_camel_case方法,则手动实现
  135. parts = name.split('_')
  136. if len(parts) == 1:
  137. return parts[0]
  138. return parts[0] + ''.join(word.capitalize() for word in parts[1:])
  139. @staticmethod
  140. def substring_before(string: str, separator: str) -> str:
  141. """
  142. 获取字符串中分隔符之前的部分
  143. Args:
  144. string (str): 输入字符串
  145. separator (str): 分隔符
  146. Returns:
  147. str: 分隔符之前的部分
  148. """
  149. if hasattr(StringUtil, 'substring_before'):
  150. return StringUtil.substring_before(string, separator)
  151. # 如果StringUtil没有substring_before方法,则手动实现
  152. if separator in string:
  153. return string.split(separator, 1)[0]
  154. return string
  155. @staticmethod
  156. def substring_after(string: str, separator: str) -> str:
  157. """
  158. 获取字符串中分隔符之后的部分
  159. Args:
  160. string (str): 输入字符串
  161. separator (str): 分隔符
  162. Returns:
  163. str: 分隔符之后的部分
  164. """
  165. if hasattr(StringUtil, 'substring_after'):
  166. return StringUtil.substring_after(string, separator)
  167. # 如果StringUtil没有substring_after方法,则手动实现
  168. if separator in string:
  169. return string.split(separator, 1)[1]
  170. return ""
  171. @staticmethod
  172. def generator_code(table: GenTable) -> BytesIO:
  173. """
  174. 生成代码
  175. Args:
  176. table (GenTable): 表信息
  177. Returns:
  178. BytesIO: 生成的代码文件
  179. """
  180. # 获取模板目录
  181. template_dir = os.path.join(os.path.dirname(__file__), 'vm')
  182. # 创建内存中的ZIP文件
  183. zip_buffer = BytesIO()
  184. with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
  185. # 定义需要处理的模板文件
  186. template_files = []
  187. for root, dirs, files in os.walk(template_dir):
  188. for file in files:
  189. if file.endswith('.vm'):
  190. full_path = os.path.join(root, file)
  191. relative_path = os.path.relpath(full_path, template_dir)
  192. template_files.append(relative_path)
  193. # 处理每个模板文件
  194. for relative_path in template_files:
  195. template_path = os.path.join(template_dir, relative_path)
  196. if os.path.exists(template_path):
  197. # 读取模板内容
  198. try:
  199. with open(template_path, 'r', encoding='utf-8') as f:
  200. template_content = f.read()
  201. # 准备模板上下文
  202. context = {
  203. 'table': table,
  204. 'datetime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
  205. 'underscore': to_underscore # 添加自定义过滤器
  206. }
  207. # 使用Jinja2渲染模板
  208. template = Template(template_content)
  209. rendered_content = template.render(**context)
  210. # 生成文件名
  211. output_file_name = GenUtils.get_file_name(relative_path, table)
  212. # 检查渲染后的内容是否为空
  213. if rendered_content.strip():
  214. # 将渲染后的内容写入ZIP文件
  215. zip_file.writestr(output_file_name, rendered_content)
  216. else:
  217. print(f"警告: 模板 {relative_path} 渲染后内容为空")
  218. except Exception as e:
  219. print(f"处理模板 {relative_path} 时出错: {e}")
  220. zip_buffer.seek(0)
  221. return zip_buffer
  222. @staticmethod
  223. def batch_generator_code(tables: List[GenTable]) -> BytesIO:
  224. """
  225. 批量生成代码
  226. Args:
  227. tables (List[GenTable]): 表列表
  228. Returns:
  229. BytesIO: 生成的代码文件
  230. """
  231. # 创建内存中的ZIP文件
  232. zip_buffer = BytesIO()
  233. with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
  234. # 跟踪已添加的文件名以避免重复
  235. added_files = set()
  236. # 处理每个表
  237. for table in tables:
  238. # 获取模板目录
  239. template_dir = os.path.join(os.path.dirname(__file__), 'vm')
  240. # 定义需要处理的模板文件
  241. template_files = []
  242. for root, dirs, files in os.walk(template_dir):
  243. for file in files:
  244. if file.endswith('.vm'):
  245. full_path = os.path.join(root, file)
  246. relative_path = os.path.relpath(full_path, template_dir)
  247. template_files.append(relative_path)
  248. # 处理每个模板文件
  249. for relative_path in template_files:
  250. template_path = os.path.join(template_dir, relative_path)
  251. if os.path.exists(template_path):
  252. # 读取模板内容
  253. try:
  254. with open(template_path, 'r', encoding='utf-8') as f:
  255. template_content = f.read()
  256. # 准备模板上下文
  257. context = {
  258. 'table': table,
  259. 'datetime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
  260. 'underscore': to_underscore # 添加自定义过滤器
  261. }
  262. # 使用Jinja2渲染模板
  263. template = Template(template_content)
  264. rendered_content = template.render(**context)
  265. # 生成文件名
  266. output_file_name = GenUtils.get_file_name(relative_path, table)
  267. # 检查是否已添加同名文件
  268. if output_file_name in added_files:
  269. # 为重复文件添加序号
  270. name, ext = os.path.splitext(output_file_name)
  271. counter = 1
  272. new_name = f"{name}_{counter}{ext}"
  273. while new_name in added_files:
  274. counter += 1
  275. new_name = f"{name}_{counter}{ext}"
  276. output_file_name = new_name
  277. # 检查渲染后的内容是否为空
  278. if rendered_content.strip():
  279. # 将渲染后的内容写入ZIP文件
  280. zip_file.writestr(output_file_name, rendered_content)
  281. added_files.add(output_file_name)
  282. else:
  283. print(f"警告: 模板 {relative_path} 渲染后内容为空")
  284. except Exception as e:
  285. print(f"处理表 {table.table_name} 的模板 {relative_path} 时出错: {e}")
  286. zip_buffer.seek(0)
  287. return zip_buffer
  288. @staticmethod
  289. def preview_code(table: GenTable) -> dict:
  290. """
  291. 预览代码
  292. Args:
  293. table (GenTable): 表信息
  294. Returns:
  295. dict: 预览代码
  296. """
  297. # 获取模板目录
  298. template_dir = os.path.join(os.path.dirname(__file__), 'vm')
  299. # 存储预览代码的字典
  300. preview_data = {}
  301. # 定义需要预览的模板文件(严格按照项目要求)
  302. template_files = []
  303. for root, dirs, files in os.walk(template_dir):
  304. for file in files:
  305. if file.endswith('.vm'):
  306. full_path = os.path.join(root, file)
  307. relative_path = os.path.relpath(full_path, template_dir)
  308. template_files.append(relative_path)
  309. # 处理每个模板文件
  310. for relative_path in template_files:
  311. template_path = os.path.join(template_dir, relative_path)
  312. if os.path.exists(template_path):
  313. # 读取模板内容
  314. try:
  315. with open(template_path, 'r', encoding='utf-8') as f:
  316. template_content = f.read()
  317. # 准备模板上下文
  318. context = {
  319. 'table': table,
  320. 'datetime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
  321. 'underscore': to_underscore # 添加自定义过滤器
  322. }
  323. # 使用Jinja2渲染模板
  324. template = Template(template_content)
  325. rendered_content = template.render(**context)
  326. # 存储渲染后的内容
  327. preview_data[relative_path] = rendered_content
  328. except Exception as e:
  329. # 如果渲染失败,存储原始模板内容
  330. with open(template_path, 'r', encoding='utf-8') as f:
  331. preview_data[relative_path] = f.read()
  332. return preview_data