diff --git a/shader_layout_generator.py b/shader_layout_generator.py index e489c4d..fae85af 100644 --- a/shader_layout_generator.py +++ b/shader_layout_generator.py @@ -1,62 +1,105 @@ #!/usr/bin/env python3 import json import sys -from typing import Dict, List, Any, Tuple +from typing import Dict, List, Any, Tuple, Optional, TextIO +from dataclasses import dataclass, field +from enum import Enum +from contextlib import contextmanager +import io -class ShaderLayoutGenerator: - def __init__(self): - # SDL GPU 格式映射 - self.format_mapping = { - ('float32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT', - ('float32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT2', - ('float32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT3', - ('float32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4', - ('int32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_INT', - ('int32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_INT2', - ('int32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_INT3', - ('int32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_INT4', - ('uint32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT', - ('uint32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT2', - ('uint32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT3', - ('uint32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT4', - } +from code_generator import IndentManager - # 类型大小映射 - self.type_sizes = { - 'float32': 4, - 'int32': 4, - 'uint32': 4, - } - def get_type_info(self, type_obj: Dict) -> Tuple[str, int, int]: - """获取类型信息:C类型名称、元素数量、字节大小""" - kind = type_obj.get('kind') +# 数据模型类 +@dataclass +class FieldType: + """字段类型信息""" + kind: str # 'scalar', 'vector', 'matrix' + scalar_type: Optional[str] = None # 'int32', 'uint32', 'float32', 'int8', 'uint8', 'int16', 'uint16', 'float16' + element_count: Optional[int] = None # for vector + row_count: Optional[int] = None # for matrix + column_count: Optional[int] = None # for matrix + + @classmethod + def from_dict(cls, data: Dict) -> 'FieldType': + """从字典创建FieldType对象""" + kind = data.get('kind') if kind == 'vector': - element_type = type_obj['elementType']['scalarType'] - element_count = type_obj['elementCount'] - c_type = 'float' if element_type == 'float32' else element_type - size = self.type_sizes[element_type] * element_count - return c_type, element_count, size + return cls( + kind=kind, + scalar_type=data['elementType']['scalarType'], + element_count=data['elementCount'] + ) elif kind == 'scalar': - scalar_type = type_obj['scalarType'] - c_type = 'float' if scalar_type == 'float32' else scalar_type - size = self.type_sizes[scalar_type] - return c_type, 1, size + return cls( + kind=kind, + scalar_type=data['scalarType'] + ) elif kind == 'matrix': - rows = type_obj['rowCount'] - cols = type_obj['columnCount'] - element_type = type_obj['elementType']['scalarType'] - size = self.type_sizes[element_type] * rows * cols - return 'float', rows * cols, size + return cls( + kind=kind, + scalar_type=data['elementType']['scalarType'], + row_count=data['rowCount'], + column_count=data['columnCount'] + ) - return 'float', 1, 4 + return cls(kind='scalar', scalar_type='float32') - def parse_vertex_input(self, json_data: Dict) -> List[Dict]: + +@dataclass +class VertexField: + """顶点输入字段""" + name: str + type: FieldType + location: int + semantic: str + semantic_index: int + + +@dataclass +class UniformField: + """Uniform缓冲区字段""" + name: str + type: FieldType + offset: int + size: int + + +@dataclass +class UniformBuffer: + """Uniform缓冲区""" + name: str + binding: int + fields: List[UniformField] = field(default_factory=list) + + +@dataclass +class ShaderLayout: + """着色器布局数据""" + vertex_fields: List[VertexField] = field(default_factory=list) + uniform_buffers: List[UniformBuffer] = field(default_factory=list) + + +# 数据解析器(保持不变) +class ShaderLayoutParser: + """解析JSON数据并提取到类对象""" + + def parse(self, json_file: str) -> ShaderLayout: + """解析JSON文件并返回ShaderLayout对象""" + with open(json_file, 'r') as f: + json_data = json.load(f) + + layout = ShaderLayout() + layout.vertex_fields = self._parse_vertex_input(json_data) + layout.uniform_buffers = self._parse_uniform_buffers(json_data) + + return layout + + def _parse_vertex_input(self, json_data: Dict) -> List[VertexField]: """解析顶点输入字段""" vertex_fields = [] - # 查找顶点输入参数 entry_points = json_data.get('entryPoints', []) for entry in entry_points: if entry.get('stage') == 'vertex': @@ -65,18 +108,18 @@ class ShaderLayoutGenerator: if param.get('name') == 'input' and param.get('stage') == 'vertex': fields = param.get('type', {}).get('fields', []) for field in fields: - field_info = { - 'name': field['name'], - 'type': field['type'], - 'location': field['binding']['index'], - 'semantic': field.get('semanticName', ''), - 'semantic_index': field.get('semanticIndex', 0) - } - vertex_fields.append(field_info) + vertex_field = VertexField( + name=field['name'], + type=FieldType.from_dict(field['type']), + location=field['binding']['index'], + semantic=field.get('semanticName', ''), + semantic_index=field.get('semanticIndex', 0) + ) + vertex_fields.append(vertex_field) return vertex_fields - def parse_uniform_buffers(self, json_data: Dict) -> List[Dict]: + def _parse_uniform_buffers(self, json_data: Dict) -> List[UniformBuffer]: """解析Uniform缓冲区""" uniform_buffers = [] @@ -84,182 +127,412 @@ class ShaderLayoutGenerator: for param in parameters: binding = param.get('binding', {}) if binding.get('kind') == 'constantBuffer': - buffer_info = { - 'name': param['name'], - 'binding': binding['index'], - 'fields': [] - } + buffer = UniformBuffer( + name=param['name'], + binding=binding['index'] + ) # 解析缓冲区字段 element_type = param.get('type', {}).get('elementType', {}) if element_type.get('kind') == 'struct': fields = element_type.get('fields', []) for field in fields: - field_info = { - 'name': field['name'], - 'type': field['type'], - 'offset': field['binding']['offset'], - 'size': field['binding']['size'] - } - buffer_info['fields'].append(field_info) + uniform_field = UniformField( + name=field['name'], + type=FieldType.from_dict(field['type']), + offset=field['binding']['offset'], + size=field['binding']['size'] + ) + buffer.fields.append(uniform_field) - uniform_buffers.append(buffer_info) + uniform_buffers.append(buffer) return uniform_buffers - def generate_vertex_struct(self, vertex_fields: List[Dict]) -> str: - """生成顶点结构体""" - code = "// Auto-generated vertex structure\n" - code += "typedef struct Vertex {\n" - for field in vertex_fields: - c_type, count, _ = self.get_type_info(field['type']) - if count == 1: - code += f" {c_type} {field['name']};\n" +# **重构后的代码生成器** +class ShaderLayoutCodeGenerator: + """从ShaderLayout对象生成C代码""" + + def __init__(self): + # SDL GPU 格式映射 + self.format_mapping = { + ('int32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_INT', + ('int32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_INT2', + ('int32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_INT3', + ('int32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_INT4', + ('uint32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT', + ('uint32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT2', + ('uint32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT3', + ('uint32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT4', + ('float32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT', + ('float32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT2', + ('float32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT3', + ('float32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4', + ('int8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE2', + ('int8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE4', + ('uint8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE2', + ('uint8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE4', + ('int16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT2', + ('int16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT4', + ('uint16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT2', + ('uint16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT4', + ('float16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF2', + ('float16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF4', + } + + # 类型大小映射 + self.type_sizes = { + 'int32': 4, + 'uint32': 4, + 'float32': 4, + 'int8': 1, + 'uint8': 1, + 'int16': 2, + 'uint16': 2, + 'float16': 2, + } + + def generate(self, layout: ShaderLayout, source_file: str) -> str: + """生成完整的头文件内容""" + output = io.StringIO() + writer = IndentManager(output, indent_char=' ') # 使用4个空格缩进 + + writer.write("#pragma once") + writer.write() + writer.write("#include ") + writer.write("#include ") + writer.write() + writer.write(f"// Auto-generated from: {source_file}") + writer.write() + + self._generate_vertex_struct(writer, layout.vertex_fields) + self._generate_uniform_structs(writer, layout.uniform_buffers) + self._generate_vertex_attributes(writer, layout.vertex_fields) + self._generate_helper_functions(writer, layout.vertex_fields) + + return output.getvalue() + + def _get_type_info(self, field_type: FieldType) -> Tuple[str, int, int]: + """获取类型信息:C类型名称、元素数量、字节大小 + + Returns: + Tuple[str, int, int]: (C类型名称, 元素数量, 总字节大小) + """ + # C类型映射表 + c_type_mapping = { + 'int32': 'int32_t', + 'uint32': 'uint32_t', + 'float32': 'float', + 'int8': 'int8_t', + 'uint8': 'uint8_t', + 'int16': 'int16_t', + 'uint16': 'uint16_t', + 'float16': 'mirage_float16', # 注意:C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现 + } + + # 对齐要求映射表 + alignment_mapping = { + 'int32': 4, + 'uint32': 4, + 'float32': 4, + 'int8': 1, + 'uint8': 1, + 'int16': 2, + 'uint16': 2, + 'float16': 2, + } + + if field_type.kind == 'scalar': + scalar_type = field_type.scalar_type or 'float32' + c_type = c_type_mapping.get(scalar_type, 'float') + size = self.type_sizes.get(scalar_type, 4) + return c_type, 1, size + + elif field_type.kind == 'vector': + scalar_type = field_type.scalar_type or 'float32' + c_type = c_type_mapping.get(scalar_type, 'float') + element_count = field_type.element_count or 4 + + # 验证向量元素数量的合法性 + if element_count not in [1, 2, 3, 4]: + raise ValueError(f"Invalid vector element count: {element_count}. Must be 1, 2, 3, or 4.") + + size = self.type_sizes.get(scalar_type, 4) * element_count + + # 对于vec3,通常需要按vec4对齐 + if element_count == 3: + alignment = alignment_mapping.get(scalar_type, 4) * 4 + size = ((size + alignment - 1) // alignment) * alignment + + return c_type, element_count, size + + elif field_type.kind == 'matrix': + scalar_type = field_type.scalar_type or 'float32' + c_type = c_type_mapping.get(scalar_type, 'float') + + row_count = field_type.row_count or 4 + column_count = field_type.column_count or 4 + + # 验证矩阵维度的合法性 + if row_count not in [2, 3, 4] or column_count not in [2, 3, 4]: + raise ValueError(f"Invalid matrix dimensions: {row_count}x{column_count}. " + f"Rows and columns must be 2, 3, or 4.") + + # 矩阵通常按列主序存储,每列都需要对齐 + element_size = self.type_sizes.get(scalar_type, 4) + + # 每列的大小(考虑对齐) + if row_count == 3: + # mat3xN 的每列通常按 vec4 对齐 + column_size = element_size * 4 else: - code += f" {c_type} {field['name']}[{count}];\n" + column_size = element_size * row_count - # 添加注释 - semantic = field['semantic'] - if field['semantic_index'] > 0: - semantic += str(field['semantic_index']) - code += f" // location: {field['location']}, semantic: {semantic}\n" + total_size = column_size * column_count + total_elements = row_count * column_count - code += "} Vertex;\n\n" - return code + return c_type, total_elements, total_size - def generate_vertex_attributes(self, vertex_fields: List[Dict]) -> str: - """生成顶点属性数组""" - code = "// Vertex attribute descriptions\n" - code += f"#define VERTEX_ATTRIBUTE_COUNT {len(vertex_fields)}\n\n" - code += "static const SDL_GPUVertexAttribute g_vertexAttributes[] = {\n" + else: + # 未知类型,返回默认值并记录警告 + print(f"Warning: Unknown field type kind '{field_type.kind}'. Using default float.") + return 'float', 1, 4 - offset = 0 - for i, field in enumerate(vertex_fields): - c_type, count, size = self.get_type_info(field['type']) - scalar_type = field['type'].get('elementType', field['type']).get('scalarType', 'float32') + def _get_c_type_declaration(self, field_type: FieldType, field_name: str) -> str: + """生成C类型声明 - format_key = (scalar_type, count) - format_name = self.format_mapping.get(format_key, 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4') + Args: + field_type: 字段类型信息 + field_name: 字段名称 - code += " {\n" - code += f" .location = {field['location']},\n" - code += f" .buffer_slot = 0,\n" - code += f" .format = {format_name},\n" - code += f" .offset = {offset}\n" - code += " }" + Returns: + str: 完整的C类型声明 + """ + c_type, count, _ = self._get_type_info(field_type) - if i < len(vertex_fields) - 1: - code += "," + if field_type.kind == 'scalar': + return f"{c_type} {field_name}" - code += f" // {field['name']}\n" + elif field_type.kind == 'vector': + # 对于向量,可以选择使用数组或专门的向量类型 + if count == 1: + return f"{c_type} {field_name}" + else: + # 选项1: 使用数组 + return f"{c_type} {field_name}[{count}]" - offset += size + # 选项2: 使用对齐的结构体(如果需要的话) + # return f"struct {{ {c_type} data[{count}]; }} {field_name}" - code += "};\n\n" + elif field_type.kind == 'matrix': + rows = field_type.row_count or 4 + cols = field_type.column_count or 4 - # 生成顶点缓冲区描述 - code += "// Vertex buffer description\n" - code += "static const SDL_GPUVertexBufferDescription g_vertexBufferDesc = {\n" - code += " .slot = 0,\n" - code += f" .pitch = {offset}, // sizeof(Vertex)\n" - code += " .input_rate = SDL_GPU_VERTEXINPUTRATE_VERTEX,\n" - code += " .instance_step_rate = 0\n" - code += "};\n\n" + # 矩阵使用二维数组(列主序) + return f"{c_type} {field_name}[{cols}][{rows}]" - # 生成顶点输入状态 - code += "// Vertex input state\n" - code += "static const SDL_GPUVertexInputState g_vertexInputState = {\n" - code += " .vertex_buffer_descriptions = &g_vertexBufferDesc,\n" - code += " .num_vertex_buffers = 1,\n" - code += " .vertex_attributes = g_vertexAttributes,\n" - code += " .num_vertex_attributes = VERTEX_ATTRIBUTE_COUNT\n" - code += "};\n\n" + else: + return f"{c_type} {field_name}" - return code + def _get_aligned_size(self, size: int, alignment: int) -> int: + """计算对齐后的大小 - def generate_uniform_structs(self, uniform_buffers: List[Dict]) -> str: + Args: + size: 原始大小 + alignment: 对齐要求 + + Returns: + int: 对齐后的大小 + """ + return ((size + alignment - 1) // alignment) * alignment + + def _generate_vertex_struct(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None: + """生成顶点结构体""" + writer.write("// Auto-generated vertex structure") + + # 添加必要的头文件包含 + writer.write("#include // For fixed-width integer types") + writer.write() + + with writer.block("typedef struct Vertex", "} Vertex;"): + total_offset = 0 + for field in vertex_fields: + # 生成字段声明 + declaration = self._get_c_type_declaration(field.type, field.name) + + # 获取字段信息 + c_type, count, size = self._get_type_info(field.type) + + # 添加详细注释 + semantic = field.semantic + if field.semantic_index > 0: + semantic += str(field.semantic_index) + + comment = f"// location: {field.location}, semantic: {semantic}" + comment += f", offset: {total_offset}, size: {size} bytes" + + writer.write(comment) + writer.write(f"{declaration};") + + total_offset += size + + # 添加结构体大小的静态断言 + writer.write() + writer.write(f"// Total size: {total_offset} bytes") + writer.write(f"static_assert(sizeof(Vertex) == {total_offset}, \"Vertex struct size mismatch\");") + writer.write() + + def _generate_uniform_structs(self, writer: IndentManager, uniform_buffers: List[UniformBuffer]) -> None: """生成Uniform缓冲区结构体""" - code = "// Uniform buffer structures\n" + writer.write("// Uniform buffer structures") for buffer in uniform_buffers: - struct_name = buffer['name'].replace('_buffer', '').title().replace('_', '') + 'Buffer' - code += f"typedef struct {struct_name} {{\n" + struct_name = buffer.name.replace('_buffer', '').title().replace('_', '') + 'Buffer' - for field in buffer['fields']: - c_type, count, _ = self.get_type_info(field['type']) - if field['type']['kind'] == 'matrix': - rows = field['type']['rowCount'] - cols = field['type']['columnCount'] - code += f" {c_type} {field['name']}[{rows}][{cols}];\n" - elif count == 1: - code += f" {c_type} {field['name']};\n" - else: - code += f" {c_type} {field['name']}[{count}];\n" + # 计算总大小和对齐要求 + total_size = 0 + max_alignment = 16 # GPU通常要求16字节对齐 - code += f"}} {struct_name};\n" - code += f"// Binding: {buffer['binding']}, Size: {sum(f['size'] for f in buffer['fields'])} bytes\n\n" + with writer.block(f"typedef struct {struct_name}", f"}} {struct_name};"): + for i, field in enumerate(buffer.fields): + # 检查是否需要填充 + if field.offset > total_size: + padding_size = field.offset - total_size + writer.write(f"uint8_t _padding{i}[{padding_size}]; // Padding") + total_size = field.offset - return code + # 生成字段声明 + declaration = self._get_c_type_declaration(field.type, field.name) + writer.write(f"{declaration}; // offset: {field.offset}, size: {field.size}") - def generate_helper_functions(self, vertex_fields: List[Dict]) -> str: + total_size = field.offset + field.size + + # 确保结构体大小正确对齐 + aligned_size = self._get_aligned_size(total_size, max_alignment) + if aligned_size > total_size: + writer.write(f"// Note: Structure may need padding to {aligned_size} bytes for alignment") + + writer.write(f"// Binding: {buffer.binding}, Size: {total_size} bytes (aligned: {aligned_size})") + writer.write() + + def _generate_vertex_attributes(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None: + """生成顶点属性数组""" + writer.write("// Vertex attribute descriptions") + writer.write(f"#define VERTEX_ATTRIBUTE_COUNT {len(vertex_fields)}") + writer.write() + + writer.write("static const SDL_GPUVertexAttribute g_vertexAttributes[] = {") + + offset = 0 + with writer.indent(): + for i, field in enumerate(vertex_fields): + c_type, count, size = self._get_type_info(field.type) + scalar_type = field.type.scalar_type or 'float32' + + format_key = (scalar_type, count) + format_name = self.format_mapping.get(format_key, 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4') + + with writer.indent("{", "}," if i < len(vertex_fields) - 1 else "}"): + writer.write(f"// {field.name}") + writer.write(f".location = {field.location},") + writer.write(f".buffer_slot = 0,") + writer.write(f".format = {format_name},") + writer.write(f".offset = {offset}") + + offset += size + + writer.write("};") + writer.write() + + # 生成顶点缓冲区描述 + writer.write("// Vertex buffer description") + with writer.block("static const SDL_GPUVertexBufferDescription g_vertexBufferDesc =", "};"): + writer.write(".slot = 0,") + writer.write(f".pitch = {offset}, // sizeof(Vertex)") + writer.write(".input_rate = SDL_GPU_VERTEXINPUTRATE_VERTEX,") + writer.write(".instance_step_rate = 0") + writer.write() + + # 生成顶点输入状态 + writer.write("// Vertex input state") + with writer.block("static const SDL_GPUVertexInputState g_vertexInputState =", "};"): + writer.write(".vertex_buffer_descriptions = &g_vertexBufferDesc,") + writer.write(".num_vertex_buffers = 1,") + writer.write(".vertex_attributes = g_vertexAttributes,") + writer.write(".num_vertex_attributes = VERTEX_ATTRIBUTE_COUNT") + writer.write() + + def _generate_helper_functions(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None: """生成辅助函数""" - code = "// Helper functions\n\n" + writer.write("// Helper functions") + writer.write() # 创建顶点缓冲区函数 - code += "static SDL_GPUBuffer* createVertexBuffer(SDL_GPUDevice* device, \n" - code += " const Vertex* vertices, \n" - code += " Uint32 vertexCount) {\n" - code += " SDL_GPUBufferCreateInfo bufferInfo = {\n" - code += " .usage = SDL_GPU_BUFFERUSAGE_VERTEX,\n" - code += " .size = static_cast(sizeof(Vertex)) * vertexCount\n" - code += " };\n" - code += " \n" - code += " SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &bufferInfo);\n" - code += " \n" - code += " // Upload vertex data\n" - code += " SDL_GPUTransferBuffer* transfer = SDL_CreateGPUTransferBuffer(device,\n" - code += " SDL_GPU_TRANSFERBUFFERUSAGE_UPLOAD, bufferInfo.size);\n" - code += " \n" - code += " void* mapped = SDL_MapGPUTransferBuffer(device, transfer, SDL_FALSE);\n" - code += " SDL_memcpy(mapped, vertices, bufferInfo.size);\n" - code += " SDL_UnmapGPUTransferBuffer(device, transfer);\n" - code += " \n" - code += " // Copy to GPU\n" - code += " SDL_GPUCommandBuffer* cmd = SDL_AcquireGPUCommandBuffer(device);\n" - code += " SDL_GPUCopyPass* copy = SDL_BeginGPUCopyPass(cmd);\n" - code += " \n" - code += " SDL_GPUTransferBufferLocation src = {.transfer_buffer = transfer, .offset = 0};\n" - code += " SDL_GPUBufferRegion dst = {.buffer = buffer, .offset = 0, .size = bufferInfo.size};\n" - code += " \n" - code += " SDL_UploadToGPUBuffer(copy, &src, &dst, SDL_FALSE);\n" - code += " SDL_EndGPUCopyPass(copy);\n" - code += " SDL_SubmitGPUCommandBuffer(cmd);\n" - code += " \n" - code += " SDL_ReleaseGPUTransferBuffer(device, transfer);\n" - code += " return buffer;\n" - code += "}\n\n" + writer.write("static SDL_GPUBuffer* createVertexBuffer(SDL_GPUDevice* device,") + writer.write(" const Vertex* vertices,") - return code + with writer.block(" Uint32 vertexCount)", "}"): + with writer.block("SDL_GPUBufferCreateInfo bufferInfo =", "};"): + writer.write(".usage = SDL_GPU_BUFFERUSAGE_VERTEX,") + writer.write(".size = static_cast(sizeof(Vertex)) * vertexCount") + + writer.write() + writer.write("SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &bufferInfo);") + writer.write() + + with writer.block("SDL_GPUTransferBufferCreateInfo transferInfo =", "};"): + writer.write(".usage = SDL_GPU_TRANSFERBUFFERUSAGE_UPLOAD,") + writer.write(".size = bufferInfo.size") + + writer.write() + writer.write("// Upload vertex data") + writer.write("SDL_GPUTransferBuffer* transfer = SDL_CreateGPUTransferBuffer(device, &transferInfo);") + writer.write() + writer.write("void* mapped = SDL_MapGPUTransferBuffer(device, transfer, SDL_FALSE);") + writer.write("SDL_memcpy(mapped, vertices, bufferInfo.size);") + writer.write("SDL_UnmapGPUTransferBuffer(device, transfer);") + writer.write() + writer.write("// Copy to GPU") + writer.write("SDL_GPUCommandBuffer* cmd = SDL_AcquireGPUCommandBuffer(device);") + writer.write("SDL_GPUCopyPass* copy = SDL_BeginGPUCopyPass(cmd);") + writer.write() + writer.write("SDL_GPUTransferBufferLocation src = {.transfer_buffer = transfer, .offset = 0};") + writer.write("SDL_GPUBufferRegion dst = {.buffer = buffer, .offset = 0, .size = bufferInfo.size};") + writer.write() + writer.write("SDL_UploadToGPUBuffer(copy, &src, &dst, SDL_FALSE);") + writer.write("SDL_EndGPUCopyPass(copy);") + writer.write("SDL_SubmitGPUCommandBuffer(cmd);") + writer.write() + writer.write("SDL_ReleaseGPUTransferBuffer(device, transfer);") + writer.write("return buffer;") + + writer.write() + + +# 主生成器类(保持不变) +class ShaderLayoutGenerator: + """主生成器类,整合解析和代码生成功能""" + + def __init__(self): + self.parser = ShaderLayoutParser() + self.code_generator = ShaderLayoutCodeGenerator() def generate_header(self, json_file: str) -> str: """生成头文件内容""" - with open(json_file, 'r') as f: - json_data = json.load(f) + # 解析JSON到类对象 + layout = self.parser.parse(json_file) - # 解析数据 - vertex_fields = self.parse_vertex_input(json_data) - uniform_buffers = self.parse_uniform_buffers(json_data) + # 基于类对象生成代码 + return self.code_generator.generate(layout, json_file) - # 生成代码 - code = "#pragma once\n\n" - code += "#include \n" - code += "#include \n\n" - code += "// Auto-generated from: " + json_file + "\n\n" - code += self.generate_vertex_struct(vertex_fields) - code += self.generate_uniform_structs(uniform_buffers) - code += self.generate_vertex_attributes(vertex_fields) - code += self.generate_helper_functions(vertex_fields) +# 使用示例 +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: shader_layout_gen.py ") + sys.exit(1) - return code + generator = ShaderLayoutGenerator() + header_content = generator.generate_header(sys.argv[1]) + print(header_content) diff --git a/shader_parser.py b/shader_parser.py index a22e217..7a0bf7b 100644 --- a/shader_parser.py +++ b/shader_parser.py @@ -12,7 +12,7 @@ import shutil import re from typing import List, Dict, Optional from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo - +from shader_layout_generator import ShaderLayoutGenerator class ShaderParser: def __init__(self, slangc_path: str, include_paths: List[str] = None): @@ -220,6 +220,11 @@ class ShaderParser: if reflection_json.strip(): try: reflection = json.loads(reflection_json) + generator = ShaderLayoutGenerator() + header_content = generator.generate_header(reflection_path) + layout_path_filename = os.path.join(os.path.dirname(output_path), f"{entry_name}_layout.h") + with open(layout_path_filename, 'w', encoding='utf-8') as out_file: + out_file.write(header_content) return self._create_shader_info_from_reflection( reflection, entry_name, stage, source )