Files
mirage/tools/template_loader.py
daiqingshuang 51ded97e48 feat: Refactor code generation tool to use Jinja2 templates
- Replaced string concatenation with Jinja2 templates for improved maintainability and readability.
- Added new template files for struct definitions, buffer helpers, and buffer manager.
- Introduced a template loader to manage Jinja2 rendering and custom filters.
- Updated the README and documentation to reflect the new template system.
- Added comprehensive tests to ensure output consistency with the previous implementation.
- Maintained full API compatibility with existing functions.
2025-11-22 15:39:08 +08:00

110 lines
3.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
模板加载和渲染模块
使用Jinja2模板引擎生成C++代码,替代原有的字符串拼接方式。
"""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict
from jinja2 import Environment, FileSystemLoader, select_autoescape
from .type_mapping import (
calculate_std430_alignment,
calculate_std430_size,
spirv_type_to_cpp,
)
from .types import ArrayTypeInfo, StructTypeInfo, VectorTypeInfo, MatrixTypeInfo
class TemplateRenderer:
"""Jinja2模板渲染器
提供统一的模板加载和渲染接口,支持自定义过滤器和全局函数。
"""
def __init__(self):
"""初始化模板渲染器"""
template_dir = Path(__file__).parent / "templates"
self.env = Environment(
loader=FileSystemLoader(str(template_dir)),
autoescape=select_autoescape([]), # C++代码不需要自动转义
trim_blocks=False, # 保留块后的换行符以确保正确格式
lstrip_blocks=False, # 保留块前的空白以确保正确格式
keep_trailing_newline=True # 保留末尾换行符
)
# 注册自定义过滤器和全局函数
self._register_filters()
self._register_globals()
def _register_filters(self):
"""注册Jinja2自定义过滤器"""
# 类型转换过滤器
self.env.filters['spirv_type_to_cpp'] = spirv_type_to_cpp
self.env.filters['calculate_alignment'] = calculate_std430_alignment
self.env.filters['calculate_size'] = calculate_std430_size
# 工具过滤器
self.env.filters['format_hex'] = lambda x: f"0x{x:08x}"
self.env.filters['join_with_comma'] = lambda items: ", ".join(str(i) for i in items)
def _register_globals(self):
"""注册Jinja2全局函数和变量"""
# Python内置函数
self.env.globals.update({
'isinstance': isinstance,
'len': len,
'enumerate': enumerate,
'range': range,
})
# 类型信息类
self.env.globals.update({
'ArrayTypeInfo': ArrayTypeInfo,
'StructTypeInfo': StructTypeInfo,
'VectorTypeInfo': VectorTypeInfo,
'MatrixTypeInfo': MatrixTypeInfo,
})
def render(self, template_name: str, context: Dict[str, Any]) -> str:
"""渲染模板
Args:
template_name: 模板文件名相对于templates目录
context: 模板上下文变量
Returns:
渲染后的字符串
"""
template = self.env.get_template(template_name)
return template.render(context)
def render_to_lines(self, template_name: str, context: Dict[str, Any]) -> list[str]:
"""渲染模板并返回行列表
Args:
template_name: 模板文件名
context: 模板上下文变量
Returns:
渲染后的行列表
"""
result = self.render(template_name, context)
return result.splitlines()
# 全局渲染器实例(单例)
_renderer = None
def get_renderer() -> TemplateRenderer:
"""获取全局模板渲染器实例"""
global _renderer
if _renderer is None:
_renderer = TemplateRenderer()
return _renderer