309 lines
14 KiB
Python
309 lines
14 KiB
Python
from pathlib import Path
|
||
|
||
from global_vars import global_vars
|
||
from indent_manager import IndentManager
|
||
from shader_reflection_type import *
|
||
|
||
|
||
class CodeGenerator:
    """C++ binding code generator for sokol_gfx.

    Emits a header that contains, inside the ``<source_file_name>_shaders``
    namespace (``<source_file_name>`` is ``global_vars.source_file_name``):

    * one ``static constexpr std::array<uint8_t, N>`` per shader stage blob,
    * ``get_<name>_shader_desc()`` returning a filled ``sg_shader_desc``,
    * ``get_<name>_pipeline_desc(...)`` returning a filled ``sg_pipeline_desc``.
    """

    def generate_binding_functions(self, binding_infos: ShaderInfos, output_path: str) -> None:
        """Public entry point: generate the C++ bindings and write them to *output_path*."""
        self._generate_cpp_bindings(binding_infos, output_path)

    def _generate_cpp_bindings(self, binding_infos: ShaderInfos, output_path: str) -> None:
        """Write the complete binding file, creating its parent directories if needed."""
        output_file = Path(output_path)
        # Make sure the output directory exists before opening the file.
        output_file.parent.mkdir(parents=True, exist_ok=True)

        with open(output_file, 'w', encoding='utf-8') as file:
            writer = IndentManager(file)
            self._write_complete_file(writer, binding_infos)

    def _write_complete_file(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the include header followed by the namespaced binding code."""
        self._write_header(writer)

        with writer.indent(f'namespace {global_vars.source_file_name}_shaders {{', '}'):
            self._write_shader_bindings_class(writer, binding_infos)
        writer.write()

    def _write_header(self, writer: IndentManager) -> None:
        """Write ``#pragma once`` and the standard includes, followed by a blank line."""
        headers = [
            '#pragma once',
            '#include <cstdint>',
            '#include <array>',
        ]
        for header in headers:
            writer.write(header)
        writer.write()

    def _write_shader_bindings_class(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the blob arrays, the shader-desc function and the pipeline-desc function."""
        self._write_blob(writer, binding_infos)
        self._write_public_methods(writer, binding_infos)
        self._write_get_pipeline_desc_method(writer, binding_infos)

    def _write_blob(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write each stage's compiled bytecode as a ``static constexpr std::array``.

        Stages whose ``blob`` is ``None`` get a warning comment instead of an array.
        """
        for info in binding_infos.stages.values():
            entry_point = info.get_entry_name()
            blob_data = info.blob

            if blob_data is None:
                writer.write(f'// Warning: No blob data for {entry_point}')
                continue

            with writer.block(
                    f'static constexpr std::array<uint8_t, {len(blob_data)}> {global_vars.source_file_name}_{entry_point}_blob =',
                    '};'):
                # 16 bytes per line for readability; every line but the last ends with a comma.
                for i in range(0, len(blob_data), 16):
                    chunk = blob_data[i:i + 16]
                    hex_str = ', '.join(f'0x{byte:02x}' for byte in chunk)
                    if i + 16 < len(blob_data):
                        hex_str += ','
                    writer.write(hex_str)

            writer.write()

    def _write_public_methods(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the public helper functions (currently only the shader-desc getter)."""
        self._write_get_shader_info_method(writer, binding_infos)

    def _write_get_shader_info_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write ``get_<name>_shader_desc()``, which fills and returns an ``sg_shader_desc``."""
        with writer.indent(f'static sg_shader_desc get_{global_vars.source_file_name}_shader_desc() {{', '}'):
            writer.write('sg_shader_desc desc = {};')
            writer.write(f'desc.label = "{global_vars.source_file_name}_shader_desc";')
            writer.write()

            # Vertex attributes (HLSL semantic name/index per attribute slot).
            writer.write('// 顶点属性')
            for i, vertex_field in enumerate(binding_infos.vertex_layout):
                writer.write(f'desc.attrs[{i}].hlsl_sem_name = "{vertex_field.semanticName}";')
                writer.write(f'desc.attrs[{i}].hlsl_sem_index = {vertex_field.semanticIndex or 0};')
            writer.write()

            # sokol's resource arrays (uniform_blocks, images, samplers,
            # storage_buffers) are shared by all stages, so slot numbers must
            # keep counting across stages instead of restarting at 0 per stage.
            slot_counters = {}

            # Shader function (bytecode + entry point) for every stage.
            for info in binding_infos.stages.values():
                entry_point = info.get_entry_name()
                entry_point_name = f'{global_vars.source_file_name}_{entry_point}_blob'
                stage_name = info.get_stage().name.lower()

                writer.write(f'// {stage_name} shader')
                if info.blob is not None:
                    writer.write(f'desc.{stage_name}_func.bytecode = SG_RANGE({entry_point_name});')
                else:
                    writer.write(f'// desc.{stage_name}_func.bytecode = SG_RANGE({entry_point_name}); // No blob data')
                writer.write(f'desc.{stage_name}_func.entry = "{entry_point}";')
                writer.write()

                # Resource bindings used by this stage.
                self._write_stage_resource_bindings(writer, info, stage_name, slot_counters)

            writer.write('return desc;')
        writer.write()

    def _write_stage_resource_bindings(self, writer: IndentManager, stage_info: ShaderStageInfo,
                                       stage_name: str,
                                       slot_counters: "dict[str, int] | None" = None) -> None:
        """Write the resource bindings (uniform blocks, images, samplers, storage buffers) of one stage.

        Args:
            writer: Output writer.
            stage_info: Reflection info of the stage; its ``parameters`` are grouped
                by ``get_binding_kind()``.
            stage_name: Lower-case stage name used in comments and ``SG_SHADERSTAGE_*``.
            slot_counters: Optional mapping from sokol array name (``'uniform_blocks'``,
                ``'images'``, ``'samplers'``, ``'storage_buffers'``) to the next free slot.
                It is updated in place so consecutive stages do not reuse slots.
                When omitted, slots start at 0 (the previous behaviour).
        """
        if slot_counters is None:
            slot_counters = {}

        # (binding kind, comment label, sokol array name, per-resource writer)
        sections = (
            (BindingKind.CONSTANT_BUFFER, 'uniform blocks', 'uniform_blocks', self._write_uniform_block),
            (BindingKind.SHADER_RESOURCE, 'images', 'images', self._write_image),
            (BindingKind.SAMPLER_STATE, 'samplers', 'samplers', self._write_sampler),
            (BindingKind.STORAGE_BUFFER, 'storage buffers', 'storage_buffers', self._write_storage_buffer),
        )

        # Query each parameter's binding kind once; keep declaration order.
        kinds = [(param, param.get_binding_kind()) for param in stage_info.parameters]

        for kind, label, array_name, write_resource in sections:
            params = [param for param, param_kind in kinds if param_kind == kind]
            if not params:
                continue

            writer.write(f'// {stage_name} {label}')
            first_slot = slot_counters.get(array_name, 0)
            for offset, param in enumerate(params):
                write_resource(writer, param, first_slot + offset, stage_name)
            slot_counters[array_name] = first_slot + len(params)
            writer.write()

    def _write_uniform_block(self, writer: IndentManager, param: Parameter, index: int, stage_name: str) -> None:
        """Write the ``desc.uniform_blocks[index]`` configuration for *param*."""
        stage_upper = stage_name.upper()
        writer.write(f'desc.uniform_blocks[{index}].stage = SG_SHADERSTAGE_{stage_upper};')
        writer.write(f'desc.uniform_blocks[{index}].size = {param.get_byte_size()};')
        writer.write(f'desc.uniform_blocks[{index}].layout = SG_UNIFORMLAYOUT_STD140;')

        target = global_vars.target
        if target == TargetFormat.GLSL:
            # GLSL needs the per-member uniform description; the helper only
            # emits it for constant-buffer / parameter-block types.
            self._write_glsl_uniforms(writer, param, index)
        elif target == TargetFormat.DXBC:
            writer.write(f'desc.uniform_blocks[{index}].hlsl_register_b_n = {param.get_register_index()};')

    def _write_glsl_uniforms(self, writer: IndentManager, param: Parameter, block_index: int) -> None:
        """Write ``glsl_uniforms[]`` entries (type, array count, name) for each field of the block."""
        if not isinstance(param.type, (ConstantBufferType, ParameterBlockType)):
            return

        fields = param.type.elementType.fields
        for i, field in enumerate(fields):
            glsl_type = self._get_glsl_uniform_type(field.type)
            writer.write(f'desc.uniform_blocks[{block_index}].glsl_uniforms[{i}].type = {glsl_type};')
            # NOTE(review): array fields are not detected; every uniform is emitted with array_count 1.
            writer.write(f'desc.uniform_blocks[{block_index}].glsl_uniforms[{i}].array_count = 1;')
            writer.write(f'desc.uniform_blocks[{block_index}].glsl_uniforms[{i}].glsl_name = "{field.name}";')

    def _get_glsl_uniform_type(self, type_info) -> str:
        """Map a reflected field type to the matching ``SG_UNIFORMTYPE_*`` name.

        Supported: 32-bit float/int/uint scalars, vectors of those, and 4x4
        matrices. Anything else falls back to ``SG_UNIFORMTYPE_FLOAT4``.
        """

        def scalar_uniform_type(scalar_type):
            # uint32 has no dedicated sokol uniform type and is treated as int.
            if scalar_type == ScalarType.FLOAT32:
                return 'SG_UNIFORMTYPE_FLOAT'
            if scalar_type in (ScalarType.INT32, ScalarType.UINT32):
                return 'SG_UNIFORMTYPE_INT'
            return None

        if isinstance(type_info, ScalarTypeInfo):
            base = scalar_uniform_type(type_info.scalarType)
            if base is not None:
                return base
        elif isinstance(type_info, VectorType):
            base = scalar_uniform_type(type_info.elementType.scalarType)
            if base is not None:
                count = type_info.elementCount
                # sokol has no FLOAT1/INT1 enum values: a 1-component vector is the scalar type.
                return base if count == 1 else f'{base}{count}'
        elif isinstance(type_info, MatrixType):
            if type_info.rowCount == 4 and type_info.columnCount == 4:
                return 'SG_UNIFORMTYPE_MAT4'
            # Other matrix shapes could be added here.

        return 'SG_UNIFORMTYPE_FLOAT4'  # fallback for unsupported types

    def _write_image(self, writer: IndentManager, param: Parameter, index: int, stage_name: str) -> None:
        """Write the ``desc.images[index]`` configuration for *param*."""
        stage_upper = stage_name.upper()
        writer.write(f'desc.images[{index}].stage = SG_SHADERSTAGE_{stage_upper};')
        # NOTE(review): image type / sample type are hard-coded; adjust from reflection if needed.
        writer.write(f'desc.images[{index}].image_type = SG_IMAGETYPE_2D;')
        writer.write(f'desc.images[{index}].sample_type = SG_IMAGESAMPLETYPE_FLOAT;')
        writer.write(f'desc.images[{index}].multisampled = false;')

        target = global_vars.target
        if target == TargetFormat.GLSL:
            writer.write(f'desc.images[{index}].glsl_name = "{param.name}";')
        elif target == TargetFormat.DXBC:
            writer.write(f'desc.images[{index}].hlsl_register_t_n = {param.get_register_index()};')

    def _write_sampler(self, writer: IndentManager, param: Parameter, index: int, stage_name: str) -> None:
        """Write the ``desc.samplers[index]`` configuration for *param*."""
        stage_upper = stage_name.upper()
        writer.write(f'desc.samplers[{index}].stage = SG_SHADERSTAGE_{stage_upper};')
        writer.write(f'desc.samplers[{index}].sampler_type = SG_SAMPLERTYPE_FILTERING;')

        target = global_vars.target
        if target == TargetFormat.GLSL:
            writer.write(f'desc.samplers[{index}].glsl_name = "{param.name}";')
        elif target == TargetFormat.DXBC:
            writer.write(f'desc.samplers[{index}].hlsl_register_s_n = {param.get_register_index()};')

    def _write_storage_buffer(self, writer: IndentManager, param: Parameter, index: int, stage_name: str) -> None:
        """Write the ``desc.storage_buffers[index]`` configuration for *param*."""
        stage_upper = stage_name.upper()
        writer.write(f'desc.storage_buffers[{index}].stage = SG_SHADERSTAGE_{stage_upper};')
        # NOTE(review): always emitted as read-write; adjust if reflection exposes access mode.
        writer.write(f'desc.storage_buffers[{index}].readonly = false;')

        target = global_vars.target
        if target == TargetFormat.GLSL:
            writer.write(f'desc.storage_buffers[{index}].glsl_binding_n = {param.get_register_index()};')
        elif target == TargetFormat.DXBC:
            # Read-write storage buffers bind to UAV (u) registers in HLSL.
            writer.write(f'desc.storage_buffers[{index}].hlsl_register_u_n = {param.get_register_index()};')

    def _write_get_pipeline_desc_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write ``get_<name>_pipeline_desc(...)``, which fills and returns an ``sg_pipeline_desc``.

        The generated function takes the shader, colour pixel format, sample
        count, primitive type and cull mode. It sets up a single interleaved
        vertex buffer from ``binding_infos.vertex_layout``, alpha blending, and
        no depth buffer.
        """
        function_params = (
            'sg_shader shader,\n'
            '\t\tsg_pixel_format pixel_format,\n'
            '\t\tint32_t sample_count = 1,\n'
            '\t\tsg_primitive_type primitive_type = SG_PRIMITIVETYPE_TRIANGLES,\n'
            '\t\tsg_cull_mode cull_mode = SG_CULLMODE_NONE'
        )

        with writer.indent(
                f'static sg_pipeline_desc get_{global_vars.source_file_name}_pipeline_desc(\n\t\t{function_params}\n\t) {{',
                '}'):
            writer.write('sg_pipeline_desc desc = {};')
            writer.write()

            writer.write('desc.shader = shader;')
            writer.write('desc.index_type = SG_INDEXTYPE_UINT32;')
            writer.write()

            # Vertex buffer layout: one interleaved buffer in slot 0.
            if binding_infos.vertex_layout:
                writer.write('// 顶点缓冲区布局')
                writer.write(f'desc.layout.buffers[0].stride = {binding_infos.vertex_size};')
                writer.write('desc.layout.buffers[0].step_func = SG_VERTEXSTEP_PER_VERTEX;')
                writer.write('desc.layout.buffers[0].step_rate = 1;')
                writer.write()

                # Vertex attributes, all sourced from buffer 0.
                writer.write('// 顶点属性')
                last_index = len(binding_infos.vertex_layout) - 1
                for i, vertex_field in enumerate(binding_infos.vertex_layout):
                    if vertex_field.semanticName:
                        writer.write(f'// {vertex_field.semanticName}{vertex_field.semanticIndex or ""}')
                    writer.write(f'desc.layout.attrs[{i}].buffer_index = 0;')
                    writer.write(f'desc.layout.attrs[{i}].offset = {vertex_field.binding.offset};')
                    writer.write(f'desc.layout.attrs[{i}].format = {vertex_field.get_sg_format()};')
                    if i < last_index:
                        writer.write()

                writer.write()

            # Rasterizer state.
            writer.write('// 渲染状态')
            writer.write('desc.primitive_type = primitive_type;')
            writer.write('desc.cull_mode = cull_mode;')
            writer.write('desc.face_winding = SG_FACEWINDING_CW;')
            writer.write()

            # Depth state: no depth buffer and no depth writes. SG_COMPAREFUNC_ALWAYS
            # (sokol's default) disables the depth test; NEVER would reject every fragment.
            writer.write('// 深度状态')
            writer.write('desc.depth.write_enabled = false;')
            writer.write('desc.depth.compare = SG_COMPAREFUNC_ALWAYS;')
            writer.write('desc.depth.pixel_format = SG_PIXELFORMAT_NONE;')
            writer.write()

            # Standard premultiplied-less alpha blending on colour attachment 0.
            writer.write('// 混合状态')
            writer.write('desc.colors[0].blend.enabled = true;')
            writer.write('desc.colors[0].blend.src_factor_rgb = SG_BLENDFACTOR_SRC_ALPHA;')
            writer.write('desc.colors[0].blend.dst_factor_rgb = SG_BLENDFACTOR_ONE_MINUS_SRC_ALPHA;')
            writer.write('desc.colors[0].pixel_format = pixel_format;')
            writer.write('desc.colors[0].write_mask = SG_COLORMASK_RGBA;')
            writer.write()

            writer.write('desc.sample_count = sample_count;')
            writer.write(f'desc.label = "{global_vars.source_file_name}_pipeline_desc";')
            writer.write()

            writer.write('return desc;')