Files
mirage/tools/code_generator.py

428 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
from global_vars import global_vars
from indent_manager import IndentManager
from shader_reflection_type import *
class CodeGenerator:
    """Generate a C++ header with shader bindings from shader reflection data.

    The emitted header contains, inside a namespace named after
    ``global_vars.source_file_name``:

    * one ``std::array<uint8_t, N>`` per shader stage holding the compiled blob,
    * ``get_shader_desc()`` returning a filled ``sg_shader_desc``,
    * ``get_pipeline_desc(...)`` returning a filled ``sg_pipeline_desc``.

    All output goes through an ``IndentManager`` writer; this class only
    decides *what* text to emit.
    """

    def generate_binding_functions(self, binding_infos: ShaderInfos, output_path: str) -> None:
        """Public entry point: write the C++ binding header to ``output_path``."""
        self._generate_cpp_bindings(binding_infos, output_path)

    def _generate_cpp_bindings(self, binding_infos: ShaderInfos, output_path: str) -> None:
        """Create the output file (and its parent directories) and write the header."""
        output_file = Path(output_path)
        # Make sure the destination directory exists before opening the file.
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with output_file.open('w', encoding='utf-8') as file:
            writer = IndentManager(file)
            self._write_complete_file(writer, binding_infos)

    def _write_complete_file(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the include preamble followed by the namespace with all bindings."""
        self._write_header(writer)
        with writer.indent(f'namespace {global_vars.source_file_name} {{', '}'):
            self._write_shader_bindings_class(writer, binding_infos)
        writer.write()

    def _write_header(self, writer: IndentManager) -> None:
        """Write ``#pragma once`` and the standard includes, then a blank line."""
        headers = [
            '#pragma once',
            '#include <cstdint>',
            '#include <array>',
        ]
        for header in headers:
            writer.write(header)
        writer.write()

    def _write_shader_bindings_class(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the namespace content: blobs, shader desc and pipeline desc."""
        self._write_blob(writer, binding_infos)
        self._write_public_methods(writer, binding_infos)
        self._write_get_pipeline_desc_method(writer, binding_infos)

    def _write_blob(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write one ``constexpr std::array<uint8_t, N>`` per stage with its blob bytes.

        Stages without blob data only produce a warning comment in the output.
        """
        for info in binding_infos.stages.values():
            entry_point = info.get_entry_name()
            blob_data = info.blob
            if blob_data is None:
                writer.write(f'// Warning: No blob data for {entry_point}')
                continue

            with writer.block(
                    f'static constexpr std::array<uint8_t, {len(blob_data)}> {global_vars.source_file_name}_{entry_point}_blob =',
                    '};'):
                # 16 bytes per row keeps the generated table readable.
                for offset in range(0, len(blob_data), 16):
                    chunk = blob_data[offset:offset + 16]
                    hex_str = ', '.join(f'0x{byte:02x}' for byte in chunk)
                    # Every row except the last one ends with a separator.
                    if offset + 16 < len(blob_data):
                        hex_str += ','
                    writer.write(hex_str)
            writer.write()

    def _write_public_methods(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the public generated functions (currently only ``get_shader_desc``)."""
        self._write_get_shader_info_method(writer, binding_infos)

    def _write_get_shader_info_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the ``get_shader_desc()`` function.

        The function body sets the label, the vertex attribute semantics, the
        per-stage bytecode / entry point / D3D11 target, the per-stage resource
        bindings and finally the collected image-sampler pairs.
        """
        with writer.indent('static sg_shader_desc get_shader_desc() {', '}'):
            writer.write('sg_shader_desc desc = {};')
            writer.write(f'desc.label = "{global_vars.source_file_name}_shader_desc";')
            writer.write()

            # Vertex attributes (the string below is emitted into the C++ output).
            writer.write('// 顶点属性')
            for i, vertex_field in enumerate(binding_infos.vertex_layout):
                writer.write(f'desc.attrs[{i}].hlsl_sem_name = "{vertex_field.semanticName}";')
                writer.write(f'desc.attrs[{i}].hlsl_sem_index = {vertex_field.semanticIndex or 0};')
            writer.write()

            # Image-sampler pairs are gathered across all stages and written last.
            image_sampler_pairs = []

            for stage, info in binding_infos.stages.items():
                entry_point = info.get_entry_name()
                entry_point_name = f'{global_vars.source_file_name}_{entry_point}_blob'
                stage_name = info.get_stage().name.lower()

                writer.write(f'// {stage_name} shader')
                if info.blob is not None:
                    writer.write(f'desc.{stage_name}_func.bytecode = SG_RANGE({entry_point_name});')
                else:
                    # No blob array was generated for this stage, so the
                    # assignment is emitted commented-out.
                    writer.write(f'// desc.{stage_name}_func.bytecode = SG_RANGE({entry_point_name}); // No blob data')
                writer.write(f'desc.{stage_name}_func.entry = "{entry_point}";')
                writer.write(f'desc.{stage_name}_func.d3d11_target = "{CodeGenerator.d3d11_target_string(stage)}";')
                writer.write()

                # Emit this stage's resource bindings and collect its pairs.
                pairs = self._write_stage_resource_bindings(writer, info, stage_name)
                image_sampler_pairs.extend(pairs)

            if image_sampler_pairs:
                writer.write('// Image-sampler pairs')
                for i, pair in enumerate(image_sampler_pairs):
                    writer.write(f'desc.image_sampler_pairs[{i}].stage = {pair["stage"]};')
                    writer.write(f'desc.image_sampler_pairs[{i}].image_slot = {pair["image_slot"]};')
                    writer.write(f'desc.image_sampler_pairs[{i}].sampler_slot = {pair["sampler_slot"]};')
                    if global_vars.target == TargetFormat.GLSL:
                        writer.write(f'desc.image_sampler_pairs[{i}].glsl_name = "{pair["glsl_name"]}";')
                writer.write()

            writer.write('return desc;')

    @staticmethod
    def d3d11_target_string(stage: Stage) -> str:
        """Return the D3D11 shader-model-5.0 profile string for ``stage``.

        Returns ``'unknown'`` for any stage other than vertex, fragment or compute.
        """
        if stage == Stage.VERTEX:
            return 'vs_5_0'
        elif stage == Stage.FRAGMENT:
            return 'ps_5_0'
        elif stage == Stage.COMPUTE:
            return 'cs_5_0'
        else:
            return 'unknown'

    def _write_stage_resource_bindings(self, writer: IndentManager, stage_info: ShaderStageInfo,
                                       stage_name: str) -> list:
        """Write the resource bindings of one stage and return its image-sampler pairs.

        Used parameters are grouped into uniform blocks, images, samplers and
        storage buffers; each group is written with indices starting at 0.

        NOTE(review): indices restart at 0 for every stage. If the ``desc.*``
        arrays are shared between stages in the target API, two stages with the
        same kind of resource would write the same slot - confirm.
        """
        uniform_blocks = []
        images = []
        samplers = []
        storage_buffers = []

        for param in stage_info.parameters:
            # Parameters the shader does not actually use get no binding.
            if not param.binding.used:
                continue

            binding_kind = param.get_binding_kind()
            if binding_kind in (BindingKind.CONSTANT_BUFFER,
                                BindingKind.UNIFORM,
                                BindingKind.SUB_ELEMENT_REGISTER_SPACE):
                uniform_blocks.append(param)
            elif binding_kind == BindingKind.SHADER_RESOURCE:
                images.append(param)
            elif binding_kind == BindingKind.SAMPLER_STATE:
                samplers.append(param)
            elif binding_kind == BindingKind.STORAGE_BUFFER:
                storage_buffers.append(param)

        if uniform_blocks:
            writer.write(f'// {stage_name} uniform blocks')
            for i, param in enumerate(uniform_blocks):
                self._write_uniform_block(writer, param, i, stage_name)
            writer.write()

        if images:
            writer.write(f'// {stage_name} images')
            for i, param in enumerate(images):
                self._write_image(writer, param, i, stage_name)
            writer.write()

        if samplers:
            writer.write(f'// {stage_name} samplers')
            for i, param in enumerate(samplers):
                self._write_sampler(writer, param, i, stage_name)
            writer.write()

        if storage_buffers:
            writer.write(f'// {stage_name} storage buffers')
            for i, param in enumerate(storage_buffers):
                self._write_storage_buffer(writer, param, i, stage_name)
            writer.write()

        # Pairing is a name-based heuristic; the real relation depends on how
        # the shader samples its textures.
        return self._match_image_sampler_pairs(images, samplers, stage_name)

    def _match_image_sampler_pairs(self, images: List[Parameter], samplers: List[Parameter], stage_name: str) -> list:
        """Pair images with samplers and return the pairs as a list of dicts.

        Each image is paired with the first sampler whose name matches according
        to ``_is_paired``. If no image matched at all, images and samplers are
        paired by position instead. Every dict has the keys ``stage``,
        ``image_slot``, ``sampler_slot`` and ``glsl_name``.
        """
        pairs = []
        stage_upper = stage_name.upper()

        for i, image in enumerate(images):
            for j, sampler in enumerate(samplers):
                # e.g. "texture" and "textureSampler"
                if self._is_paired(image.name, sampler.name):
                    pairs.append({
                        "stage": f"SG_SHADERSTAGE_{stage_upper}",
                        "image_slot": i,
                        "sampler_slot": j,
                        "glsl_name": image.name  # could also be a combined name
                    })
                    # Only the first matching sampler is used for an image.
                    break

        # Fallback when no name matched: assume declaration order corresponds.
        if not pairs and images and samplers:
            for i in range(min(len(images), len(samplers))):
                pairs.append({
                    "stage": f"SG_SHADERSTAGE_{stage_upper}",
                    "image_slot": i,
                    "sampler_slot": i,
                    "glsl_name": images[i].name
                })

        return pairs

    def _is_paired(self, image_name: str, sampler_name: str) -> bool:
        """Return True when an image name and a sampler name look like a pair.

        Heuristic: strip the common suffixes ("Texture"/"Map"/"Image" and
        "SamplerState"/"Sampler") and compare the remaining base names
        case-insensitively, e.g. ``diffuseMap`` / ``diffuseSampler``. As a second
        chance, one full name containing the other also counts as a pair, e.g.
        ``texture`` / ``textureSampler``.
        """
        image_base = image_name.replace("Texture", "").replace("Map", "").replace("Image", "")
        # "SamplerState" must be stripped before "Sampler": the other order
        # turns "fooSamplerState" into "fooState" and the match is lost.
        sampler_base = sampler_name.replace("SamplerState", "").replace("Sampler", "")

        if image_base.lower() == sampler_base.lower():
            return True

        # Containment check on the full names (substring, not only prefix).
        if image_name.lower() in sampler_name.lower() or sampler_name.lower() in image_name.lower():
            return True

        return False

    def _get_image_sampler_pairs_from_bindings(self, entry_point: EntryPoint) -> list:
        """Placeholder for binding-based pairing; currently always returns ``[]``.

        The used images and samplers of ``entry_point`` are collected, but no
        pairing logic is implemented yet, so the result list stays empty.
        """
        pairs = []

        used_images = []
        used_samplers = []

        for binding in entry_point.bindings:
            if binding.binding.used and binding.binding.used > 0:
                if binding.binding.kind == BindingKind.SHADER_RESOURCE:
                    used_images.append(binding)
                elif binding.binding.kind == BindingKind.SAMPLER_STATE:
                    used_samplers.append(binding)

        # TODO: derive pairs from used_images / used_samplers once the pairing
        # rule for the shaders is known.
        return pairs

    def _write_uniform_block(self, writer: IndentManager, param: Parameter, index: int, stage_name: str) -> None:
        """Write ``desc.uniform_blocks[index]`` for ``param``.

        For GLSL the individual uniforms are listed as well (constant buffer and
        parameter block types only); for DXBC the ``b`` register index is written.
        """
        stage_upper = stage_name.upper()

        writer.write(f'desc.uniform_blocks[{index}].stage = SG_SHADERSTAGE_{stage_upper};')
        writer.write(f'desc.uniform_blocks[{index}].size = {param.get_byte_size()};')
        writer.write(f'desc.uniform_blocks[{index}].layout = SG_UNIFORMLAYOUT_STD140;')

        target = global_vars.target
        if target == TargetFormat.GLSL:
            if isinstance(param.type, (ConstantBufferType, ParameterBlockType)):
                self._write_glsl_uniforms(writer, param, index)
        elif target == TargetFormat.DXBC:
            writer.write(f'desc.uniform_blocks[{index}].hlsl_register_b_n = {param.get_register_index()};')

    def _write_glsl_uniforms(self, writer: IndentManager, param: Parameter, block_index: int) -> None:
        """Write type, array count and name of every field of a uniform block.

        Does nothing unless ``param.type`` is a constant buffer or parameter
        block type. The array count is always written as 1.
        """
        if isinstance(param.type, (ConstantBufferType, ParameterBlockType)):
            fields = param.type.elementType.fields
            for i, field in enumerate(fields):
                glsl_type = self._get_glsl_uniform_type(field.type)
                writer.write(f'desc.uniform_blocks[{block_index}].glsl_uniforms[{i}].type = {glsl_type};')
                writer.write(f'desc.uniform_blocks[{block_index}].glsl_uniforms[{i}].array_count = 1;')
                writer.write(f'desc.uniform_blocks[{block_index}].glsl_uniforms[{i}].glsl_name = "{field.name}";')

    def _get_glsl_uniform_type(self, type_info) -> str:
        """Map a reflected field type to an ``SG_UNIFORMTYPE_*`` constant name.

        Handles float32/int32/uint32 scalars and vectors and 4x4 matrices.
        Anything else falls back to ``"SG_UNIFORMTYPE_FLOAT4"``.
        """
        if isinstance(type_info, ScalarTypeInfo):
            if type_info.scalarType == ScalarType.FLOAT32:
                return "SG_UNIFORMTYPE_FLOAT"
            elif type_info.scalarType in (ScalarType.INT32, ScalarType.UINT32):
                return "SG_UNIFORMTYPE_INT"
        elif isinstance(type_info, VectorType):
            count = type_info.elementCount
            if type_info.elementType.scalarType == ScalarType.FLOAT32:
                return f"SG_UNIFORMTYPE_FLOAT{count}"
            elif type_info.elementType.scalarType in (ScalarType.INT32, ScalarType.UINT32):
                return f"SG_UNIFORMTYPE_INT{count}"
        elif isinstance(type_info, MatrixType):
            if type_info.rowCount == 4 and type_info.columnCount == 4:
                return "SG_UNIFORMTYPE_MAT4"
            # Other matrix shapes can be added here.

        return "SG_UNIFORMTYPE_FLOAT4"  # default

    def _write_image(self, writer: IndentManager, param: Parameter, index: int, stage_name: str) -> None:
        """Write ``desc.images[index]`` for ``param``.

        Image type, sample type and multisampling are hard-coded (2D, float,
        not multisampled). GLSL gets the name, DXBC the ``t`` register index.
        """
        stage_upper = stage_name.upper()

        writer.write(f'desc.images[{index}].stage = SG_SHADERSTAGE_{stage_upper};')
        # TODO: derive the image type from the reflected resource type.
        writer.write(f'desc.images[{index}].image_type = SG_IMAGETYPE_2D;')
        writer.write(f'desc.images[{index}].sample_type = SG_IMAGESAMPLETYPE_FLOAT;')
        writer.write(f'desc.images[{index}].multisampled = false;')

        target = global_vars.target
        if target == TargetFormat.GLSL:
            writer.write(f'desc.images[{index}].glsl_name = "{param.name}";')
        elif target == TargetFormat.DXBC:
            writer.write(f'desc.images[{index}].hlsl_register_t_n = {param.get_register_index()};')

    def _write_sampler(self, writer: IndentManager, param: Parameter, index: int, stage_name: str) -> None:
        """Write ``desc.samplers[index]`` for ``param``.

        The sampler type is hard-coded to filtering. GLSL gets the name, DXBC
        the ``s`` register index.
        """
        stage_upper = stage_name.upper()

        writer.write(f'desc.samplers[{index}].stage = SG_SHADERSTAGE_{stage_upper};')
        writer.write(f'desc.samplers[{index}].sampler_type = SG_SAMPLERTYPE_FILTERING;')

        target = global_vars.target
        if target == TargetFormat.GLSL:
            writer.write(f'desc.samplers[{index}].glsl_name = "{param.name}";')
        elif target == TargetFormat.DXBC:
            writer.write(f'desc.samplers[{index}].hlsl_register_s_n = {param.get_register_index()};')

    def _write_storage_buffer(self, writer: IndentManager, param: Parameter, index: int, stage_name: str) -> None:
        """Write ``desc.storage_buffers[index]`` for ``param``.

        ``readonly`` is hard-coded to false. GLSL gets the binding index, DXBC
        the ``u`` register index (both taken from ``get_register_index()``).
        """
        stage_upper = stage_name.upper()

        writer.write(f'desc.storage_buffers[{index}].stage = SG_SHADERSTAGE_{stage_upper};')
        # TODO: derive read-only access from reflection instead of hard-coding.
        writer.write(f'desc.storage_buffers[{index}].readonly = false;')

        target = global_vars.target
        if target == TargetFormat.GLSL:
            writer.write(f'desc.storage_buffers[{index}].glsl_binding_n = {param.get_register_index()};')
        elif target == TargetFormat.DXBC:
            writer.write(f'desc.storage_buffers[{index}].hlsl_register_u_n = {param.get_register_index()};')

    def _write_get_pipeline_desc_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
        """Write the ``get_pipeline_desc(...)`` function.

        The generated function takes the shader, the colour pixel format and
        optional sample count, primitive type and cull mode. It fills in the
        vertex layout (single buffer 0), fixed depth and alpha-blend state and
        a fixed-width label.
        """
        function_params = (
            'sg_shader shader,\n'
            '\t\tsg_pixel_format pixel_format,\n'
            '\t\tint32_t sample_count = 1,\n'
            '\t\tsg_primitive_type primitive_type = SG_PRIMITIVETYPE_TRIANGLES,\n'
            '\t\tsg_cull_mode cull_mode = SG_CULLMODE_NONE'
        )

        with writer.indent(
                f'static sg_pipeline_desc get_pipeline_desc(\n\t\t{function_params}\n\t) {{',
                '}'):
            writer.write('sg_pipeline_desc desc = {};')
            writer.write()
            writer.write('desc.shader = shader;')
            writer.write('desc.index_type = SG_INDEXTYPE_UINT32;')
            writer.write()

            # Vertex buffer layout and attributes, only when the shader has
            # vertex inputs. The '// ...' strings are emitted into the output.
            if binding_infos.vertex_layout:
                writer.write('// 顶点缓冲区布局')
                writer.write(f'desc.layout.buffers[0].stride = {binding_infos.vertex_size};')
                writer.write('desc.layout.buffers[0].step_func = SG_VERTEXSTEP_PER_VERTEX;')
                writer.write('desc.layout.buffers[0].step_rate = 1;')
                writer.write()

                writer.write('// 顶点属性')
                last_index = len(binding_infos.vertex_layout) - 1
                for i, vertex_field in enumerate(binding_infos.vertex_layout):
                    if vertex_field.semanticName:
                        writer.write(f'// {vertex_field.semanticName}{vertex_field.semanticIndex or ""}')
                    writer.write(f'desc.layout.attrs[{i}].buffer_index = 0;')
                    writer.write(f'desc.layout.attrs[{i}].offset = {vertex_field.binding.offset};')
                    writer.write(f'desc.layout.attrs[{i}].format = {vertex_field.get_sg_format()};')
                    # Blank line between attributes, but not after the last one.
                    if i < last_index:
                        writer.write()
                writer.write()

            writer.write('// 渲染状态')
            writer.write('desc.primitive_type = primitive_type;')
            writer.write('desc.cull_mode = cull_mode;')
            writer.write('desc.face_winding = SG_FACEWINDING_CW;')
            writer.write()

            writer.write('// 深度状态')
            writer.write('desc.depth.write_enabled = false;')
            writer.write('desc.depth.compare = SG_COMPAREFUNC_NEVER;')
            writer.write('desc.depth.pixel_format = SG_PIXELFORMAT_NONE;')
            writer.write()

            writer.write('// 混合状态')
            writer.write('desc.colors[0].blend.enabled = true;')
            writer.write('desc.colors[0].blend.src_factor_rgb = SG_BLENDFACTOR_SRC_ALPHA;')
            writer.write('desc.colors[0].blend.dst_factor_rgb = SG_BLENDFACTOR_ONE_MINUS_SRC_ALPHA;')
            writer.write('desc.colors[0].pixel_format = pixel_format;')
            writer.write('desc.colors[0].write_mask = SG_COLORMASK_RGBA;')
            writer.write('desc.color_count = 1;')
            writer.write()

            label = f'{global_vars.source_file_name}_pipeline_desc'
            # The label must have a fixed width: pad with spaces to 52
            # characters and cut anything beyond that.
            label = label.ljust(52)[:52]
            writer.write('desc.sample_count = sample_count;')
            writer.write(f'desc.label = "{label}";')
            writer.write()
            writer.write('return desc;')