- Replaced string concatenation with Jinja2 templates for improved maintainability and readability. - Added new template files for struct definitions, buffer helpers, and buffer manager. - Introduced a template loader to manage Jinja2 rendering and custom filters. - Updated the README and documentation to reflect the new template system. - Added comprehensive tests to ensure output consistency with the previous implementation. - Maintained full API compatibility with existing functions.
110 lines
3.4 KiB
Python
110 lines
3.4 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
模板加载和渲染模块
|
||
|
||
使用Jinja2模板引擎生成C++代码,替代原有的字符串拼接方式。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from pathlib import Path
|
||
from typing import Any, Dict
|
||
|
||
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
||
|
||
from .type_mapping import (
|
||
calculate_std430_alignment,
|
||
calculate_std430_size,
|
||
spirv_type_to_cpp,
|
||
)
|
||
from .types import ArrayTypeInfo, StructTypeInfo, VectorTypeInfo, MatrixTypeInfo
|
||
|
||
|
||
class TemplateRenderer:
|
||
"""Jinja2模板渲染器
|
||
|
||
提供统一的模板加载和渲染接口,支持自定义过滤器和全局函数。
|
||
"""
|
||
|
||
def __init__(self):
|
||
"""初始化模板渲染器"""
|
||
template_dir = Path(__file__).parent / "templates"
|
||
|
||
self.env = Environment(
|
||
loader=FileSystemLoader(str(template_dir)),
|
||
autoescape=select_autoescape([]), # C++代码不需要自动转义
|
||
trim_blocks=False, # 保留块后的换行符以确保正确格式
|
||
lstrip_blocks=False, # 保留块前的空白以确保正确格式
|
||
keep_trailing_newline=True # 保留末尾换行符
|
||
)
|
||
|
||
# 注册自定义过滤器和全局函数
|
||
self._register_filters()
|
||
self._register_globals()
|
||
|
||
def _register_filters(self):
|
||
"""注册Jinja2自定义过滤器"""
|
||
# 类型转换过滤器
|
||
self.env.filters['spirv_type_to_cpp'] = spirv_type_to_cpp
|
||
self.env.filters['calculate_alignment'] = calculate_std430_alignment
|
||
self.env.filters['calculate_size'] = calculate_std430_size
|
||
|
||
# 工具过滤器
|
||
self.env.filters['format_hex'] = lambda x: f"0x{x:08x}"
|
||
self.env.filters['join_with_comma'] = lambda items: ", ".join(str(i) for i in items)
|
||
|
||
def _register_globals(self):
|
||
"""注册Jinja2全局函数和变量"""
|
||
# Python内置函数
|
||
self.env.globals.update({
|
||
'isinstance': isinstance,
|
||
'len': len,
|
||
'enumerate': enumerate,
|
||
'range': range,
|
||
})
|
||
|
||
# 类型信息类
|
||
self.env.globals.update({
|
||
'ArrayTypeInfo': ArrayTypeInfo,
|
||
'StructTypeInfo': StructTypeInfo,
|
||
'VectorTypeInfo': VectorTypeInfo,
|
||
'MatrixTypeInfo': MatrixTypeInfo,
|
||
})
|
||
|
||
def render(self, template_name: str, context: Dict[str, Any]) -> str:
|
||
"""渲染模板
|
||
|
||
Args:
|
||
template_name: 模板文件名(相对于templates目录)
|
||
context: 模板上下文变量
|
||
|
||
Returns:
|
||
渲染后的字符串
|
||
"""
|
||
template = self.env.get_template(template_name)
|
||
return template.render(context)
|
||
|
||
def render_to_lines(self, template_name: str, context: Dict[str, Any]) -> list[str]:
|
||
"""渲染模板并返回行列表
|
||
|
||
Args:
|
||
template_name: 模板文件名
|
||
context: 模板上下文变量
|
||
|
||
Returns:
|
||
渲染后的行列表
|
||
"""
|
||
result = self.render(template_name, context)
|
||
return result.splitlines()
|
||
|
||
|
||
# 全局渲染器实例(单例)
|
||
_renderer = None
|
||
|
||
|
||
def get_renderer() -> TemplateRenderer:
|
||
"""获取全局模板渲染器实例"""
|
||
global _renderer
|
||
if _renderer is None:
|
||
_renderer = TemplateRenderer()
|
||
return _renderer |