Files
mirage/tools/test_template_refactor.py
daiqingshuang 51ded97e48 feat: Refactor code generation tool to use Jinja2 templates
- Replaced string concatenation with Jinja2 templates for improved maintainability and readability.
- Added new template files for struct definitions, buffer helpers, and buffer manager.
- Introduced a template loader to manage Jinja2 rendering and custom filters.
- Updated the README and documentation to reflect the new template system.
- Added comprehensive tests to ensure output consistency with the previous implementation.
- Maintained full API compatibility with existing functions.
2025-11-22 15:39:08 +08:00

222 lines
5.4 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试Jinja2模板重构
验证重构后的代码生成器输出与原始版本一致。
"""
import sys
import io
from pathlib import Path
# Force UTF-8 on stdout so the Chinese progress messages below can be printed
# on consoles whose default encoding cannot represent them.
# Use reconfigure() on the existing stream rather than wrapping
# sys.stdout.buffer in a second TextIOWrapper. A new wrapper has three problems:
#   - it fails when stdout has no ``buffer`` attribute (e.g. a StringIO
#     substitute) or is None;
#   - it loses the console's line buffering, so stdout and the stderr
#     tracebacks can appear out of order;
#   - it leaves two independently buffered writers on one byte stream.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")
# Put the parent of this file's directory on sys.path so that the ``tools``
# package imported below resolves when the script is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))
from tools.code_generator import (
generate_single_struct,
generate_buffer_helper_functions,
generate_typed_buffer_aliases,
generate_buffer_manager_class,
spirv_to_cpp_array,
)
from tools.types import (
BufferInfo,
StructTypeInfo,
MemberInfo,
ScalarTypeInfo,
VectorTypeInfo,
BaseType,
SPIRVReflection,
)
def create_test_buffer() -> BufferInfo:
    """Build the BufferInfo fixture shared by the generator tests.

    The buffer (set 0, binding 0, ``storage_buffer``, variable ``particles``)
    wraps a ``Particle`` struct (type id 1) with two members, ``position`` and
    ``velocity``. Both reference type id 2, which ``create_test_type_map``
    registers as ``vec4``, at byte offsets 0 and 16 respectively.
    """
    # (member name, byte offset) in declaration order; the index comes from enumerate.
    member_specs = (("position", 0), ("velocity", 16))
    particle_members = [
        MemberInfo(index=idx, name=member_name, type_id=2, offset=byte_offset)
        for idx, (member_name, byte_offset) in enumerate(member_specs)
    ]
    particle_struct = StructTypeInfo(
        id=1,
        name="Particle",
        members=particle_members,
    )
    return BufferInfo(
        name="Particle",
        binding=0,
        descriptor_set=0,
        descriptor_type="storage_buffer",
        struct_type=particle_struct,
        variable_name="particles",
    )
def create_test_type_map():
    """Build the type-id lookup table consumed by ``generate_single_struct``.

    Returns:
        A dict mapping id 2 to a 4-component vector type named ``vec4`` and
        id 3 to its 32-bit float component type. The vector's
        ``resolved_component_type`` is pre-linked to that float object.
    """
    vector4 = VectorTypeInfo(id=2, name="vec4", component_type_id=3, component_count=4)
    scalar_float = ScalarTypeInfo(id=3, name="float", base_type=BaseType.FLOAT, bit_width=32)
    # Link the component type up front, as the test does not run a separate resolution pass.
    vector4.resolved_component_type = scalar_float
    type_map = {}
    type_map[2] = vector4
    type_map[3] = scalar_float
    return type_map
def test_struct_generation():
    """Check that ``generate_single_struct`` renders the Particle struct.

    Renders the fixture buffer against the fixture type map and joins the
    returned strings with newlines. It then verifies that the output contains
    the ``struct`` keyword, the struct name, both member names and an
    ``alignas`` specifier. The rendered text is printed for manual inspection.

    Raises:
        AssertionError: whose message names the first expected snippet that
            is missing from the output.
    """
    print("测试结构体生成...")
    buffer = create_test_buffer()
    type_map = create_test_type_map()
    result = generate_single_struct(buffer, type_map)
    # The generator returns an iterable of strings; join it for substring checks.
    result_str = "\n".join(result)
    for expected in ("struct", "Particle", "position", "velocity", "alignas"):
        # Carry the snippet in the message: main() prints str(e), which is
        # empty for a bare assert.
        assert expected in result_str, f"struct output missing {expected!r}"
    print("[PASS] 结构体生成测试通过")
    print(result_str)
    print()
def test_helper_functions():
    """Check the buffer helper functions rendered for the fixture buffer.

    Wraps the fixture buffer in a ``SPIRVReflection`` and verifies that
    ``generate_buffer_helper_functions`` emits create/upload/download helpers.
    The expected helper names contain ``particles``, the fixture's
    variable_name. The first 500 characters of the output are printed.

    Raises:
        AssertionError: whose message names the first expected helper that
            is missing from the output.
    """
    print("测试辅助函数生成...")
    buffer = create_test_buffer()
    reflection = SPIRVReflection(buffers=[buffer])
    result = generate_buffer_helper_functions(reflection, "test_shader")
    for expected in (
        "create_particles_buffer",
        "upload_particles",
        "download_particles",
    ):
        # Name the missing helper so main()'s failure report is informative.
        assert expected in result, f"helper output missing {expected!r}"
    print("[PASS] 辅助函数生成测试通过")
    print(result[:500] + "...\n")
def test_typed_buffer():
    """Check the typed-buffer aliases and factory rendered for the fixture.

    Verifies that ``generate_typed_buffer_aliases`` emits the
    ``particles_buffer`` name, the ``create_particles_typed`` factory and a
    ``typed_buffer<Particle>`` specialisation. The first 500 characters of
    the output are printed.

    Raises:
        AssertionError: whose message names the first expected snippet that
            is missing from the output.
    """
    print("测试类型化buffer生成...")
    buffer = create_test_buffer()
    reflection = SPIRVReflection(buffers=[buffer])
    result = generate_typed_buffer_aliases(reflection, "test_shader")
    for expected in (
        "particles_buffer",
        "create_particles_typed",
        "typed_buffer<Particle>",
    ):
        # Name the missing snippet so main()'s failure report is informative.
        assert expected in result, f"typed buffer output missing {expected!r}"
    print("[PASS] 类型化buffer生成测试通过")
    print(result[:500] + "...\n")
def test_buffer_manager():
    """Check the buffer-manager class rendered for the ``test_shader`` shader.

    Verifies that ``generate_buffer_manager_class`` emits the
    ``test_shader_buffer_manager`` name, a nested ``struct buffers``, a
    ``static auto create`` factory, a ``void initialize`` method and a
    ``bind_to_descriptor_set`` member. The first 500 characters of the
    output are printed.

    Raises:
        AssertionError: whose message names the first expected snippet that
            is missing from the output.
    """
    print("测试buffer管理器生成...")
    buffer = create_test_buffer()
    reflection = SPIRVReflection(buffers=[buffer])
    result = generate_buffer_manager_class(reflection, "test_shader")
    for expected in (
        "test_shader_buffer_manager",
        "struct buffers",
        "static auto create",
        "void initialize",
        "bind_to_descriptor_set",
    ):
        # Name the missing snippet so main()'s failure report is informative.
        assert expected in result, f"buffer manager output missing {expected!r}"
    print("[PASS] Buffer管理器生成测试通过")
    print(result[:500] + "...\n")
def test_spirv_array():
    """Check that ``spirv_to_cpp_array`` emits a little-endian uint32 array.

    Feeds 80 bytes (20 repetitions of ``03 02 23 07``). It verifies that the
    output declares ``static constexpr uint32_t test_spirv[]`` and contains
    the word ``0x07230203`` (the SPIR-V magic number). That word appears only
    if the bytes are decoded as little-endian 32-bit words. The first 200
    characters of the output are printed.

    Raises:
        AssertionError: whose message names the first expected snippet that
            is missing from the output.
    """
    print("测试SPIR-V数组生成...")
    test_data = b"\x03\x02\x23\x07" * 20  # 80 bytes = 20 little-endian words
    result = spirv_to_cpp_array(test_data, "test_spirv")
    for expected in ("static constexpr uint32_t test_spirv[]", "0x07230203"):
        # Name the missing snippet so main()'s failure report is informative.
        assert expected in result, f"SPIR-V array output missing {expected!r}"
    print("[PASS] SPIR-V数组生成测试通过")
    print(result[:200] + "...\n")
def main():
    """Run every generator test in order and report the outcome.

    Returns:
        0 when all tests pass, 1 as soon as one raises. A failed assertion
        is reported with a ``[FAIL]`` tag and any other exception with an
        ``[ERROR]`` tag. Either report is followed by the full traceback.
    """
    rule = "=" * 60
    for banner_line in (rule, "Jinja2模板重构测试", rule, ""):
        print(banner_line)
    try:
        suite = (
            test_struct_generation,
            test_helper_functions,
            test_typed_buffer,
            test_buffer_manager,
            test_spirv_array,
        )
        for run_case in suite:
            run_case()
        for summary_line in (rule, "所有测试通过!", rule):
            print(summary_line)
        return 0
    except Exception as exc:
        # AssertionError is an Exception subclass; tell it apart so test
        # failures and unexpected errors keep their distinct labels.
        tag = "[FAIL] 测试失败" if isinstance(exc, AssertionError) else "[ERROR] 错误"
        print(f"\n{tag}: {exc}")
        import traceback
        traceback.print_exc()
        return 1
if __name__ == "__main__":
    # Propagate main()'s 0/1 result as the process exit status.
    raise SystemExit(main())