#!/usr/bin/env python3 """ 代码生成模块 负责生成C++代码,包括结构体定义、SPIR-V数组、头文件和绑定信息。 使用Jinja2模板引擎代替字符串拼接。 """ from __future__ import annotations import struct from typing import Dict, List from .template_loader import get_renderer from .type_mapping import ( calculate_std430_alignment, calculate_std430_size, spirv_type_to_cpp, ) from .types import ( ArrayTypeInfo, BufferInfo, CombinedPushConstantInfo, CompilationResult, InstanceBufferInfo, MemberInfo, PushConstantInfo, ShaderMetadata, SPIRVReflection, StagePushConstantInfo, ToolError, TypeInfo, VertexAttribute, VertexLayout, ) from .spirv_parser import analyze_dual_stage_push_constants from .constants import ( PUSH_CONSTANT_HEADER_SIZE, PUSH_CONSTANT_CUSTOM_OFFSET, PUSH_CONSTANT_CUSTOM_MAX_SIZE, ) def check_supports_procedural_vertex_shader(reflection: SPIRVReflection) -> bool: """检测片元着色器是否支持 custom_shader_quad_procedural 顶点着色器 检测条件:片元着色器的 push_constant 中第一个用户自定义成员的偏移正好是 16 字节。 这意味着片元着色器的 push_constant 布局: - Header [0-15]: scale + translate (由顶点着色器定义并填充) - Custom [16-127]: 用户自定义数据 (第一个成员从 offset 16 开始) Args: reflection: SPIR-V 反射信息(通常来自片元着色器) Returns: True 如果支持 procedural 顶点着色器,否则 False """ push_constant = reflection.push_constant if not push_constant: return False # 检查原始结构体成员,找到 offset >= 16 的第一个成员 struct_type = push_constant.struct_type if not struct_type or not struct_type.members: return False # 找到所有 offset >= PUSH_CONSTANT_CUSTOM_OFFSET 的成员 custom_members = [m for m in struct_type.members if m.offset >= PUSH_CONSTANT_CUSTOM_OFFSET] if not custom_members: return False # 检查第一个 custom 成员的偏移是否正好是 16 first_custom_offset = min(m.offset for m in custom_members) return first_custom_offset == PUSH_CONSTANT_CUSTOM_OFFSET # ============ Structure Generation ============ def generate_buffer_structures(reflection: SPIRVReflection) -> str: """生成所有buffer/uniform结构体定义""" if not reflection.buffers: return "" lines = [] for buffer in reflection.buffers: struct_lines = generate_single_struct(buffer, reflection.types) lines.extend(struct_lines) lines.append("") return "\n".join(lines) def generate_single_struct(buffer: BufferInfo, type_map: Dict[int, TypeInfo]) -> List[str]: """生成单个结构体定义""" renderer = get_renderer() context = { 'buffer': buffer, 'type_map': type_map, 'alignment': calculate_std430_alignment(buffer.struct_type, type_map), } result = renderer.render('base/struct.jinja2', context) return result.splitlines() def generate_member_definition(member: MemberInfo, type_map: Dict[int, TypeInfo]) -> List[str]: """生成结构体成员定义""" renderer = get_renderer() context = { 'member': member, 'type_map': type_map, } result = renderer.render('base/struct_member.jinja2', context) return result.splitlines() # ============ Push Constants Custom Structure Generation ============ def generate_push_constant_custom_struct(reflection: SPIRVReflection, shader_name: str) -> str: """生成 Push Constants 用户自定义参数结构体定义 (Custom 部分) 布局规范: ┌─────────────────────────────────────────────────────────────────────────────┐ │ Push Constants 内存布局 (128 bytes) │ ├─────────────────────────────────────────────────────────────────────────────┤ │ Header [0-15]: scale + translate - 16 bytes │ ← 系统自动填充 │ Custom [16-127]: 用户自定义数据 - 112 bytes │ ← 本函数生成 └─────────────────────────────────────────────────────────────────────────────┘ 识别规则: - 仅收集 offset >= 16 的成员(跳过 scale 和 translate) - 计算相对偏移量(绝对偏移 - 16) - 生成的 C++ 结构体只包含 Custom 区域的变量 Args: reflection: SPIR-V反射信息 shader_name: 着色器名称(用于生成结构体名称) Returns: 生成的C++结构体代码,如果没有用户自定义部分则返回空字符串 Note: 与 push_constant_traits.h 中的 PUSH_CONSTANT_CUSTOM_OFFSET (16) 和 PUSH_CONSTANT_CUSTOM_MAX_SIZE (112) 常量保持一致。 """ push_constant = reflection.push_constant if not push_constant or not push_constant.effect_members: return "" # 生成结构体名称(保持兼容性,内部使用 custom_members) struct_name = f"{_to_pascal_case(shader_name)}PushConstantCustom" # 准备成员信息 members = [] type_map = reflection.types for member in push_constant.effect_members: # effect_members 现在存储的是 custom_members # 计算对齐和大小 alignment = calculate_std430_alignment(member.resolved_type, type_map) size = calculate_std430_size(member.resolved_type, type_map) cpp_type = spirv_type_to_cpp(member.resolved_type, type_map) members.append({ 'name': member.name, 'cpp_type': cpp_type, 'offset': member.offset, 'size': size, 'alignment': alignment, }) renderer = get_renderer() context = { 'struct_name': struct_name, 'base_offset': PUSH_CONSTANT_CUSTOM_OFFSET, # 使用 Custom 部分偏移量 'members': members, } return renderer.render('push_constant/custom_struct.jinja2', context) # 兼容性别名 def generate_push_constant_effect_struct(reflection: SPIRVReflection, shader_name: str) -> str: """已弃用:请使用 generate_push_constant_custom_struct 保留此函数以保持向后兼容性。 """ return generate_push_constant_custom_struct(reflection, shader_name) def generate_push_constant_header_layout(reflection: SPIRVReflection, shader_name: str) -> str: """生成 Push Constants Header 部分布局信息结构体 Header 部分包含 scale 和 translate 变量(偏移 0-15),用于 NDC 变换: ┌─────────────────────────────────────────────────────────────────────────────┐ │ Header [0-15]: scale (vec2) + translate (vec2) - 16 bytes │ │ scale: 偏移 0, 大小 8 │ │ translate: 偏移 8, 大小 8 │ └─────────────────────────────────────────────────────────────────────────────┘ Args: reflection: SPIR-V反射信息 shader_name: 着色器名称(用于生成结构体名称) Returns: 生成的C++代码,如果没有 Header 部分则返回空字符串 """ push_constant = reflection.push_constant if not push_constant or not push_constant.base_info: return "" base_info = push_constant.base_info # 准备模板数据 header_members = [ { 'name': m.name, 'offset': m.offset, 'size': m.size, 'expected_offset': m.expected_offset, 'expected_size': m.expected_size, 'is_valid': m.is_valid, } for m in base_info.members ] expected_names = ['scale', 'translate'] present_names = [m.name for m in base_info.members] renderer = get_renderer() # 渲染布局结构体 layout_context = { 'struct_name': shader_name, 'header_members': header_members, 'expected_names': expected_names, 'present_names': present_names, 'header_size': base_info.total_size, 'is_standard': base_info.is_standard_layout, } layout_code = renderer.render('push_constant/base_layout.jinja2', layout_context) return layout_code def generate_push_constant_static_check(reflection: SPIRVReflection, shader_name: str) -> str: """生成 Push Constants 静态检查代码(不含 custom 结构体检查) 注意:custom 结构体的 sizeof 检查需要在结构体定义之后进行, 因此此函数只生成 Header 部分的静态检查。 Args: reflection: SPIR-V反射信息 shader_name: 着色器名称(用于生成结构体名称) Returns: 生成的静态检查C++代码 """ push_constant = reflection.push_constant if not push_constant or not push_constant.base_info: return "" base_info = push_constant.base_info # 准备模板数据 header_members = [ { 'name': m.name, 'offset': m.offset, 'size': m.size, 'expected_offset': m.expected_offset, 'expected_size': m.expected_size, 'is_valid': m.is_valid, } for m in base_info.members ] expected_names = ['scale', 'translate'] present_names = [m.name for m in base_info.members] renderer = get_renderer() # 渲染静态检查代码(不含 custom 部分) check_context = { 'struct_name': shader_name, 'header_members': header_members, 'is_standard': base_info.is_standard_layout, 'has_custom_members': False, # 不在这里检查 custom } return renderer.render('push_constant/static_check.jinja2', check_context) # 兼容性别名 def generate_push_constant_base_layout(reflection: SPIRVReflection, shader_name: str) -> str: """已弃用:请使用 generate_push_constant_header_layout 保留此函数以保持向后兼容性。 """ return generate_push_constant_header_layout(reflection, shader_name) # ============ Dual Stage Push Constants Structure Generation ============ def _generate_stage_push_constant_struct( stage_info: StagePushConstantInfo, shader_name: str, type_map: Dict[int, TypeInfo] ) -> str: """生成单个阶段的 Push Constant 结构体 Args: stage_info: 阶段 Push Constant 信息 shader_name: 着色器名称 type_map: 类型映射 Returns: 生成的 C++ 结构体代码 """ if not stage_info or not stage_info.has_members: return "" # 生成结构体名称 stage_suffix = "Vertex" if stage_info.stage == "vertex" else "Fragment" struct_name = f"{_to_pascal_case(shader_name)}{stage_suffix}PushConstant" # 准备成员信息 members = [] for member in stage_info.members: alignment = calculate_std430_alignment(member.resolved_type, type_map) size = calculate_std430_size(member.resolved_type, type_map) cpp_type = spirv_type_to_cpp(member.resolved_type, type_map) members.append({ 'name': member.name, 'cpp_type': cpp_type, 'offset': member.offset, 'size': size, 'alignment': alignment, }) renderer = get_renderer() context = { 'struct_name': struct_name, 'base_offset': stage_info.relative_offset, 'members': members, } return renderer.render('push_constant/custom_struct.jinja2', context) def _generate_dual_stage_fill_function( combined_info: CombinedPushConstantInfo, shader_name: str, type_map: Dict[int, TypeInfo] ) -> str: """生成双阶段 Push Constant 填充函数 根据布局模式生成不同的填充函数: - 共享模式:单一填充函数,同时填充两个阶段 - 分离模式:分别为每个阶段生成填充函数 Args: combined_info: 合并后的双阶段信息 shader_name: 着色器名称 type_map: 类型映射 Returns: 生成的 C++ 填充函数代码 """ lines = [] pascal_name = _to_pascal_case(shader_name) if combined_info.is_shared_layout: # 共享模式:生成单一填充函数 struct_name = f"{pascal_name}PushConstantCustom" lines.append(f"/**") lines.append(f" * @brief 填充 Push Constants Custom 区域(共享布局)") lines.append(f" * ") lines.append(f" * 顶点和片元着色器共享相同的 Custom 区域布局。") lines.append(f" * ") lines.append(f" * @param cmd 命令缓冲") lines.append(f" * @param layout 管线布局") lines.append(f" * @param data Custom 区域数据") lines.append(f" */") lines.append(f"inline void fill_push_constants(") lines.append(f" vk::CommandBuffer cmd,") lines.append(f" vk::PipelineLayout layout,") lines.append(f" const {struct_name}& data") lines.append(f") {{") lines.append(f" cmd.pushConstants(") lines.append(f" layout,") lines.append(f" vk::ShaderStageFlagBits::eVertex | vk::ShaderStageFlagBits::eFragment,") lines.append(f" PUSH_CONSTANT_CUSTOM_OFFSET,") lines.append(f" sizeof({struct_name}),") lines.append(f" &data") lines.append(f" );") lines.append(f"}}") else: # 分离模式:为每个阶段生成填充函数 if combined_info.has_vertex: vert_struct = f"{pascal_name}VertexPushConstant" lines.append(f"/**") lines.append(f" * @brief 填充顶点着色器 Push Constants Custom 区域") lines.append(f" * ") lines.append(f" * @param cmd 命令缓冲") lines.append(f" * @param layout 管线布局") lines.append(f" * @param data 顶点着色器 Custom 区域数据") lines.append(f" */") lines.append(f"inline void fill_vertex_push_constants(") lines.append(f" vk::CommandBuffer cmd,") lines.append(f" vk::PipelineLayout layout,") lines.append(f" const {vert_struct}& data") lines.append(f") {{") lines.append(f" cmd.pushConstants(") lines.append(f" layout,") lines.append(f" vk::ShaderStageFlagBits::eVertex,") lines.append(f" PUSH_CONSTANT_CUSTOM_OFFSET,") lines.append(f" sizeof({vert_struct}),") lines.append(f" &data") lines.append(f" );") lines.append(f"}}") lines.append(f"") if combined_info.has_fragment: frag_struct = f"{pascal_name}FragmentPushConstant" lines.append(f"/**") lines.append(f" * @brief 填充片元着色器 Push Constants Custom 区域") lines.append(f" * ") lines.append(f" * @param cmd 命令缓冲") lines.append(f" * @param layout 管线布局") lines.append(f" * @param data 片元着色器 Custom 区域数据") lines.append(f" */") lines.append(f"inline void fill_fragment_push_constants(") lines.append(f" vk::CommandBuffer cmd,") lines.append(f" vk::PipelineLayout layout,") lines.append(f" const {frag_struct}& data") lines.append(f") {{") lines.append(f" cmd.pushConstants(") lines.append(f" layout,") lines.append(f" vk::ShaderStageFlagBits::eFragment,") lines.append(f" PUSH_CONSTANT_CUSTOM_OFFSET,") lines.append(f" sizeof({frag_struct}),") lines.append(f" &data") lines.append(f" );") lines.append(f"}}") return "\n".join(lines) def _generate_dual_stage_layout_constants( combined_info: CombinedPushConstantInfo, shader_name: str ) -> str: """生成双阶段 Push Constant 布局常量 Args: combined_info: 合并后的双阶段信息 shader_name: 着色器名称 Returns: 生成的 C++ 布局常量代码 """ lines = [] pascal_name = _to_pascal_case(shader_name) lines.append(f"/**") lines.append(f" * @brief {pascal_name} Push Constant 布局信息") lines.append(f" */") lines.append(f"struct {pascal_name}PushConstantLayout {{") lines.append(f" /// 是否使用共享布局模式") lines.append(f" static constexpr bool is_shared_layout = {'true' if combined_info.is_shared_layout else 'false'};") lines.append(f" ") if combined_info.has_vertex: lines.append(f" /// 顶点着色器 Custom 区域大小") lines.append(f" static constexpr std::uint32_t vertex_custom_size = {combined_info.vertex_info.total_size};") if combined_info.has_fragment: lines.append(f" /// 片元着色器 Custom 区域大小") lines.append(f" static constexpr std::uint32_t fragment_custom_size = {combined_info.fragment_info.total_size};") if combined_info.shared_members: lines.append(f" ") lines.append(f" /// 共享的成员数量") lines.append(f" static constexpr std::size_t shared_member_count = {len(combined_info.shared_members)};") lines.append(f"}};") return "\n".join(lines) def generate_dual_stage_push_constant_structs( vert_reflection: SPIRVReflection, frag_reflection: SPIRVReflection, shader_name: str ) -> str: """生成双阶段 Push Constant 结构体和填充函数 分析顶点和片元着色器的 Push Constant 布局,生成: - {ShaderName}VertexPushConstant 结构体(如果有顶点专用数据) - {ShaderName}FragmentPushConstant 结构体(如果有片元专用数据) - {ShaderName}PushConstantCustom 结构体(共享模式时) - fill_push_constants() 或 fill_vertex/fragment_push_constants() 填充函数 - {ShaderName}PushConstantLayout 布局信息结构体 Args: vert_reflection: 顶点着色器的 SPIR-V 反射信息 frag_reflection: 片元着色器的 SPIR-V 反射信息 shader_name: 着色器名称 Returns: 生成的 C++ 代码,如果没有 Custom 数据则返回空字符串 Example: >>> vert_ref = parse_spirv_type_system(vert_spirv, "vertex") >>> frag_ref = parse_spirv_type_system(frag_spirv, "fragment") >>> code = generate_dual_stage_push_constant_structs(vert_ref, frag_ref, "my_shader") >>> print(code) // 生成的结构体定义和填充函数... """ # 分析双阶段布局 combined_info = analyze_dual_stage_push_constants(vert_reflection, frag_reflection) # 如果两个阶段都没有 Custom 成员,返回空字符串 if not combined_info.has_vertex and not combined_info.has_fragment: return "" sections = [] pascal_name = _to_pascal_case(shader_name) # 添加注释头 sections.append(f"// ============ {pascal_name} Push Constants (Dual Stage) ============") sections.append("") if combined_info.is_shared_layout: # 共享模式:使用现有的 generate_push_constant_custom_struct # 优先使用片元着色器的反射信息(通常更完整) ref = frag_reflection if frag_reflection.push_constant else vert_reflection custom_struct = generate_push_constant_custom_struct(ref, shader_name) if custom_struct: sections.append(custom_struct) sections.append("") else: # 分离模式:为每个阶段生成独立结构体 type_map = vert_reflection.types if vert_reflection else frag_reflection.types if combined_info.has_vertex: vert_struct = _generate_stage_push_constant_struct( combined_info.vertex_info, shader_name, type_map ) if vert_struct: sections.append(vert_struct) sections.append("") if combined_info.has_fragment: frag_struct = _generate_stage_push_constant_struct( combined_info.fragment_info, shader_name, frag_reflection.types if frag_reflection else type_map ) if frag_struct: sections.append(frag_struct) sections.append("") # 生成填充函数 type_map = vert_reflection.types if vert_reflection else frag_reflection.types fill_function = _generate_dual_stage_fill_function(combined_info, shader_name, type_map) if fill_function: sections.append(fill_function) sections.append("") # 生成布局常量 layout_constants = _generate_dual_stage_layout_constants(combined_info, shader_name) if layout_constants: sections.append(layout_constants) sections.append("") return "\n".join(sections) # ============ Instance Data Structure Generation ============ def generate_instance_data_struct(instance_buffer: InstanceBufferInfo, type_map: Dict[int, TypeInfo]) -> str: """生成实例数据结构体定义 从 @instance_buffer 标记的 SSBO 中的结构体类型生成 C++ 结构体, 包含编译期静态检查以验证布局正确性。 Args: instance_buffer: 实例缓冲信息 type_map: 类型映射 Returns: 生成的 C++ 结构体代码 Example output: struct alignas(16) gradient_instance_data { Eigen::Vector4f rect; // 必需:x, y, width, height Eigen::Vector4f start_color; // 起始颜色 Eigen::Vector4f transform; // rotation, anchor_x, anchor_y, scale Eigen::Vector4f custom_params; }; static_assert(sizeof(gradient_instance_data) == 64); static_assert(offsetof(gradient_instance_data, rect) == 0); """ # 准备成员信息 members = [] for member in instance_buffer.members: alignment = calculate_std430_alignment(member.resolved_type, type_map) size = calculate_std430_size(member.resolved_type, type_map) cpp_type = spirv_type_to_cpp(member.resolved_type, type_map) members.append({ 'name': member.name, 'cpp_type': cpp_type, 'offset': member.offset, 'size': size, 'alignment': alignment, }) renderer = get_renderer() context = { 'struct_name': instance_buffer.struct_type_name, 'members': members, 'total_size': instance_buffer.total_size, 'alignment': 16, # 实例数据结构必须 16 字节对齐 } return renderer.render('instance/instance_data_struct.jinja2', context) def generate_instance_data_info(reflection: SPIRVReflection, shader_name: str) -> dict: """生成实例数据相关的模板上下文信息 Args: reflection: SPIR-V 反射信息 shader_name: 着色器名称 Returns: 包含实例化相关模板变量的字典 """ if not reflection or not reflection.instance_buffer: return { 'has_instance_data': False, 'instance_data_struct': '', 'instance_type_name': 'void', 'instance_buffer_binding': 0, 'instance_buffer_set': 0, 'instance_data_size': 0, } instance_buffer = reflection.instance_buffer # 生成实例数据结构体代码 instance_data_struct = generate_instance_data_struct(instance_buffer, reflection.types) return { 'has_instance_data': True, 'instance_data_struct': instance_data_struct, 'instance_type_name': instance_buffer.struct_type_name, 'instance_buffer_binding': instance_buffer.binding, 'instance_buffer_set': instance_buffer.set_number, 'instance_data_size': instance_buffer.total_size, } # ============ Vertex Structure Generation ============ def generate_vertex_struct(vertex_layout: VertexLayout, shader_name: str) -> str: """生成顶点输入结构体定义 从 SPIR-V 反射中提取的顶点输入布局生成 C++ 结构体, 包含静态方法用于获取 Vulkan 顶点输入描述。 Args: vertex_layout: 顶点布局信息 shader_name: 着色器名称(用于生成结构体名称) Returns: 生成的 C++ 结构体代码 Example output: struct ImageShaderVertex { Eigen::Vector2f position; // location = 0 Eigen::Vector4f color; // location = 1 Eigen::Vector2f uv; // location = 2 static VkVertexInputBindingDescription get_binding_description() { return { .binding = 0, .stride = sizeof(ImageShaderVertex), .inputRate = VK_VERTEX_INPUT_RATE_VERTEX }; } static std::array get_attribute_descriptions() { return {{ {0, 0, VK_FORMAT_R32G32_SFLOAT, offsetof(ImageShaderVertex, position)}, {1, 0, VK_FORMAT_R32G32B32A32_SFLOAT, offsetof(ImageShaderVertex, color)}, {2, 0, VK_FORMAT_R32G32_SFLOAT, offsetof(ImageShaderVertex, uv)} }}; } }; """ if not vertex_layout or not vertex_layout.attributes: return "" # 生成结构体名称 (shader_name -> ShaderNameVertex) struct_name = f"{_to_pascal_case(shader_name)}Vertex" vertex_layout.struct_name = struct_name # 准备属性信息 attributes = [] for attr in vertex_layout.attributes: attributes.append({ 'name': attr.name, 'cpp_type': attr.cpp_type, 'location': attr.location, 'vk_format': attr.vk_format, 'offset': attr.offset, 'size': attr.size, }) renderer = get_renderer() context = { 'struct_name': struct_name, 'attributes': attributes, 'stride': vertex_layout.stride, 'attribute_count': len(attributes), } return renderer.render('vertex/vertex_struct.jinja2', context) def generate_vertex_layout_info(reflection: SPIRVReflection, shader_name: str) -> dict: """生成顶点布局相关的模板上下文信息 Args: reflection: SPIR-V 反射信息 shader_name: 着色器名称 Returns: 包含顶点布局相关模板变量的字典 """ if not reflection or not reflection.vertex_layout: return { 'has_vertex_layout': False, 'vertex_struct': '', 'vertex_type_name': 'void', 'vertex_stride': 0, 'vertex_attribute_count': 0, } vertex_layout = reflection.vertex_layout # 生成顶点结构体代码 vertex_struct = generate_vertex_struct(vertex_layout, shader_name) return { 'has_vertex_layout': True, 'vertex_struct': vertex_struct, 'vertex_type_name': vertex_layout.struct_name, 'vertex_stride': vertex_layout.stride, 'vertex_attribute_count': len(vertex_layout.attributes), } # ============ SPIR-V Array Generation ============ def spirv_to_cpp_array(spirv_data: bytes, symbol_name: str) -> str: """将SPIR-V二进制转换为C++ uint32_t数组""" if len(spirv_data) % 4 != 0: raise ToolError(f"SPIR-V data size {len(spirv_data)} is not a multiple of 4") uint32_array = struct.unpack(f"<{len(spirv_data)//4}I", spirv_data) # 将数组分成每行12个元素的块 chunks = [] for i in range(0, len(uint32_array), 12): chunk = uint32_array[i : i + 12] hex_values = [f"0x{val:08x}" for val in chunk] chunks.append(hex_values) renderer = get_renderer() context = { 'symbol_name': symbol_name, 'chunks': chunks, } return renderer.render('base/spirv_array.jinja2', context) # ============ 缓冲区辅助函数生成 ============ def generate_buffer_helper_functions(reflection: SPIRVReflection, shader_name: str) -> str: """生成 buffer 数据操作的辅助函数""" if not reflection.buffers: return "" lines = [] for buffer in reflection.buffers: # 生成创建函数 lines.extend(_generate_create_buffer_function(buffer, shader_name)) lines.append("") # 生成上传函数 lines.extend(_generate_upload_function(buffer)) lines.append("") # 生成下载函数 lines.extend(_generate_download_function(buffer)) lines.append("") return "\n".join(lines) def _generate_create_buffer_function(buffer: BufferInfo, shader_name: str) -> List[str]: """生成创建 buffer 的便捷函数""" renderer = get_renderer() context = {'buffer': buffer, 'shader_name': shader_name} result = renderer.render('buffer_helpers/create_buffer.jinja2', context) return result.splitlines() def _generate_upload_function(buffer: BufferInfo) -> List[str]: """生成上传数据到 buffer 的函数""" renderer = get_renderer() context = {'buffer': buffer} result = renderer.render('buffer_helpers/upload.jinja2', context) return result.splitlines() def _generate_download_function(buffer: BufferInfo) -> List[str]: """生成从 buffer 下载数据的函数""" renderer = get_renderer() context = {'buffer': buffer} result = renderer.render('buffer_helpers/download.jinja2', context) return result.splitlines() # ============ Typed Buffer Aliases Generation ============ def generate_typed_buffer_aliases(reflection: SPIRVReflection, shader_name: str) -> str: """生成类型安全的 buffer 类型别名和工厂函数""" if not reflection.buffers: return "" lines = [] for buffer in reflection.buffers: lines.extend(_generate_buffer_type_alias_and_factory(buffer, shader_name)) lines.append("") return "\n".join(lines) def _generate_buffer_type_alias_and_factory(buffer: BufferInfo, shader_name: str) -> List[str]: """为特定 buffer 生成类型别名和工厂函数""" renderer = get_renderer() context = {'buffer': buffer, 'shader_name': shader_name} result = renderer.render('typed_buffer/alias_and_factory.jinja2', context) return result.splitlines() # ============ Buffer Manager Class Generation ============ def generate_buffer_manager_class(reflection: SPIRVReflection, shader_name: str) -> str: """生成 buffer 管理器类""" if not reflection.buffers: return "" lines = [] class_name = f"{shader_name}_buffer_manager" # 判断是否是计算着色器 is_compute = reflection.shader_stage == "compute" renderer = get_renderer() context = { 'class_name': class_name, 'buffers': reflection.buffers, 'is_compute': is_compute, } result = renderer.render('buffer_manager/manager_class.jinja2', context) lines.append(result) return "\n".join(lines) def _generate_buffers_struct(buffers: List[BufferInfo]) -> List[str]: """生成 buffers 结构体(已被模板替代,保留用于兼容性)""" lines = [] lines.append(" // Container for all buffers") lines.append(" struct buffers {") for buffer in buffers: var_name = buffer.variable_name alias_name = f"{var_name}_buffer" lines.append(f" {alias_name} {var_name};") lines.append(" };") return lines def _generate_manager_factory(buffers: List[BufferInfo], class_name: str) -> List[str]: """生成管理器的工厂方法(已被模板替代,保留用于兼容性)""" # 此函数已被模板替代,保留签名以防外部调用 return [] def _generate_initialize_method(buffers: List[BufferInfo]) -> List[str]: """生成初始化方法(已被模板替代,保留用于兼容性)""" # 此函数已被模板替代,保留签名以防外部调用 return [] def _generate_bind_method(buffers: List[BufferInfo]) -> List[str]: """生成绑定到 descriptor set 的方法(已被模板替代,保留用于兼容性)""" # 此函数已被模板替代,保留签名以防外部调用 return [] def _generate_manager_constructor(buffers: List[BufferInfo], class_name: str) -> List[str]: """生成管理器的私有构造函数(已被模板替代,保留用于兼容性)""" # 此函数已被模板替代,保留签名以防外部调用 return [] # ============ Vulkan Type Conversion ============ def vk_descriptor_type_to_cpp(descriptor_type: str) -> str: """将描述符类型转换为Vulkan C++ enum""" type_map = { "storage_buffer": "vk::DescriptorType::eStorageBuffer", "uniform_buffer": "vk::DescriptorType::eUniformBuffer", "sampled_image": "vk::DescriptorType::eSampledImage", "combined_image_sampler": "vk::DescriptorType::eCombinedImageSampler", "storage_image": "vk::DescriptorType::eStorageImage", "sampler": "vk::DescriptorType::eSampler", } return type_map.get(descriptor_type, "vk::DescriptorType::eStorageBuffer") def vk_shader_stage_to_cpp(stages: List[str]) -> str: """将着色器阶段列表转换为Vulkan C++ flags""" stage_map = { "vertex": "vk::ShaderStageFlagBits::eVertex", "fragment": "vk::ShaderStageFlagBits::eFragment", "compute": "vk::ShaderStageFlagBits::eCompute", "geometry": "vk::ShaderStageFlagBits::eGeometry", "tess_control": "vk::ShaderStageFlagBits::eTessellationControl", "tess_eval": "vk::ShaderStageFlagBits::eTessellationEvaluation", } if not stages: stages = ["compute"] flags = [stage_map[stage] for stage in stages if stage in stage_map] if not flags: flags = ["vk::ShaderStageFlagBits::eCompute"] return " | ".join(flags) # ============ Header File Generation ============ def _to_pascal_case(name: str) -> str: """将 snake_case 转换为 PascalCase""" return ''.join(word.capitalize() for word in name.split('_')) def generate_header( metadata: ShaderMetadata, compilation_results: List[CompilationResult], ) -> str: """生成完整的C++头文件""" # 准备buffer结构体代码 buffer_structures = "" if metadata.reflection and metadata.reflection.buffers: buffer_structures = generate_buffer_structures(metadata.reflection) # 准备buffer辅助函数代码 buffer_helpers = "" if metadata.generate_buffer_helpers and metadata.reflection and metadata.reflection.buffers: buffer_helpers = generate_buffer_helper_functions(metadata.reflection, metadata.name) # 准备typed buffer代码 typed_buffers = "" if metadata.generate_typed_buffers and metadata.reflection and metadata.reflection.buffers: typed_buffers = generate_typed_buffer_aliases(metadata.reflection, metadata.name) # 准备buffer管理器代码 buffer_manager = "" if metadata.generate_buffer_manager and metadata.reflection and metadata.reflection.buffers: buffer_manager = generate_buffer_manager_class(metadata.reflection, metadata.name) # 准备Push Constants用户自定义部分结构体代码 push_constant_custom_struct = "" push_constant_custom_type_name = "void" if metadata.reflection and metadata.reflection.push_constant and metadata.reflection.push_constant.effect_members: push_constant_custom_struct = generate_push_constant_custom_struct(metadata.reflection, metadata.name) push_constant_custom_type_name = f"{_to_pascal_case(metadata.name)}PushConstantCustom" # 兼容性:保持旧变量名 push_constant_effect_struct = push_constant_custom_struct push_constant_effect_type_name = push_constant_custom_type_name # 准备Push Constants Header部分布局代码 push_constant_header_layout = "" has_push_constant_header_layout = False push_constant_header_member_names = [] push_constant_header_member_offsets = [] if metadata.reflection and metadata.reflection.push_constant and metadata.reflection.push_constant.base_info: push_constant_header_layout = generate_push_constant_header_layout(metadata.reflection, metadata.name) has_push_constant_header_layout = bool(push_constant_header_layout) # 提取 Header 成员名称和偏移量供 shader_spec 使用 base_info = metadata.reflection.push_constant.base_info for member in base_info.members: push_constant_header_member_names.append(member.name) push_constant_header_member_offsets.append(member.offset) # 兼容性:保持旧变量名 push_constant_base_layout = push_constant_header_layout has_push_constant_base_layout = has_push_constant_header_layout push_constant_base_member_names = push_constant_header_member_names push_constant_base_member_offsets = push_constant_header_member_offsets # 生成 Header 静态检查代码 push_constant_static_check = "" if metadata.reflection and metadata.reflection.push_constant and metadata.reflection.push_constant.base_info: push_constant_static_check = generate_push_constant_static_check(metadata.reflection, metadata.name) # 准备SPIR-V数组和着色器类型检测 compilation_results_with_arrays = [] has_vert_shader = False has_frag_shader = False vert_spirv_symbol = "" frag_spirv_symbol = "" for result in compilation_results: symbol_name = f"{metadata.name}_{result.shader_type}_spirv" spirv_array = spirv_to_cpp_array(result.spirv_data, symbol_name) compilation_results_with_arrays.append({ 'shader_type': result.shader_type, 'source_file': result.source_file, 'spirv_array': spirv_array, }) # 检测着色器类型 if result.shader_type == "vert": has_vert_shader = True vert_spirv_symbol = symbol_name elif result.shader_type == "frag": has_frag_shader = True frag_spirv_symbol = symbol_name # 准备绑定信息 # 使用 set * 1000 + binding 作为最终的 binding 值 bindings_data = [] if metadata.bindings: for binding in metadata.bindings: # 计算包含set偏移的binding值 final_binding = binding.descriptor_set * 1000 + binding.binding bindings_data.append({ 'binding': final_binding, 'descriptor_type_cpp': vk_descriptor_type_to_cpp(binding.descriptor_type), 'count': binding.count, 'stage_flags': vk_shader_stage_to_cpp(binding.stages), }) # 确定参数类型名称(从 buffer 反射信息中获取) # 优先级:buffer > push_constant效果部分 > void params_type_name = "void" if metadata.reflection and metadata.reflection.buffers: # 使用第一个 buffer 的结构体类型名称 params_type_name = metadata.reflection.buffers[0].struct_type.name or params_type_name elif metadata.reflection and metadata.reflection.push_constant and metadata.reflection.push_constant.effect_members: # 没有 buffer,但有 push_constant 效果部分,使用效果结构体名称 params_type_name = f"{_to_pascal_case(metadata.name)}PushConstantCustom" # 准备实例数据相关变量 instance_info = generate_instance_data_info(metadata.reflection, metadata.name) # 准备顶点布局相关变量 vertex_info = generate_vertex_layout_info(metadata.reflection, metadata.name) # 检测是否支持 procedural 顶点着色器 # 条件:片元着色器的 push_constant 第一个用户自定义成员偏移为 16 字节 supports_procedural_vertex_shader = False if metadata.reflection: supports_procedural_vertex_shader = check_supports_procedural_vertex_shader(metadata.reflection) # 使用模板渲染 renderer = get_renderer() context = { 'namespace': metadata.namespace, 'shader_name': metadata.name, 'metadata_struct_name': f"{metadata.name}_metadata", 'bindings_name': f"{metadata.name}_bindings", 'shader_spec_name': f"{metadata.name}_shader_spec", 'params_type_name': params_type_name, 'entry_point': compilation_results[0].entry_point if compilation_results else "main", 'is_compute': metadata.shader_type == "compute", 'spirv_symbol_name': f"{metadata.name}_{compilation_results[0].shader_type}_spirv" if compilation_results else "", 'has_buffers': metadata.reflection and metadata.reflection.buffers, 'generate_typed_buffers': metadata.generate_typed_buffers, 'buffer_structures': buffer_structures, 'buffer_helpers': buffer_helpers, 'typed_buffers': typed_buffers, 'buffer_manager': buffer_manager, 'compilation_results': compilation_results_with_arrays, 'bindings': bindings_data, # shader_spec 相关变量 'has_vert_shader': has_vert_shader, 'has_frag_shader': has_frag_shader, 'vert_spirv_symbol': vert_spirv_symbol, 'frag_spirv_symbol': frag_spirv_symbol, # Push Constants 效果结构体相关变量 'push_constant_effect_struct': push_constant_effect_struct, 'push_constant_effect_type_name': push_constant_effect_type_name, 'has_push_constant_effect': bool(push_constant_effect_struct), # Push Constants 基础部分布局相关变量 'push_constant_base_layout': push_constant_base_layout, 'has_push_constant_base_layout': has_push_constant_base_layout, 'push_constant_base_member_names': push_constant_base_member_names, 'push_constant_base_member_offsets': push_constant_base_member_offsets, # Push Constants 静态检查代码 'push_constant_static_check': push_constant_static_check, # 是否支持 procedural 顶点着色器 'supports_procedural_vertex_shader': supports_procedural_vertex_shader, # 实例化渲染相关变量 **instance_info, # 顶点布局相关变量 **vertex_info, } return renderer.render('base/header.jinja2', context)