from pathlib import Path

from global_vars import global_vars
from indent_manager import IndentManager
from shader_reflection_type import *


class CodeGenerator:
    """C++ binding code generator.

    Emits a C++ header for a reflected shader. Everything goes inside a
    namespace named after ``global_vars.source_file_name`` and consists of:

    * the compiled per-stage blobs as ``std::array`` constants;
    * ``get_shader_desc()``, which fills a sokol_gfx ``sg_shader_desc``;
    * ``get_pipeline_desc(...)``, which fills an ``sg_pipeline_desc``.

    ``global_vars.target`` (``TargetFormat.GLSL`` / ``TargetFormat.DXBC``)
    selects which backend-specific fields are written.
    """

    def generate_binding_functions(self, binding_infos: ShaderInfos, output_path: str) -> None:
        """Entry point: write the C++ binding header for *binding_infos* to *output_path*."""
        self._generate_cpp_bindings(binding_infos, output_path)

    def _generate_cpp_bindings(self, binding_infos: ShaderInfos, output_path: str) -> None:
        """Create the output directory if needed, then write the header as UTF-8."""
        output_file = Path(output_path)
        # Make sure the parent directory exists before opening the file.
        output_file.parent.mkdir(parents=True, exist_ok=True)

        with open(output_file, 'w', encoding='utf-8') as file:
            writer = IndentManager(file)
            self._write_complete_file(writer, binding_infos)

    def _write_complete_file(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the whole file: header lines, then everything inside the namespace."""
        self._write_header(writer)
        with writer.indent(f'namespace {global_vars.source_file_name} {{', '}'):
            self._write_shader_bindings_class(writer, binding_infos)
        writer.write()

    def _write_header(self, writer: IndentManager) -> None:
        """Write ``#pragma once`` and the include lines, followed by a blank line."""
        # NOTE(review): the two '#include ' entries carry no header name in this
        # copy of the file (angle-bracket text may have been lost); as written
        # they produce invalid C++. Confirm the intended headers (the generated
        # code uses std::array, uint8_t/int32_t and the sokol_gfx sg_* API).
        headers = [
            '#pragma once',
            '#include ',
            '#include ',
        ]
        for header in headers:
            writer.write(header)
        writer.write()

    def _write_shader_bindings_class(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the namespace body: blob constants, get_shader_desc(), get_pipeline_desc().

        Despite the name, no C++ class is emitted; the definitions are written
        directly into the enclosing namespace.
        """
        self._write_blob(writer, binding_infos)
        self._write_public_methods(writer, binding_infos)
        self._write_get_pipeline_desc_method(writer, binding_infos)

    def _write_blob(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write one ``std::array<uint8_t, N>`` constant per stage holding its compiled blob.

        Stages whose ``blob`` is ``None`` get a warning comment instead.

        The element type and size are spelled out explicitly. Without them,
        class template argument deduction would turn the list of hex literals
        into an array of ``int``. ``SG_RANGE(...)`` (used by get_shader_desc)
        takes ``sizeof`` of its argument, so the reported byte size would be
        wrong.
        """
        bytes_per_line = 16  # bytes per output line, for readability

        for info in binding_infos.stages.values():
            entry_point = info.get_entry_name()
            blob_data = info.blob
            if blob_data is None:
                writer.write(f'// Warning: No blob data for {entry_point}')
                continue

            declaration = (
                f'static constexpr std::array<uint8_t, {len(blob_data)}> '
                f'{global_vars.source_file_name}_{entry_point}_blob ='
            )
            with writer.block(declaration, '};'):
                total = len(blob_data)
                for start in range(0, total, bytes_per_line):
                    chunk = blob_data[start:start + bytes_per_line]
                    line = ', '.join(f'0x{byte:02x}' for byte in chunk)
                    # Every line except the last ends with a separating comma.
                    if start + bytes_per_line < total:
                        line += ','
                    writer.write(line)
            writer.write()

    def _write_public_methods(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the public helper functions (currently only get_shader_desc())."""
        self._write_get_shader_info_method(writer, binding_infos)

    def _write_get_shader_info_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write ``get_shader_desc()``, which builds and returns an ``sg_shader_desc``.

        Resource slots are numbered continuously across all stages. This covers
        uniform blocks, images, samplers and storage buffers. The
        ``desc.<resource>[i]`` arrays are shared by every stage, and each entry
        records its own ``.stage``. Restarting at 0 for every stage would make a
        later stage overwrite an earlier stage's entries.
        """
        with writer.indent(f'static sg_shader_desc get_shader_desc() {{', '}'):
            writer.write('sg_shader_desc desc = {};')
            writer.write(f'desc.label = "{global_vars.source_file_name}_shader_desc";')
            writer.write()

            # Vertex attributes: HLSL semantic name/index per attribute slot.
            writer.write('// 顶点属性')
            for i, vertex_field in enumerate(binding_infos.vertex_layout):
                writer.write(f'desc.attrs[{i}].hlsl_sem_name = "{vertex_field.semanticName}";')
                writer.write(f'desc.attrs[{i}].hlsl_sem_index = {vertex_field.semanticIndex or 0};')
            writer.write()

            # Image-sampler pairs collected from every stage.
            image_sampler_pairs = []
            # Next free slot per resource array; advanced by each stage.
            slot_offsets = {}

            # Per-stage bytecode, entry point and D3D11 target, then its resources.
            for stage, info in binding_infos.stages.items():
                entry_point = info.get_entry_name()
                entry_point_name = f'{global_vars.source_file_name}_{entry_point}_blob'
                stage_name = info.get_stage().name.lower()

                writer.write(f'// {stage_name} shader')
                if info.blob is not None:
                    writer.write(f'desc.{stage_name}_func.bytecode = SG_RANGE({entry_point_name});')
                else:
                    writer.write(f'// desc.{stage_name}_func.bytecode = SG_RANGE({entry_point_name}); // No blob data')
                writer.write(f'desc.{stage_name}_func.entry = "{entry_point}";')
                writer.write(f'desc.{stage_name}_func.d3d11_target = "{CodeGenerator.d3d11_target_string(stage)}";')
                writer.write()

                pairs = self._write_stage_resource_bindings(writer, info, stage_name, slot_offsets)
                image_sampler_pairs.extend(pairs)

            if image_sampler_pairs:
                writer.write('// Image-sampler pairs')
                for i, pair in enumerate(image_sampler_pairs):
                    writer.write(f'desc.image_sampler_pairs[{i}].stage = {pair["stage"]};')
                    writer.write(f'desc.image_sampler_pairs[{i}].image_slot = {pair["image_slot"]};')
                    writer.write(f'desc.image_sampler_pairs[{i}].sampler_slot = {pair["sampler_slot"]};')
                    if global_vars.target == TargetFormat.GLSL:
                        writer.write(f'desc.image_sampler_pairs[{i}].glsl_name = "{pair["glsl_name"]}";')
                    writer.write()

            writer.write('return desc;')

    @staticmethod
    def d3d11_target_string(stage: Stage) -> str:
        """Return the D3D11 shader-model target for *stage*, or 'unknown'."""
        if stage == Stage.VERTEX:
            return 'vs_5_0'
        elif stage == Stage.FRAGMENT:
            return 'ps_5_0'
        elif stage == Stage.COMPUTE:
            return 'cs_5_0'
        else:
            return 'unknown'

    def _write_stage_resource_bindings(self, writer: IndentManager, stage_info: ShaderStageInfo,
                                       stage_name: str, slot_offsets: "dict | None" = None) -> list:
        """Write the resource bindings of one stage and return its image-sampler pairs.

        Only parameters whose ``binding.used`` is truthy are considered. They
        are grouped into uniform blocks, images, samplers and storage buffers.

        Args:
            writer: output writer.
            stage_info: reflected stage whose ``parameters`` are emitted.
            stage_name: lower-case stage name, used in comments and in the
                ``SG_SHADERSTAGE_*`` constant.
            slot_offsets: optional dict holding the first free slot of each
                resource array (keys 'uniform_blocks', 'images', 'samplers',
                'storage_buffers'). It is advanced in place by the number of
                slots this stage uses. When omitted, numbering starts at 0.

        Returns:
            List of image-sampler pair dicts with slots in the global numbering.
        """
        if slot_offsets is None:
            slot_offsets = {}

        uniform_block_kinds = (
            BindingKind.CONSTANT_BUFFER,
            BindingKind.UNIFORM,
            BindingKind.SUB_ELEMENT_REGISTER_SPACE,
        )
        uniform_blocks = []
        images = []
        samplers = []
        storage_buffers = []

        for param in stage_info.parameters:
            if not param.binding.used:
                continue
            binding_kind = param.get_binding_kind()
            if binding_kind in uniform_block_kinds:
                uniform_blocks.append(param)
            elif binding_kind == BindingKind.SHADER_RESOURCE:
                images.append(param)
            elif binding_kind == BindingKind.SAMPLER_STATE:
                samplers.append(param)
            elif binding_kind == BindingKind.STORAGE_BUFFER:
                storage_buffers.append(param)

        image_base = slot_offsets.get('images', 0)
        sampler_base = slot_offsets.get('samplers', 0)

        # (comment label, offset key, params, per-entry writer), in output order.
        sections = (
            ('uniform blocks', 'uniform_blocks', uniform_blocks, self._write_uniform_block),
            ('images', 'images', images, self._write_image),
            ('samplers', 'samplers', samplers, self._write_sampler),
            ('storage buffers', 'storage_buffers', storage_buffers, self._write_storage_buffer),
        )
        for label, key, params, write_entry in sections:
            base = slot_offsets.get(key, 0)
            if params:
                writer.write(f'// {stage_name} {label}')
                for i, param in enumerate(params):
                    write_entry(writer, param, base + i, stage_name)
                    writer.write()
            slot_offsets[key] = base + len(params)

        # Pair images with samplers using a name-based heuristic
        # (see _match_image_sampler_pairs / _is_paired).
        return self._match_image_sampler_pairs(images, samplers, stage_name,
                                               image_base=image_base, sampler_base=sampler_base)

    def _match_image_sampler_pairs(self, images: List[Parameter], samplers: List[Parameter],
                                   stage_name: str, image_base: int = 0, sampler_base: int = 0) -> list:
        """Pair images with samplers of one stage.

        Each image is paired with the first sampler whose name matches per
        ``_is_paired``. If nothing matches at all, images and samplers are
        paired by position, up to the shorter list.

        ``image_base`` / ``sampler_base`` are added to the local indices so the
        returned slots refer to the global ``desc.images`` / ``desc.samplers``
        numbering.

        Returns:
            List of dicts with keys 'stage', 'image_slot', 'sampler_slot' and
            'glsl_name' (the image's name).
        """
        stage_enum = f"SG_SHADERSTAGE_{stage_name.upper()}"
        pairs = []

        # Name-similarity pairing, e.g. "texture" with "textureSampler".
        for i, image in enumerate(images):
            for j, sampler in enumerate(samplers):
                if self._is_paired(image.name, sampler.name):
                    pairs.append({
                        "stage": stage_enum,
                        "image_slot": image_base + i,
                        "sampler_slot": sampler_base + j,
                        "glsl_name": image.name,
                    })
                    break

        # No explicit match: fall back to positional pairing.
        if not pairs and images and samplers:
            for i, (image, _sampler) in enumerate(zip(images, samplers)):
                pairs.append({
                    "stage": stage_enum,
                    "image_slot": image_base + i,
                    "sampler_slot": sampler_base + i,
                    "glsl_name": image.name,
                })

        return pairs

    def _is_paired(self, image_name: str, sampler_name: str) -> bool:
        """Heuristically decide whether *image_name* and *sampler_name* belong together.

        Two names are treated as a pair in either of these cases:

        * they are equal (case-insensitively) after removing the markers
          "Texture" / "Map" / "Image" from the image name and
          "SamplerState" / "Sampler" from the sampler name;
        * one name is a case-insensitive substring of the other.

        Examples: "diffuseMap" / "diffuseSampler", "texture" / "textureSampler".
        """
        image_base = image_name.replace("Texture", "").replace("Map", "").replace("Image", "")
        # Strip the longer marker first. Removing "Sampler" first would turn
        # "SamplerState" into "State", and the "SamplerState" replace would
        # never match.
        sampler_base = sampler_name.replace("SamplerState", "").replace("Sampler", "")

        if image_base.lower() == sampler_base.lower():
            return True

        image_lower = image_name.lower()
        sampler_lower = sampler_name.lower()
        return image_lower in sampler_lower or sampler_lower in image_lower

    def _get_image_sampler_pairs_from_bindings(self, entry_point: EntryPoint) -> list:
        """Stub: derive image-sampler pairs from an EntryPoint's bindings.

        Used images and samplers are collected, but no pairing logic is
        implemented yet, so this always returns an empty list.
        """
        pairs = []
        used_images = []
        used_samplers = []
        for binding in entry_point.bindings:
            if binding.binding.used and binding.binding.used > 0:
                if binding.binding.kind == BindingKind.SHADER_RESOURCE:
                    used_images.append(binding)
                elif binding.binding.kind == BindingKind.SAMPLER_STATE:
                    used_samplers.append(binding)
        # TODO: pair used_images with used_samplers based on the shader's conventions.
        return pairs

    def _write_uniform_block(self, writer: IndentManager, param: Parameter, index: int, stage_name: str) -> None:
        """Write the ``desc.uniform_blocks[index]`` entry for *param*.

        Always writes stage, size and std140 layout. For GLSL it also writes
        the per-member ``glsl_uniforms`` of constant buffers / parameter
        blocks. For DXBC it writes the ``b`` register.
        """
        stage_upper = stage_name.upper()
        writer.write(f'desc.uniform_blocks[{index}].stage = SG_SHADERSTAGE_{stage_upper};')
        writer.write(f'desc.uniform_blocks[{index}].size = {param.get_byte_size()};')
        writer.write(f'desc.uniform_blocks[{index}].layout = SG_UNIFORMLAYOUT_STD140;')

        target = global_vars.target
        if target == TargetFormat.GLSL:
            # GLSL needs each uniform member described individually.
            if isinstance(param.type, (ConstantBufferType, ParameterBlockType)):
                self._write_glsl_uniforms(writer, param, index)
        elif target == TargetFormat.DXBC:
            writer.write(f'desc.uniform_blocks[{index}].hlsl_register_b_n = {param.get_register_index()};')

    def _write_glsl_uniforms(self, writer: IndentManager, param: Parameter, block_index: int) -> None:
        """Write one ``glsl_uniforms[i]`` entry per field of *param*'s element type.

        Every field is written with ``array_count = 1``; array members are not
        handled here.
        """
        if isinstance(param.type, (ConstantBufferType, ParameterBlockType)):
            fields = param.type.elementType.fields
            for i, field in enumerate(fields):
                glsl_type = self._get_glsl_uniform_type(field.type)
                writer.write(f'desc.uniform_blocks[{block_index}].glsl_uniforms[{i}].type = {glsl_type};')
                writer.write(f'desc.uniform_blocks[{block_index}].glsl_uniforms[{i}].array_count = 1;')
                writer.write(f'desc.uniform_blocks[{block_index}].glsl_uniforms[{i}].glsl_name = "{field.name}";')

    def _get_glsl_uniform_type(self, type_info) -> str:
        """Map a reflected field type to an ``SG_UNIFORMTYPE_*`` name.

        The mapping is:

        * float32 scalars/vectors map to FLOAT[n];
        * int32/uint32 scalars/vectors map to INT[n];
        * a 4x4 matrix maps to MAT4.

        A 1-component vector maps to the scalar name, because there is no
        ``*1`` variant. Anything else falls back to ``SG_UNIFORMTYPE_FLOAT4``.
        """
        if isinstance(type_info, ScalarTypeInfo):
            if type_info.scalarType == ScalarType.FLOAT32:
                return "SG_UNIFORMTYPE_FLOAT"
            elif type_info.scalarType in (ScalarType.INT32, ScalarType.UINT32):
                return "SG_UNIFORMTYPE_INT"
        elif isinstance(type_info, VectorType):
            count = type_info.elementCount
            # sokol names scalars SG_UNIFORMTYPE_FLOAT / _INT (no trailing "1").
            suffix = '' if count == 1 else f'{count}'
            if type_info.elementType.scalarType == ScalarType.FLOAT32:
                return f"SG_UNIFORMTYPE_FLOAT{suffix}"
            elif type_info.elementType.scalarType in (ScalarType.INT32, ScalarType.UINT32):
                return f"SG_UNIFORMTYPE_INT{suffix}"
        elif isinstance(type_info, MatrixType):
            if type_info.rowCount == 4 and type_info.columnCount == 4:
                return "SG_UNIFORMTYPE_MAT4"
            # Other matrix shapes could be added here.
        return "SG_UNIFORMTYPE_FLOAT4"  # default

    def _write_image(self, writer: IndentManager, param: Parameter, index: int, stage_name: str) -> None:
        """Write the ``desc.images[index]`` entry for *param*.

        The image is always declared as a non-multisampled 2D float image.
        """
        stage_upper = stage_name.upper()
        writer.write(f'desc.images[{index}].stage = SG_SHADERSTAGE_{stage_upper};')
        writer.write(f'desc.images[{index}].image_type = SG_IMAGETYPE_2D;')  # TODO: derive from the actual type
        writer.write(f'desc.images[{index}].sample_type = SG_IMAGESAMPLETYPE_FLOAT;')
        writer.write(f'desc.images[{index}].multisampled = false;')

        target = global_vars.target
        if target == TargetFormat.GLSL:
            writer.write(f'desc.images[{index}].glsl_name = "{param.name}";')
        elif target == TargetFormat.DXBC:
            writer.write(f'desc.images[{index}].hlsl_register_t_n = {param.get_register_index()};')

    def _write_sampler(self, writer: IndentManager, param: Parameter, index: int, stage_name: str) -> None:
        """Write the ``desc.samplers[index]`` entry for *param*, always as a filtering sampler."""
        stage_upper = stage_name.upper()
        writer.write(f'desc.samplers[{index}].stage = SG_SHADERSTAGE_{stage_upper};')
        writer.write(f'desc.samplers[{index}].sampler_type = SG_SAMPLERTYPE_FILTERING;')

        target = global_vars.target
        if target == TargetFormat.GLSL:
            writer.write(f'desc.samplers[{index}].glsl_name = "{param.name}";')
        elif target == TargetFormat.DXBC:
            writer.write(f'desc.samplers[{index}].hlsl_register_s_n = {param.get_register_index()};')

    def _write_storage_buffer(self, writer: IndentManager, param: Parameter, index: int, stage_name: str) -> None:
        """Write the ``desc.storage_buffers[index]`` entry for *param*, always as read-write."""
        stage_upper = stage_name.upper()
        writer.write(f'desc.storage_buffers[{index}].stage = SG_SHADERSTAGE_{stage_upper};')
        writer.write(f'desc.storage_buffers[{index}].readonly = false;')  # TODO: derive from the actual access

        target = global_vars.target
        if target == TargetFormat.GLSL:
            writer.write(f'desc.storage_buffers[{index}].glsl_binding_n = {param.get_register_index()};')
        elif target == TargetFormat.DXBC:
            writer.write(f'desc.storage_buffers[{index}].hlsl_register_u_n = {param.get_register_index()};')

    def _write_get_pipeline_desc_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write ``get_pipeline_desc(...)``, which builds an ``sg_pipeline_desc``.

        The generated function takes the shader, the color pixel format,
        sample count, primitive type and cull mode (the last three have C++
        defaults).

        It writes a single interleaved vertex buffer layout (buffer 0, stride
        ``binding_infos.vertex_size``) when the shader has vertex inputs. The
        rest is hard-coded: uint32 indices, clockwise winding, depth write off,
        no depth pixel format, and standard alpha blending on one color target.
        """
        function_params = (
            'sg_shader shader,\n'
            '\t\tsg_pixel_format pixel_format,\n'
            '\t\tint32_t sample_count = 1,\n'
            '\t\tsg_primitive_type primitive_type = SG_PRIMITIVETYPE_TRIANGLES,\n'
            '\t\tsg_cull_mode cull_mode = SG_CULLMODE_NONE'
        )
        with writer.indent(
                f'static sg_pipeline_desc get_pipeline_desc(\n\t\t{function_params}\n\t) {{', '}'):
            writer.write('sg_pipeline_desc desc = {};')
            writer.write()
            writer.write('desc.shader = shader;')
            writer.write('desc.index_type = SG_INDEXTYPE_UINT32;')
            writer.write()

            # Vertex buffer layout: one per-vertex buffer in slot 0.
            if binding_infos.vertex_layout:
                writer.write('// 顶点缓冲区布局')
                writer.write(f'desc.layout.buffers[0].stride = {binding_infos.vertex_size};')
                writer.write('desc.layout.buffers[0].step_func = SG_VERTEXSTEP_PER_VERTEX;')
                writer.write('desc.layout.buffers[0].step_rate = 1;')
                writer.write()

                # Vertex attributes, all sourced from buffer 0.
                writer.write('// 顶点属性')
                last = len(binding_infos.vertex_layout) - 1
                for i, vertex_field in enumerate(binding_infos.vertex_layout):
                    if vertex_field.semanticName:
                        writer.write(f'// {vertex_field.semanticName}{vertex_field.semanticIndex or ""}')
                    writer.write(f'desc.layout.attrs[{i}].buffer_index = 0;')
                    writer.write(f'desc.layout.attrs[{i}].offset = {vertex_field.binding.offset};')
                    writer.write(f'desc.layout.attrs[{i}].format = {vertex_field.get_sg_format()};')
                    if i < last:
                        writer.write()
                writer.write()

            # Rasterizer state.
            writer.write('// 渲染状态')
            writer.write('desc.primitive_type = primitive_type;')
            writer.write('desc.cull_mode = cull_mode;')
            writer.write('desc.face_winding = SG_FACEWINDING_CW;')
            writer.write()

            # Depth state.
            # NOTE(review): SG_COMPAREFUNC_NEVER would reject every fragment if a
            # depth attachment were present. With pixel_format NONE this
            # presumably has no effect, but SG_COMPAREFUNC_ALWAYS is the usual
            # "disabled" value. Confirm.
            writer.write('// 深度状态')
            writer.write('desc.depth.write_enabled = false;')
            writer.write('desc.depth.compare = SG_COMPAREFUNC_NEVER;')
            writer.write('desc.depth.pixel_format = SG_PIXELFORMAT_NONE;')
            writer.write()

            # Blend state: standard alpha blending on color target 0.
            writer.write('// 混合状态')
            writer.write('desc.colors[0].blend.enabled = true;')
            writer.write('desc.colors[0].blend.src_factor_rgb = SG_BLENDFACTOR_SRC_ALPHA;')
            writer.write('desc.colors[0].blend.dst_factor_rgb = SG_BLENDFACTOR_ONE_MINUS_SRC_ALPHA;')
            writer.write('desc.colors[0].pixel_format = pixel_format;')
            writer.write('desc.colors[0].write_mask = SG_COLORMASK_RGBA;')
            writer.write('desc.color_count = 1;')
            writer.write()

            label = f'{global_vars.source_file_name}_pipeline_desc'
            # The label must have a fixed length: pad with spaces (and truncate)
            # to exactly 52 characters.
            label = label.ljust(52)[:52]
            writer.write('desc.sample_count = sample_count;')
            writer.write(f'desc.label = "{label}";')
            writer.write()
            writer.write('return desc;')