Files
mirage/tools/code_generator.py
2025-06-18 16:33:11 +08:00

174 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Slang Compiler - Code Generator
生成C/C++绑定代码
"""
from pathlib import Path
from global_vars import global_vars
from indent_manager import IndentManager
from shader_reflection_type import *
class CodeGenerator:
    """Generator of C++ binding code for sokol_gfx.

    The generated header contains, inside ``namespace <source>_shaders``:
    one ``static constexpr std::array<uint8_t, N>`` bytecode blob per shader
    stage, a ``get_<source>_shader_desc()`` function returning an
    ``sg_shader_desc``, and a ``get_<source>_pipeline_desc(...)`` function
    returning an ``sg_pipeline_desc``. The header only includes <cstdint> and
    <array>; presumably the including translation unit provides sokol_gfx.h
    (``sg_*`` types, ``SG_RANGE``) — TODO confirm.
    """

    def generate_binding_functions(self, binding_infos: ShaderInfos, output_path: str) -> None:
        """Entry point: write the C++ binding header for *binding_infos* to *output_path*."""
        self._generate_cpp_bindings(binding_infos, output_path)

    def _generate_cpp_bindings(self, binding_infos: ShaderInfos, output_path: str) -> None:
        """Create the output directory if needed and write the complete header (UTF-8)."""
        output_file = Path(output_path)
        # Make sure the destination directory exists before opening the file.
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, 'w', encoding='utf-8') as file:
            writer = IndentManager(file)
            self._write_complete_file(writer, binding_infos)

    def _write_complete_file(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the header preamble followed by the namespace holding all bindings."""
        self._write_header(writer)
        with writer.indent(f'namespace {global_vars.source_file_name}_shaders {{', '}'):
            self._write_shader_bindings_class(writer, binding_infos)
        writer.write()

    def _write_header(self, writer: IndentManager) -> None:
        """Write ``#pragma once`` and the standard includes, then a blank line."""
        headers = [
            '#pragma once',
            '#include <cstdint>',
            '#include <array>',
        ]
        for header in headers:
            writer.write(header)
        writer.write()

    def _write_shader_bindings_class(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the namespace contents: bytecode blobs, shader desc and pipeline desc."""
        # NOTE: uniform-struct generation (_generate_uniform_structs) is currently disabled.
        self._write_blob(writer, binding_infos)
        self._write_public_methods(writer, binding_infos)
        self._write_get_pipeline_desc_method(writer, binding_infos)

    def _write_blob(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write one ``static constexpr std::array<uint8_t, N>`` bytecode array per stage.

        The array is named ``<source>_<entry point>_blob``; bytes are written as
        ``0x..`` literals, 16 per line for readability.
        """
        for _stage, info in binding_infos.stages.items():
            entry_point = info.get_entry_name()
            blob_data = info.blob
            with writer.block(f'static constexpr std::array<uint8_t, {len(blob_data)}> {global_vars.source_file_name}_{entry_point}_blob =', '};'):
                # 16 bytes per line keeps the generated file readable.
                for i in range(0, len(blob_data), 16):
                    chunk = blob_data[i:i + 16]
                    hex_str = ', '.join(f'0x{byte:02x}' for byte in chunk)
                    # Separate this line from the next one; the last line gets no comma.
                    if i + 16 < len(blob_data):
                        hex_str += ','
                    writer.write(hex_str)
            writer.write()
        writer.write()

    def _write_public_methods(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the public accessor functions (currently only the shader desc getter)."""
        self._write_get_shader_info_method(writer, binding_infos)

    def _write_get_shader_info_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write ``get_<source>_shader_desc()``, which builds and returns an ``sg_shader_desc``.

        The generated function sets the label, the HLSL semantic name/index of
        every vertex attribute, the bytecode and entry point of every stage,
        and each stage's resource bindings.
        """
        source_name = global_vars.source_file_name
        with writer.indent(f'static sg_shader_desc get_{source_name}_shader_desc() {{', '}'):
            writer.write('sg_shader_desc desc = {};')
            writer.write(f'desc.label = "{source_name}_shader_desc";')
            writer.write('// 顶点布局')
            with writer.indent('{', '}'):
                for i, vertex_field in enumerate(binding_infos.vertex_layout):
                    writer.write(f'desc.attrs[{i}].hlsl_sem_name = "{vertex_field.semanticName}";')
                    writer.write(f'desc.attrs[{i}].hlsl_sem_index = {vertex_field.semanticIndex or 0};')
            # desc.uniform_blocks[] is one bind-slot array shared by all stages, so the
            # per-binding-kind slot counter must persist across stages; restarting it for
            # each stage would make a later stage overwrite an earlier stage's slots.
            resource_index = {}
            for _stage, info in binding_infos.stages.items():
                entry_point = info.get_entry_name()
                entry_point_name = f'{source_name}_{entry_point}_blob'
                stage_name = info.get_stage().name
                writer.write(f'// {stage_name}')
                with writer.indent('{', '}'):
                    writer.write(f'desc.{stage_name.lower()}_func.bytecode = SG_RANGE({entry_point_name});')
                    writer.write(f'desc.{stage_name.lower()}_func.entry = "{entry_point}";')
                    writer.write('// 资源绑定')
                    self._write_resource_binding(writer, info.parameters, resource_index)
            # The generated function is non-void; falling off its end would be
            # undefined behaviour in C++.
            writer.write('return desc;')

    def _write_resource_binding(self, writer: IndentManager, parameters: List[Parameter],
                                resource_index: "dict | None" = None) -> None:
        """Write the sokol binding statements for one stage's *parameters*.

        Every parameter consumes the next slot of its binding kind, but only
        ``BindingKind.UNIFORM`` parameters currently produce output
        (``desc.uniform_blocks[slot]``: stage, size and target-specific fields).

        Args:
            writer: Destination for the generated C++ lines.
            parameters: Reflected shader parameters of one stage.
            resource_index: Next free slot per binding kind, updated in place.
                Pass the same dict for every stage so slots stay unique across
                stages; when omitted, counting starts at 0 for this call only.
        """
        if resource_index is None:
            resource_index = {}
        target = global_vars.target
        glsl_block_index = {}
        for p in parameters:
            binding_kind = p.get_binding_kind()
            index = resource_index.get(binding_kind, 0)
            resource_index[binding_kind] = index + 1
            stage_name = p.stage.name.upper()
            if binding_kind == BindingKind.UNIFORM:
                t = f'desc.uniform_blocks[{index}]'
                writer.write(f'{t}.stage = SG_SHADERSTAGE_{stage_name};')
                writer.write(f'{t}.size = {p.get_byte_size()};')
                if target == TargetFormat.GLSL:
                    glsl_index = glsl_block_index.get(binding_kind, 0)
                    glsl_block_index[binding_kind] = glsl_index + 1
                    # TODO(review): the uniform type is never filled in, so this emits
                    # `....type = ;`, which is not valid C++ — the GLSL path looks unfinished.
                    writer.write(f'{t}.glsl_uniforms[{glsl_index}].type = ;')
                    writer.write(f'{t}.glsl_uniforms[{glsl_index}].array_count = 1;')
                    writer.write(f'{t}.glsl_uniforms[{glsl_index}].glsl_name = "{p.name}";')
                elif target == TargetFormat.DXBC:
                    writer.write(f'{t}.hlsl_register_b_n = {p.get_register_index()};')

    def _write_get_pipeline_desc_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write ``get_<source>_pipeline_desc(...)``, which builds and returns an ``sg_pipeline_desc``.

        The generated function takes the shader handle, colour pixel format,
        sample count, primitive type and cull mode. Everything else is fixed:
        one per-vertex buffer (slot 0) with the reflected stride and attribute
        layout, UINT32 indices, clockwise front faces, no depth attachment, and
        SRC_ALPHA / ONE_MINUS_SRC_ALPHA blending on colour attachment 0.
        """
        source_name = global_vars.source_file_name
        function_params = ('\n\t\t\tsg_shader shader, \n'
                           '\t\t\tsg_pixel_format pixel_format, \n'
                           '\t\t\tint32_t sample_count = 1, \n'
                           '\t\t\tsg_primitive_type primitive_type = SG_PRIMITIVETYPE_TRIANGLES, \n'
                           '\t\t\tsg_cull_mode cull_mode = SG_CULLMODE_NONE\n')
        with writer.indent(f'static sg_pipeline_desc get_{source_name}_pipeline_desc({function_params}\t) {{', '}'):
            writer.write('sg_pipeline_desc desc = {};')
            writer.write('desc.shader = shader;')
            writer.write('desc.index_type = SG_INDEXTYPE_UINT32;')
            writer.write()
            # Single interleaved vertex buffer in slot 0.
            writer.write(f'desc.layout.buffers[0].stride = {binding_infos.vertex_size};')
            writer.write('desc.layout.buffers[0].step_func = SG_VERTEXSTEP_PER_VERTEX;')
            writer.write('desc.layout.buffers[0].step_rate = 1;')
            writer.write()
            # Vertex input layout, one attribute per reflected vertex field.
            for i, vertex_field in enumerate(binding_infos.vertex_layout):
                # Evaluated outside the f-string: reusing the enclosing quote character
                # inside a replacement field is a SyntaxError before Python 3.12 (PEP 701).
                semantic_index = vertex_field.semanticIndex or ''
                writer.write(f'// {vertex_field.semanticName} {semantic_index}')
                writer.write(f'desc.layout.attrs[{i}].buffer_index = 0;')
                writer.write(f'desc.layout.attrs[{i}].offset = {vertex_field.binding.offset};')
                writer.write(f'desc.layout.attrs[{i}].format = {vertex_field.get_sg_format()};')
                writer.write()
            writer.write('desc.primitive_type = primitive_type;')
            writer.write('desc.cull_mode = cull_mode;')
            writer.write('desc.face_winding = SG_FACEWINDING_CW;')
            writer.write('desc.depth.write_enabled = false;')
            writer.write('desc.depth.compare = SG_COMPAREFUNC_NEVER;')
            writer.write('desc.depth.pixel_format = SG_PIXELFORMAT_NONE;')
            writer.write('desc.colors[0].blend.enabled = true;')
            writer.write('desc.colors[0].blend.src_factor_rgb = SG_BLENDFACTOR_SRC_ALPHA;')
            writer.write('desc.colors[0].blend.dst_factor_rgb = SG_BLENDFACTOR_ONE_MINUS_SRC_ALPHA;')
            writer.write('desc.colors[0].pixel_format = pixel_format;')
            writer.write('desc.colors[0].write_mask = SG_COLORMASK_RGBA;')
            writer.write('desc.sample_count = sample_count;')
            writer.write(f'desc.label = "{source_name}_pipeline_desc";')
            writer.write('return desc;')