Files
mirage/tools/code_generator.py

177 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Slang Compiler - Code Generator
生成C/C++绑定代码
"""
from pathlib import Path
from code_generator_helper import *
from global_vars import global_vars
from indent_manager import IndentManager
from shader_types import *
class CodeGenerator:
"""C++绑定代码生成器"""
def generate_binding_functions(self, binding_infos: List[ShaderInfo], output_path: str) -> None:
    """Public entry point: write the C++ binding header for ``binding_infos``.

    Args:
        binding_infos: Per-entry-point records (entry point, stage, bytecode
            blob, reflected resources) to embed in the header.
        output_path: Path of the header file to create or overwrite.

    All of the work is delegated to ``_generate_cpp_bindings``.
    """
    return self._generate_cpp_bindings(binding_infos, output_path)
def _generate_uniform_structs(self, writer: IndentManager, uniform_buffers: List[UniformBuffer],
                              *, max_alignment: int = 16) -> None:
    """Emit one C++ struct per uniform buffer, mirroring its reflected layout.

    Fields are written in order. When a field's reflected offset lies past the
    end of the previous field, an explicit ``uint8_t`` padding array is
    inserted, so the C++ member offsets follow the reflected ones. After each
    struct, comments record the binding slot, the unpadded size, and that size
    rounded up to ``max_alignment``.

    Args:
        writer: Destination for the generated C++ text.
        uniform_buffers: Reflected uniform buffers. Each must expose ``name``,
            ``binding`` and ``fields``; each field must expose ``name``,
            ``type``, ``offset`` and ``size``.
        max_alignment: Alignment in bytes used for the size report. The
            default of 16 follows the common GPU constant-buffer alignment.
    """
    suffix = '_buffer'
    writer.write("// Uniform buffer structures")
    for buffer in uniform_buffers:
        # Strip only a trailing "_buffer". str.replace would also remove it from
        # the middle of a name, so e.g. "a_buffer_b" and "a_b" would both map
        # to the same struct name.
        base_name = buffer.name
        if base_name.endswith(suffix):
            base_name = base_name[:-len(suffix)]
        struct_name = base_name.title().replace('_', '') + 'Buffer'

        end_of_data = 0  # byte offset just past the last emitted member
        with writer.block(f'struct {struct_name}', '};'):
            for i, field in enumerate(buffer.fields):
                gap = field.offset - end_of_data
                if gap > 0:
                    # Explicit padding keeps this member at its reflected offset.
                    writer.write(f"uint8_t _padding{i}[{gap}]; // Padding")
                elif gap < 0:
                    # Overlapping / out-of-order offsets cannot be reproduced by a
                    # plain struct; flag it in the output instead of staying silent.
                    writer.write(f"// WARNING: field '{field.name}' at offset {field.offset} "
                                 f"overlaps the previous member (which ends at {end_of_data})")
                declaration = get_c_type_declaration(field.type, field.name)
                writer.write(f"{declaration}; // offset: {field.offset}, size: {field.size}")
                end_of_data = field.offset + field.size

        # Report (but do not emit) the padding needed to reach the aligned size.
        aligned_size = get_aligned_size(end_of_data, max_alignment)
        if aligned_size > end_of_data:
            writer.write(f"// Note: Structure may need padding to {aligned_size} bytes for alignment")
        writer.write(f"// Binding: {buffer.binding}, Size: {end_of_data} bytes (aligned: {aligned_size})")
        writer.write()
def _generate_cpp_bindings(self, binding_infos: List[ShaderInfo], output_path: str) -> None:
    """Create ``output_path`` (plus any missing parent directories) and fill it.

    The file is opened in text mode ``'w'`` with UTF-8 encoding, so existing
    content is replaced. The whole header is produced by
    ``_write_complete_file`` through an ``IndentManager`` wrapping the stream.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open('w', encoding='utf-8') as stream:
        self._write_complete_file(IndentManager(stream), binding_infos)
def _write_complete_file(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
    """Write the entire generated header.

    The output is, in order:

    1. The preamble from ``_write_header``.
    2. A ``namespace <source>_shaders { ... }`` wrapping everything
       shader-specific.
    3. A final no-argument ``writer.write()`` after the closing brace.
    """
    self._write_header(writer)
    namespace_open = f'namespace {global_vars.source_file_name}_shaders {{'
    with writer.indent(namespace_open, '}'):
        self._write_shader_bindings_class(writer, binding_infos)
    writer.write()
def _write_header(self, writer: IndentManager) -> None:
    """Write the preamble: ``#pragma once`` and the standard includes.

    The generated code relies on ``<cstdint>`` for ``uint8_t`` and on
    ``<array>`` for ``std::array``.
    """
    for line in ('#pragma once', '#include <cstdint>', '#include <array>'):
        writer.write(line)
    writer.write()  # separator after the include lines
def _write_shader_bindings_class(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
    """Write the body of the shader namespace.

    Despite the name, no C++ class is emitted. The output is, in order:

    * the uniform-buffer structs from ``global_vars.layout``;
    * the embedded bytecode arrays;
    * the shader-desc factory function;
    * the pipeline-desc factory function.
    """
    self._generate_uniform_structs(writer, global_vars.layout.uniform_buffers)
    for emit in (self._write_blob,
                 self._write_public_methods,
                 self._write_get_pipeline_desc_method):
        emit(writer, binding_infos)
def _write_blob(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
    """Embed each entry point's compiled bytecode as a ``constexpr std::array``.

    For every shader, one array named ``<source>_<entry_point>_blob`` is
    written. Its size is the blob length and its elements are hex byte
    literals, 16 per line. Every line except the last ends with a comma.
    """
    bytes_per_row = 16  # keeps the generated source readable
    for info in binding_infos:
        entry_point = info.entry_point
        blob = info.blob
        total = len(blob)
        array_name = f'{global_vars.source_file_name}_{entry_point}_blob'
        with writer.block(f'static constexpr std::array<uint8_t, {total}> {array_name} =', '};'):
            for start in range(0, total, bytes_per_row):
                row = ', '.join(f'0x{value:02x}' for value in blob[start:start + bytes_per_row])
                is_last_row = start + bytes_per_row >= total
                writer.write(row if is_last_row else row + ',')
        writer.write()
    writer.write()
def _write_public_methods(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
    """Write the generated public factory functions.

    At present this is only the ``sg_shader_desc`` factory.
    """
    return self._write_get_shader_info_method(writer, binding_infos)
def _write_get_shader_info_method(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
    """Emit the C++ factory ``static sg_shader_desc get_<source>_shader_desc()``.

    The generated function value-initialises an ``sg_shader_desc``, fills it
    in, and returns it by value. It sets:

    * ``label``: the source file name;
    * ``attrs[i].hlsl_sem_name`` / ``hlsl_sem_index``: one pair per vertex
      field in ``global_vars.layout.vertex_fields``;
    * for every entry in ``binding_infos``:
        - ``<stage>_func.bytecode``, pointing at the ``<source>_<entry>_blob``
          array;
        - ``<stage>_func.entry``;
        - one group of assignments per reflected resource.
    """
    source_name = global_vars.source_file_name
    signature = f'static sg_shader_desc get_{source_name}_shader_desc() {{'
    with writer.indent(signature, '}'):
        writer.write('sg_shader_desc desc = {};')
        writer.write(f'desc.label = "{source_name}";')

        # Vertex input layout: HLSL semantic name/index for each attribute slot.
        writer.write('// 顶点布局')
        with writer.indent('{', '}'):
            for i, vertex_field in enumerate(global_vars.layout.vertex_fields):
                writer.write(f'desc.attrs[{i}].hlsl_sem_name = "{vertex_field.semantic}";')
                writer.write(f'desc.attrs[{i}].hlsl_sem_index = {vertex_field.semantic_index};')

        for info in binding_infos:
            entry_point = info.entry_point
            # Name of the bytecode array that _write_blob emits for this entry point.
            blob_name = f'{source_name}_{entry_point}_blob'
            stage_name = info.stage.name
            writer.write(f'// {stage_name}')
            with writer.indent('{', '}'):
                # The trailing ';' used to be missing here, which broke the generated C++.
                writer.write(f'desc.{stage_name.lower()}_func.bytecode = SG_RANGE({blob_name});')
                writer.write(f'desc.{stage_name.lower()}_func.entry = "{entry_point}";')

                # Resource bindings.
                writer.write('// 资源绑定')
                # Next free desc slot per resource type.
                # NOTE(review): the counter restarts for every stage, so same-typed
                # resources of different stages get the same slot. Confirm intended.
                # NOTE(review): member names below (`name`, `sampled_textures`, ...) are
                # emitted as-is; verify they exist in the targeted sokol_gfx.h. Also,
                # uniform blocks get no `.stage` assignment while images do. Confirm.
                next_slot = {}
                for resource in info.resources:
                    slot = next_slot.get(resource.type, 0)
                    next_slot[resource.type] = slot + 1
                    if resource.type == ResourceType.UNIFORM_BUFFER:
                        uniform_info = resource.uniform_data
                        writer.write(f'desc.uniform_blocks[{slot}].name = "{resource.name}";')
                        writer.write(f'desc.uniform_blocks[{slot}].size = {uniform_info.size};')
                        writer.write(f'desc.uniform_blocks[{slot}].hlsl_register_b_n = {uniform_info.binding};')
                        writer.write(f'desc.uniform_blocks[{slot}].msl_buffer_n = {uniform_info.binding};')
                        writer.write(f'desc.uniform_blocks[{slot}].wgsl_group0_binding_n = {uniform_info.binding};')
                    elif resource.type == ResourceType.SAMPLED_TEXTURE:
                        # ';' added: the generated statement previously had none.
                        writer.write(f'desc.images[{slot}].stage = SG_SHADERSTAGE_{stage_name.upper()};')
                        writer.write(f'desc.images[{slot}].name = "{resource.name}";')
                        writer.write(f'desc.sampled_textures[{slot}].name = "{resource.name}";')
                    elif resource.type == ResourceType.STORAGE_BUFFER:
                        writer.write(f'desc.storage_buffers[{slot}].name = "{resource.name}";')
                    elif resource.type == ResourceType.SAMPLER:
                        writer.write(f'desc.samplers[{slot}].name = "{resource.name}";')

        # Without this, the generated non-void function fell off its end (undefined behavior).
        writer.write('return desc;')
def _write_get_pipeline_desc_method(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
"""写入getPipelineDesc方法"""
with writer.indent(f'static sg_pipeline_desc get_{global_vars.source_file_name}_pipeline_desc(sg_shader shader, sg_pixel_format pixel_format, int32_t sample_count)'' {', '}'):
writer.write('sg_pipeline_desc desc = {};')
writer.write(f'desc.label = "{global_vars.source_file_name}_pipeline";')
writer.write('desc.shader = shader;')
writer.write('desc.index_type = SG_INDEXTYPE_UINT32;')
# 处理顶点输入布局
for i, vertex_field in enumerate(global_vars.layout.vertex_fields):
writer.write(f'// {vertex_field.semantic}')
writer.write(f'desc.layout.attrs[{i}].buffer_index = {vertex_field.location}')
writer.write(f'desc.layout.attrs[{i}].offset = {vertex_field.offset}')
writer.write(f'desc.layout.attrs[{i}].format = {vertex_field.type.scalar_type}')