Files
mirage/tools/code_generator.py
2025-12-25 13:20:16 +08:00

714 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
代码生成模块
负责生成C++代码，包括结构体定义、SPIR-V数组、头文件和绑定信息。
使用Jinja2模板引擎代替字符串拼接。
"""
from __future__ import annotations
import struct
from typing import Dict, List
from .template_loader import get_renderer
from .type_mapping import (
calculate_std430_alignment,
calculate_std430_size,
spirv_type_to_cpp,
)
from .types import (
ArrayTypeInfo,
BufferInfo,
CompilationResult,
InstanceBufferInfo,
MemberInfo,
PushConstantInfo,
ShaderMetadata,
SPIRVReflection,
ToolError,
TypeInfo,
)
from .constants import (
PUSH_CONSTANT_HEADER_SIZE,
PUSH_CONSTANT_CUSTOM_OFFSET,
PUSH_CONSTANT_CUSTOM_MAX_SIZE,
)
# ============ Structure Generation ============
def generate_buffer_structures(reflection: SPIRVReflection) -> str:
    """Render C++ struct definitions for every reflected buffer/uniform block.

    Each rendered struct is followed by one blank separator line.

    Args:
        reflection: SPIR-V reflection data; ``buffers`` and ``types`` are read.

    Returns:
        All struct definitions joined with newlines, or ``""`` when the
        reflection contains no buffers.
    """
    buffers = reflection.buffers
    if not buffers:
        return ""
    output: List[str] = []
    for buf in buffers:
        output += generate_single_struct(buf, reflection.types)
        output.append("")  # blank separator after each struct
    return "\n".join(output)
def generate_single_struct(buffer: BufferInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render one buffer's C++ struct definition and return it as a list of lines.

    Args:
        buffer: reflected buffer whose ``struct_type`` is rendered.
        type_map: reflection type table, used for the std430 alignment lookup.

    Returns:
        The output of ``base/struct.jinja2`` split into lines.
    """
    struct_alignment = calculate_std430_alignment(buffer.struct_type, type_map)
    rendered = get_renderer().render(
        'base/struct.jinja2',
        {
            'buffer': buffer,
            'type_map': type_map,
            'alignment': struct_alignment,
        },
    )
    return rendered.splitlines()
def generate_member_definition(member: MemberInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render the C++ declaration of a single struct member as a list of lines.

    Args:
        member: reflected struct member.
        type_map: reflection type table passed through to the template.

    Returns:
        The output of ``base/struct_member.jinja2`` split into lines.
    """
    template_context = dict(member=member, type_map=type_map)
    return get_renderer().render('base/struct_member.jinja2', template_context).splitlines()
# ============ Push Constants Custom Structure Generation ============
def generate_push_constant_custom_struct(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate the C++ struct for the user-defined ("Custom") push-constant region.

    Push-constant layout, as documented for this tool (128 bytes in total):

        Header [0-15]   : scale + translate, 16 bytes  (filled by the system)
        Custom [16-127] : user-defined data, 112 bytes (generated here)

    ``push_constant.effect_members`` is expected to already contain only the
    Custom-region members (offset >= 16); the name is historical, and it now
    holds the custom members. Member offsets are passed to the template
    unchanged, together with ``base_offset=PUSH_CONSTANT_CUSTOM_OFFSET``,
    presumably so the template can emit offsets relative to the Custom region
    (absolute - 16). TODO(review): confirm against the template.

    These values are meant to stay in sync with PUSH_CONSTANT_CUSTOM_OFFSET (16)
    and PUSH_CONSTANT_CUSTOM_MAX_SIZE (112) in push_constant_traits.h.

    Args:
        reflection: SPIR-V reflection data.
        shader_name: snake_case shader name, used to build the struct name
            ``<PascalName>PushConstantCustom``.

    Returns:
        The rendered C++ struct, or ``""`` when there are no custom members.
    """
    push_constant = reflection.push_constant
    if not (push_constant and push_constant.effect_members):
        return ""

    type_map = reflection.types
    member_rows = []
    for m in push_constant.effect_members:
        align = calculate_std430_alignment(m.resolved_type, type_map)
        nbytes = calculate_std430_size(m.resolved_type, type_map)
        member_rows.append(dict(
            name=m.name,
            cpp_type=spirv_type_to_cpp(m.resolved_type, type_map),
            offset=m.offset,
            size=nbytes,
            alignment=align,
        ))

    return get_renderer().render(
        'push_constant/custom_struct.jinja2',
        dict(
            struct_name=f"{_to_pascal_case(shader_name)}PushConstantCustom",
            base_offset=PUSH_CONSTANT_CUSTOM_OFFSET,  # start of the Custom region
            members=member_rows,
        ),
    )
# 兼容性别名
def generate_push_constant_effect_struct(reflection: SPIRVReflection, shader_name: str) -> str:
    """Deprecated alias of ``generate_push_constant_custom_struct``.

    Kept only for backward compatibility; new code should call
    ``generate_push_constant_custom_struct`` directly.
    """
    return generate_push_constant_custom_struct(reflection=reflection, shader_name=shader_name)
def generate_push_constant_header_layout(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate layout information for the push-constant Header region.

    The Header occupies offsets 0-15 and holds the NDC-transform variables:

        scale     (vec2): offset 0, size 8
        translate (vec2): offset 8, size 8

    Args:
        reflection: SPIR-V reflection data.
        shader_name: shader name, passed to the template as ``struct_name``.

    Returns:
        The rendered ``push_constant/base_layout.jinja2`` output, or ``""``
        when there is no push-constant block or no header info.
    """
    push_constant = reflection.push_constant
    base_info = push_constant.base_info if push_constant else None
    if not base_info:
        return ""

    members = base_info.members
    context = {
        'struct_name': shader_name,
        'header_members': [
            dict(
                name=m.name,
                offset=m.offset,
                size=m.size,
                expected_offset=m.expected_offset,
                expected_size=m.expected_size,
                is_valid=m.is_valid,
            )
            for m in members
        ],
        'expected_names': ['scale', 'translate'],
        'present_names': [m.name for m in members],
        'header_size': base_info.total_size,
        'is_standard': base_info.is_standard_layout,
    }
    return get_renderer().render('push_constant/base_layout.jinja2', context)
def generate_push_constant_static_check(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate compile-time checks for the push-constant Header region only.

    The ``sizeof`` check for the Custom struct must appear after that struct's
    definition, so it is not emitted here; the template is always told
    ``has_custom_members=False``.

    Args:
        reflection: SPIR-V reflection data.
        shader_name: shader name, passed to the template as ``struct_name``.

    Returns:
        The rendered ``push_constant/static_check.jinja2`` output, or ``""``
        when there is no push-constant block or no header info.
    """
    push_constant = reflection.push_constant
    if not push_constant or not push_constant.base_info:
        return ""
    base_info = push_constant.base_info

    header_members = [
        {
            'name': m.name,
            'offset': m.offset,
            'size': m.size,
            'expected_offset': m.expected_offset,
            'expected_size': m.expected_size,
            'is_valid': m.is_valid,
        }
        for m in base_info.members
    ]

    check_context = {
        'struct_name': shader_name,
        'header_members': header_members,
        'is_standard': base_info.is_standard_layout,
        'has_custom_members': False,  # Custom-struct checks are emitted elsewhere
    }
    return get_renderer().render('push_constant/static_check.jinja2', check_context)
# 兼容性别名
def generate_push_constant_base_layout(reflection: SPIRVReflection, shader_name: str) -> str:
    """Deprecated alias of ``generate_push_constant_header_layout``.

    Kept only for backward compatibility; new code should call
    ``generate_push_constant_header_layout`` directly.
    """
    return generate_push_constant_header_layout(reflection=reflection, shader_name=shader_name)
# ============ Instance Data Structure Generation ============
def generate_instance_data_struct(instance_buffer: InstanceBufferInfo, type_map: Dict[int, TypeInfo]) -> str:
    """Generate the C++ per-instance data struct for an ``@instance_buffer`` SSBO.

    The struct is rendered from the SSBO's element struct type. Per the
    original documentation, the output includes compile-time static checks
    that verify the layout.

    Args:
        instance_buffer: reflected instance-buffer description.
        type_map: reflection type table.

    Returns:
        The rendered ``instance/instance_data_struct.jinja2`` output.

    Example output::

        struct alignas(16) gradient_instance_data {
            Eigen::Vector4f rect;          // required: x, y, width, height
            Eigen::Vector4f start_color;   // start colour
            Eigen::Vector4f transform;     // rotation, anchor_x, anchor_y, scale
            Eigen::Vector4f custom_params;
        };
        static_assert(sizeof(gradient_instance_data) == 64);
        static_assert(offsetof(gradient_instance_data, rect) == 0);
    """
    member_rows = []
    for m in instance_buffer.members:
        mtype = m.resolved_type
        align = calculate_std430_alignment(mtype, type_map)
        nbytes = calculate_std430_size(mtype, type_map)
        member_rows.append(dict(
            name=m.name,
            cpp_type=spirv_type_to_cpp(mtype, type_map),
            offset=m.offset,
            size=nbytes,
            alignment=align,
        ))

    return get_renderer().render(
        'instance/instance_data_struct.jinja2',
        dict(
            struct_name=instance_buffer.struct_type_name,
            members=member_rows,
            total_size=instance_buffer.total_size,
            alignment=16,  # instance data structs are always 16-byte aligned
        ),
    )
def generate_instance_data_info(reflection: SPIRVReflection, shader_name: str) -> dict:
    """Build the instancing-related template variables for the header template.

    Args:
        reflection: SPIR-V reflection data; may be ``None``.
        shader_name: shader name (currently unused).

    Returns:
        A dict with ``has_instance_data``, ``instance_data_struct``,
        ``instance_type_name``, ``instance_buffer_binding``,
        ``instance_buffer_set`` and ``instance_data_size``. When no instance
        buffer is present, neutral defaults are returned: False, '', 'void',
        0, 0, 0.
    """
    instance_buffer = reflection.instance_buffer if reflection else None
    if not instance_buffer:
        return {
            'has_instance_data': False,
            'instance_data_struct': '',
            'instance_type_name': 'void',
            'instance_buffer_binding': 0,
            'instance_buffer_set': 0,
            'instance_data_size': 0,
        }
    return {
        'has_instance_data': True,
        'instance_data_struct': generate_instance_data_struct(instance_buffer, reflection.types),
        'instance_type_name': instance_buffer.struct_type_name,
        'instance_buffer_binding': instance_buffer.binding,
        'instance_buffer_set': instance_buffer.set_number,
        'instance_data_size': instance_buffer.total_size,
    }
# ============ SPIR-V Array Generation ============
def spirv_to_cpp_array(spirv_data: bytes, symbol_name: str, words_per_line: int = 12) -> str:
    """Convert a SPIR-V binary into a C++ ``uint32_t`` array definition.

    The bytes are decoded as little-endian 32-bit words and emitted as
    ``0x%08x`` hex literals, ``words_per_line`` per line.

    Args:
        spirv_data: raw SPIR-V module bytes.
        symbol_name: C++ identifier for the generated array.
        words_per_line: number of words per output line (default 12, the
            previously hard-coded value).

    Returns:
        The rendered ``base/spirv_array.jinja2`` output.

    Raises:
        ToolError: if ``words_per_line`` is not positive, if ``spirv_data``
            is empty (a zero-length array is not a valid module and would
            produce an ill-formed C++ array), or if its size is not a
            multiple of 4.
    """
    if words_per_line < 1:
        raise ToolError(f"words_per_line must be positive, got {words_per_line}")
    if not spirv_data:
        raise ToolError("SPIR-V data is empty")
    if len(spirv_data) % 4 != 0:
        raise ToolError(f"SPIR-V data size {len(spirv_data)} is not a multiple of 4")

    words = struct.unpack(f"<{len(spirv_data) // 4}I", spirv_data)

    chunks = [
        [f"0x{word:08x}" for word in words[start:start + words_per_line]]
        for start in range(0, len(words), words_per_line)
    ]

    context = {
        'symbol_name': symbol_name,
        'chunks': chunks,
    }
    return get_renderer().render('base/spirv_array.jinja2', context)
# ============ 缓冲区辅助函数生成 ============
def generate_buffer_helper_functions(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate the create/upload/download C++ helper functions for every buffer.

    For each buffer, three functions are emitted in order: create, upload,
    then download. Each is followed by a blank line.

    Args:
        reflection: SPIR-V reflection data; ``buffers`` is read.
        shader_name: shader name, forwarded to the create-function template.

    Returns:
        The helper source joined with newlines, or ``""`` when there are no buffers.
    """
    if not reflection.buffers:
        return ""
    out: List[str] = []
    for buf in reflection.buffers:
        sections = (
            _generate_create_buffer_function(buf, shader_name),
            _generate_upload_function(buf),
            _generate_download_function(buf),
        )
        for section in sections:
            out.extend(section)
            out.append("")
    return "\n".join(out)
def _generate_create_buffer_function(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the convenience function that creates ``buffer``, split into lines."""
    rendered = get_renderer().render(
        'buffer_helpers/create_buffer.jinja2',
        dict(buffer=buffer, shader_name=shader_name),
    )
    return rendered.splitlines()
def _generate_upload_function(buffer: BufferInfo) -> List[str]:
    """Render the function that uploads data into ``buffer``, split into lines."""
    template_path = 'buffer_helpers/upload.jinja2'
    return get_renderer().render(template_path, dict(buffer=buffer)).splitlines()
def _generate_download_function(buffer: BufferInfo) -> List[str]:
    """Render the function that downloads data from ``buffer``, split into lines."""
    template_path = 'buffer_helpers/download.jinja2'
    return get_renderer().render(template_path, dict(buffer=buffer)).splitlines()
# ============ Typed Buffer Aliases Generation ============
def generate_typed_buffer_aliases(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate type-safe buffer aliases and their factory functions.

    Args:
        reflection: SPIR-V reflection data; ``buffers`` is read.
        shader_name: shader name, forwarded to the per-buffer template.

    Returns:
        The alias/factory source for all buffers (each followed by a blank
        line) joined with newlines, or ``""`` when there are no buffers.
    """
    buffers = reflection.buffers
    if not buffers:
        return ""
    out: List[str] = []
    for buf in buffers:
        out += _generate_buffer_type_alias_and_factory(buf, shader_name)
        out.append("")
    return "\n".join(out)
def _generate_buffer_type_alias_and_factory(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the type alias and factory function for one buffer, split into lines."""
    rendered = get_renderer().render(
        'typed_buffer/alias_and_factory.jinja2',
        dict(buffer=buffer, shader_name=shader_name),
    )
    return rendered.splitlines()
# ============ Buffer Manager Class Generation ============
def generate_buffer_manager_class(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the ``<shader_name>_buffer_manager`` C++ class.

    Args:
        reflection: SPIR-V reflection data; ``buffers`` and ``shader_stage``
            are read.
        shader_name: shader name used to build the class name.

    Returns:
        The rendered ``buffer_manager/manager_class.jinja2`` output, or ``""``
        when there are no buffers.
    """
    if not reflection.buffers:
        return ""
    context = {
        'class_name': f"{shader_name}_buffer_manager",
        'buffers': reflection.buffers,
        # Tells the template whether this is a compute shader.
        'is_compute': reflection.shader_stage == "compute",
    }
    # The rendered class is a single string; no line list/join is needed.
    return get_renderer().render('buffer_manager/manager_class.jinja2', context)
def _generate_buffers_struct(buffers: List[BufferInfo]) -> List[str]:
    """Emit the legacy ``struct buffers`` container as a list of C++ lines.

    Superseded by template-based generation; kept for compatibility. Each
    buffer contributes one member ``<var>_buffer <var>;``.
    """
    member_lines = []
    for buf in buffers:
        name = buf.variable_name
        member_lines.append(f" {name}_buffer {name};")
    return [
        " // Container for all buffers",
        " struct buffers {",
        *member_lines,
        " };",
    ]
def _generate_manager_factory(buffers: List[BufferInfo], class_name: str) -> List[str]:
    """Legacy stub for the buffer-manager factory method.

    Superseded by template-based generation. The signature is kept so any
    external caller keeps working; it always returns an empty list.
    """
    return list()
def _generate_initialize_method(buffers: List[BufferInfo]) -> List[str]:
    """Legacy stub for the buffer-manager initialize method.

    Superseded by template-based generation. The signature is kept so any
    external caller keeps working; it always returns an empty list.
    """
    return list()
def _generate_bind_method(buffers: List[BufferInfo]) -> List[str]:
    """Legacy stub for the method that binds buffers to a descriptor set.

    Superseded by template-based generation. The signature is kept so any
    external caller keeps working; it always returns an empty list.
    """
    return list()
def _generate_manager_constructor(buffers: List[BufferInfo], class_name: str) -> List[str]:
    """Legacy stub for the buffer-manager's private constructor.

    Superseded by template-based generation. The signature is kept so any
    external caller keeps working; it always returns an empty list.
    """
    return list()
# ============ Vulkan Type Conversion ============
def vk_descriptor_type_to_cpp(descriptor_type: str) -> str:
    """Map a reflected descriptor-type name to a Vulkan-Hpp ``vk::DescriptorType`` enumerator.

    Besides the original six types, the texel-buffer, dynamic-buffer and
    input-attachment descriptor types are also recognised. Those previously
    fell through to ``eStorageBuffer`` silently.

    Args:
        descriptor_type: snake_case descriptor type name, e.g. ``"storage_buffer"``.

    Returns:
        The C++ enumerator spelling. Unknown names fall back to
        ``vk::DescriptorType::eStorageBuffer``, as before.
    """
    type_map = {
        "storage_buffer": "vk::DescriptorType::eStorageBuffer",
        "uniform_buffer": "vk::DescriptorType::eUniformBuffer",
        "sampled_image": "vk::DescriptorType::eSampledImage",
        "combined_image_sampler": "vk::DescriptorType::eCombinedImageSampler",
        "storage_image": "vk::DescriptorType::eStorageImage",
        "sampler": "vk::DescriptorType::eSampler",
        "uniform_texel_buffer": "vk::DescriptorType::eUniformTexelBuffer",
        "storage_texel_buffer": "vk::DescriptorType::eStorageTexelBuffer",
        "uniform_buffer_dynamic": "vk::DescriptorType::eUniformBufferDynamic",
        "storage_buffer_dynamic": "vk::DescriptorType::eStorageBufferDynamic",
        "input_attachment": "vk::DescriptorType::eInputAttachment",
    }
    return type_map.get(descriptor_type, "vk::DescriptorType::eStorageBuffer")
def vk_shader_stage_to_cpp(stages: List[str]) -> str:
    """Convert a list of stage names into an OR-ed ``vk::ShaderStageFlagBits`` expression.

    Args:
        stages: stage names such as ``"vertex"`` or ``"fragment"``. An empty
            or ``None`` value is treated as ``["compute"]``.

    Returns:
        The recognised flags joined with ``" | "``, in input order (duplicates
        are kept). Unrecognised names are skipped. If nothing is recognised,
        ``"vk::ShaderStageFlagBits::eCompute"`` is returned.
    """
    lookup = {
        "vertex": "vk::ShaderStageFlagBits::eVertex",
        "fragment": "vk::ShaderStageFlagBits::eFragment",
        "compute": "vk::ShaderStageFlagBits::eCompute",
        "geometry": "vk::ShaderStageFlagBits::eGeometry",
        "tess_control": "vk::ShaderStageFlagBits::eTessellationControl",
        "tess_eval": "vk::ShaderStageFlagBits::eTessellationEvaluation",
    }
    resolved = [lookup[name] for name in (stages or ["compute"]) if name in lookup]
    return " | ".join(resolved) if resolved else "vk::ShaderStageFlagBits::eCompute"
# ============ Header File Generation ============
def _to_pascal_case(name: str) -> str:
    """Convert a snake_case name to PascalCase.

    Each underscore-separated word goes through ``str.capitalize``, so any
    letters after the first are lower-cased (``"myHDR_pass"`` -> ``"MyhdrPass"``).
    """
    words = name.split('_')
    return ''.join(map(str.capitalize, words))
def generate_header(
    metadata: ShaderMetadata,
    compilation_results: List[CompilationResult],
) -> str:
    """Generate the complete C++ header for a shader.

    Collects every generated fragment and renders it through
    ``base/header.jinja2``. The fragments are:

    - buffer structs, helper functions, typed aliases and the buffer manager;
    - push-constant Custom struct, Header layout and Header static checks;
    - embedded SPIR-V arrays;
    - descriptor bindings;
    - instance-data information.

    Args:
        metadata: shader metadata; ``metadata.reflection`` may be ``None``.
        compilation_results: compiled stages. The first entry supplies the
            entry point and the default SPIR-V symbol name.

    Returns:
        The rendered header source.

    Raises:
        ToolError: if a descriptor binding index is 1000 or larger. Bindings
            are flattened as ``set * 1000 + binding``, so such an index would
            collide with a binding of the next set. Also propagated from
            ``spirv_to_cpp_array`` for malformed SPIR-V data.
    """
    reflection = metadata.reflection
    push_constant = reflection.push_constant if reflection else None
    has_buffers = bool(reflection and reflection.buffers)
    has_custom_push_constant = bool(push_constant and push_constant.effect_members)
    has_header_push_constant = bool(push_constant and push_constant.base_info)
    # Must match the struct name built in generate_push_constant_custom_struct.
    custom_type_name = (
        f"{_to_pascal_case(metadata.name)}PushConstantCustom" if has_custom_push_constant else None
    )

    # ---- Buffer-related code ----
    buffer_structures = ""
    buffer_helpers = ""
    typed_buffers = ""
    buffer_manager = ""
    if has_buffers:
        buffer_structures = generate_buffer_structures(reflection)
        if metadata.generate_buffer_helpers:
            buffer_helpers = generate_buffer_helper_functions(reflection, metadata.name)
        if metadata.generate_typed_buffers:
            typed_buffers = generate_typed_buffer_aliases(reflection, metadata.name)
        if metadata.generate_buffer_manager:
            buffer_manager = generate_buffer_manager_class(reflection, metadata.name)

    # ---- Push constants: Custom (user-defined) part ----
    push_constant_custom_struct = ""
    push_constant_custom_type_name = "void"
    if has_custom_push_constant:
        push_constant_custom_struct = generate_push_constant_custom_struct(reflection, metadata.name)
        push_constant_custom_type_name = custom_type_name

    # ---- Push constants: Header (scale/translate) part ----
    push_constant_header_layout = ""
    push_constant_static_check = ""
    push_constant_header_member_names: List[str] = []
    push_constant_header_member_offsets: List[int] = []
    if has_header_push_constant:
        push_constant_header_layout = generate_push_constant_header_layout(reflection, metadata.name)
        # Header member names/offsets are consumed by the shader_spec section.
        for member in push_constant.base_info.members:
            push_constant_header_member_names.append(member.name)
            push_constant_header_member_offsets.append(member.offset)
        push_constant_static_check = generate_push_constant_static_check(reflection, metadata.name)
    has_push_constant_header_layout = bool(push_constant_header_layout)

    # ---- SPIR-V arrays and stage detection ----
    compilation_results_with_arrays = []
    has_vert_shader = False
    has_frag_shader = False
    vert_spirv_symbol = ""
    frag_spirv_symbol = ""
    for result in compilation_results:
        symbol_name = f"{metadata.name}_{result.shader_type}_spirv"
        compilation_results_with_arrays.append({
            'shader_type': result.shader_type,
            'source_file': result.source_file,
            'spirv_array': spirv_to_cpp_array(result.spirv_data, symbol_name),
        })
        if result.shader_type == "vert":
            has_vert_shader = True
            vert_spirv_symbol = symbol_name
        elif result.shader_type == "frag":
            has_frag_shader = True
            frag_spirv_symbol = symbol_name

    # ---- Descriptor bindings, flattened as set * 1000 + binding ----
    binding_set_stride = 1000
    bindings_data = []
    for binding in metadata.bindings or []:
        if binding.binding >= binding_set_stride:
            # A larger index would alias a binding of the next descriptor set.
            raise ToolError(
                f"Binding {binding.binding} in set {binding.descriptor_set} is out of range: "
                f"the set * {binding_set_stride} + binding encoding supports bindings "
                f"0-{binding_set_stride - 1}"
            )
        bindings_data.append({
            'binding': binding.descriptor_set * binding_set_stride + binding.binding,
            'descriptor_type_cpp': vk_descriptor_type_to_cpp(binding.descriptor_type),
            'count': binding.count,
            'stage_flags': vk_shader_stage_to_cpp(binding.stages),
        })

    # ---- Parameter type: first buffer struct > Custom push constants > void ----
    params_type_name = "void"
    if has_buffers:
        params_type_name = reflection.buffers[0].struct_type.name or params_type_name
    elif has_custom_push_constant:
        params_type_name = custom_type_name

    instance_info = generate_instance_data_info(reflection, metadata.name)

    context = {
        'namespace': metadata.namespace,
        'shader_name': metadata.name,
        'metadata_struct_name': f"{metadata.name}_metadata",
        'bindings_name': f"{metadata.name}_bindings",
        'shader_spec_name': f"{metadata.name}_shader_spec",
        'params_type_name': params_type_name,
        'entry_point': compilation_results[0].entry_point if compilation_results else "main",
        'is_compute': metadata.shader_type == "compute",
        'spirv_symbol_name': f"{metadata.name}_{compilation_results[0].shader_type}_spirv" if compilation_results else "",
        # Passed as a real bool rather than the buffer list / None.
        'has_buffers': has_buffers,
        'generate_typed_buffers': metadata.generate_typed_buffers,
        'buffer_structures': buffer_structures,
        'buffer_helpers': buffer_helpers,
        'typed_buffers': typed_buffers,
        'buffer_manager': buffer_manager,
        'compilation_results': compilation_results_with_arrays,
        'bindings': bindings_data,
        # shader_spec variables
        'has_vert_shader': has_vert_shader,
        'has_frag_shader': has_frag_shader,
        'vert_spirv_symbol': vert_spirv_symbol,
        'frag_spirv_symbol': frag_spirv_symbol,
        # Push-constant Custom struct; the template still uses the legacy "effect" keys
        'push_constant_effect_struct': push_constant_custom_struct,
        'push_constant_effect_type_name': push_constant_custom_type_name,
        'has_push_constant_effect': bool(push_constant_custom_struct),
        # Push-constant Header layout; the template still uses the legacy "base" keys
        'push_constant_base_layout': push_constant_header_layout,
        'has_push_constant_base_layout': has_push_constant_header_layout,
        'push_constant_base_member_names': push_constant_header_member_names,
        'push_constant_base_member_offsets': push_constant_header_member_offsets,
        # Push-constant Header static checks
        'push_constant_static_check': push_constant_static_check,
        # Instanced-rendering variables
        **instance_info,
    }
    return get_renderer().render('base/header.jinja2', context)