462 lines
19 KiB
Python
462 lines
19 KiB
Python
#!/usr/bin/env python3
|
||
import io
|
||
import json
|
||
import sys
|
||
from typing import Tuple
|
||
|
||
from code_generator import IndentManager
|
||
from global_vars import global_vars
|
||
from shader_types import *
|
||
|
||
|
||
# Data parser
|
||
class ShaderLayoutParser:
    """Parse a shader reflection JSON document into layout objects.

    Only the keys read below are relied upon: ``entryPoints`` (vertex input
    fields) and ``parameters`` (constant buffers).
    NOTE(review): the schema looks like Slang reflection output - confirm.
    """

    def parse(self, json_file: str) -> ShaderLayout:
        """Load ``json_file`` and return the ShaderLayout it describes.

        Args:
            json_file: Path of the reflection JSON file.

        Returns:
            A ShaderLayout whose ``vertex_fields`` and ``uniform_buffers``
            attributes are filled from the document.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If a vertex field or constant buffer lacks a required key.
        """
        # Decode explicitly as UTF-8 instead of depending on the platform's
        # default locale encoding.
        with open(json_file, 'r', encoding='utf-8') as f:
            json_data = json.load(f)

        layout = ShaderLayout()
        layout.vertex_fields = self._parse_vertex_input(json_data)
        layout.uniform_buffers = self._parse_uniform_buffers(json_data)
        return layout

    def _parse_vertex_input(self, json_data: Dict) -> List[VertexField]:
        """Collect the fields of the vertex stage's ``input`` parameter.

        Entry points whose ``stage`` is not ``'vertex'`` and parameters that
        are not named ``input`` (with stage ``'vertex'``) are skipped. Missing
        or ``null`` containers are treated as empty.

        Raises:
            KeyError: If a field lacks ``name``, ``type`` or ``binding.index``.
        """
        vertex_fields = []

        # ``x.get(key) or []`` also covers an explicit JSON null, which
        # ``x.get(key, [])`` would hand back as None.
        for entry in json_data.get('entryPoints') or []:
            if entry.get('stage') != 'vertex':
                continue
            for param in entry.get('parameters') or []:
                if param.get('name') != 'input' or param.get('stage') != 'vertex':
                    continue
                for field in (param.get('type') or {}).get('fields') or []:
                    vertex_fields.append(VertexField(
                        name=field['name'],
                        type=FieldType.from_dict(field['type']),
                        location=field['binding']['index'],
                        semantic=field.get('semanticName', ''),
                        semantic_index=field.get('semanticIndex', 0)
                    ))

        return vertex_fields

    def _parse_uniform_buffers(self, json_data: Dict) -> List[UniformBuffer]:
        """Collect every top-level parameter bound as a constant buffer.

        A parameter is a constant buffer when ``binding.kind`` equals
        ``'constantBuffer'``. Its fields are read only when the buffer's
        ``type.elementType.kind`` is ``'struct'``; otherwise the buffer is
        returned with no fields.

        Raises:
            KeyError: If a buffer lacks ``name`` / ``binding.index`` or a field
                lacks ``name``, ``type``, ``binding.offset`` or ``binding.size``.
        """
        uniform_buffers = []

        for param in json_data.get('parameters') or []:
            binding = param.get('binding') or {}
            if binding.get('kind') != 'constantBuffer':
                continue

            buffer = UniformBuffer(
                name=param['name'],
                binding=binding['index']
            )

            # Parse the buffer's struct members, if it has any.
            element_type = (param.get('type') or {}).get('elementType') or {}
            if element_type.get('kind') == 'struct':
                for field in element_type.get('fields') or []:
                    buffer.fields.append(UniformField(
                        name=field['name'],
                        type=FieldType.from_dict(field['type']),
                        offset=field['binding']['offset'],
                        size=field['binding']['size']
                    ))

            uniform_buffers.append(buffer)

        return uniform_buffers
|
||
|
||
|
||
# Code generator
|
||
class ShaderLayoutCodeGenerator:
    """Generate a C++ header for SDL3's GPU API from the parsed shader layout.

    The layout is not passed in: ``generate`` reads ``global_vars.layout`` and
    ``global_vars.source_file_name``.
    """

    def __init__(self):
        # (scalar type, component count) -> SDL GPU vertex element format.
        self.format_mapping = {
            ('int32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_INT',
            ('int32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_INT2',
            ('int32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_INT3',
            ('int32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_INT4',
            ('uint32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT',
            ('uint32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT2',
            ('uint32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT3',
            ('uint32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT4',
            ('float32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT',
            ('float32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT2',
            ('float32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT3',
            ('float32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4',
            ('int8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE2',
            ('int8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE4',
            ('uint8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE2',
            ('uint8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE4',
            ('int16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT2',
            ('int16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT4',
            ('uint16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT2',
            ('uint16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT4',
            ('float16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF2',
            ('float16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF4',
        }

        # Size in bytes of one scalar of each supported type.
        self.type_sizes = {
            'int32': 4,
            'uint32': 4,
            'float32': 4,
            'int8': 1,
            'uint8': 1,
            'int16': 2,
            'uint16': 2,
            'float16': 2,
        }

    def generate(self) -> str:
        """Return the complete text of the generated header.

        Reads the layout from ``global_vars.layout`` and names the enclosing
        namespace ``<global_vars.source_file_name>_shader``.
        """
        output = io.StringIO()
        # Indent generated code with four spaces per level.
        writer = IndentManager(output, indent_char='    ')

        writer.write("#pragma once")
        writer.write()
        writer.write("#include <SDL3/SDL.h>")
        writer.write("#include <SDL3/SDL_gpu.h>")
        writer.write("#include <cstdint> // For fixed-width integer types")
        writer.write()

        writer.write("// Auto-generated vertex structure")
        with writer.block(f'namespace {global_vars.source_file_name}_shader'):
            self._generate_vertex_struct(writer, global_vars.layout.vertex_fields)
            self._generate_uniform_structs(writer, global_vars.layout.uniform_buffers)
            self._generate_vertex_attributes(writer, global_vars.layout.vertex_fields)
            self._generate_helper_functions(writer, global_vars.layout.vertex_fields)

        return output.getvalue()

    def _get_type_info(self, field_type: FieldType, packed: bool = False) -> Tuple[str, int, int]:
        """Describe a field type as (C type name, element count, size in bytes).

        Args:
            field_type: Type to describe; its ``kind`` selects the scalar,
                vector or matrix rules.
            packed: When False (the default) a 3-component vector is sized as
                4 components and each column of a 3-row matrix as 4 rows.
                When True the size is exactly ``element size * element count``,
                which is the size of the C array emitted by
                ``_get_c_type_declaration``.

        Returns:
            Tuple[str, int, int]: (C type name, element count, total size in
            bytes). An unknown ``kind`` yields ``('float', 1, 4)`` after a
            warning on stderr.

        Raises:
            ValueError: If a vector has a component count outside 1-4 or a
                matrix has a dimension outside 2-4.
        """
        # Scalar type name -> C type name.
        c_type_mapping = {
            'int32': 'int32_t',
            'uint32': 'uint32_t',
            'float32': 'float',
            'int8': 'int8_t',
            'uint8': 'uint8_t',
            'int16': 'int16_t',
            'uint16': 'uint16_t',
            # There is no standard 16-bit float type, so the generated code
            # refers to a `mirage_float16` name the user must provide.
            'float16': 'mirage_float16',
        }

        kind = field_type.kind
        if kind not in ('scalar', 'vector', 'matrix'):
            # The generated header is written to stdout by the script entry
            # point, so diagnostics must go to stderr to keep it clean.
            print(f"Warning: Unknown field type kind '{field_type.kind}'. Using default float.",
                  file=sys.stderr)
            return 'float', 1, 4

        scalar_type = field_type.scalar_type or 'float32'
        c_type = c_type_mapping.get(scalar_type, 'float')
        element_size = self.type_sizes.get(scalar_type, 4)

        if kind == 'scalar':
            return c_type, 1, element_size

        if kind == 'vector':
            element_count = field_type.element_count or 4

            # Validate the component count.
            if element_count not in (1, 2, 3, 4):
                raise ValueError(f"Invalid vector element count: {element_count}. Must be 1, 2, 3, or 4.")

            # Unless packed, a vec3 occupies the space of a vec4.
            sized_count = 4 if (element_count == 3 and not packed) else element_count
            return c_type, element_count, element_size * sized_count

        # kind == 'matrix'
        row_count = field_type.row_count or 4
        column_count = field_type.column_count or 4

        # Validate the matrix dimensions.
        if row_count not in (2, 3, 4) or column_count not in (2, 3, 4):
            raise ValueError(f"Invalid matrix dimensions: {row_count}x{column_count}. "
                             f"Rows and columns must be 2, 3, or 4.")

        # Sized column by column; unless packed, a 3-row column occupies
        # the space of a 4-row column.
        sized_rows = 4 if (row_count == 3 and not packed) else row_count
        total_size = element_size * sized_rows * column_count
        total_elements = row_count * column_count

        return c_type, total_elements, total_size

    def _get_c_type_declaration(self, field_type: FieldType, field_name: str) -> str:
        """Build the C member declaration (without trailing ';') for a field.

        Args:
            field_type: Type of the field.
            field_name: Name of the member.

        Returns:
            str: ``T name`` for scalars, 1-component vectors and unknown kinds,
            ``T name[N]`` for larger vectors, ``T name[cols][rows]`` for
            matrices.
        """
        c_type, count, _ = self._get_type_info(field_type)

        if field_type.kind == 'vector' and count != 1:
            return f"{c_type} {field_name}[{count}]"

        if field_type.kind == 'matrix':
            rows = field_type.row_count or 4
            cols = field_type.column_count or 4
            # Two-dimensional array indexed [column][row].
            return f"{c_type} {field_name}[{cols}][{rows}]"

        return f"{c_type} {field_name}"

    def _get_aligned_size(self, size: int, alignment: int) -> int:
        """Round ``size`` up to the next multiple of ``alignment``.

        Args:
            size: Original size.
            alignment: Required alignment; must be non-zero.

        Returns:
            int: The smallest multiple of ``alignment`` that is >= ``size``.
        """
        return ((size + alignment - 1) // alignment) * alignment

    def _compute_vertex_layout(self, vertex_fields: List[VertexField]) -> Tuple[List[Tuple[int, int, int]], int]:
        """Compute where each vertex field lands in the generated ``vertex_t``.

        Fields are laid out in order using their packed size, each one aligned
        to the size of its scalar type, and the total is rounded up to the
        largest such alignment. This mirrors natural C struct layout for the
        emitted array members.
        NOTE(review): assumes ``mirage_float16`` is 2 bytes with 2-byte
        alignment - confirm against the user-provided definition.

        Returns:
            A pair ``(entries, stride)`` where ``entries`` holds one
            ``(offset, element count, size in bytes)`` tuple per field and
            ``stride`` is the size of the whole struct in bytes.
        """
        entries = []
        cursor = 0
        struct_alignment = 1

        for field in vertex_fields:
            _, count, size = self._get_type_info(field.type, packed=True)
            alignment = self.type_sizes.get(field.type.scalar_type or 'float32', 4)

            cursor = self._get_aligned_size(cursor, alignment)
            entries.append((cursor, count, size))
            cursor += size
            struct_alignment = max(struct_alignment, alignment)

        return entries, self._get_aligned_size(cursor, struct_alignment)

    def _generate_vertex_struct(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
        """Emit ``struct vertex_t`` followed by a static_assert on its size.

        Each member is preceded by a comment giving its location, semantic,
        byte offset and size, using the layout from ``_compute_vertex_layout``.
        """
        layout, stride = self._compute_vertex_layout(vertex_fields)

        with writer.block("struct vertex_t", "};"):
            for field, (field_offset, _, size) in zip(vertex_fields, layout):
                declaration = self._get_c_type_declaration(field.type, field.name)

                # A non-zero semantic index is appended to the semantic name.
                semantic = field.semantic
                if field.semantic_index > 0:
                    semantic += str(field.semantic_index)

                writer.write(f"// location: {field.location}, semantic: {semantic}"
                             f", offset: {field_offset}, size: {size} bytes")
                writer.write(f"{declaration};")

        # The assertion must follow the struct: inside its own body vertex_t
        # is still an incomplete type and sizeof() cannot be applied to it.
        writer.write()
        writer.write(f"// Total size: {stride} bytes")
        writer.write(f"static_assert(sizeof(vertex_t) == {stride}, \"Vertex struct size mismatch\");")
        writer.write()

    def _generate_uniform_structs(self, writer: IndentManager, uniform_buffers: List[UniformBuffer]) -> None:
        """Emit one typedef'd struct per uniform buffer.

        Member offsets and sizes are taken from the parsed fields as-is. A
        ``uint8_t`` padding array is inserted whenever a field starts beyond
        the end of the previous one, so fields are expected in ascending
        offset order.
        """
        writer.write("// Uniform buffer structures")

        for buffer in uniform_buffers:
            struct_name = buffer.name.replace('_buffer', '').title().replace('_', '') + 'Buffer'

            total_size = 0
            max_alignment = 16  # Buffer size is reported rounded up to 16 bytes.

            with writer.block(f"typedef struct {struct_name}", f"}} {struct_name};"):
                for i, field in enumerate(buffer.fields):
                    # Fill any gap between the previous field and this one.
                    if field.offset > total_size:
                        padding_size = field.offset - total_size
                        writer.write(f"uint8_t _padding{i}[{padding_size}]; // Padding")
                        total_size = field.offset

                    declaration = self._get_c_type_declaration(field.type, field.name)
                    writer.write(f"{declaration}; // offset: {field.offset}, size: {field.size}")

                    total_size = field.offset + field.size

                # Only report the aligned size; no trailing padding member
                # is emitted.
                aligned_size = self._get_aligned_size(total_size, max_alignment)
                if aligned_size > total_size:
                    writer.write(f"// Note: Structure may need padding to {aligned_size} bytes for alignment")

            writer.write(f"// Binding: {buffer.binding}, Size: {total_size} bytes (aligned: {aligned_size})")
            writer.write()

    def _generate_vertex_attributes(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
        """Emit the vertex attribute array, buffer description and input state.

        Attribute offsets and the buffer pitch come from
        ``_compute_vertex_layout`` so they agree with ``vertex_t``. A field
        whose (scalar type, element count) pair has no entry in
        ``format_mapping`` falls back to FLOAT4 with a warning on stderr.
        """
        layout, stride = self._compute_vertex_layout(vertex_fields)

        writer.write("// Vertex attribute descriptions")
        writer.write(f"static constexpr uint32_t VERTEX_ATTRIBUTE_COUNT = {len(vertex_fields)};")
        writer.write()

        writer.write("static const SDL_GPUVertexAttribute vertex_attributes[] = {")

        last_index = len(vertex_fields) - 1
        with writer.indent():
            for i, (field, (field_offset, count, _)) in enumerate(zip(vertex_fields, layout)):
                scalar_type = field.type.scalar_type or 'float32'
                format_key = (scalar_type, count)

                format_name = self.format_mapping.get(format_key)
                if format_name is None:
                    # Keep the historical fallback, but do not hide it.
                    format_name = 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4'
                    print(f"Warning: No SDL vertex format for {scalar_type} x {count} "
                          f"(field '{field.name}'). Using {format_name}.",
                          file=sys.stderr)

                # Every initializer but the last is followed by a comma.
                with writer.indent("{", "}," if i < last_index else "}"):
                    writer.write(f"// {field.name}")
                    writer.write(f".location = {field.location},")
                    writer.write(f".buffer_slot = 0,")
                    writer.write(f".format = {format_name},")
                    writer.write(f".offset = {field_offset}")

        writer.write("};")
        writer.write()

        # Vertex buffer description.
        writer.write("// Vertex buffer description")
        with writer.block("static constexpr SDL_GPUVertexBufferDescription vertex_buffer_desc =", "};"):
            writer.write(".slot = 0,")
            writer.write(f".pitch = {stride}, // sizeof(vertex)")
            writer.write(".input_rate = SDL_GPU_VERTEXINPUTRATE_VERTEX,")
            writer.write(".instance_step_rate = 0")
        writer.write()

        # Vertex input state.
        writer.write("// Vertex input state")
        with writer.block("static constexpr SDL_GPUVertexInputState vertex_input_state =", "};"):
            writer.write(".vertex_buffer_descriptions = &vertex_buffer_desc,")
            writer.write(".num_vertex_buffers = 1,")
            writer.write(".vertex_attributes = vertex_attributes,")
            writer.write(".num_vertex_attributes = VERTEX_ATTRIBUTE_COUNT")
        writer.write()

    def _generate_helper_functions(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
        """Emit the ``create_vertex_buffer`` helper.

        The emitted function creates a vertex buffer, stages the vertices in
        an upload transfer buffer, records and submits the copy, releases the
        transfer buffer and returns the vertex buffer. ``vertex_fields`` is
        accepted for symmetry with the other generators but is not used.
        """
        writer.write("// Helper functions")
        writer.write()

        # The signature spans three lines; the last one opens the body block.
        writer.write("static SDL_GPUBuffer* create_vertex_buffer(SDL_GPUDevice* device,")
        writer.write(" const vertex_t* vertices,")

        with writer.block(" Uint32 vertex_count)", "}"):
            with writer.block("SDL_GPUBufferCreateInfo buffer_info =", "};"):
                writer.write(".usage = SDL_GPU_BUFFERUSAGE_VERTEX,")
                writer.write(".size = static_cast<Uint32>(sizeof(vertex_t)) * vertex_count")

            writer.write()
            writer.write("SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &buffer_info);")
            writer.write()

            with writer.block("SDL_GPUTransferBufferCreateInfo transfer_info =", "};"):
                writer.write(".usage = SDL_GPU_TRANSFERBUFFERUSAGE_UPLOAD,")
                writer.write(".size = buffer_info.size")

            writer.write()
            writer.write("// Upload vertex data")
            writer.write("SDL_GPUTransferBuffer* transfer = SDL_CreateGPUTransferBuffer(device, &transfer_info);")
            writer.write()
            writer.write("void* mapped = SDL_MapGPUTransferBuffer(device, transfer, false);")
            writer.write("SDL_memcpy(mapped, vertices, buffer_info.size);")
            writer.write("SDL_UnmapGPUTransferBuffer(device, transfer);")
            writer.write()
            writer.write("// Copy to GPU")
            writer.write("SDL_GPUCommandBuffer* cmd = SDL_AcquireGPUCommandBuffer(device);")
            writer.write("SDL_GPUCopyPass* copy = SDL_BeginGPUCopyPass(cmd);")
            writer.write()
            writer.write("SDL_GPUTransferBufferLocation src = {.transfer_buffer = transfer, .offset = 0};")
            writer.write("SDL_GPUBufferRegion dst = {.buffer = buffer, .offset = 0, .size = buffer_info.size};")
            writer.write()
            writer.write("SDL_UploadToGPUBuffer(copy, &src, &dst, false);")
            writer.write("SDL_EndGPUCopyPass(copy);")
            writer.write("SDL_SubmitGPUCommandBuffer(cmd);")
            writer.write()
            writer.write("SDL_ReleaseGPUTransferBuffer(device, transfer);")
            writer.write("return buffer;")

        writer.write()
|
||
|
||
|
||
# Main generator class
|
||
class ShaderLayoutGenerator:
    """Front end that chains JSON parsing and header generation."""

    def __init__(self):
        self.parser = ShaderLayoutParser()
        self.code_generator = ShaderLayoutCodeGenerator()

    def generate_header(self, json_file: str) -> str:
        """Return the header text generated from ``json_file``.

        The parsed layout is stored in ``global_vars.layout``, which is where
        the code generator reads it from.
        """
        parsed_layout = self.parser.parse(json_file)
        global_vars.layout = parsed_layout

        header_text = self.code_generator.generate()
        return header_text
|
||
|
||
|
||
# Command-line entry point
|
||
if __name__ == "__main__":
    # Exactly one positional argument is needed: the reflection JSON path.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: shader_layout_gen.py <shader.json>")
        sys.exit(1)

    # The generated header is written to stdout.
    print(ShaderLayoutGenerator().generate_header(cli_args[0]))
|