Files
mirage_slang/shader_layout_generator.py
2025-06-07 03:39:44 +08:00

462 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
import io
import json
import sys
from typing import Tuple
from code_generator import IndentManager
from global_vars import global_vars
from shader_types import *
# Data parser (unchanged by the refactor)
class ShaderLayoutParser:
    """Parse shader reflection JSON into a ShaderLayout object.

    Only two parts of the JSON document are read:

    * ``entryPoints[*].parameters`` of vertex-stage entry points, for the
      vertex input fields;
    * the top-level ``parameters`` list, for constant (uniform) buffers.
    """

    def parse(self, json_file: str) -> ShaderLayout:
        """Load ``json_file`` and return the populated ShaderLayout.

        Args:
            json_file: Path to the reflection JSON file.

        Returns:
            ShaderLayout: Layout with ``vertex_fields`` and
            ``uniform_buffers`` filled in.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If a matched field lacks a required key
                (``name``, ``type``, ``binding``...).
        """
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default text encoding.
        with open(json_file, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        layout = ShaderLayout()
        layout.vertex_fields = self._parse_vertex_input(json_data)
        layout.uniform_buffers = self._parse_uniform_buffers(json_data)
        return layout

    def _parse_vertex_input(self, json_data: Dict) -> List[VertexField]:
        """Collect the vertex input fields of every vertex entry point.

        A parameter is treated as the vertex input only when it is named
        ``input`` and its own ``stage`` is ``vertex``.

        Args:
            json_data: Decoded reflection JSON.

        Returns:
            List[VertexField]: Fields in the order they appear in the JSON;
            empty when nothing matches.
        """
        vertex_fields = []
        for entry in json_data.get('entryPoints', []):
            if entry.get('stage') != 'vertex':
                continue
            for param in entry.get('parameters', []):
                if param.get('name') != 'input' or param.get('stage') != 'vertex':
                    continue
                for field in param.get('type', {}).get('fields', []):
                    vertex_fields.append(
                        VertexField(
                            name=field['name'],
                            type=FieldType.from_dict(field['type']),
                            location=field['binding']['index'],
                            semantic=field.get('semanticName', ''),
                            semantic_index=field.get('semanticIndex', 0)
                        )
                    )
        return vertex_fields

    def _parse_uniform_buffers(self, json_data: Dict) -> List[UniformBuffer]:
        """Collect every top-level parameter bound as a constant buffer.

        Args:
            json_data: Decoded reflection JSON.

        Returns:
            List[UniformBuffer]: One entry per parameter whose binding kind
            is ``constantBuffer``. Buffer fields are only filled in when the
            buffer's element type is a struct.
        """
        uniform_buffers = []
        for param in json_data.get('parameters', []):
            binding = param.get('binding', {})
            if binding.get('kind') != 'constantBuffer':
                continue
            buffer = UniformBuffer(
                name=param['name'],
                binding=binding['index']
            )
            # Parse the buffer's member fields.
            element_type = param.get('type', {}).get('elementType', {})
            if element_type.get('kind') == 'struct':
                for field in element_type.get('fields', []):
                    buffer.fields.append(
                        UniformField(
                            name=field['name'],
                            type=FieldType.from_dict(field['type']),
                            offset=field['binding']['offset'],
                            size=field['binding']['size']
                        )
                    )
            uniform_buffers.append(buffer)
        return uniform_buffers
# Refactored code generator
class ShaderLayoutCodeGenerator:
    """Generate a C/C++ header from the ShaderLayout in ``global_vars.layout``.

    The header contains a ``vertex_t`` struct, one struct per uniform buffer,
    SDL GPU vertex attribute/buffer descriptions and a vertex-buffer upload
    helper.
    """

    def __init__(self):
        # (scalar type, component count) -> SDL GPU vertex element format.
        self.format_mapping = {
            ('int32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_INT',
            ('int32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_INT2',
            ('int32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_INT3',
            ('int32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_INT4',
            ('uint32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT',
            ('uint32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT2',
            ('uint32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT3',
            ('uint32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT4',
            ('float32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT',
            ('float32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT2',
            ('float32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT3',
            ('float32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4',
            ('int8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE2',
            ('int8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE4',
            ('uint8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE2',
            ('uint8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE4',
            ('int16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT2',
            ('int16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT4',
            ('uint16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT2',
            ('uint16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT4',
            ('float16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF2',
            ('float16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF4',
        }
        # Scalar type -> size in bytes.
        self.type_sizes = {
            'int32': 4,
            'uint32': 4,
            'float32': 4,
            'int8': 1,
            'uint8': 1,
            'int16': 2,
            'uint16': 2,
            'float16': 2,
        }

    def generate(self) -> str:
        """Return the complete header text for ``global_vars.layout``.

        Returns:
            str: Header source; all declarations are wrapped in the namespace
            ``<global_vars.source_file_name>_shader``.
        """
        output = io.StringIO()
        # NOTE(review): the original comment on this line said "indent with
        # 4 spaces" while the literal is a single space - confirm which one
        # is intended.
        writer = IndentManager(output, indent_char=' ')
        writer.write("#pragma once")
        writer.write()
        writer.write("#include <SDL3/SDL.h>")
        writer.write("#include <SDL3/SDL_gpu.h>")
        writer.write("#include <cstdint> // For fixed-width integer types")
        writer.write()
        writer.write("// Auto-generated vertex structure")
        with writer.block(f'namespace {global_vars.source_file_name}_shader'):
            self._generate_vertex_struct(writer, global_vars.layout.vertex_fields)
            self._generate_uniform_structs(writer, global_vars.layout.uniform_buffers)
            self._generate_vertex_attributes(writer, global_vars.layout.vertex_fields)
            self._generate_helper_functions(writer, global_vars.layout.vertex_fields)
        return output.getvalue()

    def _get_type_info(self, field_type: FieldType) -> Tuple[str, int, int]:
        """Return the C scalar type, element count and padded size of a type.

        The returned size is padded the way GPU buffer layouts usually are
        (a 3-component vector or matrix column occupies four components). It
        is therefore NOT necessarily the ``sizeof`` of the C declaration
        produced by ``_get_c_type_declaration``; use ``_get_packed_layout``
        for that.

        Args:
            field_type: Type description with ``kind``, ``scalar_type`` and,
                depending on the kind, ``element_count`` or
                ``row_count``/``column_count``.

        Returns:
            Tuple[str, int, int]: (C type name, element count, padded size in
            bytes).

        Raises:
            ValueError: If a vector has a component count outside 1..4 or a
                matrix has a dimension outside 2..4.
        """
        # Scalar type -> C type name.
        c_type_mapping = {
            'int32': 'int32_t',
            'uint32': 'uint32_t',
            'float32': 'float',
            'int8': 'int8_t',
            'uint8': 'uint8_t',
            'int16': 'int16_t',
            'uint16': 'uint16_t',
            # C has no standard float16; mirage_float16 is a name the user of
            # the generated header is expected to provide.
            'float16': 'mirage_float16',
        }
        # Scalar type -> alignment requirement in bytes.
        alignment_mapping = {
            'int32': 4,
            'uint32': 4,
            'float32': 4,
            'int8': 1,
            'uint8': 1,
            'int16': 2,
            'uint16': 2,
            'float16': 2,
        }
        if field_type.kind == 'scalar':
            scalar_type = field_type.scalar_type or 'float32'
            c_type = c_type_mapping.get(scalar_type, 'float')
            size = self.type_sizes.get(scalar_type, 4)
            return c_type, 1, size
        elif field_type.kind == 'vector':
            scalar_type = field_type.scalar_type or 'float32'
            c_type = c_type_mapping.get(scalar_type, 'float')
            element_count = field_type.element_count or 4
            # Reject component counts that no vector type can have.
            if element_count not in [1, 2, 3, 4]:
                raise ValueError(f"Invalid vector element count: {element_count}. Must be 1, 2, 3, or 4.")
            size = self.type_sizes.get(scalar_type, 4) * element_count
            # A vec3 is rounded up to the size of a vec4.
            if element_count == 3:
                alignment = alignment_mapping.get(scalar_type, 4) * 4
                size = ((size + alignment - 1) // alignment) * alignment
            return c_type, element_count, size
        elif field_type.kind == 'matrix':
            scalar_type = field_type.scalar_type or 'float32'
            c_type = c_type_mapping.get(scalar_type, 'float')
            row_count = field_type.row_count or 4
            column_count = field_type.column_count or 4
            # Reject dimensions that no matrix type can have.
            if row_count not in [2, 3, 4] or column_count not in [2, 3, 4]:
                raise ValueError(f"Invalid matrix dimensions: {row_count}x{column_count}. "
                                 f"Rows and columns must be 2, 3, or 4.")
            # Matrices are sized column by column, each column padded.
            element_size = self.type_sizes.get(scalar_type, 4)
            if row_count == 3:
                # A 3-row column is padded to four elements.
                column_size = element_size * 4
            else:
                column_size = element_size * row_count
            total_size = column_size * column_count
            total_elements = row_count * column_count
            return c_type, total_elements, total_size
        else:
            # Unknown kind: fall back to a single float. The warning goes to
            # stderr so it cannot end up inside a header written to stdout.
            print(f"Warning: Unknown field type kind '{field_type.kind}'. Using default float.",
                  file=sys.stderr)
            return 'float', 1, 4

    def _get_c_type_declaration(self, field_type: FieldType, field_name: str) -> str:
        """Build the C member declaration (without ``;``) for a field.

        Args:
            field_type: Type description of the field.
            field_name: Name of the member.

        Returns:
            str: ``T name`` for scalars, one-component vectors and unknown
            kinds, ``T name[N]`` for vectors, ``T name[cols][rows]`` for
            matrices.
        """
        c_type, count, _ = self._get_type_info(field_type)
        if field_type.kind == 'vector' and count != 1:
            # Vectors are emitted as plain arrays.
            return f"{c_type} {field_name}[{count}]"
        if field_type.kind == 'matrix':
            rows = field_type.row_count or 4
            cols = field_type.column_count or 4
            # Matrices are emitted as column-major 2D arrays.
            return f"{c_type} {field_name}[{cols}][{rows}]"
        return f"{c_type} {field_name}"

    def _get_aligned_size(self, size: int, alignment: int) -> int:
        """Round ``size`` up to the next multiple of ``alignment``.

        Args:
            size: Unaligned size in bytes.
            alignment: Required alignment in bytes (must be non-zero).

        Returns:
            int: Smallest multiple of ``alignment`` that is >= ``size``.
        """
        return ((size + alignment - 1) // alignment) * alignment

    def _get_packed_layout(self, field_type: FieldType) -> Tuple[int, int]:
        """Return (size, alignment) of the C declaration emitted for a type.

        Unlike ``_get_type_info`` this describes the array that
        ``_get_c_type_declaration`` actually emits, which has no vec3/mat3
        padding. Assumes the scalar C types have the sizes listed in
        ``self.type_sizes`` and are naturally aligned (for the user-supplied
        ``mirage_float16`` this means 2 bytes - TODO confirm).
        """
        _, count, _ = self._get_type_info(field_type)
        if field_type.kind not in ('scalar', 'vector', 'matrix'):
            # Unknown kinds are declared as a single float.
            return 4, 4
        scalar_size = self.type_sizes.get(field_type.scalar_type or 'float32', 4)
        return scalar_size * count, scalar_size

    def _compute_vertex_layout(self, vertex_fields: List[VertexField]) -> Tuple[List[Tuple[int, int]], int]:
        """Compute member offsets and total size of the emitted ``vertex_t``.

        Members are placed in order, each at the next offset that satisfies
        its alignment; the total is rounded up to the largest member
        alignment, mirroring how a C compiler lays out an unpacked struct.

        Returns:
            Tuple[List[Tuple[int, int]], int]: ``[(offset, size), ...]`` per
            field, and the struct size (vertex stride) in bytes.
        """
        placements = []
        cursor = 0
        struct_alignment = 1
        for field in vertex_fields:
            size, alignment = self._get_packed_layout(field.type)
            cursor = self._get_aligned_size(cursor, alignment)
            placements.append((cursor, size))
            cursor += size
            struct_alignment = max(struct_alignment, alignment)
        return placements, self._get_aligned_size(cursor, struct_alignment)

    def _generate_vertex_struct(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
        """Emit ``struct vertex_t`` followed by a size ``static_assert``.

        Offsets and sizes in the comments and in the assertion describe the
        emitted C members, so the assertion holds for 3-component fields too.
        """
        placements, total_size = self._compute_vertex_layout(vertex_fields)
        with writer.block("struct vertex_t", "};"):
            for field, (field_offset, field_size) in zip(vertex_fields, placements):
                declaration = self._get_c_type_declaration(field.type, field.name)
                # Append the semantic index only when it is non-zero
                # (e.g. TEXCOORD vs TEXCOORD1).
                semantic = field.semantic
                if field.semantic_index > 0:
                    semantic += str(field.semantic_index)
                comment = f"// location: {field.location}, semantic: {semantic}"
                comment += f", offset: {field_offset}, size: {field_size} bytes"
                writer.write(comment)
                writer.write(f"{declaration};")
        writer.write()
        writer.write(f"// Total size: {total_size} bytes")
        if vertex_fields:
            # sizeof of an empty struct is never 0, so the assertion is only
            # meaningful when there is at least one member.
            writer.write(f"static_assert(sizeof(vertex_t) == {total_size}, \"Vertex struct size mismatch\");")
        writer.write()

    def _generate_uniform_structs(self, writer: IndentManager, uniform_buffers: List[UniformBuffer]) -> None:
        """Emit one ``typedef struct`` per uniform buffer.

        Field offsets and sizes are taken from the reflection data; explicit
        ``uint8_t`` padding members are inserted wherever the next field
        starts beyond the end of the previous one.
        """
        writer.write("// Uniform buffer structures")
        for buffer in uniform_buffers:
            # e.g. "camera_data_buffer" -> "CameraDataBuffer".
            struct_name = buffer.name.replace('_buffer', '').title().replace('_', '') + 'Buffer'
            total_size = 0
            max_alignment = 16  # GPUs usually require 16-byte alignment
            with writer.block(f"typedef struct {struct_name}", f"}} {struct_name};"):
                for i, field in enumerate(buffer.fields):
                    # Fill any gap before this field with explicit padding.
                    if field.offset > total_size:
                        padding_size = field.offset - total_size
                        writer.write(f"uint8_t _padding{i}[{padding_size}]; // Padding")
                        total_size = field.offset
                    declaration = self._get_c_type_declaration(field.type, field.name)
                    writer.write(f"{declaration}; // offset: {field.offset}, size: {field.size}")
                    total_size = field.offset + field.size
                # Point out when the struct size is not a multiple of 16.
                aligned_size = self._get_aligned_size(total_size, max_alignment)
                if aligned_size > total_size:
                    writer.write(f"// Note: Structure may need padding to {aligned_size} bytes for alignment")
            writer.write(f"// Binding: {buffer.binding}, Size: {total_size} bytes (aligned: {aligned_size})")
            writer.write()

    def _generate_vertex_attributes(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
        """Emit the SDL GPU vertex attribute array, buffer description and
        input state.

        Attribute offsets and the buffer pitch match the layout of the
        emitted ``vertex_t``. A (scalar type, count) pair without an SDL
        format falls back to ``SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4`` and a
        warning is printed to stderr.
        """
        placements, stride = self._compute_vertex_layout(vertex_fields)
        last_index = len(vertex_fields) - 1
        writer.write("// Vertex attribute descriptions")
        writer.write(f"static constexpr uint32_t VERTEX_ATTRIBUTE_COUNT = {len(vertex_fields)};")
        writer.write()
        writer.write("static const SDL_GPUVertexAttribute vertex_attributes[] = {")
        with writer.indent():
            for i, field in enumerate(vertex_fields):
                _, count, _ = self._get_type_info(field.type)
                scalar_type = field.type.scalar_type or 'float32'
                format_name = self.format_mapping.get((scalar_type, count))
                if format_name is None:
                    format_name = 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4'
                    print(f"Warning: No SDL GPU vertex format for '{scalar_type}' x{count} "
                          f"(field '{field.name}'). Using {format_name}.",
                          file=sys.stderr)
                field_offset, _ = placements[i]
                # Every initializer but the last is followed by a comma.
                with writer.indent("{", "}," if i < last_index else "}"):
                    writer.write(f"// {field.name}")
                    writer.write(f".location = {field.location},")
                    writer.write(f".buffer_slot = 0,")
                    writer.write(f".format = {format_name},")
                    writer.write(f".offset = {field_offset}")
        writer.write("};")
        writer.write()
        # Vertex buffer description.
        writer.write("// Vertex buffer description")
        with writer.block("static constexpr SDL_GPUVertexBufferDescription vertex_buffer_desc =", "};"):
            writer.write(".slot = 0,")
            writer.write(f".pitch = {stride}, // sizeof(vertex)")
            writer.write(".input_rate = SDL_GPU_VERTEXINPUTRATE_VERTEX,")
            writer.write(".instance_step_rate = 0")
        writer.write()
        # Vertex input state.
        writer.write("// Vertex input state")
        with writer.block("static constexpr SDL_GPUVertexInputState vertex_input_state =", "};"):
            writer.write(".vertex_buffer_descriptions = &vertex_buffer_desc,")
            writer.write(".num_vertex_buffers = 1,")
            writer.write(".vertex_attributes = vertex_attributes,")
            writer.write(".num_vertex_attributes = VERTEX_ATTRIBUTE_COUNT")
        writer.write()

    def _generate_helper_functions(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
        """Emit the ``create_vertex_buffer`` helper into the header.

        ``vertex_fields`` is accepted for signature symmetry with the other
        generators but is not used.
        """
        writer.write("// Helper functions")
        writer.write()
        # create_vertex_buffer: create the GPU buffer, stage the vertices in
        # a transfer buffer and submit the copy.
        writer.write("static SDL_GPUBuffer* create_vertex_buffer(SDL_GPUDevice* device,")
        writer.write(" const vertex_t* vertices,")
        with writer.block(" Uint32 vertex_count)", "}"):
            with writer.block("SDL_GPUBufferCreateInfo buffer_info =", "};"):
                writer.write(".usage = SDL_GPU_BUFFERUSAGE_VERTEX,")
                writer.write(".size = static_cast<Uint32>(sizeof(vertex_t)) * vertex_count")
            writer.write()
            writer.write("SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &buffer_info);")
            writer.write()
            with writer.block("SDL_GPUTransferBufferCreateInfo transfer_info =", "};"):
                writer.write(".usage = SDL_GPU_TRANSFERBUFFERUSAGE_UPLOAD,")
                writer.write(".size = buffer_info.size")
            writer.write()
            writer.write("// Upload vertex data")
            writer.write("SDL_GPUTransferBuffer* transfer = SDL_CreateGPUTransferBuffer(device, &transfer_info);")
            writer.write()
            writer.write("void* mapped = SDL_MapGPUTransferBuffer(device, transfer, false);")
            writer.write("SDL_memcpy(mapped, vertices, buffer_info.size);")
            writer.write("SDL_UnmapGPUTransferBuffer(device, transfer);")
            writer.write()
            writer.write("// Copy to GPU")
            writer.write("SDL_GPUCommandBuffer* cmd = SDL_AcquireGPUCommandBuffer(device);")
            writer.write("SDL_GPUCopyPass* copy = SDL_BeginGPUCopyPass(cmd);")
            writer.write()
            writer.write("SDL_GPUTransferBufferLocation src = {.transfer_buffer = transfer, .offset = 0};")
            writer.write("SDL_GPUBufferRegion dst = {.buffer = buffer, .offset = 0, .size = buffer_info.size};")
            writer.write()
            writer.write("SDL_UploadToGPUBuffer(copy, &src, &dst, false);")
            writer.write("SDL_EndGPUCopyPass(copy);")
            writer.write("SDL_SubmitGPUCommandBuffer(cmd);")
            writer.write()
            writer.write("SDL_ReleaseGPUTransferBuffer(device, transfer);")
            writer.write("return buffer;")
        writer.write()
# Main generator class (unchanged by the refactor)
class ShaderLayoutGenerator:
    """Facade that chains JSON parsing and header code generation."""

    def __init__(self):
        self.parser = ShaderLayoutParser()
        self.code_generator = ShaderLayoutCodeGenerator()

    def generate_header(self, json_file: str) -> str:
        """Parse ``json_file`` and return the generated header text.

        The parsed layout is stored in ``global_vars.layout``, which is where
        the code generator reads it from.

        Args:
            json_file: Path to the shader reflection JSON file.

        Returns:
            str: Content of the generated header.
        """
        parsed_layout = self.parser.parse(json_file)
        # Publish the layout before generating; generate() takes no arguments.
        global_vars.layout = parsed_layout
        header_text = self.code_generator.generate()
        return header_text
# Example usage
# Command-line entry point: print the header generated from the JSON file
# given as the first argument.
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    if not cli_args:
        # No input file given: show usage and exit with a failure status.
        print("Usage: shader_layout_gen.py <shader.json>")
        sys.exit(1)
    print(ShaderLayoutGenerator().generate_header(cli_args[0]))