#!/usr/bin/env python3
"""
Code generation module.

Generates C++ code: struct definitions, SPIR-V word arrays, header files and
descriptor-binding information. Output is produced through Jinja2 templates
instead of manual string concatenation.
"""

from __future__ import annotations

import struct
from typing import Dict, List

from .template_loader import get_renderer
from .type_mapping import (
    calculate_std430_alignment,
    calculate_std430_size,
    spirv_type_to_cpp,
)
from .types import (
    ArrayTypeInfo,
    BufferInfo,
    CompilationResult,
    MemberInfo,
    ShaderMetadata,
    SPIRVReflection,
    ToolError,
    TypeInfo,
)


# ============ Structure Generation ============


def generate_buffer_structures(reflection: SPIRVReflection) -> str:
    """Generate C++ struct definitions for every buffer/uniform block.

    Args:
        reflection: SPIR-V reflection data; ``reflection.buffers`` and
            ``reflection.types`` are read.

    Returns:
        A banner comment followed by one struct per buffer (each followed by a
        blank line), joined with newlines; ``""`` when there are no buffers.
    """
    if not reflection.buffers:
        return ""

    lines = [
        "// ============ Buffer Structures ============",
        "// Auto-generated from SPIR-V reflection",
        "",
    ]
    for buffer in reflection.buffers:
        lines.extend(generate_single_struct(buffer, reflection.types))
        lines.append("")
    return "\n".join(lines)


def generate_single_struct(buffer: BufferInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render one buffer's struct definition via 'base/struct.jinja2'.

    Args:
        buffer: the buffer whose struct is generated.
        type_map: reflection type table keyed by type id.

    Returns:
        The rendered text split into lines.
    """
    renderer = get_renderer()
    context = {
        'buffer': buffer,
        'type_map': type_map,
        # The std430 alignment of the buffer's struct type is passed to the template.
        'alignment': calculate_std430_alignment(buffer.struct_type, type_map),
    }
    result = renderer.render('base/struct.jinja2', context)
    return result.splitlines()


def generate_member_definition(member: MemberInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render a single struct member via 'base/struct_member.jinja2'.

    Args:
        member: the member to render.
        type_map: reflection type table keyed by type id.

    Returns:
        The rendered text split into lines.
    """
    renderer = get_renderer()
    context = {
        'member': member,
        'type_map': type_map,
    }
    result = renderer.render('base/struct_member.jinja2', context)
    return result.splitlines()


# ============ SPIR-V Array Generation ============

# First word of every SPIR-V module. The byte order in which it is stored in
# the file identifies the module's endianness.
_SPIRV_MAGIC = 0x07230203


def _spirv_byte_order(spirv_data: bytes) -> str:
    """Return the ``struct`` byte-order prefix to decode *spirv_data* with.

    ``">"`` is returned only when the first word is the SPIR-V magic number
    stored big-endian. Everything else (little-endian modules, data shorter
    than one word, non-SPIR-V input) keeps the historical ``"<"``.
    """
    if len(spirv_data) >= 4 and struct.unpack_from(">I", spirv_data)[0] == _SPIRV_MAGIC:
        return ">"
    return "<"


def spirv_to_cpp_array(spirv_data: bytes, symbol_name: str, values_per_line: int = 12) -> str:
    """Render a SPIR-V binary as a C++ ``uint32_t`` array.

    Args:
        spirv_data: raw SPIR-V module bytes; the length must be a multiple of 4.
        symbol_name: C++ identifier handed to the template.
        values_per_line: number of words per emitted row (default 12, the
            previously hard-coded value).

    Returns:
        The output of 'base/spirv_array.jinja2', rendered with ``symbol_name``
        and ``chunks`` (a list of rows, each a list of ``"0x%08x"`` strings).

    Raises:
        ToolError: if ``len(spirv_data)`` is not a multiple of 4.
        ValueError: if ``values_per_line`` is less than 1.
    """
    if len(spirv_data) % 4 != 0:
        raise ToolError(f"SPIR-V data size {len(spirv_data)} is not a multiple of 4")
    if values_per_line < 1:
        raise ValueError(f"values_per_line must be >= 1, got {values_per_line}")

    # Decode in the module's own byte order so the emitted literals are the
    # logical word values (e.g. 0x07230203 for the magic) even for big-endian files.
    byte_order = _spirv_byte_order(spirv_data)
    words = struct.unpack(f"{byte_order}{len(spirv_data) // 4}I", spirv_data)

    # Split the words into rows of `values_per_line` hex literals.
    chunks = [
        [f"0x{word:08x}" for word in words[start : start + values_per_line]]
        for start in range(0, len(words), values_per_line)
    ]

    renderer = get_renderer()
    context = {
        'symbol_name': symbol_name,
        'chunks': chunks,
    }
    return renderer.render('base/spirv_array.jinja2', context)


# ============ Buffer Helper Function Generation ============


def generate_buffer_helper_functions(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate create/upload/download helper functions for every buffer.

    Args:
        reflection: SPIR-V reflection data; ``reflection.buffers`` is read.
        shader_name: shader name forwarded to the create-buffer template.

    Returns:
        The generated helpers joined with newlines, or ``""`` without buffers.
    """
    if not reflection.buffers:
        return ""

    lines = [
        "// ============ 缓冲区辅助函数 ============",
        "// 自动生成缓冲区数据管理的便利函数",
        "",
    ]
    for buffer in reflection.buffers:
        # Buffer creation helper.
        lines.extend(_generate_create_buffer_function(buffer, shader_name))
        lines.append("")
        # Upload helper.
        lines.extend(_generate_upload_function(buffer))
        lines.append("")
        # Download helper.
        lines.extend(_generate_download_function(buffer))
        lines.append("")
    return "\n".join(lines)


def _generate_create_buffer_function(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the create-buffer convenience function for *buffer*."""
    renderer = get_renderer()
    context = {'buffer': buffer, 'shader_name': shader_name}
    result = renderer.render('buffer_helpers/create_buffer.jinja2', context)
    return result.splitlines()


def _generate_upload_function(buffer: BufferInfo) -> List[str]:
    """Render the function that uploads data into *buffer*."""
    renderer = get_renderer()
    context = {'buffer': buffer}
    result = renderer.render('buffer_helpers/upload.jinja2', context)
    return result.splitlines()


def _generate_download_function(buffer: BufferInfo) -> List[str]:
    """Render the function that downloads data from *buffer*."""
    renderer = get_renderer()
    context = {'buffer': buffer}
    result = renderer.render('buffer_helpers/download.jinja2', context)
    return result.splitlines()


# ============ Typed Buffer Aliases Generation ============


def generate_typed_buffer_aliases(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate type aliases and factory functions for type-safe buffers.

    Args:
        reflection: SPIR-V reflection data; ``reflection.buffers`` is read.
        shader_name: shader name forwarded to the template.

    Returns:
        The generated code joined with newlines, or ``""`` without buffers.
    """
    if not reflection.buffers:
        return ""

    lines = [
        "// ============ Typed Buffer Aliases ============",
        "// Type aliases and factory functions for type-safe buffer wrappers",
        "",
    ]
    for buffer in reflection.buffers:
        lines.extend(_generate_buffer_type_alias_and_factory(buffer, shader_name))
        lines.append("")
    return "\n".join(lines)


def _generate_buffer_type_alias_and_factory(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the type alias and factory function for one buffer."""
    renderer = get_renderer()
    context = {'buffer': buffer, 'shader_name': shader_name}
    result = renderer.render('typed_buffer/alias_and_factory.jinja2', context)
    return result.splitlines()


# ============ Buffer Manager Class Generation ============


def generate_buffer_manager_class(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate the buffer manager class for the shader.

    Args:
        reflection: SPIR-V reflection data; ``buffers`` and ``shader_stage``
            are read.
        shader_name: used to build the class name ``<shader_name>_buffer_manager``.

    Returns:
        A banner comment followed by the rendered class, or ``""`` without buffers.
    """
    if not reflection.buffers:
        return ""

    lines = [
        "// ============ Buffer Manager Class ============",
        "// Automatic lifecycle management for all shader buffers",
        "",
    ]

    class_name = f"{shader_name}_buffer_manager"
    # The template receives a flag telling it whether this is a compute shader.
    is_compute = reflection.shader_stage == "compute"

    renderer = get_renderer()
    context = {
        'class_name': class_name,
        'buffers': reflection.buffers,
        'is_compute': is_compute,
    }
    lines.append(renderer.render('buffer_manager/manager_class.jinja2', context))
    return "\n".join(lines)


def _generate_buffers_struct(buffers: List[BufferInfo]) -> List[str]:
    """Generate the ``buffers`` container struct.

    Legacy: superseded by the manager template, kept for compatibility.
    """
    lines = [" // Container for all buffers", " struct buffers {"]
    for buffer in buffers:
        var_name = buffer.variable_name
        alias_name = f"{var_name}_buffer"
        lines.append(f" {alias_name} {var_name};")
    lines.append(" };")
    return lines


def _generate_manager_factory(buffers: List[BufferInfo], class_name: str) -> List[str]:
    """Generate the manager factory method.

    Legacy: superseded by templates; the signature is kept in case external
    code still calls it. Always returns an empty list.
    """
    return []


def _generate_initialize_method(buffers: List[BufferInfo]) -> List[str]:
    """Generate the initialize method.

    Legacy: superseded by templates; the signature is kept in case external
    code still calls it. Always returns an empty list.
    """
    return []


def _generate_bind_method(buffers: List[BufferInfo]) -> List[str]:
    """Generate the method that binds buffers to a descriptor set.

    Legacy: superseded by templates; the signature is kept in case external
    code still calls it. Always returns an empty list.
    """
    return []
def _generate_manager_constructor(buffers: List[BufferInfo], class_name: str) -> List[str]:
    """Generate the manager's private constructor.

    Legacy: superseded by templates; the signature is kept in case external
    code still calls it. Always returns an empty list.
    """
    return []


# ============ Vulkan Type Conversion ============

# Reflection descriptor-type name -> Vulkan-Hpp enumerator expression.
# NOTE(review): the entries after "sampler" follow the same snake_case naming
# as the original five; confirm the reflection stage emits these names.
_VK_DESCRIPTOR_TYPES: Dict[str, str] = {
    "storage_buffer": "vk::DescriptorType::eStorageBuffer",
    "uniform_buffer": "vk::DescriptorType::eUniformBuffer",
    "sampled_image": "vk::DescriptorType::eSampledImage",
    "storage_image": "vk::DescriptorType::eStorageImage",
    "sampler": "vk::DescriptorType::eSampler",
    "combined_image_sampler": "vk::DescriptorType::eCombinedImageSampler",
    "uniform_texel_buffer": "vk::DescriptorType::eUniformTexelBuffer",
    "storage_texel_buffer": "vk::DescriptorType::eStorageTexelBuffer",
    "input_attachment": "vk::DescriptorType::eInputAttachment",
}

# Fallback for unrecognized descriptor types (historical behavior).
_DEFAULT_VK_DESCRIPTOR_TYPE = "vk::DescriptorType::eStorageBuffer"

# Reflection stage name -> Vulkan-Hpp shader stage flag expression.
_VK_SHADER_STAGES: Dict[str, str] = {
    "vertex": "vk::ShaderStageFlagBits::eVertex",
    "fragment": "vk::ShaderStageFlagBits::eFragment",
    "compute": "vk::ShaderStageFlagBits::eCompute",
    "geometry": "vk::ShaderStageFlagBits::eGeometry",
    "tess_control": "vk::ShaderStageFlagBits::eTessellationControl",
    "tess_eval": "vk::ShaderStageFlagBits::eTessellationEvaluation",
}


def vk_descriptor_type_to_cpp(descriptor_type: str) -> str:
    """Map a reflection descriptor-type name to a Vulkan-Hpp enumerator.

    Args:
        descriptor_type: snake_case name such as ``"uniform_buffer"``.

    Returns:
        The C++ expression, e.g. ``"vk::DescriptorType::eUniformBuffer"``.
        Unknown names fall back to ``eStorageBuffer``.
    """
    return _VK_DESCRIPTOR_TYPES.get(descriptor_type, _DEFAULT_VK_DESCRIPTOR_TYPE)


def vk_shader_stage_to_cpp(stages: List[str]) -> str:
    """Convert a list of stage names into a Vulkan-Hpp stage-flag expression.

    Args:
        stages: stage names (e.g. ``["vertex", "fragment"]``); may be empty or None.

    Returns:
        The recognized flags joined with ``" | "`` in input order. Unknown names
        are skipped; if none are recognized (or *stages* is empty), the
        compute stage flag is returned.
    """
    flags = [_VK_SHADER_STAGES[stage] for stage in (stages or ()) if stage in _VK_SHADER_STAGES]
    return " | ".join(flags) if flags else _VK_SHADER_STAGES["compute"]


# ============ Header File Generation ============


def _spirv_symbol_name(shader_name: str, shader_type: str) -> str:
    """Return the C++ symbol name of the embedded SPIR-V array for one stage."""
    return f"{shader_name}_{shader_type}_spirv"


def generate_header(
    metadata: ShaderMetadata,
    compilation_results: List[CompilationResult],
) -> str:
    """Render the complete C++ header for one shader via 'base/header.jinja2'.

    Args:
        metadata: shader metadata. ``name``, ``namespace``, ``shader_type``,
            ``reflection``, ``bindings`` and the ``generate_*`` flags are read.
        compilation_results: compiled stages. Every entry is embedded as a SPIR-V
            array. The first entry supplies the entry point and primary SPIR-V
            symbol; defaults are used when the list is empty.

    Returns:
        The rendered header text.
    """
    reflection = metadata.reflection
    # Evaluate the guard once, as a real bool; the raw expression
    # `reflection and reflection.buffers` yields None or a list.
    has_buffers = bool(reflection and reflection.buffers)

    # Optional buffer-related sections; each is "" when disabled or bufferless.
    buffer_structures = generate_buffer_structures(reflection) if has_buffers else ""
    buffer_helpers = (
        generate_buffer_helper_functions(reflection, metadata.name)
        if has_buffers and metadata.generate_buffer_helpers
        else ""
    )
    typed_buffers = (
        generate_typed_buffer_aliases(reflection, metadata.name)
        if has_buffers and metadata.generate_typed_buffers
        else ""
    )
    buffer_manager = (
        generate_buffer_manager_class(reflection, metadata.name)
        if has_buffers and metadata.generate_buffer_manager
        else ""
    )

    # One embedded SPIR-V array per compiled stage.
    compiled_stages = [
        {
            'shader_type': result.shader_type,
            'source_file': result.source_file,
            'spirv_array': spirv_to_cpp_array(
                result.spirv_data,
                _spirv_symbol_name(metadata.name, result.shader_type),
            ),
        }
        for result in compilation_results
    ]

    # Descriptor-binding rows, translated to Vulkan-Hpp expressions.
    bindings_data = [
        {
            'binding': binding.binding,
            'descriptor_type_cpp': vk_descriptor_type_to_cpp(binding.descriptor_type),
            'count': binding.count,
            'stage_flags': vk_shader_stage_to_cpp(binding.stages),
        }
        for binding in (metadata.bindings or ())
    ]

    first_result = compilation_results[0] if compilation_results else None

    renderer = get_renderer()
    context = {
        'namespace': metadata.namespace,
        'shader_name': metadata.name,
        'metadata_struct_name': f"{metadata.name}_metadata",
        'bindings_name': f"{metadata.name}_bindings",
        'entry_point': first_result.entry_point if first_result is not None else "main",
        'is_compute': metadata.shader_type == "compute",
        'spirv_symbol_name': (
            _spirv_symbol_name(metadata.name, first_result.shader_type)
            if first_result is not None
            else ""
        ),
        'has_buffers': has_buffers,
        'generate_typed_buffers': metadata.generate_typed_buffers,
        'buffer_structures': buffer_structures,
        'buffer_helpers': buffer_helpers,
        'typed_buffers': typed_buffers,
        'buffer_manager': buffer_manager,
        'compilation_results': compiled_stages,
        'bindings': bindings_data,
    }
    return renderer.render('base/header.jinja2', context)