- Replaced string concatenation with Jinja2 templates for improved maintainability and readability.
- Added new template files for struct definitions, buffer helpers, and the buffer manager.
- Introduced a template loader to manage Jinja2 rendering and custom filters.
- Updated the README and documentation to reflect the new template system.
- Added comprehensive tests to ensure output consistency with the previous implementation.
- Maintained full API compatibility with existing functions.
222 lines
5.4 KiB
Python
222 lines
5.4 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
测试Jinja2模板重构
|
|
|
|
验证重构后的代码生成器输出与原始版本一致。
|
|
"""
|
|
|
|
import sys
|
|
import io
|
|
from pathlib import Path
|
|
|
|
# Force UTF-8 on stdout so the Chinese progress messages print correctly on
# consoles whose default encoding is not UTF-8.
# reconfigure() changes the existing stream in place rather than stacking a
# second TextIOWrapper over the same buffer. That keeps the stream's line
# buffering (so stdout output stays ordered relative to stderr tracebacks).
# The getattr guard skips replacement streams that lack reconfigure() or
# .buffer (e.g. IDE consoles, test-runner capture objects) instead of crashing
# at import time.
_reconfigure_stdout = getattr(sys.stdout, "reconfigure", None)
if _reconfigure_stdout is not None:
    _reconfigure_stdout(encoding="utf-8")

# Put the directory two levels above this file (presumably the repository
# root) first on sys.path so the `tools.*` imports below resolve when this
# file is run directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from tools.code_generator import (
|
|
generate_single_struct,
|
|
generate_buffer_helper_functions,
|
|
generate_typed_buffer_aliases,
|
|
generate_buffer_manager_class,
|
|
spirv_to_cpp_array,
|
|
)
|
|
from tools.types import (
|
|
BufferInfo,
|
|
StructTypeInfo,
|
|
MemberInfo,
|
|
ScalarTypeInfo,
|
|
VectorTypeInfo,
|
|
BaseType,
|
|
SPIRVReflection,
|
|
)
|
|
|
|
|
|
def create_test_buffer() -> BufferInfo:
    """Build the storage-buffer fixture shared by the generator tests.

    Returns a BufferInfo named ``Particle`` (variable name ``particles``) at
    descriptor set 0, binding 0, with descriptor type ``storage_buffer``. Its
    struct type (id 1, also named ``Particle``) has two members:
    ``position`` at byte offset 0 and ``velocity`` at byte offset 16. Both
    reference type id 2, which create_test_type_map maps to a vec4.
    """
    # (member name, byte offset) in declaration order; every member uses type id 2.
    member_layout = (("position", 0), ("velocity", 16))

    particle_struct = StructTypeInfo(
        id=1,
        name="Particle",
        members=[
            MemberInfo(
                index=member_index,
                name=member_name,
                type_id=2,
                offset=member_offset,
            )
            for member_index, (member_name, member_offset) in enumerate(member_layout)
        ],
    )

    return BufferInfo(
        name="Particle",
        binding=0,
        descriptor_set=0,
        descriptor_type="storage_buffer",
        struct_type=particle_struct,
        variable_name="particles",
    )
|
|
|
|
|
|
def create_test_type_map():
    """Build the type-id lookup table referenced by create_test_buffer's members.

    Returns a dict with two entries:
    - type id 2: a 4-component vector type named ``vec4``;
    - type id 3: a 32-bit scalar float type named ``float``.
    The vec4 entry's ``resolved_component_type`` is set to the float entry.
    """
    scalar_float = ScalarTypeInfo(
        id=3,
        name="float",
        base_type=BaseType.FLOAT,
        bit_width=32,
    )
    vector_vec4 = VectorTypeInfo(
        id=2,
        name="vec4",
        component_type_id=3,
        component_count=4,
    )
    # This fixture is hand-built, so link the vector to its component type by hand.
    vector_vec4.resolved_component_type = scalar_float

    type_map = {}
    type_map[2] = vector_vec4
    type_map[3] = scalar_float
    return type_map
|
|
|
|
|
|
def test_struct_generation():
    """Check that generate_single_struct emits the Particle struct definition.

    Builds the Particle buffer fixture and its type map, then joins the
    generator's result with newlines. It asserts that the text contains
    ``struct``, the struct name, both member names and an ``alignas``
    specifier. The generated code is printed for manual inspection.

    Raises:
        AssertionError: naming the first expected fragment that is missing,
            followed by the generated text.
    """
    print("测试结构体生成...")

    buffer = create_test_buffer()
    type_map = create_test_type_map()

    result = generate_single_struct(buffer, type_map)

    # The generator yields strings (presumably one per line); join them so the
    # checks below can do plain substring tests.
    result_str = "\n".join(result)

    # Give every check a message. main() prints str(e) for AssertionError,
    # and a bare assert would report an empty reason.
    for expected in ("struct", "Particle", "position", "velocity", "alignas"):
        assert expected in result_str, (
            f"生成的结构体代码缺少 {expected!r}:\n{result_str}"
        )

    print("[PASS] 结构体生成测试通过")
    print(result_str)
    print()
|
|
|
|
|
|
def test_helper_functions():
    """Check that generate_buffer_helper_functions emits the per-buffer helpers.

    Wraps the Particle buffer fixture in a SPIRVReflection, generates helpers
    for shader name ``test_shader``, and asserts that the create/upload/download
    helpers whose names contain ``particles`` all appear. Prints the first
    500 characters of the output.

    Raises:
        AssertionError: naming the first missing helper, followed by the output.
    """
    print("测试辅助函数生成...")

    reflection = SPIRVReflection(buffers=[create_test_buffer()])
    result = generate_buffer_helper_functions(reflection, "test_shader")

    # Messages make main()'s "[FAIL] ..." line say which helper is missing.
    for expected in (
        "create_particles_buffer",
        "upload_particles",
        "download_particles",
    ):
        assert expected in result, f"生成的辅助函数缺少 {expected!r}:\n{result}"

    print("[PASS] 辅助函数生成测试通过")
    print(result[:500] + "...\n")
|
|
|
|
|
|
def test_typed_buffer():
    """Check that generate_typed_buffer_aliases emits the typed-buffer code.

    Wraps the Particle buffer fixture in a SPIRVReflection and asserts that
    the output contains the ``particles_buffer`` alias, the
    ``create_particles_typed`` factory and a ``typed_buffer<Particle>``
    instantiation. Prints the first 500 characters of the output.

    Raises:
        AssertionError: naming the first missing fragment, followed by the output.
    """
    print("测试类型化buffer生成...")

    reflection = SPIRVReflection(buffers=[create_test_buffer()])
    result = generate_typed_buffer_aliases(reflection, "test_shader")

    # Messages make main()'s "[FAIL] ..." line say which fragment is missing.
    for expected in (
        "particles_buffer",
        "create_particles_typed",
        "typed_buffer<Particle>",
    ):
        assert expected in result, f"生成的类型化buffer代码缺少 {expected!r}:\n{result}"

    print("[PASS] 类型化buffer生成测试通过")
    print(result[:500] + "...\n")
|
|
|
|
|
|
def test_buffer_manager():
    """Check that generate_buffer_manager_class emits the buffer manager class.

    Wraps the Particle buffer fixture in a SPIRVReflection and asserts that
    the output contains the ``test_shader_buffer_manager`` name, a
    ``struct buffers`` declaration, a ``static auto create`` factory, a
    ``void initialize`` method and ``bind_to_descriptor_set``. Prints the
    first 500 characters of the output.

    Raises:
        AssertionError: naming the first missing fragment, followed by the output.
    """
    print("测试buffer管理器生成...")

    reflection = SPIRVReflection(buffers=[create_test_buffer()])
    result = generate_buffer_manager_class(reflection, "test_shader")

    # Messages make main()'s "[FAIL] ..." line say which fragment is missing.
    for expected in (
        "test_shader_buffer_manager",
        "struct buffers",
        "static auto create",
        "void initialize",
        "bind_to_descriptor_set",
    ):
        assert expected in result, f"生成的buffer管理器缺少 {expected!r}:\n{result}"

    print("[PASS] Buffer管理器生成测试通过")
    print(result[:500] + "...\n")
|
|
|
|
|
|
def test_spirv_array():
    """Check that spirv_to_cpp_array renders bytes as a C++ uint32_t array.

    Feeds 80 bytes (the 4-byte pattern 03 02 23 07 repeated 20 times) and
    asserts that the output declares ``static constexpr uint32_t test_spirv[]``
    and contains the word ``0x07230203``. Prints the first 200 characters of
    the output.

    Raises:
        AssertionError: naming the first missing fragment, followed by the output.
    """
    print("测试SPIR-V数组生成...")

    # 80 bytes = 20 words. Read as a little-endian uint32, 03 02 23 07 is
    # 0x07230203, the SPIR-V magic number.
    test_data = b'\x03\x02\x23\x07' * 20

    result = spirv_to_cpp_array(test_data, "test_spirv")

    # The second fragment verifies the little-endian word assembly.
    for expected in ("static constexpr uint32_t test_spirv[]", "0x07230203"):
        assert expected in result, f"生成的SPIR-V数组缺少 {expected!r}:\n{result}"

    print("[PASS] SPIR-V数组生成测试通过")
    print(result[:200] + "...\n")
|
|
|
|
|
|
def main():
    """Run every test in order and return a process exit code.

    Stops at the first test that raises.

    Returns:
        0 when all tests pass.
        1 when a test raises AssertionError (reported as ``[FAIL]``) or any
        other exception (reported as ``[ERROR]``); the traceback is printed
        in both cases.
    """
    import traceback

    tests = (
        test_struct_generation,
        test_helper_functions,
        test_typed_buffer,
        test_buffer_manager,
        test_spirv_array,
    )

    print("=" * 60)
    print("Jinja2模板重构测试")
    print("=" * 60)
    print()

    # Keep only the test calls inside try, so the success banner below is
    # never mistaken for a test failure.
    try:
        for run_test in tests:
            run_test()
    except AssertionError as e:
        print(f"\n[FAIL] 测试失败: {e}")
        traceback.print_exc()
        return 1
    except Exception as e:
        # Top-level boundary: report unexpected errors (generator crashes,
        # broken fixtures) as a failed run instead of an unhandled crash.
        print(f"\n[ERROR] 错误: {e}")
        traceback.print_exc()
        return 1

    print("=" * 60)
    print("所有测试通过!")
    print("=" * 60)
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # main() returns 0 on success and 1 on failure; hand it to the OS as the exit status.
    exit_status = main()
    sys.exit(exit_status)