714 lines
27 KiB
Python
714 lines
27 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
代码生成模块
|
||
|
||
负责生成C++代码,包括结构体定义、SPIR-V数组、头文件和绑定信息。
|
||
使用Jinja2模板引擎代替字符串拼接。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import struct
|
||
from typing import Dict, List
|
||
|
||
from .template_loader import get_renderer
|
||
from .type_mapping import (
|
||
calculate_std430_alignment,
|
||
calculate_std430_size,
|
||
spirv_type_to_cpp,
|
||
)
|
||
from .types import (
|
||
ArrayTypeInfo,
|
||
BufferInfo,
|
||
CompilationResult,
|
||
InstanceBufferInfo,
|
||
MemberInfo,
|
||
PushConstantInfo,
|
||
ShaderMetadata,
|
||
SPIRVReflection,
|
||
ToolError,
|
||
TypeInfo,
|
||
)
|
||
from .constants import (
|
||
PUSH_CONSTANT_HEADER_SIZE,
|
||
PUSH_CONSTANT_CUSTOM_OFFSET,
|
||
PUSH_CONSTANT_CUSTOM_MAX_SIZE,
|
||
)
|
||
|
||
|
||
# ============ Structure Generation ============
|
||
|
||
def generate_buffer_structures(reflection: SPIRVReflection) -> str:
    """Render the C++ struct definitions for every buffer/uniform block.

    Each buffer's struct lines are followed by one empty line, and all lines
    are joined with ``"\\n"``.

    Args:
        reflection: SPIR-V reflection data; its ``buffers`` are rendered using
            its ``types`` map.

    Returns:
        The generated C++ source, or ``""`` when there are no buffers.
    """
    buffers = reflection.buffers
    if not buffers:
        return ""

    output: List[str] = []
    for buf in buffers:
        output += generate_single_struct(buf, reflection.types)
        output.append("")  # blank separator after each struct
    return "\n".join(output)
|
||
|
||
|
||
def generate_single_struct(buffer: BufferInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render one buffer's C++ struct definition via ``base/struct.jinja2``.

    Args:
        buffer: The buffer whose struct is rendered.
        type_map: SPIR-V type id -> type info, used for std430 alignment and
            passed through to the template.

    Returns:
        The rendered output split into lines.
    """
    renderer = get_renderer()
    struct_alignment = calculate_std430_alignment(buffer.struct_type, type_map)
    template_context = dict(
        buffer=buffer,
        type_map=type_map,
        alignment=struct_alignment,
    )
    rendered = renderer.render('base/struct.jinja2', template_context)
    return rendered.splitlines()
|
||
|
||
|
||
def generate_member_definition(member: MemberInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render a single struct member declaration via ``base/struct_member.jinja2``.

    Args:
        member: The struct member to render.
        type_map: SPIR-V type id -> type info, passed through to the template.

    Returns:
        The rendered output split into lines.
    """
    renderer = get_renderer()
    rendered = renderer.render(
        'base/struct_member.jinja2',
        dict(member=member, type_map=type_map),
    )
    return rendered.splitlines()
|
||
|
||
|
||
# ============ Push Constants Custom Structure Generation ============
|
||
|
||
|
||
def generate_push_constant_custom_struct(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the C++ struct for the user-defined ("Custom") push-constant area.

    Push-constant layout convention (128 bytes):

        Header [0-15]  : scale + translate   - 16 bytes   (filled by the system)
        Custom [16-127]: user-defined data   - 112 bytes  (generated here)

    This function does not filter members itself: ``push_constant.effect_members``
    is expected to already contain only the Custom-area members. Each member is
    passed to the template with its absolute ``offset`` together with
    ``base_offset`` = PUSH_CONSTANT_CUSTOM_OFFSET. Presumably the template derives
    the relative offset (absolute - 16) from these; confirm against
    ``push_constant/custom_struct.jinja2``.

    The struct is named ``<PascalCase(shader_name)>PushConstantCustom``.

    Args:
        reflection: SPIR-V reflection data.
        shader_name: Shader name, used to build the struct name.

    Returns:
        The rendered C++ struct, or ``""`` when there is no push-constant block
        or it has no Custom members.

    Note:
        Keep in sync with PUSH_CONSTANT_CUSTOM_OFFSET (16) and
        PUSH_CONSTANT_CUSTOM_MAX_SIZE (112) in push_constant_traits.h.
    """
    pc = reflection.push_constant
    if not (pc and pc.effect_members):
        return ""

    struct_name = f"{_to_pascal_case(shader_name)}PushConstantCustom"
    types_by_id = reflection.types

    member_entries = []
    # Despite the historical name, effect_members holds the Custom-area members.
    for m in pc.effect_members:
        align = calculate_std430_alignment(m.resolved_type, types_by_id)
        byte_size = calculate_std430_size(m.resolved_type, types_by_id)
        cpp_name = spirv_type_to_cpp(m.resolved_type, types_by_id)
        member_entries.append(dict(
            name=m.name,
            cpp_type=cpp_name,
            offset=m.offset,
            size=byte_size,
            alignment=align,
        ))

    renderer = get_renderer()
    return renderer.render(
        'push_constant/custom_struct.jinja2',
        {
            'struct_name': struct_name,
            'base_offset': PUSH_CONSTANT_CUSTOM_OFFSET,
            'members': member_entries,
        },
    )
|
||
|
||
|
||
# 兼容性别名
|
||
def generate_push_constant_effect_struct(reflection: SPIRVReflection, shader_name: str) -> str:
    """Deprecated alias of ``generate_push_constant_custom_struct``.

    Kept only for backward compatibility; forwards its arguments unchanged.
    """
    custom_struct = generate_push_constant_custom_struct(reflection, shader_name)
    return custom_struct
|
||
|
||
|
||
def generate_push_constant_header_layout(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the layout-info struct for the push-constant Header area.

    The Header holds ``scale`` and ``translate`` (bytes 0-15), used for the NDC
    transform:

        Header [0-15]: scale (vec2) + translate (vec2) - 16 bytes
          scale:     offset 0, size 8
          translate: offset 8, size 8

    Args:
        reflection: SPIR-V reflection data.
        shader_name: Shader name, passed to the template as ``struct_name``.

    Returns:
        The rendered C++ code, or ``""`` when there is no push-constant block
        or it has no ``base_info``.
    """
    pc = reflection.push_constant
    if not pc or not pc.base_info:
        return ""

    info = pc.base_info
    header_members = [
        dict(
            name=m.name,
            offset=m.offset,
            size=m.size,
            expected_offset=m.expected_offset,
            expected_size=m.expected_size,
            is_valid=m.is_valid,
        )
        for m in info.members
    ]
    present_names = [m.name for m in info.members]

    renderer = get_renderer()
    layout_context = {
        'struct_name': shader_name,
        'header_members': header_members,
        'expected_names': ['scale', 'translate'],
        'present_names': present_names,
        'header_size': info.total_size,
        'is_standard': info.is_standard_layout,
    }
    return renderer.render('push_constant/base_layout.jinja2', layout_context)
|
||
|
||
|
||
def generate_push_constant_static_check(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the push-constant Header static checks (no Custom-struct checks).

    The ``sizeof`` check for the Custom struct has to appear after that struct
    is defined, so this function only emits the Header checks and always passes
    ``has_custom_members=False`` to the template.

    Args:
        reflection: SPIR-V reflection data.
        shader_name: Shader name, passed to the template as ``struct_name``.

    Returns:
        The rendered static-check C++ code, or ``""`` when there is no
        push-constant block or it has no ``base_info``.
    """
    push_constant = reflection.push_constant
    if not push_constant or not push_constant.base_info:
        return ""

    base_info = push_constant.base_info

    # Per-member actual vs. expected layout, checked by the template.
    # (The unused expected_names/present_names locals copied from
    # generate_push_constant_header_layout were dropped: this template
    # context never consumed them.)
    header_members = [
        {
            'name': m.name,
            'offset': m.offset,
            'size': m.size,
            'expected_offset': m.expected_offset,
            'expected_size': m.expected_size,
            'is_valid': m.is_valid,
        }
        for m in base_info.members
    ]

    renderer = get_renderer()
    check_context = {
        'struct_name': shader_name,
        'header_members': header_members,
        'is_standard': base_info.is_standard_layout,
        'has_custom_members': False,  # Custom checks are emitted after the Custom struct
    }
    return renderer.render('push_constant/static_check.jinja2', check_context)
|
||
|
||
|
||
# 兼容性别名
|
||
def generate_push_constant_base_layout(reflection: SPIRVReflection, shader_name: str) -> str:
    """Deprecated alias of ``generate_push_constant_header_layout``.

    Kept only for backward compatibility; forwards its arguments unchanged.
    """
    header_layout = generate_push_constant_header_layout(reflection, shader_name)
    return header_layout
|
||
|
||
|
||
# ============ Instance Data Structure Generation ============
|
||
|
||
def generate_instance_data_struct(instance_buffer: InstanceBufferInfo, type_map: Dict[int, TypeInfo]) -> str:
    """Render the C++ struct for one element of the ``@instance_buffer`` SSBO.

    Each member is described by its name, C++ type, offset, std430 size and
    std430 alignment. The template is given a fixed struct alignment of 16 and
    the buffer's ``total_size``. The example below suggests it also emits
    ``static_assert`` layout checks.

    Args:
        instance_buffer: Instance-buffer reflection info.
        type_map: SPIR-V type id -> type info.

    Returns:
        The rendered C++ struct code.

    Example output:
        struct alignas(16) gradient_instance_data {
            Eigen::Vector4f rect;           // required: x, y, width, height
            Eigen::Vector4f start_color;    // start color
            Eigen::Vector4f transform;      // rotation, anchor_x, anchor_y, scale
            Eigen::Vector4f custom_params;
        };

        static_assert(sizeof(gradient_instance_data) == 64);
        static_assert(offsetof(gradient_instance_data, rect) == 0);
    """
    member_entries = []
    for m in instance_buffer.members:
        align = calculate_std430_alignment(m.resolved_type, type_map)
        byte_size = calculate_std430_size(m.resolved_type, type_map)
        cpp_name = spirv_type_to_cpp(m.resolved_type, type_map)
        member_entries.append(dict(
            name=m.name,
            cpp_type=cpp_name,
            offset=m.offset,
            size=byte_size,
            alignment=align,
        ))

    renderer = get_renderer()
    template_context = dict(
        struct_name=instance_buffer.struct_type_name,
        members=member_entries,
        total_size=instance_buffer.total_size,
        alignment=16,  # instance data structs are always 16-byte aligned
    )
    return renderer.render('instance/instance_data_struct.jinja2', template_context)
|
||
|
||
|
||
def generate_instance_data_info(reflection: SPIRVReflection, shader_name: str) -> dict:
    """Build the instancing-related template variables.

    Args:
        reflection: SPIR-V reflection data (may be ``None``).
        shader_name: Shader name (currently unused; kept for interface
            compatibility).

    Returns:
        A dict with ``has_instance_data``, ``instance_data_struct``,
        ``instance_type_name``, ``instance_buffer_binding``,
        ``instance_buffer_set`` and ``instance_data_size``. When there is no
        instance buffer, placeholder values are returned (``False``, ``''``,
        ``'void'``, ``0``, ``0``, ``0``).
    """
    instance_buffer = reflection.instance_buffer if reflection else None
    if not instance_buffer:
        return {
            'has_instance_data': False,
            'instance_data_struct': '',
            'instance_type_name': 'void',
            'instance_buffer_binding': 0,
            'instance_buffer_set': 0,
            'instance_data_size': 0,
        }

    return {
        'has_instance_data': True,
        'instance_data_struct': generate_instance_data_struct(instance_buffer, reflection.types),
        'instance_type_name': instance_buffer.struct_type_name,
        'instance_buffer_binding': instance_buffer.binding,
        'instance_buffer_set': instance_buffer.set_number,
        'instance_data_size': instance_buffer.total_size,
    }
|
||
|
||
|
||
# ============ SPIR-V Array Generation ============
|
||
|
||
def spirv_to_cpp_array(spirv_data: bytes, symbol_name: str, *, values_per_line: int = 12) -> str:
    """Convert a SPIR-V binary into a C++ ``uint32_t`` array via a template.

    The binary is decoded into 32-bit words, formatted as zero-padded lowercase
    hex literals (``0x%08x``), and grouped into rows of ``values_per_line``
    words for ``base/spirv_array.jinja2``.

    SPIR-V is defined as a stream of words. When it is stored as bytes, the
    first word (magic number 0x07230203) reveals the byte order. Little-endian
    is assumed unless the data starts with the big-endian magic, in which case
    it is decoded big-endian so the emitted words keep their intended values.

    Args:
        spirv_data: Raw SPIR-V bytes; the length must be a multiple of 4.
        symbol_name: C++ identifier for the generated array.
        values_per_line: Words per output row (default 12, the previous
            hard-coded value). Must be positive.

    Returns:
        The rendered C++ source.

    Raises:
        ToolError: If the data length is not a multiple of 4, or
            ``values_per_line`` is not positive.
    """
    if len(spirv_data) % 4 != 0:
        raise ToolError(f"SPIR-V data size {len(spirv_data)} is not a multiple of 4")
    if values_per_line < 1:
        raise ToolError(f"values_per_line must be positive, got {values_per_line}")

    spirv_magic = 0x07230203
    word_count = len(spirv_data) // 4
    byte_order = "<"
    if word_count and struct.unpack_from(">I", spirv_data)[0] == spirv_magic:
        byte_order = ">"  # module was stored big-endian
    words = struct.unpack(f"{byte_order}{word_count}I", spirv_data)

    chunks = [
        [f"0x{word:08x}" for word in words[start:start + values_per_line]]
        for start in range(0, len(words), values_per_line)
    ]

    renderer = get_renderer()
    context = {
        'symbol_name': symbol_name,
        'chunks': chunks,
    }
    return renderer.render('base/spirv_array.jinja2', context)
|
||
|
||
|
||
# ============ 缓冲区辅助函数生成 ============
|
||
|
||
def generate_buffer_helper_functions(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the create/upload/download helper functions for every buffer.

    For each buffer, the three helpers are emitted in that order, each followed
    by an empty line; all lines are joined with ``"\\n"``.

    Args:
        reflection: SPIR-V reflection data.
        shader_name: Shader name, passed to the create-buffer helper.

    Returns:
        The generated C++ source, or ``""`` when there are no buffers.
    """
    if not reflection.buffers:
        return ""

    pieces: List[str] = []
    for buf in reflection.buffers:
        sections = (
            _generate_create_buffer_function(buf, shader_name),
            _generate_upload_function(buf),
            _generate_download_function(buf),
        )
        for section in sections:
            pieces.extend(section)
            pieces.append("")
    return "\n".join(pieces)
|
||
|
||
|
||
def _generate_create_buffer_function(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the convenience function that creates ``buffer``; returns its lines."""
    renderer = get_renderer()
    rendered = renderer.render(
        'buffer_helpers/create_buffer.jinja2',
        dict(buffer=buffer, shader_name=shader_name),
    )
    return rendered.splitlines()
|
||
|
||
|
||
def _generate_upload_function(buffer: BufferInfo) -> List[str]:
    """Render the function that uploads data into ``buffer``; returns its lines."""
    renderer = get_renderer()
    rendered = renderer.render('buffer_helpers/upload.jinja2', dict(buffer=buffer))
    return rendered.splitlines()
|
||
|
||
|
||
def _generate_download_function(buffer: BufferInfo) -> List[str]:
    """Render the function that downloads data from ``buffer``; returns its lines."""
    renderer = get_renderer()
    rendered = renderer.render('buffer_helpers/download.jinja2', dict(buffer=buffer))
    return rendered.splitlines()
|
||
|
||
|
||
# ============ Typed Buffer Aliases Generation ============
|
||
|
||
def generate_typed_buffer_aliases(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the type-safe buffer aliases and factory functions.

    Each buffer's rendered lines are followed by one empty line; all lines are
    joined with ``"\\n"``.

    Args:
        reflection: SPIR-V reflection data.
        shader_name: Shader name, passed through to the template.

    Returns:
        The generated C++ source, or ``""`` when there are no buffers.
    """
    buffers = reflection.buffers
    if not buffers:
        return ""

    output: List[str] = []
    for buf in buffers:
        output += _generate_buffer_type_alias_and_factory(buf, shader_name)
        output.append("")
    return "\n".join(output)
|
||
|
||
|
||
def _generate_buffer_type_alias_and_factory(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the type alias and factory function for one buffer; returns its lines."""
    renderer = get_renderer()
    rendered = renderer.render(
        'typed_buffer/alias_and_factory.jinja2',
        dict(buffer=buffer, shader_name=shader_name),
    )
    return rendered.splitlines()
|
||
|
||
|
||
# ============ Buffer Manager Class Generation ============
|
||
|
||
def generate_buffer_manager_class(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the ``<shader_name>_buffer_manager`` C++ class.

    The former implementation wrapped the single rendered string in a list and
    joined it back with ``"\\n"``, which is a no-op; the result is now returned
    directly.

    Args:
        reflection: SPIR-V reflection data; ``buffers`` and ``shader_stage``
            are read.
        shader_name: Shader name, used to build the class name.

    Returns:
        The rendered class source, or ``""`` when there are no buffers.
    """
    if not reflection.buffers:
        return ""

    class_name = f"{shader_name}_buffer_manager"
    # Flag handed to the template so it can specialize output for compute shaders.
    is_compute = reflection.shader_stage == "compute"

    renderer = get_renderer()
    context = {
        'class_name': class_name,
        'buffers': reflection.buffers,
        'is_compute': is_compute,
    }
    return renderer.render('buffer_manager/manager_class.jinja2', context)
|
||
|
||
|
||
def _generate_buffers_struct(buffers: List[BufferInfo]) -> List[str]:
|
||
"""生成 buffers 结构体(已被模板替代,保留用于兼容性)"""
|
||
lines = []
|
||
lines.append(" // Container for all buffers")
|
||
lines.append(" struct buffers {")
|
||
for buffer in buffers:
|
||
var_name = buffer.variable_name
|
||
alias_name = f"{var_name}_buffer"
|
||
lines.append(f" {alias_name} {var_name};")
|
||
lines.append(" };")
|
||
return lines
|
||
|
||
|
||
def _generate_manager_factory(buffers: List[BufferInfo], class_name: str) -> List[str]:
|
||
"""生成管理器的工厂方法(已被模板替代,保留用于兼容性)"""
|
||
# 此函数已被模板替代,保留签名以防外部调用
|
||
return []
|
||
|
||
|
||
def _generate_initialize_method(buffers: List[BufferInfo]) -> List[str]:
|
||
"""生成初始化方法(已被模板替代,保留用于兼容性)"""
|
||
# 此函数已被模板替代,保留签名以防外部调用
|
||
return []
|
||
|
||
|
||
def _generate_bind_method(buffers: List[BufferInfo]) -> List[str]:
|
||
"""生成绑定到 descriptor set 的方法(已被模板替代,保留用于兼容性)"""
|
||
# 此函数已被模板替代,保留签名以防外部调用
|
||
return []
|
||
|
||
|
||
def _generate_manager_constructor(buffers: List[BufferInfo], class_name: str) -> List[str]:
|
||
"""生成管理器的私有构造函数(已被模板替代,保留用于兼容性)"""
|
||
# 此函数已被模板替代,保留签名以防外部调用
|
||
return []
|
||
|
||
|
||
# ============ Vulkan Type Conversion ============
|
||
|
||
def vk_descriptor_type_to_cpp(descriptor_type: str) -> str:
    """Map a reflected descriptor-type name to its Vulkan-Hpp enum expression.

    Unrecognized names fall back to ``vk::DescriptorType::eStorageBuffer``.
    """
    enum_suffix = {
        "storage_buffer": "eStorageBuffer",
        "uniform_buffer": "eUniformBuffer",
        "sampled_image": "eSampledImage",
        "combined_image_sampler": "eCombinedImageSampler",
        "storage_image": "eStorageImage",
        "sampler": "eSampler",
    }.get(descriptor_type, "eStorageBuffer")
    return "vk::DescriptorType::" + enum_suffix
|
||
|
||
|
||
def vk_shader_stage_to_cpp(stages: List[str]) -> str:
    """Convert reflected stage names into a Vulkan-Hpp ``ShaderStageFlagBits`` expression.

    Known stage names are mapped in input order and OR-ed together with
    ``" | "``. Unknown names are skipped. If ``stages`` is empty/falsy or
    contains no known name, ``vk::ShaderStageFlagBits::eCompute`` is returned.
    """
    prefix = "vk::ShaderStageFlagBits::"
    stage_bits = {
        "vertex": "eVertex",
        "fragment": "eFragment",
        "compute": "eCompute",
        "geometry": "eGeometry",
        "tess_control": "eTessellationControl",
        "tess_eval": "eTessellationEvaluation",
    }

    flags = [prefix + stage_bits[name] for name in (stages or ()) if name in stage_bits]
    return " | ".join(flags) if flags else prefix + "eCompute"
|
||
|
||
|
||
# ============ Header File Generation ============
|
||
|
||
def _to_pascal_case(name: str) -> str:
|
||
"""将 snake_case 转换为 PascalCase"""
|
||
return ''.join(word.capitalize() for word in name.split('_'))
|
||
|
||
|
||
def _build_push_constant_context(reflection: SPIRVReflection, shader_name: str) -> dict:
    """Collect the push-constant template variables (Custom struct, Header layout, static checks)."""
    push_constant = reflection.push_constant if reflection else None

    # Custom (user-defined) area
    custom_struct = ""
    custom_type_name = "void"
    if push_constant and push_constant.effect_members:
        custom_struct = generate_push_constant_custom_struct(reflection, shader_name)
        custom_type_name = f"{_to_pascal_case(shader_name)}PushConstantCustom"

    # Header (scale/translate) area; member names and offsets feed shader_spec.
    header_layout = ""
    header_names: List[str] = []
    header_offsets: list = []
    static_check = ""
    if push_constant and push_constant.base_info:
        header_layout = generate_push_constant_header_layout(reflection, shader_name)
        for member in push_constant.base_info.members:
            header_names.append(member.name)
            header_offsets.append(member.offset)
        static_check = generate_push_constant_static_check(reflection, shader_name)

    # Template variable names keep the legacy "effect"/"base" naming.
    return {
        'push_constant_effect_struct': custom_struct,
        'push_constant_effect_type_name': custom_type_name,
        'has_push_constant_effect': bool(custom_struct),
        'push_constant_base_layout': header_layout,
        'has_push_constant_base_layout': bool(header_layout),
        'push_constant_base_member_names': header_names,
        'push_constant_base_member_offsets': header_offsets,
        'push_constant_static_check': static_check,
    }


def _collect_spirv_arrays(shader_name: str, compilation_results: List[CompilationResult]) -> dict:
    """Render one SPIR-V array per compiled stage and record the vert/frag symbols.

    Returns the template variables ``compilation_results``, ``has_vert_shader``,
    ``has_frag_shader``, ``vert_spirv_symbol`` and ``frag_spirv_symbol``. If a
    stage type appears more than once, the last occurrence's symbol wins.
    """
    entries = []
    has_vert = has_frag = False
    vert_symbol = frag_symbol = ""
    for result in compilation_results:
        symbol_name = f"{shader_name}_{result.shader_type}_spirv"
        spirv_array = spirv_to_cpp_array(result.spirv_data, symbol_name)
        entries.append({
            'shader_type': result.shader_type,
            'source_file': result.source_file,
            'spirv_array': spirv_array,
        })
        if result.shader_type == "vert":
            has_vert, vert_symbol = True, symbol_name
        elif result.shader_type == "frag":
            has_frag, frag_symbol = True, symbol_name
    return {
        'compilation_results': entries,
        'has_vert_shader': has_vert,
        'has_frag_shader': has_frag,
        'vert_spirv_symbol': vert_symbol,
        'frag_spirv_symbol': frag_symbol,
    }


def _collect_bindings(bindings) -> List[dict]:
    """Flatten descriptor bindings into template rows.

    The emitted binding value is ``descriptor_set * 1000 + binding``.
    NOTE(review): this encoding collides if any binding index reaches 1000.
    """
    return [
        {
            'binding': b.descriptor_set * 1000 + b.binding,
            'descriptor_type_cpp': vk_descriptor_type_to_cpp(b.descriptor_type),
            'count': b.count,
            'stage_flags': vk_shader_stage_to_cpp(b.stages),
        }
        for b in (bindings or [])
    ]


def _resolve_params_type_name(reflection: SPIRVReflection, shader_name: str) -> str:
    """Pick the params type: first buffer's struct name > Custom push-constant struct > "void"."""
    if reflection and reflection.buffers:
        return reflection.buffers[0].struct_type.name or "void"
    if reflection and reflection.push_constant and reflection.push_constant.effect_members:
        return f"{_to_pascal_case(shader_name)}PushConstantCustom"
    return "void"


def generate_header(
    metadata: ShaderMetadata,
    compilation_results: List[CompilationResult],
) -> str:
    """Generate the complete C++ header for a shader via ``base/header.jinja2``.

    Gathers buffer structs and the optional helpers/typed aliases/manager
    (gated by ``metadata.generate_*`` flags), push-constant code, the SPIR-V
    arrays, descriptor bindings, the params type name and the instancing
    variables, then renders the header template.

    ``has_buffers`` is passed as a real bool (previously it could be a list
    or ``None``); template truthiness is unchanged.

    Args:
        metadata: Shader metadata, including the optional reflection data.
        compilation_results: One entry per compiled shader stage. The first
            entry supplies the entry point and ``spirv_symbol_name``.

    Returns:
        The rendered header source.
    """
    reflection = metadata.reflection
    has_buffers = bool(reflection and reflection.buffers)

    buffer_structures = generate_buffer_structures(reflection) if has_buffers else ""
    buffer_helpers = ""
    if metadata.generate_buffer_helpers and has_buffers:
        buffer_helpers = generate_buffer_helper_functions(reflection, metadata.name)
    typed_buffers = ""
    if metadata.generate_typed_buffers and has_buffers:
        typed_buffers = generate_typed_buffer_aliases(reflection, metadata.name)
    buffer_manager = ""
    if metadata.generate_buffer_manager and has_buffers:
        buffer_manager = generate_buffer_manager_class(reflection, metadata.name)

    push_constant_context = _build_push_constant_context(reflection, metadata.name)
    stage_context = _collect_spirv_arrays(metadata.name, compilation_results)
    bindings_data = _collect_bindings(metadata.bindings)
    params_type_name = _resolve_params_type_name(reflection, metadata.name)
    instance_info = generate_instance_data_info(reflection, metadata.name)

    renderer = get_renderer()
    context = {
        'namespace': metadata.namespace,
        'shader_name': metadata.name,
        'metadata_struct_name': f"{metadata.name}_metadata",
        'bindings_name': f"{metadata.name}_bindings",
        'shader_spec_name': f"{metadata.name}_shader_spec",
        'params_type_name': params_type_name,
        'entry_point': compilation_results[0].entry_point if compilation_results else "main",
        'is_compute': metadata.shader_type == "compute",
        'spirv_symbol_name': f"{metadata.name}_{compilation_results[0].shader_type}_spirv" if compilation_results else "",
        'has_buffers': has_buffers,
        'generate_typed_buffers': metadata.generate_typed_buffers,
        'buffer_structures': buffer_structures,
        'buffer_helpers': buffer_helpers,
        'typed_buffers': typed_buffers,
        'buffer_manager': buffer_manager,
        'bindings': bindings_data,
        # compilation_results + shader_spec variables (has_*_shader, *_spirv_symbol)
        **stage_context,
        # push-constant Custom struct, Header layout and static checks
        **push_constant_context,
        # instanced-rendering variables
        **instance_info,
    }
    return renderer.render('base/header.jinja2', context)